├── DAP.png ├── LICENSE ├── NOTICE ├── README.md ├── configs ├── base-dap.yaml └── dap │ ├── chestx.yaml │ ├── cifar.yaml │ ├── cropdisease.yaml │ ├── domainnet.yaml │ ├── eurosat.yaml │ ├── imagenetr.yaml │ ├── isic.yaml │ ├── pets.yaml │ └── resisc45.yaml ├── env_setup.sh ├── launch.py ├── src ├── configs │ ├── config.py │ ├── config_node.py │ └── vit_configs.py ├── data │ ├── base.py │ ├── loader.py │ ├── registry.py │ ├── tf_dataset.py │ ├── tf_load.py │ └── transforms.py ├── engine │ └── trainer.py ├── models │ ├── build_model.py │ ├── build_vit_backbone.py │ ├── mlp.py │ ├── vit_backbones │ │ └── vit.py │ ├── vit_dap │ │ └── vit.py │ └── vit_models.py ├── solver │ ├── losses.py │ ├── lr_scheduler.py │ └── optimizer.py └── utils │ ├── distributed.py │ ├── file_io.py │ ├── io_utils.py │ ├── logger.py │ └── train_utils.py └── train.py /DAP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/dap-cl/12c012bee35d25f52a7385bc7e02d3315283f57c/DAP.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-present NAVER Cloud Corp. 2 | 3 | Attribution-NonCommercial 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. 
If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More_considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial 4.0 International Public 60 | License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial 4.0 International Public License ("Public 65 | License"). To the extent this Public License may be interpreted as a 66 | contract, You are granted the Licensed Rights in consideration of Your 67 | acceptance of these terms and conditions, and the Licensor grants You 68 | such rights in consideration of benefits the Licensor receives from 69 | making the Licensed Material available under these terms and 70 | conditions. 71 | 72 | Section 1 -- Definitions. 73 | 74 | a. Adapted Material means material subject to Copyright and Similar 75 | Rights that is derived from or based upon the Licensed Material 76 | and in which the Licensed Material is translated, altered, 77 | arranged, transformed, or otherwise modified in a manner requiring 78 | permission under the Copyright and Similar Rights held by the 79 | Licensor. For purposes of this Public License, where the Licensed 80 | Material is a musical work, performance, or sound recording, 81 | Adapted Material is always produced where the Licensed Material is 82 | synched in timed relation with a moving image. 83 | 84 | b. Adapter's License means the license You apply to Your Copyright 85 | and Similar Rights in Your contributions to Adapted Material in 86 | accordance with the terms and conditions of this Public License. 87 | 88 | c. Copyright and Similar Rights means copyright and/or similar rights 89 | closely related to copyright including, without limitation, 90 | performance, broadcast, sound recording, and Sui Generis Database 91 | Rights, without regard to how the rights are labeled or 92 | categorized. For purposes of this Public License, the rights 93 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 94 | Rights. 95 | d. Effective Technological Measures means those measures that, in the 96 | absence of proper authority, may not be circumvented under laws 97 | fulfilling obligations under Article 11 of the WIPO Copyright 98 | Treaty adopted on December 20, 1996, and/or similar international 99 | agreements. 100 | 101 | e. Exceptions and Limitations means fair use, fair dealing, and/or 102 | any other exception or limitation to Copyright and Similar Rights 103 | that applies to Your use of the Licensed Material. 104 | 105 | f. Licensed Material means the artistic or literary work, database, 106 | or other material to which the Licensor applied this Public 107 | License. 108 | 109 | g. 
Licensed Rights means the rights granted to You subject to the 110 | terms and conditions of this Public License, which are limited to 111 | all Copyright and Similar Rights that apply to Your use of the 112 | Licensed Material and that the Licensor has authority to license. 113 | 114 | h. Licensor means the individual(s) or entity(ies) granting rights 115 | under this Public License. 116 | 117 | i. NonCommercial means not primarily intended for or directed towards 118 | commercial advantage or monetary compensation. For purposes of 119 | this Public License, the exchange of the Licensed Material for 120 | other material subject to Copyright and Similar Rights by digital 121 | file-sharing or similar means is NonCommercial provided there is 122 | no payment of monetary compensation in connection with the 123 | exchange. 124 | 125 | j. Share means to provide material to the public by any means or 126 | process that requires permission under the Licensed Rights, such 127 | as reproduction, public display, public performance, distribution, 128 | dissemination, communication, or importation, and to make material 129 | available to the public including in ways that members of the 130 | public may access the material from a place and at a time 131 | individually chosen by them. 132 | 133 | k. Sui Generis Database Rights means rights other than copyright 134 | resulting from Directive 96/9/EC of the European Parliament and of 135 | the Council of 11 March 1996 on the legal protection of databases, 136 | as amended and/or succeeded, as well as other essentially 137 | equivalent rights anywhere in the world. 138 | 139 | l. You means the individual or entity exercising the Licensed Rights 140 | under this Public License. Your has a corresponding meaning. 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. 
Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | Section 3 -- License Conditions. 222 | 223 | Your exercise of the Licensed Rights is expressly made subject to the 224 | following conditions. 225 | 226 | a. Attribution. 227 | 228 | 1. If You Share the Licensed Material (including in modified 229 | form), You must: 230 | 231 | a. retain the following if it is supplied by the Licensor 232 | with the Licensed Material: 233 | 234 | i. identification of the creator(s) of the Licensed 235 | Material and any others designated to receive 236 | attribution, in any reasonable manner requested by 237 | the Licensor (including by pseudonym if 238 | designated); 239 | 240 | ii. a copyright notice; 241 | 242 | iii. a notice that refers to this Public License; 243 | 244 | iv. a notice that refers to the disclaimer of 245 | warranties; 246 | 247 | v. a URI or hyperlink to the Licensed Material to the 248 | extent reasonably practicable; 249 | 250 | b. indicate if You modified the Licensed Material and 251 | retain an indication of any previous modifications; and 252 | 253 | c. indicate the Licensed Material is licensed under this 254 | Public License, and include the text of, or the URI or 255 | hyperlink to, this Public License. 256 | 257 | 2. You may satisfy the conditions in Section 3(a)(1) in any 258 | reasonable manner based on the medium, means, and context in 259 | which You Share the Licensed Material. For example, it may be 260 | reasonable to satisfy the conditions by providing a URI or 261 | hyperlink to a resource that includes the required 262 | information. 263 | 264 | 3. 
If requested by the Licensor, You must remove any of the 265 | information required by Section 3(a)(1)(A) to the extent 266 | reasonably practicable. 267 | 268 | 4. If You Share Adapted Material You produce, the Adapter's 269 | License You apply must not prevent recipients of the Adapted 270 | Material from complying with this Public License. 271 | 272 | Section 4 -- Sui Generis Database Rights. 273 | 274 | Where the Licensed Rights include Sui Generis Database Rights that 275 | apply to Your use of the Licensed Material: 276 | 277 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 278 | to extract, reuse, reproduce, and Share all or a substantial 279 | portion of the contents of the database for NonCommercial purposes 280 | only; 281 | 282 | b. if You include all or a substantial portion of the database 283 | contents in a database in which You have Sui Generis Database 284 | Rights, then the database in which You have Sui Generis Database 285 | Rights (but not its individual contents) is Adapted Material; and 286 | 287 | c. You must comply with the conditions in Section 3(a) if You Share 288 | all or a substantial portion of the contents of the database. 289 | 290 | For the avoidance of doubt, this Section 4 supplements and does not 291 | replace Your obligations under this Public License where the Licensed 292 | Rights include other Copyright and Similar Rights. 293 | 294 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 295 | 296 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 297 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 298 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 299 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 300 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 301 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 302 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 303 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 304 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 305 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 306 | 307 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 308 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 309 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 310 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 311 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 312 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 313 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 314 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 315 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 316 | 317 | c. The disclaimer of warranties and limitation of liability provided 318 | above shall be interpreted in a manner that, to the extent 319 | possible, most closely approximates an absolute disclaimer and 320 | waiver of all liability. 321 | 322 | Section 6 -- Term and Termination. 323 | 324 | a. This Public License applies for the term of the Copyright and 325 | Similar Rights licensed here. However, if You fail to comply with 326 | this Public License, then Your rights under this Public License 327 | terminate automatically. 328 | 329 | b. Where Your right to use the Licensed Material has terminated under 330 | Section 6(a), it reinstates: 331 | 332 | 1. 
automatically as of the date the violation is cured, provided 333 | it is cured within 30 days of Your discovery of the 334 | violation; or 335 | 336 | 2. upon express reinstatement by the Licensor. 337 | 338 | For the avoidance of doubt, this Section 6(b) does not affect any 339 | right the Licensor may have to seek remedies for Your violations 340 | of this Public License. 341 | 342 | c. For the avoidance of doubt, the Licensor may also offer the 343 | Licensed Material under separate terms or conditions or stop 344 | distributing the Licensed Material at any time; however, doing so 345 | will not terminate this Public License. 346 | 347 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 348 | License. 349 | 350 | Section 7 -- Other Terms and Conditions. 351 | 352 | a. The Licensor shall not be bound by any additional or different 353 | terms or conditions communicated by You unless expressly agreed. 354 | 355 | b. Any arrangements, understandings, or agreements regarding the 356 | Licensed Material not stated herein are separate from and 357 | independent of the terms and conditions of this Public License. 358 | 359 | Section 8 -- Interpretation. 360 | 361 | a. For the avoidance of doubt, this Public License does not, and 362 | shall not be interpreted to, reduce, limit, restrict, or impose 363 | conditions on any use of the Licensed Material that could lawfully 364 | be made without permission under this Public License. 365 | 366 | b. To the extent possible, if any provision of this Public License is 367 | deemed unenforceable, it shall be automatically reformed to the 368 | minimum extent necessary to make it enforceable. If the provision 369 | cannot be reformed, it shall be severed from this Public License 370 | without affecting the enforceability of the remaining terms and 371 | conditions. 372 | 373 | c. No term or condition of this Public License will be waived and no 374 | failure to comply consented to unless expressly agreed to by the 375 | Licensor. 376 | 377 | d. Nothing in this Public License constitutes or may be interpreted 378 | as a limitation upon, or waiver of, any privileges and immunities 379 | that apply to the Licensor or You, including from the legal 380 | processes of any jurisdiction or authority. 381 | 382 | ======================================================================= 383 | 384 | Creative Commons is not a party to its public 385 | licenses. Notwithstanding, Creative Commons may elect to apply one of 386 | its public licenses to material it publishes and in those instances 387 | will be considered the “Licensor.” The text of the Creative Commons 388 | public licenses is dedicated to the public domain under the CC0 Public 389 | Domain Dedication. Except for the limited purpose of indicating that 390 | material is shared under a Creative Commons public license or as 391 | otherwise permitted by the Creative Commons policies published at 392 | creativecommons.org/policies, Creative Commons does not authorize the 393 | use of the trademark "Creative Commons" or any other trademark or logo 394 | of Creative Commons without its prior written consent including, 395 | without limitation, in connection with any unauthorized modifications 396 | to any of its public licenses or any other arrangements, 397 | understandings, or agreements concerning use of licensed material. For 398 | the avoidance of doubt, this paragraph does not form part of the 399 | public licenses. 400 | 401 | Creative Commons may be contacted at creativecommons.org. 
402 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | dap-cl 2 | Copyright (c) 2023-present NAVER Cloud Corp. 3 | CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | -------------------------------------------------------------------------------------- 6 | 7 | This project contains subcomponents with separate copyright notices and license terms. 8 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 9 | 10 | ===== 11 | 12 | KMnP/vpt 13 | https://github.com/KMnP/vpt 14 | 15 | 16 | Attribution-NonCommercial 4.0 International 17 | 18 | ======================================================================= 19 | 20 | Creative Commons Corporation ("Creative Commons") is not a law firm and 21 | does not provide legal services or legal advice. Distribution of 22 | Creative Commons public licenses does not create a lawyer-client or 23 | other relationship. Creative Commons makes its licenses and related 24 | information available on an "as-is" basis. Creative Commons gives no 25 | warranties regarding its licenses, any material licensed under their 26 | terms and conditions, or any related information. Creative Commons 27 | disclaims all liability for damages resulting from their use to the 28 | fullest extent possible. 29 | 30 | Using Creative Commons Public Licenses 31 | 32 | Creative Commons public licenses provide a standard set of terms and 33 | conditions that creators and other rights holders may use to share 34 | original works of authorship and other material subject to copyright 35 | and certain other rights specified in the public license below. The 36 | following considerations are for informational purposes only, are not 37 | exhaustive, and do not form part of our licenses. 38 | 39 | Considerations for licensors: Our public licenses are 40 | intended for use by those authorized to give the public 41 | permission to use material in ways otherwise restricted by 42 | copyright and certain other rights. Our licenses are 43 | irrevocable. Licensors should read and understand the terms 44 | and conditions of the license they choose before applying it. 45 | Licensors should also secure all rights necessary before 46 | applying our licenses so that the public can reuse the 47 | material as expected. Licensors should clearly mark any 48 | material not subject to the license. This includes other CC- 49 | licensed material, or material used under an exception or 50 | limitation to copyright. More considerations for licensors: 51 | wiki.creativecommons.org/Considerations_for_licensors 52 | 53 | Considerations for the public: By using one of our public 54 | licenses, a licensor grants the public permission to use the 55 | licensed material under specified terms and conditions. If 56 | the licensor's permission is not necessary for any reason--for 57 | example, because of any applicable exception or limitation to 58 | copyright--then that use is not regulated by the license. Our 59 | licenses grant only permissions under copyright and certain 60 | other rights that a licensor has authority to grant. Use of 61 | the licensed material may still be restricted for other 62 | reasons, including because others have copyright or other 63 | rights in the material. A licensor may make special requests, 64 | such as asking that all changes be marked or described. 
65 | Although not required by our licenses, you are encouraged to 66 | respect those requests where reasonable. More_considerations 67 | for the public: 68 | wiki.creativecommons.org/Considerations_for_licensees 69 | 70 | ======================================================================= 71 | 72 | Creative Commons Attribution-NonCommercial 4.0 International Public 73 | License 74 | 75 | By exercising the Licensed Rights (defined below), You accept and agree 76 | to be bound by the terms and conditions of this Creative Commons 77 | Attribution-NonCommercial 4.0 International Public License ("Public 78 | License"). To the extent this Public License may be interpreted as a 79 | contract, You are granted the Licensed Rights in consideration of Your 80 | acceptance of these terms and conditions, and the Licensor grants You 81 | such rights in consideration of benefits the Licensor receives from 82 | making the Licensed Material available under these terms and 83 | conditions. 84 | 85 | Section 1 -- Definitions. 86 | 87 | a. Adapted Material means material subject to Copyright and Similar 88 | Rights that is derived from or based upon the Licensed Material 89 | and in which the Licensed Material is translated, altered, 90 | arranged, transformed, or otherwise modified in a manner requiring 91 | permission under the Copyright and Similar Rights held by the 92 | Licensor. For purposes of this Public License, where the Licensed 93 | Material is a musical work, performance, or sound recording, 94 | Adapted Material is always produced where the Licensed Material is 95 | synched in timed relation with a moving image. 96 | 97 | b. Adapter's License means the license You apply to Your Copyright 98 | and Similar Rights in Your contributions to Adapted Material in 99 | accordance with the terms and conditions of this Public License. 100 | 101 | c. Copyright and Similar Rights means copyright and/or similar rights 102 | closely related to copyright including, without limitation, 103 | performance, broadcast, sound recording, and Sui Generis Database 104 | Rights, without regard to how the rights are labeled or 105 | categorized. For purposes of this Public License, the rights 106 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 107 | Rights. 108 | d. Effective Technological Measures means those measures that, in the 109 | absence of proper authority, may not be circumvented under laws 110 | fulfilling obligations under Article 11 of the WIPO Copyright 111 | Treaty adopted on December 20, 1996, and/or similar international 112 | agreements. 113 | 114 | e. Exceptions and Limitations means fair use, fair dealing, and/or 115 | any other exception or limitation to Copyright and Similar Rights 116 | that applies to Your use of the Licensed Material. 117 | 118 | f. Licensed Material means the artistic or literary work, database, 119 | or other material to which the Licensor applied this Public 120 | License. 121 | 122 | g. Licensed Rights means the rights granted to You subject to the 123 | terms and conditions of this Public License, which are limited to 124 | all Copyright and Similar Rights that apply to Your use of the 125 | Licensed Material and that the Licensor has authority to license. 126 | 127 | h. Licensor means the individual(s) or entity(ies) granting rights 128 | under this Public License. 129 | 130 | i. NonCommercial means not primarily intended for or directed towards 131 | commercial advantage or monetary compensation. 
For purposes of 132 | this Public License, the exchange of the Licensed Material for 133 | other material subject to Copyright and Similar Rights by digital 134 | file-sharing or similar means is NonCommercial provided there is 135 | no payment of monetary compensation in connection with the 136 | exchange. 137 | 138 | j. Share means to provide material to the public by any means or 139 | process that requires permission under the Licensed Rights, such 140 | as reproduction, public display, public performance, distribution, 141 | dissemination, communication, or importation, and to make material 142 | available to the public including in ways that members of the 143 | public may access the material from a place and at a time 144 | individually chosen by them. 145 | 146 | k. Sui Generis Database Rights means rights other than copyright 147 | resulting from Directive 96/9/EC of the European Parliament and of 148 | the Council of 11 March 1996 on the legal protection of databases, 149 | as amended and/or succeeded, as well as other essentially 150 | equivalent rights anywhere in the world. 151 | 152 | l. You means the individual or entity exercising the Licensed Rights 153 | under this Public License. Your has a corresponding meaning. 154 | 155 | Section 2 -- Scope. 156 | 157 | a. License grant. 158 | 159 | 1. Subject to the terms and conditions of this Public License, 160 | the Licensor hereby grants You a worldwide, royalty-free, 161 | non-sublicensable, non-exclusive, irrevocable license to 162 | exercise the Licensed Rights in the Licensed Material to: 163 | 164 | a. reproduce and Share the Licensed Material, in whole or 165 | in part, for NonCommercial purposes only; and 166 | 167 | b. produce, reproduce, and Share Adapted Material for 168 | NonCommercial purposes only. 169 | 170 | 2. Exceptions and Limitations. For the avoidance of doubt, where 171 | Exceptions and Limitations apply to Your use, this Public 172 | License does not apply, and You do not need to comply with 173 | its terms and conditions. 174 | 175 | 3. Term. The term of this Public License is specified in Section 176 | 6(a). 177 | 178 | 4. Media and formats; technical modifications allowed. The 179 | Licensor authorizes You to exercise the Licensed Rights in 180 | all media and formats whether now known or hereafter created, 181 | and to make technical modifications necessary to do so. The 182 | Licensor waives and/or agrees not to assert any right or 183 | authority to forbid You from making technical modifications 184 | necessary to exercise the Licensed Rights, including 185 | technical modifications necessary to circumvent Effective 186 | Technological Measures. For purposes of this Public License, 187 | simply making modifications authorized by this Section 2(a) 188 | (4) never produces Adapted Material. 189 | 190 | 5. Downstream recipients. 191 | 192 | a. Offer from the Licensor -- Licensed Material. Every 193 | recipient of the Licensed Material automatically 194 | receives an offer from the Licensor to exercise the 195 | Licensed Rights under the terms and conditions of this 196 | Public License. 197 | 198 | b. No downstream restrictions. You may not offer or impose 199 | any additional or different terms or conditions on, or 200 | apply any Effective Technological Measures to, the 201 | Licensed Material if doing so restricts exercise of the 202 | Licensed Rights by any recipient of the Licensed 203 | Material. 204 | 205 | 6. No endorsement. 
Nothing in this Public License constitutes or 206 | may be construed as permission to assert or imply that You 207 | are, or that Your use of the Licensed Material is, connected 208 | with, or sponsored, endorsed, or granted official status by, 209 | the Licensor or others designated to receive attribution as 210 | provided in Section 3(a)(1)(A)(i). 211 | 212 | b. Other rights. 213 | 214 | 1. Moral rights, such as the right of integrity, are not 215 | licensed under this Public License, nor are publicity, 216 | privacy, and/or other similar personality rights; however, to 217 | the extent possible, the Licensor waives and/or agrees not to 218 | assert any such rights held by the Licensor to the limited 219 | extent necessary to allow You to exercise the Licensed 220 | Rights, but not otherwise. 221 | 222 | 2. Patent and trademark rights are not licensed under this 223 | Public License. 224 | 225 | 3. To the extent possible, the Licensor waives any right to 226 | collect royalties from You for the exercise of the Licensed 227 | Rights, whether directly or through a collecting society 228 | under any voluntary or waivable statutory or compulsory 229 | licensing scheme. In all other cases the Licensor expressly 230 | reserves any right to collect such royalties, including when 231 | the Licensed Material is used other than for NonCommercial 232 | purposes. 233 | 234 | Section 3 -- License Conditions. 235 | 236 | Your exercise of the Licensed Rights is expressly made subject to the 237 | following conditions. 238 | 239 | a. Attribution. 240 | 241 | 1. If You Share the Licensed Material (including in modified 242 | form), You must: 243 | 244 | a. retain the following if it is supplied by the Licensor 245 | with the Licensed Material: 246 | 247 | i. identification of the creator(s) of the Licensed 248 | Material and any others designated to receive 249 | attribution, in any reasonable manner requested by 250 | the Licensor (including by pseudonym if 251 | designated); 252 | 253 | ii. a copyright notice; 254 | 255 | iii. a notice that refers to this Public License; 256 | 257 | iv. a notice that refers to the disclaimer of 258 | warranties; 259 | 260 | v. a URI or hyperlink to the Licensed Material to the 261 | extent reasonably practicable; 262 | 263 | b. indicate if You modified the Licensed Material and 264 | retain an indication of any previous modifications; and 265 | 266 | c. indicate the Licensed Material is licensed under this 267 | Public License, and include the text of, or the URI or 268 | hyperlink to, this Public License. 269 | 270 | 2. You may satisfy the conditions in Section 3(a)(1) in any 271 | reasonable manner based on the medium, means, and context in 272 | which You Share the Licensed Material. For example, it may be 273 | reasonable to satisfy the conditions by providing a URI or 274 | hyperlink to a resource that includes the required 275 | information. 276 | 277 | 3. If requested by the Licensor, You must remove any of the 278 | information required by Section 3(a)(1)(A) to the extent 279 | reasonably practicable. 280 | 281 | 4. If You Share Adapted Material You produce, the Adapter's 282 | License You apply must not prevent recipients of the Adapted 283 | Material from complying with this Public License. 284 | 285 | Section 4 -- Sui Generis Database Rights. 286 | 287 | Where the Licensed Rights include Sui Generis Database Rights that 288 | apply to Your use of the Licensed Material: 289 | 290 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 291 | to extract, reuse, reproduce, and Share all or a substantial 292 | portion of the contents of the database for NonCommercial purposes 293 | only; 294 | 295 | b. if You include all or a substantial portion of the database 296 | contents in a database in which You have Sui Generis Database 297 | Rights, then the database in which You have Sui Generis Database 298 | Rights (but not its individual contents) is Adapted Material; and 299 | 300 | c. You must comply with the conditions in Section 3(a) if You Share 301 | all or a substantial portion of the contents of the database. 302 | 303 | For the avoidance of doubt, this Section 4 supplements and does not 304 | replace Your obligations under this Public License where the Licensed 305 | Rights include other Copyright and Similar Rights. 306 | 307 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 308 | 309 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 310 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 311 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 312 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 313 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 314 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 315 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 316 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 317 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 318 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 319 | 320 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 321 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 322 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 323 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 324 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 325 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 326 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 327 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 328 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 329 | 330 | c. The disclaimer of warranties and limitation of liability provided 331 | above shall be interpreted in a manner that, to the extent 332 | possible, most closely approximates an absolute disclaimer and 333 | waiver of all liability. 334 | 335 | Section 6 -- Term and Termination. 336 | 337 | a. This Public License applies for the term of the Copyright and 338 | Similar Rights licensed here. However, if You fail to comply with 339 | this Public License, then Your rights under this Public License 340 | terminate automatically. 341 | 342 | b. Where Your right to use the Licensed Material has terminated under 343 | Section 6(a), it reinstates: 344 | 345 | 1. automatically as of the date the violation is cured, provided 346 | it is cured within 30 days of Your discovery of the 347 | violation; or 348 | 349 | 2. upon express reinstatement by the Licensor. 350 | 351 | For the avoidance of doubt, this Section 6(b) does not affect any 352 | right the Licensor may have to seek remedies for Your violations 353 | of this Public License. 354 | 355 | c. 
For the avoidance of doubt, the Licensor may also offer the 356 | Licensed Material under separate terms or conditions or stop 357 | distributing the Licensed Material at any time; however, doing so 358 | will not terminate this Public License. 359 | 360 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 361 | License. 362 | 363 | Section 7 -- Other Terms and Conditions. 364 | 365 | a. The Licensor shall not be bound by any additional or different 366 | terms or conditions communicated by You unless expressly agreed. 367 | 368 | b. Any arrangements, understandings, or agreements regarding the 369 | Licensed Material not stated herein are separate from and 370 | independent of the terms and conditions of this Public License. 371 | 372 | Section 8 -- Interpretation. 373 | 374 | a. For the avoidance of doubt, this Public License does not, and 375 | shall not be interpreted to, reduce, limit, restrict, or impose 376 | conditions on any use of the Licensed Material that could lawfully 377 | be made without permission under this Public License. 378 | 379 | b. To the extent possible, if any provision of this Public License is 380 | deemed unenforceable, it shall be automatically reformed to the 381 | minimum extent necessary to make it enforceable. If the provision 382 | cannot be reformed, it shall be severed from this Public License 383 | without affecting the enforceability of the remaining terms and 384 | conditions. 385 | 386 | c. No term or condition of this Public License will be waived and no 387 | failure to comply consented to unless expressly agreed to by the 388 | Licensor. 389 | 390 | d. Nothing in this Public License constitutes or may be interpreted 391 | as a limitation upon, or waiver of, any privileges and immunities 392 | that apply to the Licensor or You, including from the legal 393 | processes of any jurisdiction or authority. 394 | 395 | ======================================================================= 396 | 397 | Creative Commons is not a party to its public 398 | licenses. Notwithstanding, Creative Commons may elect to apply one of 399 | its public licenses to material it publishes and in those instances 400 | will be considered the “Licensor.” The text of the Creative Commons 401 | public licenses is dedicated to the public domain under the CC0 Public 402 | Domain Dedication. Except for the limited purpose of indicating that 403 | material is shared under a Creative Commons public license or as 404 | otherwise permitted by the Creative Commons policies published at 405 | creativecommons.org/policies, Creative Commons does not authorize the 406 | use of the trademark "Creative Commons" or any other trademark or logo 407 | of Creative Commons without its prior written consent including, 408 | without limitation, in connection with any unauthorized modifications 409 | to any of its public licenses or any other arrangements, 410 | understandings, or agreements concerning use of licensed material. For 411 | the avoidance of doubt, this paragraph does not form part of the 412 | public licenses. 413 | 414 | Creative Commons may be contacted at creativecommons.org. 415 | 416 | ===== 417 | 418 | huggingface/transformers 419 | https://github.com/huggingface/transformers 420 | 421 | 422 | Copyright [yyyy] [name of copyright owner] 423 | 424 | Licensed under the Apache License, Version 2.0 (the "License"); 425 | you may not use this file except in compliance with the License. 
426 | You may obtain a copy of the License at 427 | 428 | http://www.apache.org/licenses/LICENSE-2.0 429 | 430 | Unless required by applicable law or agreed to in writing, software 431 | distributed under the License is distributed on an "AS IS" BASIS, 432 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 433 | See the License for the specific language governing permissions and limitations under the License. 434 | 435 | ===== 436 | 437 | jeonsworld/ViT-pytorch 438 | https://github.com/jeonsworld/ViT-pytorch 439 | 440 | 441 | MIT License 442 | 443 | Copyright (c) 2020 jeonsworld 444 | 445 | Permission is hereby granted, free of charge, to any person obtaining a copy 446 | of this software and associated documentation files (the "Software"), to deal 447 | in the Software without restriction, including without limitation the rights 448 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 449 | copies of the Software, and to permit persons to whom the Software is 450 | furnished to do so, subject to the following conditions: 451 | 452 | The above copyright notice and this permission notice shall be included in all 453 | copies or substantial portions of the Software. 454 | 455 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 456 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 457 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 458 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 459 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 460 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 461 | SOFTWARE. 462 | 463 | ===== 464 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating Instance-level Prompts for Rehearsal-free Continual Learning 2 | This is the official PyTorch implementation for "[Generating Instance-level Prompts for Rehearsal-free Continual Learning](https://openaccess.thecvf.com/content/ICCV2023/papers/Jung_Generating_Instance-level_Prompts_for_Rehearsal-free_Continual_Learning_ICCV_2023_paper.pdf)" [ICCV 2023 ORAL]. 3 | 4 | # Abstract 5 | We introduce Domain-Adaptive Prompt (DAP), a novel method for continual learning using Vision Transformers (ViT). Prompt-based continual learning has recently gained attention due to its rehearsal-free nature. Currently, the prompt pool, which is suggested by prompt-based continual learning, is key to effectively exploiting the frozen pre-trained ViT backbone in a sequence of tasks. However, we observe that the use of a prompt pool creates a domain scalability problem between pre-training and continual learning. This problem arises due to the inherent encoding of group-level instructions within the prompt pool. To address this problem, we propose DAP, a pool-free approach that generates a suitable prompt in an instance-level manner at inference time. We optimize an adaptive prompt generator that creates instance-specific fine-grained instructions required for each input, enabling enhanced model plasticity and reduced forgetting. Our experiments on seven datasets with varying degrees of domain similarity to ImageNet demonstrate the superiority of DAP over state-of-the-art prompt-based methods. 
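To make the idea concrete, the snippet below is a minimal editorial sketch of an instance-level prompt generator. It is not the implementation in src/models/vit_dap/vit.py: the class name, the mean-pooled instance summary, and the two-layer generator are assumptions made purely for exposition, and only the dimensions echo the configs in this repository (ViT-B/16 hidden size 768, MODEL.DAP.NUM_DAP_TOKENS, MODEL.DAP.TASK_EMB).
```
import torch
import torch.nn as nn

class ToyInstancePromptGenerator(nn.Module):
    """Illustrative sketch only: builds instance-specific prompt tokens for a frozen ViT."""

    def __init__(self, hidden_size=768, num_dap_tokens=10, task_emb_dim=16, num_tasks=10):
        super().__init__()
        self.num_dap_tokens = num_dap_tokens
        self.hidden_size = hidden_size
        self.task_emb = nn.Embedding(num_tasks, task_emb_dim)  # coarse task conditioning
        self.generator = nn.Sequential(                        # hypothetical prompt generator
            nn.Linear(hidden_size + task_emb_dim, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_dap_tokens * hidden_size),
        )

    def forward(self, patch_tokens, task_ids):
        # patch_tokens: (B, N, D) patch embeddings from the frozen backbone
        # task_ids:     (B,) long tensor of task indices
        pooled = patch_tokens.mean(dim=1)                      # (B, D) per-instance summary
        cond = torch.cat([pooled, self.task_emb(task_ids)], dim=-1)
        prompts = self.generator(cond).view(-1, self.num_dap_tokens, self.hidden_size)
        # prepend the generated, instance-specific prompts to the frozen encoder's token sequence
        return torch.cat([prompts, patch_tokens], dim=1)

tokens = torch.randn(4, 196, 768)  # fake 14x14 patch tokens for ViT-B/16
out = ToyInstancePromptGenerator()(tokens, torch.zeros(4, dtype=torch.long))
print(out.shape)  # torch.Size([4, 206, 768])
```
For the actual model and training loop, see src/models/vit_dap/vit.py and src/engine/trainer.py.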
6 | 7 | ![DAP](DAP.png) 8 | 9 | # Requirements 10 | 11 | Our experiments were run with: 12 | 13 | - python 3.8 14 | - pytorch 1.11.0 15 | - tensorflow 2.5.0 16 | - numpy 1.23.4 17 | - fvcore 0.1.6 18 | - tensorflow_datasets 4.9.2 19 | - scipy 1.10.1 20 | - ml-collections 0.1.1 21 | 22 | # Environment setup 23 | ``` 24 | conda create -n [ENV_NAME] python=3.8 25 | conda activate [ENV_NAME] 26 | bash env_setup.sh 27 | ``` 28 | 29 | # Data preparation 30 | - The datasets should be located in the 'dataset' folder (CIFAR-100, Oxford-IIIT Pets, EuroSAT, RESISC45, CropDiseases, ISIC, ChestX, ImageNet-R, and DomainNet) 31 | - For [Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/), [CropDiseases](https://www.frontiersin.org/articles/10.3389/fpls.2016.01419/full), [ISIC](https://challenge.isic-archive.com/landing/2018/47/), and [ChestX](https://openaccess.thecvf.com/content_cvpr_2017/papers/Wang_ChestX-ray8_Hospital-Scale_Chest_CVPR_2017_paper.pdf), we convert each dataset into a TFDS-compatible form, following the tutorial at [this link](https://www.tensorflow.org/datasets/add_dataset), to cover the CL scenario (see Sec. 5 and Supp. A for details). 32 | - The remaining datasets can be downloaded directly from tensorflow_datasets. 33 | - TFDS builds of the benchmarks used in our experiments can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1bBqS8MuTQXUBV3DXJ_-YZyNOR4ejvm1O?usp=sharing) 34 | 35 | # Pretrained ViT model 36 | - The pretrained ViT-B/16 checkpoint should be located in the 'model' folder. 37 | - The ViT-B/16 checkpoint used in the paper can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1bBqS8MuTQXUBV3DXJ_-YZyNOR4ejvm1O?usp=sharing) 38 | 39 | # Run experiments 40 | ``` 41 | CUDA_VISIBLE_DEVICES=0 python train.py --config-file [CONFIG_FILE] 42 | ``` 43 | For example, for Split CIFAR-100, execute 44 | ``` 45 | CUDA_VISIBLE_DEVICES=0 python train.py --config-file configs/dap/cifar.yaml 46 | ``` 47 | 48 | # Acknowledgement 49 | 50 | This repository is based on the official PyTorch implementation of "[Visual Prompt Tuning](https://github.com/KMnP/vpt)" [ECCV 2022]. 51 | 52 | # License 53 | 54 | Licensed under [CC BY-NC 4.0](LICENSE) 55 | 56 | ``` 57 | dap-cl 58 | Copyright (c) 2023-present NAVER Cloud Corp. 59 | CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/) 60 | ``` 61 | 62 | # How to cite 63 | ``` 64 | @inproceedings{jung2023generating, 65 | title={Generating Instance-level Prompts for Rehearsal-free Continual Learning}, 66 | author={Jung, Dahuin and Han, Dongyoon and Bang, Jihwan and Song, Hwanjun}, 67 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 68 | pages={11847--11857}, 69 | year={2023} 70 | } 71 | ``` 72 | 73 | -------------------------------------------------------------------------------- /configs/base-dap.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TYPE: "vit" 7 | TRANSFER_TYPE: "dap" 8 | MODEL_ROOT: "./model" 9 | DAP: 10 | CURRENT_LAMBDA: 1.0 11 | SIM_LAMBDA: 0.1 12 | NUM_DAP_TOKENS: 10 13 | TASK_EMB: 16 14 | SOLVER: 15 | BASE_LR: 0.01 16 | OPTIMIZER: 'adam' 17 | MOMENTUM: 0.9 18 | BETA1: 0.9 19 | BETA2: 0.9 20 | WEIGHT_DECAY: 0.0 21 | WARMUP_EPOCH: 0 22 | TOTAL_EPOCH: 30 23 | SCHEDULER: "linear" 24 | GRAD_CLIP: 1.
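# Editorial note: the per-dataset files under configs/dap/ extend this base file through
# the _BASE_ key (fvcore CfgNode behavior), and any key not set in YAML falls back to the
# defaults in src/configs/config.py. When GRAD_CLIP_APPLY below is enabled, gradients are
# clipped to the max norm given by GRAD_CLIP (config.py notes torch.nn.utils.clip_grad_norm_).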
25 | GRAD_CLIP_APPLY: True 26 | PATIENCE: 300 27 | LOSS: "softmax" 28 | LOG_EVERY_N: 100 29 | DATA: 30 | NAME: "" 31 | NUMBER_CLASSES: -1 32 | BATCH_SIZE: 128 33 | DATAPATH: "./dataset" 34 | FEATURE: "sup_vitb16_imagenet21k" 35 | -------------------------------------------------------------------------------- /configs/dap/chestx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 3 4 | INCREMENT: 3 5 | N_TASKS: 2 6 | DATA: 7 | NAME: "chestx" 8 | NUMBER_CLASSES: 6 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 2 12 | PROMPT_POOL: 2 13 | SEED: 42 14 | OUTPUT_DIR: "./results/seed42" 15 | -------------------------------------------------------------------------------- /configs/dap/cifar.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 10 4 | INCREMENT: 10 5 | N_TASKS: 10 6 | DATA: 7 | NAME: "cifar100" 8 | NUMBER_CLASSES: 100 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 10 12 | PROMPT_POOL: 10 13 | SOLVER: 14 | TOTAL_EPOCH: 5 15 | SEED: 42 16 | OUTPUT_DIR: "./results/seed42" 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/dap/cropdisease.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 5 4 | INCREMENT: 5 5 | N_TASKS: 7 6 | DATA: 7 | NAME: "cropdisease" 8 | NUMBER_CLASSES: 35 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 7 12 | PROMPT_POOL: 7 13 | SOLVER: 14 | TOTAL_EPOCH: 5 15 | SEED: 42 16 | OUTPUT_DIR: "./results/seed42" 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/dap/domainnet.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 23 4 | INCREMENT: 23 5 | N_TASKS: 15 6 | DATA: 7 | NAME: "domainnet/real" 8 | NUMBER_CLASSES: 345 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 15 12 | PROMPT_POOL: 15 13 | SOLVER: 14 | TOTAL_EPOCH: 5 15 | SEED: 42 16 | OUTPUT_DIR: "./results/seed42" 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /configs/dap/eurosat.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 2 4 | INCREMENT: 2 5 | N_TASKS: 5 6 | DATA: 7 | NAME: "eurosat/rgb" 8 | NUMBER_CLASSES: 10 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 5 12 | PROMPT_POOL: 5 13 | SOLVER: 14 | TOTAL_EPOCH: 5 15 | SEED: 42 16 | OUTPUT_DIR: "./results/seed42" 17 | -------------------------------------------------------------------------------- /configs/dap/imagenetr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 20 4 | INCREMENT: 20 5 | N_TASKS: 10 6 | DATA: 7 | NAME: "imagenet_r" 8 | NUMBER_CLASSES: 200 9 | MODEL: 10 | DAP: 11 | SIM_LAMBDA: 2.0 12 | NUM_DAP_TOKENS: 30 13 | TASK_EMB: 64 14 | NUM_TASKS_FOR_EMB: 10 15 | PROMPT_POOL: 10 16 | SOLVER: 17 | BASE_LR: 0.25 18 | OPTIMIZER: 'sgd' 19 | TOTAL_EPOCH: 20 20 | SEED: 42 21 | OUTPUT_DIR: "./results/seed42" 22 | 23 | -------------------------------------------------------------------------------- /configs/dap/isic.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | 
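# Editorial note (a reading consistent with every config in this folder): INITIAL is the number
# of classes in the first task, INCREMENT the classes added by each later task, and N_TASKS the
# number of tasks, so INITIAL + (N_TASKS - 1) * INCREMENT equals DATA.NUMBER_CLASSES (here 2 + 2*2 = 6).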
CONTINUAL: 3 | INITIAL: 2 4 | INCREMENT: 2 5 | N_TASKS: 3 6 | DATA: 7 | NAME: "isic" 8 | NUMBER_CLASSES: 6 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 3 12 | PROMPT_POOL: 3 13 | SEED: 42 14 | OUTPUT_DIR: "./results/seed42" 15 | -------------------------------------------------------------------------------- /configs/dap/pets.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 5 4 | INCREMENT: 5 5 | N_TASKS: 7 6 | DATA: 7 | NAME: "oxford_iiit_pet" 8 | NUMBER_CLASSES: 35 9 | MODEL: 10 | DAP: 11 | NUM_TASKS_FOR_EMB: 7 12 | PROMPT_POOL: 7 13 | SEED: 42 14 | OUTPUT_DIR: "./results/seed42" 15 | 16 | -------------------------------------------------------------------------------- /configs/dap/resisc45.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-dap.yaml" 2 | CONTINUAL: 3 | INITIAL: 5 4 | INCREMENT: 5 5 | N_TASKS: 9 6 | DATA: 7 | NAME: "resisc45" 8 | NUMBER_CLASSES: 45 9 | MODEL: 10 | DAP: 11 | SIM_LAMBDA: 0.5 12 | NUM_TASKS_FOR_EMB: 9 13 | PROMPT_POOL: 9 14 | SEED: 42 15 | OUTPUT_DIR: "./results/seed42" 16 | 17 | -------------------------------------------------------------------------------- /env_setup.sh: -------------------------------------------------------------------------------- 1 | 2 | pip install tensorflow==2.5.0 3 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 4 | conda install -c conda-forge cudnn 5 | 6 | pip install git+https://github.com/facebookresearch/fvcore.git 7 | pip install numpy==1.23.4 8 | pip install tensorflow_datasets 9 | pip install scipy 10 | pip install ml-collections 11 | pip install -U --force-reinstall charset-normalizer 12 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | launch helper functions 4 | """ 5 | 6 | import argparse 7 | import os 8 | import sys 9 | import PIL 10 | from collections import defaultdict 11 | from tabulate import tabulate 12 | from typing import Tuple 13 | 14 | import torch 15 | 16 | def collect_torch_env() -> str: 17 | try: 18 | import torch.__config__ 19 | 20 | return torch.__config__.show() 21 | except ImportError: 22 | # compatible with older versions of pytorch 23 | from torch.utils.collect_env import get_pretty_env_info 24 | 25 | return get_pretty_env_info() 26 | 27 | def get_env_module() -> Tuple[str]: 28 | var_name = "ENV_MODULE" 29 | return var_name, os.environ.get(var_name, "") 30 | 31 | 32 | def collect_env_info() -> str: 33 | data = [] 34 | data.append(("Python", sys.version.replace("\n", ""))) 35 | data.append(get_env_module()) 36 | data.append(("PyTorch", torch.__version__)) 37 | data.append(("PyTorch Debug Build", torch.version.debug)) 38 | 39 | has_cuda = torch.cuda.is_available() 40 | data.append(("CUDA available", has_cuda)) 41 | if has_cuda: 42 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 43 | devices = defaultdict(list) 44 | for k in range(torch.cuda.device_count()): 45 | devices[torch.cuda.get_device_name(k)].append(str(k)) 46 | for name, devids in devices.items(): 47 | data.append(("GPU " + ",".join(devids), name)) 48 | data.append(("Pillow", PIL.__version__)) 49 | 50 | try: 51 | import cv2 52 | 53 | data.append(("cv2", cv2.__version__)) 54 | except ImportError: 55 | pass 56 | env_str = tabulate(data) + "\n" 57 | env_str += 
collect_torch_env() 58 | return env_str 59 | 60 | 61 | def default_argument_parser(): 62 | """ 63 | create a simple parser to wrap around config file 64 | """ 65 | parser = argparse.ArgumentParser(description="dap") 66 | parser.add_argument( 67 | "--config-file", metavar="FILE", help="path to config file") 68 | parser.add_argument( 69 | "opts", 70 | help="Modify config options using the command-line", 71 | default=None, 72 | nargs=argparse.REMAINDER, 73 | ) 74 | 75 | return parser 76 | 77 | -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | config system (based on Detectron's) 4 | """ 5 | 6 | from .config_node import CfgNode 7 | 8 | 9 | # Global config object 10 | _C = CfgNode() 11 | # Example usage: 12 | # from configs.config import cfg 13 | 14 | _C.DBG = False 15 | _C.OUTPUT_DIR = "./output" 16 | _C.RUN_N_TIMES = 5 17 | # Perform benchmarking to select the fastest CUDNN algorithms to use 18 | # Note that this may increase the memory usage and will likely not result 19 | # in overall speedups when variable size inputs are used (e.g. COCO training) 20 | _C.CUDNN_BENCHMARK = False 21 | 22 | # Number of GPUs to use (applies to both training and testing) 23 | _C.NUM_GPUS = 1 24 | _C.NUM_SHARDS = 1 25 | 26 | # Note that non-determinism may still be present due to non-deterministic 27 | # operator implementations in GPU operator libraries 28 | _C.SEED = None 29 | 30 | # --------------------------------------------------------------------- 31 | # Continual options 32 | # --------------------------------------------------------------------- 33 | _C.CONTINUAL = CfgNode() 34 | _C.CONTINUAL.INITIAL = 10 35 | _C.CONTINUAL.INCREMENT = 10 36 | _C.CONTINUAL.N_TASKS = 10 37 | _C.CONTINUAL.COLOR_JITTER = None 38 | _C.CONTINUAL.AA = None 39 | _C.CONTINUAL.REPROB = 0. 
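# Editorial note: COLOR_JITTER, AA, REPROB, REMODE, RECOUNT and RESPLIT appear to mirror
# timm-style augmentation arguments (color jitter, auto-augment policy, and random-erasing
# probability/mode/count/split); this is inferred from the field names, not documented here.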
40 | _C.CONTINUAL.REMODE = "pixel" 41 | _C.CONTINUAL.RECOUNT = 1 42 | _C.CONTINUAL.RESPLIT = False 43 | # ---------------------------------------------------------------------- 44 | # Model options 45 | # ---------------------------------------------------------------------- 46 | _C.MODEL = CfgNode() 47 | _C.MODEL.TRANSFER_TYPE = "linear" 48 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file 49 | _C.MODEL.SAVE_CKPT = False 50 | _C.MODEL.NUM_HEAD = 12 51 | 52 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights 53 | 54 | _C.MODEL.TYPE = "vit" 55 | _C.MODEL.MLP_NUM = 0 56 | 57 | _C.MODEL.LINEAR = CfgNode() 58 | _C.MODEL.LINEAR.MLP_SIZES = [] 59 | _C.MODEL.LINEAR.DROPOUT = 0.1 60 | 61 | # ---------------------------------------------------------------------- 62 | # dap options 63 | # ---------------------------------------------------------------------- 64 | _C.MODEL.DAP = CfgNode() 65 | _C.MODEL.DAP.DROPOUT = 0.3 66 | _C.MODEL.DAP.INIT = None 67 | 68 | _C.MODEL.DAP.CURRENT_LAMBDA = 1.0 69 | _C.MODEL.DAP.SIM_LAMBDA = 0.5 70 | _C.MODEL.DAP.NUM_DAP_TOKENS = 3 71 | _C.MODEL.DAP.TASK_EMB = 64 72 | _C.MODEL.DAP.NUM_TASKS_FOR_EMB = 10 73 | _C.MODEL.DAP.PROMPT_POOL = 10 74 | 75 | # ---------------------------------------------------------------------- 76 | # Solver options 77 | # ---------------------------------------------------------------------- 78 | _C.SOLVER = CfgNode() 79 | _C.SOLVER.LOSS = "softmax" 80 | _C.SOLVER.LOSS_ALPHA = 0.01 81 | 82 | _C.SOLVER.OPTIMIZER = "sgd" 83 | _C.SOLVER.MOMENTUM = 0.9 84 | _C.SOLVER.BETA1 = 0.9 85 | _C.SOLVER.BETA2 = 0.9 86 | _C.SOLVER.WEIGHT_DECAY = 0.0001 87 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 88 | _C.SOLVER.GRAD_CLIP = 1. 89 | _C.SOLVER.GRAD_CLIP_APPLY = False 90 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 91 | 92 | _C.SOLVER.PATIENCE = 300 93 | 94 | 95 | _C.SOLVER.SCHEDULER = "cosine" 96 | 97 | _C.SOLVER.BASE_LR = 0.01 98 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias 99 | 100 | _C.SOLVER.WARMUP_EPOCH = 5 101 | _C.SOLVER.TOTAL_EPOCH = 30 102 | _C.SOLVER.LOG_EVERY_N = 1000 103 | 104 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params 105 | 106 | # ---------------------------------------------------------------------- 107 | # Dataset options 108 | # ---------------------------------------------------------------------- 109 | _C.DATA = CfgNode() 110 | 111 | _C.DATA.NAME = "" 112 | _C.DATA.DATAPATH = "" 113 | _C.DATA.FEATURE = "" 114 | 115 | _C.DATA.PERCENTAGE = 1.0 116 | _C.DATA.NUMBER_CLASSES = -1 117 | _C.DATA.MULTILABEL = False 118 | _C.DATA.CLASS_WEIGHTS_TYPE = "none" 119 | 120 | _C.DATA.CROPSIZE = 224 # or 384 121 | 122 | _C.DATA.NO_TEST = False 123 | _C.DATA.BATCH_SIZE = 32 124 | # Number of data loader workers per training process 125 | _C.DATA.NUM_WORKERS = 4 126 | # Load data to pinned host memory 127 | _C.DATA.PIN_MEMORY = True 128 | 129 | _C.DIST_BACKEND = "nccl" 130 | _C.DIST_INIT_PATH = "env://" 131 | _C.DIST_INIT_FILE = "" 132 | 133 | 134 | def get_cfg(): 135 | """ 136 | Get a copy of the default config. 
137 | """ 138 | return _C.clone() 139 | -------------------------------------------------------------------------------- /src/configs/config_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | config node based on fvcore 4 | """ 5 | 6 | from fvcore.common.config import CfgNode as _CfgNode 7 | from ..utils.file_io import PathManager 8 | 9 | 10 | class CfgNode(_CfgNode): 11 | """ 12 | The same as `fvcore.common.config.CfgNode`, but different in: 13 | support manifold path 14 | """ 15 | @classmethod 16 | def _open_cfg(cls, filename): 17 | return PathManager.open(filename, "r") 18 | 19 | def dump(self, *args, **kwargs): 20 | """ 21 | Returns: 22 | str: a yaml string representation of the config 23 | """ 24 | # to make it show up in docs 25 | return super().dump(*args, **kwargs) 26 | -------------------------------------------------------------------------------- /src/configs/vit_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py 5 | """ 6 | 7 | import ml_collections 8 | 9 | def get_testing(): 10 | """Returns a minimal configuration for testing.""" 11 | config = ml_collections.ConfigDict() 12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 13 | config.hidden_size = 1 14 | config.transformer = ml_collections.ConfigDict() 15 | config.transformer.mlp_dim = 1 16 | config.transformer.num_heads = 1 17 | config.transformer.num_layers = 1 18 | config.transformer.attention_dropout_rate = 0.0 19 | config.transformer.dropout_rate = 0.1 20 | config.classifier = 'token' 21 | config.representation_size = None 22 | return config 23 | 24 | def get_b16_config(): 25 | """Returns the ViT-B/16 configuration.""" 26 | config = ml_collections.ConfigDict() 27 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 28 | config.hidden_size = 768 29 | config.transformer = ml_collections.ConfigDict() 30 | config.transformer.mlp_dim = 3072 31 | config.transformer.num_heads = 12 32 | config.transformer.num_layers = 12 33 | config.transformer.attention_dropout_rate = 0.0 34 | config.transformer.dropout_rate = 0.1 35 | config.classifier = 'token' 36 | config.representation_size = None 37 | return config 38 | 39 | -------------------------------------------------------------------------------- /src/data/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ 17 | abstract class for reading the data using tfds 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import abc 25 | import six 26 | import tensorflow.compat.v1 as tf 27 | import tensorflow_datasets as tfds 28 | 29 | 30 | def make_get_tensors_fn(output_tensors): 31 | """Create a function that outputs a collection of tensors from the dataset.""" 32 | 33 | def _get_fn(data): 34 | """Get tensors by name.""" 35 | return {tensor_name: data[tensor_name] for tensor_name in output_tensors} 36 | 37 | return _get_fn 38 | 39 | 40 | def make_get_and_cast_tensors_fn(output_tensors): 41 | """Create a function that gets and casts a set of tensors from the dataset. 42 | 43 | Optionally, you can also rename the tensors. 44 | 45 | Examples: 46 | # This simply gets "image" and "label" tensors without any casting. 47 | # Note that this is equivalent to make_get_tensors_fn(["image", "label"]). 48 | make_get_and_cast_tensors_fn({ 49 | "image": None, 50 | "label": None, 51 | }) 52 | 53 | # This gets the "image" tensor without any type conversion, casts the 54 | # "heatmap" tensor to tf.float32, and renames the tensor "class/label" to 55 | # "label" and casts it to tf.int64. 56 | make_get_and_cast_tensors_fn({ 57 | "image": None, 58 | "heatmap": tf.float32, 59 | "class/label": ("label", tf.int64), 60 | }) 61 | 62 | Args: 63 | output_tensors: dictionary specifying the set of tensors to get and cast 64 | from the dataset. 65 | 66 | Returns: 67 | The function performing the operation. 68 | """ 69 | 70 | def _tensors_to_cast(): 71 | tensors_to_cast = [] # AutoGraph does not support generators. 72 | for tensor_name, tensor_dtype in output_tensors.items(): 73 | if isinstance(tensor_dtype, tuple) and len(tensor_dtype) == 2: 74 | tensors_to_cast.append((tensor_name, tensor_dtype[0], tensor_dtype[1])) 75 | elif tensor_dtype is None or isinstance(tensor_dtype, tf.dtypes.DType): 76 | tensors_to_cast.append((tensor_name, tensor_name, tensor_dtype)) 77 | else: 78 | raise ValueError('Values of the output_tensors dictionary must be ' 79 | 'None, tf.dtypes.DType or 2-tuples.') 80 | return tensors_to_cast 81 | 82 | def _get_and_cast_fn(data): 83 | """Get and cast tensors by name, optionally changing the name too.""" 84 | 85 | return { 86 | new_name: 87 | data[name] if new_dtype is None else tf.cast(data[name], new_dtype) 88 | for name, new_name, new_dtype in _tensors_to_cast() 89 | } 90 | 91 | return _get_and_cast_fn 92 | 93 | 94 | def compose_preprocess_fn(*functions): 95 | """Compose two or more preprocessing functions. 96 | 97 | Args: 98 | *functions: Sequence of preprocess functions to compose. 99 | 100 | Returns: 101 | The composed function. 102 | """ 103 | 104 | def _composed_fn(x): 105 | for fn in functions: 106 | if fn is not None: # Note: If one function is None, equiv. to identity. 107 | x = fn(x) 108 | return x 109 | 110 | return _composed_fn 111 | 112 | 113 | # Note: DO NOT implement any method in this abstract class. 
114 | @six.add_metaclass(abc.ABCMeta) 115 | class ImageDataInterface(object): 116 | """Interface to the image data classes.""" 117 | 118 | @property 119 | @abc.abstractmethod 120 | def default_label_key(self): 121 | """Returns the default label key of the dataset.""" 122 | 123 | @property 124 | @abc.abstractmethod 125 | def label_keys(self): 126 | """Returns a tuple with the available label keys of the dataset.""" 127 | 128 | @property 129 | @abc.abstractmethod 130 | def num_channels(self): 131 | """Returns the number of channels of the images in the dataset.""" 132 | 133 | @property 134 | @abc.abstractmethod 135 | def splits(self): 136 | """Returns the splits defined in the dataset.""" 137 | 138 | @abc.abstractmethod 139 | def get_num_samples(self, split_name): 140 | """Returns the number of images in the given split name.""" 141 | 142 | @abc.abstractmethod 143 | def get_num_classes(self, label_key=None): 144 | """Returns the number of classes of the given label_key.""" 145 | 146 | @abc.abstractmethod 147 | def get_tf_data(self, 148 | split_name, 149 | batch_size, 150 | pairwise_mix_fn=None, 151 | preprocess_fn=None, 152 | preprocess_before_filter=None, 153 | epochs=None, 154 | drop_remainder=True, 155 | for_eval=False, 156 | shuffle_buffer_size=None, 157 | prefetch=1, 158 | train_examples=None, 159 | filtered_num_samples=None, 160 | filter_fn=None, 161 | batch_preprocess_fn=None, 162 | ignore_errors=False, 163 | shuffle_files=False): 164 | """Provides preprocessed and batched data. 165 | 166 | Args: 167 | split_name: name of a data split to provide. Can be "train", "val", 168 | "trainval" or "test". 169 | batch_size: batch size. 170 | pairwise_mix_fn: a function for mixing each data with another random one. 171 | preprocess_fn: a function for preprocessing input data. It expects a 172 | dictionary with a key "image" associated with a 3D image tensor. 173 | preprocess_before_filter: a function for preprocessing input data, 174 | before filter_fn. It is only designed for light preprocessing, 175 | i.e. augment with image id. For heavy preprocessing, it's more 176 | efficient to do it after filter_fn. 177 | epochs: number of full passes through the data. If None, the data is 178 | provided indefinitely. 179 | drop_remainder: if True, the last incomplete batch of data is dropped. 180 | Normally, this parameter should be True, otherwise it leads to 181 | the unknown batch dimension, which is not compatible with training 182 | or evaluation on TPUs. 183 | for_eval: get data for evaluation. Disables shuffling. 184 | shuffle_buffer_size: overrides default shuffle buffer size. 185 | prefetch: number of batches to prefetch. 186 | train_examples: optional number of examples to take for training. 187 | If greater than available number of examples, equivalent to None (all). 188 | Ignored with for_eval is True. 189 | filtered_num_samples: required when filter_fn is set, number of 190 | samples after applying filter_fn. 191 | filter_fn: filter function for generating training subset. 192 | batch_preprocess_fn: optional function for preprocessing a full batch of 193 | input data. Analoguous to preprocess_fn with an extra batch-dimension 194 | on all tensors. 195 | ignore_errors: whether to skip images that encountered an error in 196 | decoding *or pre-processing*, the latter is why it is False by default. 197 | shuffle_files: whether to shuffle the dataset files or not. 198 | 199 | Returns: 200 | A tf.data.Dataset object as a dictionary containing the output tensors. 
201 | """ 202 | 203 | 204 | class ImageData(ImageDataInterface): 205 | """Abstract data provider class. 206 | 207 | IMPORTANT: You should use ImageTfdsData below whenever is posible. We want 208 | to use as many datasets in TFDS as possible to ensure reproducibility of our 209 | experiments. Your data class should only inherit directly from this if you 210 | are doing experiments while creating a TFDS dataset. 211 | """ 212 | 213 | @abc.abstractmethod 214 | def __init__(self, 215 | num_samples_splits, 216 | shuffle_buffer_size, 217 | num_preprocessing_threads, 218 | num_classes, 219 | default_label_key='label', 220 | base_preprocess_fn=None, 221 | filter_fn=None, 222 | image_decoder=None, 223 | num_channels=3): 224 | """Initializer for the base ImageData class. 225 | 226 | Args: 227 | num_samples_splits: a dictionary, that maps splits ("train", and "test") to the corresponding number of samples. 228 | shuffle_buffer_size: size of a buffer used for shuffling. 229 | num_preprocessing_threads: the number of parallel threads for data 230 | preprocessing. 231 | num_classes: int/dict, number of classes in this dataset for the 232 | `default_label_key` tensor, or dictionary with the number of classes in 233 | each label tensor. 234 | default_label_key: optional, string with the name of the tensor to use 235 | as label. Default is "label". 236 | base_preprocess_fn: optional, base preprocess function to apply in all 237 | cases for this dataset. 238 | filter_fn: optional, function to filter the examples to use in the 239 | dataset. DEPRECATED, soon to be removed. 240 | image_decoder: a function to decode image. 241 | num_channels: number of channels in the dataset image. 242 | """ 243 | self._log_warning_if_direct_inheritance() 244 | self._num_samples_splits = num_samples_splits 245 | self._shuffle_buffer_size = shuffle_buffer_size 246 | self._num_preprocessing_threads = num_preprocessing_threads 247 | self._base_preprocess_fn = base_preprocess_fn 248 | self._default_label_key = default_label_key 249 | self._filter_fn = filter_fn 250 | if self._filter_fn: 251 | tf.logging.warning('Using deprecated filtering mechanism.') 252 | self._image_decoder = image_decoder 253 | self._num_channels = num_channels 254 | 255 | if isinstance(num_classes, dict): 256 | self._num_classes = num_classes 257 | if default_label_key not in num_classes: 258 | raise ValueError( 259 | 'No num_classes was specified for the default_label_key %r' % 260 | default_label_key) 261 | elif isinstance(num_classes, int): 262 | self._num_classes = {default_label_key: num_classes} 263 | else: 264 | raise ValueError( 265 | '"num_classes" must be a int or a dict, but type %r was given' % 266 | type(num_classes)) 267 | 268 | @property 269 | def default_label_key(self): 270 | return self._default_label_key 271 | 272 | @property 273 | def label_keys(self): 274 | return tuple(self._num_classes.keys()) 275 | 276 | @property 277 | def num_channels(self): 278 | return self._num_channels 279 | 280 | @property 281 | def splits(self): 282 | return tuple(self._num_samples_splits.keys()) 283 | 284 | def get_num_samples(self, split_name): 285 | return self._num_samples_splits[split_name] 286 | 287 | def get_num_classes(self, label_key=None): 288 | if label_key is None: 289 | label_key = self._default_label_key 290 | return self._num_classes[label_key] 291 | 292 | def get_version(self): 293 | return NotImplementedError('Version is not supported outside TFDS.') 294 | 295 | def get_tf_data(self, 296 | split_name, 297 | batch_size, 298 | 
pairwise_mix_fn=None, 299 | preprocess_fn=None, 300 | preprocess_before_filter=None, 301 | epochs=None, 302 | drop_remainder=True, 303 | for_eval=False, 304 | shuffle_buffer_size=None, 305 | prefetch=1, 306 | train_examples=None, 307 | filtered_num_samples=None, 308 | filter_fn=None, 309 | batch_preprocess_fn=None, 310 | ignore_errors=False, 311 | shuffle_files=False): 312 | # Obtains tf.data object. 313 | # We shuffle later when not for eval, it's important to not shuffle before 314 | # a subset of data is retrieved. 315 | data = self._get_dataset_split( 316 | split_name=split_name, 317 | shuffle_files=shuffle_files) 318 | 319 | if preprocess_before_filter is not None: 320 | data = preprocess_before_filter(data) 321 | 322 | 323 | if self._filter_fn and (filter_fn is None): 324 | filter_fn = self._filter_fn 325 | 326 | # Dataset filtering priority: (1) filter_fn; (2) train_examples. 327 | if filter_fn and train_examples: 328 | raise ValueError('You must not set both filter_fn and train_examples.') 329 | 330 | if filter_fn: 331 | tf.logging.warning( 332 | 'You are filtering the dataset. Notice that this may hurt your ' 333 | 'throughput, since examples still need to be decoded, and may ' 334 | 'make the result of get_num_samples() inacurate. ' 335 | 'train_examples is ignored for filtering, but only used for ' 336 | 'calculating training steps.') 337 | data = data.filter(filter_fn) 338 | num_samples = filtered_num_samples 339 | assert num_samples is not None, ( 340 | 'You must set filtered_num_samples if filter_fn is set.') 341 | 342 | elif not for_eval and train_examples: 343 | # Deterministic for same dataset version. 344 | data = data.take(train_examples) 345 | num_samples = train_examples 346 | 347 | else: 348 | num_samples = self.get_num_samples(split_name) 349 | 350 | data = self._cache_data_if_possible( 351 | data, split_name=split_name, num_samples=num_samples, for_eval=for_eval) 352 | 353 | def print_filtered_subset(ex): 354 | """Print filtered subset for debug purpose.""" 355 | if isinstance(ex, dict) and 'id' in ex and 'label' in ex: 356 | print_op = tf.print( 357 | 'filtered_example:', 358 | ex['id'], 359 | ex['label'], 360 | output_stream=tf.logging.error) 361 | with tf.control_dependencies([print_op]): 362 | ex['id'] = tf.identity(ex['id']) 363 | return ex 364 | if not for_eval and filter_fn: 365 | data = data.map(print_filtered_subset) 366 | 367 | # Repeats data `epochs` time or indefinitely if `epochs` is None. 368 | if epochs is None or epochs > 1: 369 | data = data.repeat(epochs) 370 | 371 | shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size 372 | if not for_eval and shuffle_buffer_size > 1: 373 | data = data.shuffle(shuffle_buffer_size) 374 | 375 | data = self._preprocess_and_batch_data(data, batch_size, drop_remainder, 376 | pairwise_mix_fn, preprocess_fn, 377 | ignore_errors) 378 | 379 | if batch_preprocess_fn is not None: 380 | data = data.map(batch_preprocess_fn, self._num_preprocessing_threads) 381 | 382 | if prefetch != 0: 383 | data = data.prefetch(prefetch) 384 | 385 | return data 386 | 387 | @abc.abstractmethod 388 | def _get_dataset_split(self, split_name, shuffle_files=False): 389 | """Return the Dataset object for the given split name. 390 | 391 | Args: 392 | split_name: Name of the dataset split to get. 393 | shuffle_files: Whether or not to shuffle files in the dataset. 394 | 395 | Returns: 396 | A tf.data.Dataset object containing the data for the given split. 
397 | """ 398 | 399 | def _log_warning_if_direct_inheritance(self): 400 | tf.logging.warning( 401 | 'You are directly inheriting from ImageData. Please, consider porting ' 402 | 'your dataset to TFDS (go/tfds) and inheriting from ImageTfdsData ' 403 | 'instead.') 404 | 405 | def _preprocess_and_batch_data(self, 406 | data, 407 | batch_size, 408 | drop_remainder=True, 409 | pairwise_mix_fn=None, 410 | preprocess_fn=None, 411 | ignore_errors=False): 412 | """Preprocesses and batches a given tf.Dataset.""" 413 | # Preprocess with basic preprocess functions (e.g. decoding images, parsing 414 | # features etc.). 415 | base_preprocess_fn = compose_preprocess_fn(self._image_decoder, 416 | self._base_preprocess_fn) 417 | # Note: `map_and_batch` is deprecated, and at least when nothing happens 418 | # in-between, automatically gets merged for efficiency. Same below. 419 | data = data.map(base_preprocess_fn, self._num_preprocessing_threads) 420 | 421 | # Mix images pair-wise before other element-wise preprocessing. 422 | # Note: The pairing is implemented by shifting `data` by 1, so the last 423 | # element of `data` will be dropped. 424 | if pairwise_mix_fn is not None: 425 | data = tf.data.Dataset.zip( 426 | (data, data.skip(1))).map(pairwise_mix_fn, 427 | self._num_preprocessing_threads) 428 | 429 | # Preprocess with customized preprocess functions. 430 | if preprocess_fn is not None: 431 | data = data.map(preprocess_fn, self._num_preprocessing_threads) 432 | 433 | if ignore_errors: 434 | tf.logging.info('Ignoring any image with errors.') 435 | data = data.apply(tf.data.experimental.ignore_errors()) 436 | 437 | return data.batch(batch_size, drop_remainder) 438 | 439 | def _cache_data_if_possible(self, data, split_name, num_samples, for_eval): 440 | del split_name 441 | 442 | if not for_eval and num_samples <= 150000: 443 | # Cache the whole dataset if it's smaller than 150K examples. 444 | data = data.cache() 445 | return data 446 | 447 | 448 | class ImageTfdsData(ImageData): 449 | """Abstract data provider class for datasets available in Tensorflow Datasets. 450 | 451 | To add new datasets inherit from this class. This class implements a simple 452 | API that is used throughout the project and provides standardized way of data 453 | preprocessing and batching. 454 | """ 455 | 456 | @abc.abstractmethod 457 | def __init__(self, dataset_builder, tfds_splits, image_key='image', **kwargs): 458 | """Initializer for the base ImageData class. 459 | 460 | Args: 461 | dataset_builder: tfds dataset builder object. 462 | tfds_splits: a dictionary, that maps splits ("train", 463 | and "test") to the corresponding tfds `Split` objects. 464 | image_key: image key. 465 | **kwargs: Additional keyword arguments for the ImageData class. 
466 | """ 467 | self._dataset_builder = dataset_builder 468 | self._tfds_splits = tfds_splits 469 | self._image_key = image_key 470 | 471 | # Overwrite image decoder 472 | def _image_decoder(data): 473 | decoder = dataset_builder.info.features[image_key].decode_example 474 | data[image_key] = decoder(data[image_key]) 475 | return data 476 | self._image_decoder = _image_decoder 477 | 478 | kwargs.update({'image_decoder': _image_decoder}) 479 | 480 | super(ImageTfdsData, self).__init__(**kwargs) 481 | 482 | def get_version(self): 483 | return self._dataset_builder.version.__str__() 484 | 485 | def _get_dataset_split(self, split_name, shuffle_files): 486 | dummy_decoder = tfds.decode.SkipDecoding() 487 | #return self._dataset_builder[split_name] 488 | return self._dataset_builder.as_dataset( 489 | split=self._tfds_splits[split_name], shuffle_files=shuffle_files, 490 | decoders={self._image_key: dummy_decoder}) 491 | 492 | def _log_warning_if_direct_inheritance(self): 493 | pass 494 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | data loader 4 | """ 5 | 6 | import torch 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch.utils.data.sampler import RandomSampler 9 | from src.data.tf_dataset import TFDataset 10 | 11 | import numpy as np 12 | from copy import deepcopy 13 | 14 | def _construct_dataset(cfg, split): 15 | """Constructs the data loader for the given dataset.""" 16 | # import the tensorflow here only if needed 17 | dataset = TFDataset(cfg, split) 18 | 19 | return dataset 20 | 21 | def _construct_continual_loader(cfg, dataset, shuffle=False): 22 | sampler = None 23 | 24 | # In the case of ImageNet-R, we observed that runs with a sequential sampler show more consistent results. 25 | # When using a random sampler, the average performance across five runs is still ~70%; however, there is a variance in performance over runs. 
26 | if cfg.DATA.NAME == 'imagenet_r': 27 | sampler = torch.utils.data.SequentialSampler(dataset) 28 | 29 | # Create a loader 30 | loader = torch.utils.data.DataLoader( 31 | dataset, 32 | batch_size=cfg.DATA.BATCH_SIZE, 33 | shuffle=(False if sampler else shuffle), 34 | sampler=sampler, 35 | num_workers=cfg.DATA.NUM_WORKERS, 36 | pin_memory=cfg.DATA.PIN_MEMORY, 37 | drop_last=False, 38 | ) 39 | return loader 40 | 41 | def _build_continual_dataset(cfg, dataset): 42 | prev_cls_increment = 0 43 | cls_increment = cfg.CONTINUAL.INITIAL 44 | scenario = [] 45 | 46 | for i in range(cfg.CONTINUAL.N_TASKS): 47 | cls_less = np.where((np.asarray(dataset._targets) < cls_increment) & (prev_cls_increment <= np.asarray(dataset._targets)))[0] 48 | 49 | _labels = [] 50 | _image_files = [] 51 | for j in cls_less: 52 | _labels.append(dataset._targets[j]) 53 | _image_files.append(dataset._image_tensor_list[j]) 54 | 55 | cur_dataset = deepcopy(dataset) 56 | 57 | cur_dataset._targets = _labels 58 | cur_dataset._image_tensor_list = _image_files 59 | cur_dataset._class_ids = dataset._class_ids[prev_cls_increment:cls_increment] 60 | cur_dataset._class_ids_mask = dataset._class_ids_mask[prev_cls_increment:cls_increment] 61 | 62 | prev_cls_increment = cls_increment 63 | cls_increment += cfg.CONTINUAL.INCREMENT 64 | 65 | scenario.append(cur_dataset) 66 | 67 | return scenario 68 | 69 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last): 70 | """Constructs the data loader for the given dataset.""" 71 | # import the tensorflow here only if needed 72 | dataset = TFDataset(cfg, split) 73 | 74 | # Create a sampler for multi-process training 75 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 76 | # Create a loader 77 | loader = torch.utils.data.DataLoader( 78 | dataset, 79 | batch_size=batch_size, 80 | shuffle=(False if sampler else shuffle), 81 | sampler=sampler, 82 | num_workers=cfg.DATA.NUM_WORKERS, 83 | pin_memory=cfg.DATA.PIN_MEMORY, 84 | drop_last=drop_last, 85 | ) 86 | return loader 87 | 88 | 89 | def construct_train_loader(cfg): 90 | """Train loader wrapper.""" 91 | if cfg.NUM_GPUS > 1: 92 | drop_last = True 93 | else: 94 | drop_last = False 95 | return _construct_loader( 96 | cfg=cfg, 97 | split="train", 98 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 99 | shuffle=True, 100 | drop_last=drop_last, 101 | ) 102 | 103 | 104 | def construct_test_loader(cfg): 105 | """Test loader wrapper.""" 106 | return _construct_loader( 107 | cfg=cfg, 108 | split="test", 109 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 110 | shuffle=False, 111 | drop_last=False, 112 | ) 113 | 114 | 115 | def shuffle(loader, cur_epoch): 116 | """"Shuffles the data.""" 117 | assert isinstance( 118 | loader.sampler, (RandomSampler, DistributedSampler) 119 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 120 | # RandomSampler handles shuffling automatically 121 | if isinstance(loader.sampler, DistributedSampler): 122 | # DistributedSampler shuffles data based on epoch 123 | loader.sampler.set_epoch(cur_epoch) 124 | -------------------------------------------------------------------------------- /src/data/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | global Registry for the task adaptation framework 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import ast 25 | import functools 26 | 27 | 28 | def partialclass(cls, *base_args, **base_kwargs): 29 | """Builds a subclass with partial application of the given args and keywords. 30 | 31 | Equivalent to functools.partial performance, base_args are preprended to the 32 | positional arguments given during object initialization and base_kwargs are 33 | updated with the kwargs given later. 34 | 35 | Args: 36 | cls: The base class. 37 | *base_args: Positional arguments to be applied to the subclass. 38 | **base_kwargs: Keyword arguments to be applied to the subclass. 39 | 40 | Returns: 41 | A subclass of the input class. 42 | """ 43 | 44 | class _NewClass(cls): 45 | 46 | def __init__(self, *args, **kwargs): 47 | bound_args = base_args + args 48 | bound_kwargs = base_kwargs.copy() 49 | bound_kwargs.update(kwargs) 50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs) 51 | 52 | return _NewClass 53 | 54 | 55 | def parse_name(string_to_parse): 56 | """Parses input to the registry's lookup function. 57 | 58 | Args: 59 | string_to_parse: can be either an arbitrary name or function call 60 | (optionally with positional and keyword arguments). 61 | e.g. "multiclass", "resnet50_v2(filters_factor=8)". 62 | 63 | Returns: 64 | A tuple of input name and a dctinary with arguments. Examples: 65 | "multiclass" -> ("multiclass", (), {}) 66 | "resnet50_v2(9, filters_factor=4)" -> 67 | ("resnet50_v2", (9,), {"filters_factor": 4}) 68 | """ 69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error 70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): 71 | raise ValueError( 72 | "The given string should be a name or a call, but a {} was parsed from " 73 | "the string {!r}".format(type(expr), string_to_parse)) 74 | 75 | # Notes: 76 | # name="some_name" -> type(expr) = ast.Name 77 | # name="module.some_name" -> type(expr) = ast.Attribute 78 | # name="some_name()" -> type(expr) = ast.Call 79 | # name="module.some_name()" -> type(expr) = ast.Call 80 | 81 | if isinstance(expr, ast.Name): 82 | return string_to_parse, {} 83 | elif isinstance(expr, ast.Attribute): 84 | return string_to_parse, {} 85 | 86 | def _get_func_name(expr): 87 | if isinstance(expr, ast.Attribute): 88 | return _get_func_name(expr.value) + "." 
+ expr.attr 89 | elif isinstance(expr, ast.Name): 90 | return expr.id 91 | else: 92 | raise ValueError( 93 | "Type {!r} is not supported in a function name, the string to parse " 94 | "was {!r}".format(type(expr), string_to_parse)) 95 | 96 | def _get_func_args_and_kwargs(call): 97 | args = tuple([ast.literal_eval(arg) for arg in call.args]) 98 | kwargs = { 99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords 100 | } 101 | return args, kwargs 102 | 103 | func_name = _get_func_name(expr.func) 104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr) 105 | if func_args: 106 | raise ValueError("Positional arguments are not supported here, but these " 107 | "were found: {!r}".format(func_args)) 108 | 109 | return func_name, func_kwargs 110 | 111 | 112 | class Registry(object): 113 | """Implements global Registry.""" 114 | 115 | _GLOBAL_REGISTRY = {} 116 | 117 | @staticmethod 118 | def global_registry(): 119 | return Registry._GLOBAL_REGISTRY 120 | 121 | @staticmethod 122 | def register(name, item_type): 123 | """Creates a function that registers its input.""" 124 | if item_type not in ["function", "class"]: 125 | raise ValueError("Unknown item type: %s" % item_type) 126 | 127 | def _register(item): 128 | if name in Registry.global_registry(): 129 | raise KeyError( 130 | "The name {!r} was already registered in with type {!r}".format( 131 | name, item_type)) 132 | 133 | Registry.global_registry()[name] = (item, item_type) 134 | return item 135 | 136 | return _register 137 | 138 | @staticmethod 139 | def lookup(lookup_string, kwargs_extra=None): 140 | """Lookup a name in the registry.""" 141 | 142 | name, kwargs = parse_name(lookup_string) 143 | if kwargs_extra: 144 | kwargs.update(kwargs_extra) 145 | item, item_type = Registry.global_registry()[name] 146 | if item_type == "function": 147 | return functools.partial(item, **kwargs) 148 | elif item_type == "class": 149 | return partialclass(item, **kwargs) 150 | -------------------------------------------------------------------------------- /src/data/tf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | handle output of tf.data 4 | """ 5 | 6 | import functools 7 | import tensorflow.compat.v1 as tf 8 | import torch 9 | import torch.utils.data 10 | import numpy as np 11 | 12 | from collections import Counter 13 | from torch import Tensor 14 | 15 | from src.data import tf_load 16 | from src.data.registry import Registry 17 | 18 | tf.config.experimental.set_visible_devices([], 'GPU') 19 | 20 | class TFDataset(torch.utils.data.Dataset): 21 | def __init__(self, cfg, split): 22 | self.cfg = cfg 23 | self._split = split 24 | self.name = cfg.DATA.NAME 25 | 26 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 27 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 28 | 29 | self.get_data(cfg, split) 30 | 31 | def get_data(self, cfg, split): 32 | tf_data = build_tf_dataset(cfg, split) 33 | data_list = list(tf_data) 34 | 35 | self._image_tensor_list = [t[0].numpy().squeeze() for t in data_list] 36 | self._targets = [int(t[1].numpy()[0]) for t in data_list] 37 | 38 | if cfg.DATA.NAME == "imagenet_r": 39 | for index, target in enumerate(self._targets): 40 | self._targets[index] = IR_LABEL_MAP[target] 41 | 42 | self._class_ids = sorted(list(set(self._targets))) 43 | self._class_ids_mask = self._class_ids 44 | 45 | del data_list 46 | del tf_data 47 | 48 | def get_info(self): 49 | num_imgs = len(self._image_tensor_list) 50 | 
return num_imgs, self.get_class_num() 51 | 52 | def get_class_num(self): 53 | return self.cfg.DATA.NUMBER_CLASSES 54 | 55 | def get_class_weights(self, weight_type): 56 | """get a list of class weight, return a list float""" 57 | if "train" not in self._split: 58 | raise ValueError( 59 | "only getting training class distribution, " + \ 60 | "got split {} instead".format(self._split) 61 | ) 62 | 63 | cls_num = self.get_class_num() 64 | if weight_type == "none": 65 | return [1.0] * cls_num 66 | 67 | id2counts = Counter(self._class_ids) 68 | assert len(id2counts) == cls_num 69 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 70 | 71 | if weight_type == 'inv': 72 | mu = -1.0 73 | elif weight_type == 'inv_sqrt': 74 | mu = -0.5 75 | weight_list = num_per_cls ** mu 76 | weight_list = np.divide( 77 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 78 | return weight_list.tolist() 79 | 80 | def __getitem__(self, index): 81 | # Load the image 82 | label = self._targets[index] 83 | im = to_torch_imgs( 84 | self._image_tensor_list[index], self.img_mean, self.img_std) 85 | 86 | sample = { 87 | "image": im, 88 | "label": label, 89 | } 90 | return sample 91 | 92 | def __len__(self): 93 | return len(self._targets) 94 | 95 | 96 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)): 97 | image = data["image"] 98 | image = tf.image.resize(image, [size, size]) 99 | 100 | image = tf.cast(image, tf.float32) / 255.0 101 | image = image * (input_range[1] - input_range[0]) + input_range[0] 102 | 103 | data["image"] = image 104 | return data 105 | 106 | 107 | def build_tf_dataset(cfg, mode): 108 | """ 109 | Builds a tf data instance, then transform to a list of tensors and labels 110 | """ 111 | data_cls = Registry.lookup("data") 112 | vtab_tf_dataloader = data_cls(data_name=cfg.DATA.NAME, data_dir=cfg.DATA.DATAPATH) 113 | split_name_dict = { 114 | "dataset_train_split_name": 'train', 115 | "dataset_test_split_name": 'test' 116 | } 117 | 118 | def _dict_to_tuple(batch): 119 | return batch['image'], batch['label'] 120 | 121 | return vtab_tf_dataloader.get_tf_data( 122 | batch_size=1, 123 | drop_remainder=False, 124 | split_name=split_name_dict[f"dataset_{mode}_split_name"], 125 | preprocess_fn=functools.partial( 126 | preprocess_fn, 127 | input_range=(0.0, 1.0), 128 | size=cfg.DATA.CROPSIZE, 129 | ), 130 | for_eval=mode != "train", 131 | shuffle_buffer_size=1000, 132 | prefetch=1, 133 | train_examples=None, 134 | epochs=1 135 | ).map(_dict_to_tuple) 136 | 137 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor: 138 | if len(img.shape) == 2: 139 | img = np.stack((img, img, img), axis=2) 140 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) 141 | t_img -= mean 142 | t_img /= std 143 | 144 | return t_img 145 | 146 | 147 | # mapping ImageNet-R, which labeled between 0-999, to 0-199 in a randomly shuffled order 148 | IR_LABEL_MAP = { 149 | 1: 162, 150 | 2: 53, 151 | 4: 4, 152 | 6: 155, 153 | 8: 54, 154 | 9: 75, 155 | 11: 79, 156 | 13: 11, 157 | 22: 46, 158 | 23: 145, 159 | 26: 186, 160 | 29: 126, 161 | 31: 29, 162 | 39: 98, 163 | 47: 84, 164 | 63: 104, 165 | 71: 39, 166 | 76: 167, 167 | 79: 7, 168 | 84: 123, 169 | 90: 139, 170 | 94: 192, 171 | 96: 19, 172 | 97: 25, 173 | 99: 164, 174 | 100: 8, 175 | 105: 47, 176 | 107: 41, 177 | 113: 67, 178 | 122: 49, 179 | 125: 35, 180 | 130: 34, 181 | 132: 69, 182 | 144: 103, 183 | 145: 187, 184 | 147: 3, 185 | 148: 22, 186 | 150: 119, 187 | 151: 60, 188 | 155: 142, 189 | 160: 38, 190 | 161: 153, 191 | 162: 9, 192 | 
163: 62, 193 | 171: 90, 194 | 172: 109, 195 | 178: 72, 196 | 187: 77, 197 | 195: 0, 198 | 199: 23, 199 | 203: 146, 200 | 207: 122, 201 | 208: 94, 202 | 219: 73, 203 | 231: 16, 204 | 232: 154, 205 | 234: 45, 206 | 235: 176, 207 | 242: 17, 208 | 245: 101, 209 | 247: 143, 210 | 250: 170, 211 | 251: 78, 212 | 254: 120, 213 | 259: 59, 214 | 260: 165, 215 | 263: 86, 216 | 265: 50, 217 | 267: 51, 218 | 269: 195, 219 | 276: 64, 220 | 277: 107, 221 | 281: 111, 222 | 288: 30, 223 | 289: 156, 224 | 291: 43, 225 | 292: 114, 226 | 293: 129, 227 | 296: 74, 228 | 299: 134, 229 | 301: 68, 230 | 308: 110, 231 | 309: 42, 232 | 310: 150, 233 | 311: 161, 234 | 314: 48, 235 | 315: 1, 236 | 319: 132, 237 | 323: 121, 238 | 327: 130, 239 | 330: 85, 240 | 334: 80, 241 | 335: 108, 242 | 337: 183, 243 | 338: 116, 244 | 340: 52, 245 | 341: 168, 246 | 344: 40, 247 | 347: 97, 248 | 353: 100, 249 | 355: 21, 250 | 361: 152, 251 | 362: 157, 252 | 365: 166, 253 | 366: 180, 254 | 367: 102, 255 | 368: 131, 256 | 372: 31, 257 | 388: 44, 258 | 390: 199, 259 | 393: 174, 260 | 397: 163, 261 | 401: 196, 262 | 407: 65, 263 | 413: 6, 264 | 414: 18, 265 | 425: 135, 266 | 428: 5, 267 | 430: 33, 268 | 435: 141, 269 | 437: 99, 270 | 441: 70, 271 | 447: 13, 272 | 448: 14, 273 | 457: 149, 274 | 462: 148, 275 | 463: 198, 276 | 469: 175, 277 | 470: 136, 278 | 471: 118, 279 | 472: 125, 280 | 476: 178, 281 | 483: 159, 282 | 487: 81, 283 | 515: 71, 284 | 546: 140, 285 | 555: 179, 286 | 558: 15, 287 | 570: 158, 288 | 579: 173, 289 | 583: 127, 290 | 587: 61, 291 | 593: 128, 292 | 594: 191, 293 | 596: 193, 294 | 609: 91, 295 | 613: 106, 296 | 617: 185, 297 | 621: 147, 298 | 629: 93, 299 | 637: 105, 300 | 657: 82, 301 | 658: 144, 302 | 701: 117, 303 | 717: 92, 304 | 724: 160, 305 | 763: 32, 306 | 768: 197, 307 | 774: 76, 308 | 776: 115, 309 | 779: 112, 310 | 780: 189, 311 | 787: 184, 312 | 805: 63, 313 | 812: 95, 314 | 815: 57, 315 | 820: 177, 316 | 824: 24, 317 | 833: 27, 318 | 847: 89, 319 | 852: 96, 320 | 866: 58, 321 | 875: 194, 322 | 883: 190, 323 | 889: 55, 324 | 895: 10, 325 | 907: 137, 326 | 928: 37, 327 | 931: 124, 328 | 932: 56, 329 | 933: 66, 330 | 934: 83, 331 | 936: 138, 332 | 937: 2, 333 | 943: 171, 334 | 945: 88, 335 | 947: 188, 336 | 948: 28, 337 | 949: 151, 338 | 951: 172, 339 | 953: 36, 340 | 954: 182, 341 | 957: 169, 342 | 963: 12, 343 | 965: 26, 344 | 967: 181, 345 | 980: 20, 346 | 981: 87, 347 | 983: 113, 348 | 988: 133 349 | } 350 | -------------------------------------------------------------------------------- /src/data/tf_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ 17 | load a built tf-dataset which can be downloaded https://drive.google.com/drive/folders/1bBqS8MuTQXUBV3DXJ_-YZyNOR4ejvm1O?usp=sharing 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow_datasets as tfds 25 | 26 | from src.data import base as base 27 | from src.data.registry import Registry 28 | 29 | TRAIN_SPLIT_PERCENT = 80 30 | TEST_SPLIT_PERCENT = 20 31 | 32 | @Registry.register("data", "class") 33 | class Data(base.ImageTfdsData): 34 | def __init__(self, data_name=None, data_dir=None): 35 | dataset_name = data_name + ":*.*.*" 36 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 37 | dataset_builder.download_and_prepare() 38 | 39 | if data_name == "chestx" or data_name == "eurosat/rgb" or data_name == "isic" or data_name == "resisc45": 40 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 41 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 42 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 43 | 44 | tfds_splits = { 45 | "train": 46 | "train[:{}]".format(train_count), 47 | "test": 48 | "train[{}:]".format(train_count), 49 | } 50 | 51 | num_samples_splits = { 52 | "train": train_count, 53 | "test": test_count, 54 | } 55 | elif data_name == "imagenet_r": 56 | num_examples = dataset_builder.info.splits[tfds.Split.TEST].num_examples 57 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 58 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 59 | 60 | tfds_splits = { 61 | "train": 62 | "test[:{}]".format(train_count), 63 | "test": 64 | "test[{}:]".format(train_count), 65 | } 66 | 67 | num_samples_splits = { 68 | "train": train_count, 69 | "test": test_count, 70 | } 71 | else: 72 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 73 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 74 | 75 | num_samples_splits = { 76 | "train": train_count, 77 | "test": test_count, 78 | } 79 | 80 | tfds_splits = { 81 | "train": "train", 82 | "test": "test" 83 | } 84 | 85 | super(Data, self).__init__( 86 | dataset_builder=dataset_builder, 87 | tfds_splits=tfds_splits, 88 | num_samples_splits=num_samples_splits, 89 | num_preprocessing_threads=400, 90 | shuffle_buffer_size=20000, 91 | # Note: Rename tensors but keep their original types. 
92 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 93 | "image": ("image", None), 94 | "label": ("label", None), 95 | }), 96 | image_key="image", 97 | num_channels=3, 98 | num_classes=dataset_builder.info.features["label"].num_classes) 99 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | image transformations 4 | """ 5 | 6 | import torchvision as tv 7 | 8 | def get_transforms(split, size): 9 | normalize = tv.transforms.Normalize( 10 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 11 | ) 12 | resize_dim = 256 13 | crop_dim = 224 14 | if split == "train": 15 | transform = tv.transforms.Compose( 16 | [ 17 | tv.transforms.Resize(resize_dim), 18 | tv.transforms.RandomCrop(crop_dim), 19 | tv.transforms.RandomHorizontalFlip(0.5), 20 | tv.transforms.ToTensor(), 21 | normalize, 22 | ] 23 | ) 24 | else: 25 | transform = tv.transforms.Compose( 26 | [ 27 | tv.transforms.Resize(resize_dim), 28 | tv.transforms.CenterCrop(crop_dim), 29 | tv.transforms.ToTensor(), 30 | normalize, 31 | ] 32 | ) 33 | return transform 34 | -------------------------------------------------------------------------------- /src/engine/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | a trainer class 4 | """ 5 | 6 | import datetime 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | import os 11 | import numpy as np 12 | from fvcore.common.config import CfgNode 13 | 14 | from ..solver.lr_scheduler import make_scheduler 15 | from ..solver.optimizer import make_optimizer 16 | from ..solver.losses import build_loss 17 | from ..utils import logger as logger_continual 18 | from ..utils.train_utils import AverageMeter, gpu_mem_usage 19 | 20 | from ..data.loader import _build_continual_dataset, _construct_continual_loader 21 | 22 | class Trainer(): 23 | def __init__( 24 | self, 25 | cfg: CfgNode, 26 | model: nn.Module, 27 | device: torch.device, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.model = model 31 | self.device = device 32 | 33 | # solver related 34 | self.optimizer = make_optimizer([self.model], cfg.SOLVER) 35 | self.scheduler = make_scheduler(self.optimizer, cfg.SOLVER) 36 | self.cls_criterion = build_loss(self.cfg) 37 | 38 | self.cpu_device = torch.device("cpu") 39 | 40 | self.prev_task = -1 41 | self.task_changed = False 42 | 43 | def forward_one_batch(self, inputs, targets, is_train, task_id=None): 44 | """Train a single (full) epoch on the model using the given 45 | data loader. 
46 | """ 47 | inputs = inputs.to(self.device, non_blocking=True) 48 | targets = targets.to(self.device, non_blocking=True) 49 | 50 | # forward 51 | with torch.set_grad_enabled(is_train): 52 | outputs, reduce_sim = self.model(inputs, task_id=task_id, is_train=is_train, cfg=self.cfg) 53 | 54 | if self.cls_criterion.is_local() and is_train: 55 | self.model.eval() 56 | loss = self.cls_criterion( 57 | outputs, targets, 58 | self.model, inputs 59 | ) 60 | elif self.cls_criterion.is_local(): 61 | return torch.tensor(1), outputs 62 | elif is_train: 63 | num_total_class = self.cfg.DATA.NUMBER_CLASSES 64 | class_mask = self.dataset_train._class_ids_mask 65 | not_mask = np.setdiff1d(np.arange(num_total_class), class_mask) 66 | outputs[:, not_mask] = -np.inf 67 | loss = self.cfg.MODEL.DAP.CURRENT_LAMBDA * self.cls_criterion( 68 | outputs, targets) 69 | else: 70 | loss = self.cls_criterion( 71 | outputs, targets) 72 | 73 | simloss = self.cfg.MODEL.DAP.SIM_LAMBDA * reduce_sim 74 | loss -= simloss 75 | 76 | # =======backward and optim step only if in training phase... ========= 77 | if is_train: 78 | self.optimizer.zero_grad() 79 | loss.backward() 80 | if self.cfg.SOLVER.GRAD_CLIP_APPLY: 81 | nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.SOLVER.GRAD_CLIP) 82 | self.optimizer.step() 83 | 84 | return loss, outputs 85 | 86 | def get_input(self, data): 87 | if not isinstance(data["image"], torch.Tensor): 88 | for k, v in data.items(): 89 | data[k] = torch.from_numpy(v) 90 | 91 | inputs = data["image"].float() 92 | labels = data["label"] 93 | return inputs, labels 94 | 95 | def train_classifier(self, train_dataset, test_dataset): 96 | """ 97 | Train a classifier using epoch 98 | """ 99 | # save the model prompt if required before training 100 | self.model.eval() 101 | 102 | # setup training epoch params 103 | total_epoch = self.cfg.SOLVER.TOTAL_EPOCH 104 | 105 | self.scenario_train = _build_continual_dataset(self.cfg, train_dataset) 106 | self.scenario_test = _build_continual_dataset(self.cfg, test_dataset) 107 | 108 | self.LOG = logger_continual.logger_all('acc', n_tasks=self.cfg.CONTINUAL.N_TASKS) 109 | 110 | for task_id, dataset_train in enumerate(self.scenario_train): 111 | print(f"Starting task id {task_id}/{len(self.scenario_train) - 1}") 112 | 113 | if task_id == 1: 114 | for k, p in self.model.enc.named_parameters(): 115 | if "dap_downsample" in k: 116 | p.requires_grad = False 117 | 118 | self.dataset_train = dataset_train 119 | 120 | loader_train = _construct_continual_loader(self.cfg, dataset_train, shuffle=True) 121 | 122 | total_data = len(loader_train) 123 | log_interval = self.cfg.SOLVER.LOG_EVERY_N 124 | 125 | losses = AverageMeter('Loss', ':.4e') 126 | batch_time = AverageMeter('Time', ':6.3f') 127 | data_time = AverageMeter('Data', ':6.3f') 128 | 129 | print(f"Start training for {total_epoch} epochs") 130 | for epoch in range(total_epoch): 131 | # reset averagemeters to measure per-epoch results 132 | losses.reset() 133 | batch_time.reset() 134 | data_time.reset() 135 | 136 | lr = self.scheduler.get_lr()[0] 137 | 138 | # Enable training mode 139 | self.model.train() 140 | 141 | end = time.time() 142 | for idx, input_data in enumerate(loader_train): 143 | X, targets = self.get_input(input_data) 144 | data_time.update(time.time() - end) 145 | 146 | train_loss, _ = self.forward_one_batch(X, targets, True, task_id) 147 | 148 | if train_loss == -1: 149 | # continue 150 | return None 151 | 152 | losses.update(train_loss.item(), X.shape[0]) 153 | 154 | # measure elapsed time 155 | 
batch_time.update(time.time() - end) 156 | end = time.time() 157 | 158 | # log during one batch 159 | if (idx + 1) % log_interval == 0: 160 | seconds_per_batch = batch_time.val 161 | eta = datetime.timedelta(seconds=int( 162 | seconds_per_batch * (total_data - idx - 1) + seconds_per_batch*total_data*(total_epoch-epoch-1))) 163 | print( 164 | "\tTraining {}/{}. train loss: {:.4f},".format( 165 | idx + 1, 166 | total_data, 167 | train_loss 168 | ) 169 | + "\t{:.4f} s / batch. (data: {:.2e}). ETA={}, ".format( 170 | seconds_per_batch, 171 | data_time.val, 172 | str(eta), 173 | ) 174 | + "max mem: {:.1f} GB ".format(gpu_mem_usage()) 175 | ) 176 | print( 177 | "Epoch {} / {}: ".format(epoch + 1, total_epoch) 178 | + "learning rate: {:.2f}, avg data time: {:.2e}, avg batch time: {:.4f}, ".format( 179 | lr, data_time.avg, batch_time.avg) 180 | + "average train loss: {:.4f}".format(losses.avg)) 181 | self.scheduler.step() 182 | 183 | # Enable eval mode 184 | self.model.eval() 185 | 186 | self.eval_classifier_continual(task_id, self.scenario_test) 187 | 188 | task = self.cfg.CONTINUAL.N_TASKS - 1 189 | final_accs = self.LOG['acc'][:, task] 190 | logger_continual.per_task_summary(self.LOG, 'final_acc', value=np.round(np.mean(final_accs), 5)) 191 | best_acc = np.max(self.LOG['acc'], 1) 192 | final_forgets = best_acc - self.LOG['acc'][:, task] 193 | logger_continual.per_task_summary(self.LOG, 'final_forget', value=np.round(np.mean(final_forgets[:-1]), 5)) 194 | final_la = np.diag(self.LOG['acc']) 195 | logger_continual.per_task_summary(self.LOG, 'final_la', value=np.round(np.mean(final_la), 5)) 196 | 197 | print('\n') 198 | print('final accuracy: {}'.format(final_accs)) 199 | print('average: {}'.format(self.LOG['final_acc'])) 200 | print('final forgetting: {}'.format(final_forgets)) 201 | print('average: {}'.format(self.LOG['final_forget'])) 202 | print('final LA: {}'.format(final_la)) 203 | print('average: {}'.format(self.LOG['final_la'])) 204 | 205 | with open(self.cfg.OUTPUT_DIR + '/final_results.txt', "w") as text_file: 206 | print(self.cfg, file=text_file) 207 | print("\n", file=text_file) 208 | print(self.LOG['acc'], file=text_file) 209 | print('\nFinal {} Accuracy: {:.5f}'.format('test', self.LOG['final_acc']), file=text_file) 210 | print('\nFinal {} Forget: {:.5f}'.format('test', self.LOG['final_forget']), file=text_file) 211 | print('\nFinal {} LA: {:.5f}'.format('test', self.LOG['final_la']), file=text_file) 212 | 213 | @torch.no_grad() 214 | def eval_classifier_continual(self, task_id, scenario_test): 215 | for task_t in range(task_id + 1): 216 | te_dataset = scenario_test[task_t] 217 | loader_te = _construct_continual_loader(self.cfg, te_dataset) 218 | 219 | LOG_eval = logger_continual.logger_eval('acc') 220 | 221 | for idx, input_data in enumerate(loader_te): 222 | X, targets = self.get_input(input_data) 223 | loss, outputs = self.forward_one_batch(X, targets, False) 224 | if loss == -1: 225 | return 226 | 227 | pred = outputs.argmax(dim=1, keepdim=True).cpu() 228 | LOG_eval['acc'] += [pred.eq(targets.view_as(pred)).sum().item() / pred.size(0)] 229 | 230 | logger_continual.per_task_summary(self.LOG, 'acc', task_id, task_t, 231 | np.round(np.mean(LOG_eval['acc']), 5)) 232 | 233 | print(self.LOG['acc']) 234 | -------------------------------------------------------------------------------- /src/models/build_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | model construction functions 4 | """ 5 | 6 | from 
tabnanny import verbose 7 | import torch 8 | 9 | from .vit_models import ViT 10 | 11 | # Supported model types 12 | _MODEL_TYPES = { 13 | "vit": ViT, 14 | } 15 | 16 | def build_model(cfg): 17 | """ 18 | build model here 19 | """ 20 | assert ( 21 | cfg.MODEL.TYPE in _MODEL_TYPES.keys() 22 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE) 23 | assert ( 24 | cfg.NUM_GPUS <= torch.cuda.device_count() 25 | ), "Cannot use more GPU devices than available" 26 | 27 | # Construct the model 28 | train_type = cfg.MODEL.TYPE 29 | model = _MODEL_TYPES[train_type](cfg) 30 | 31 | model, device = load_model_to_device(model, cfg) 32 | 33 | return model, device 34 | 35 | 36 | def get_current_device(): 37 | if torch.cuda.is_available(): 38 | # Determine the GPU used by the current process 39 | cur_device = torch.cuda.current_device() 40 | else: 41 | cur_device = torch.device('cpu') 42 | return cur_device 43 | 44 | 45 | def load_model_to_device(model, cfg): 46 | cur_device = get_current_device() 47 | if torch.cuda.is_available(): 48 | # Transfer the model to the current GPU device 49 | model = model.cuda(device=cur_device) 50 | # Use multi-process data parallel model in the multi-gpu setting 51 | if cfg.NUM_GPUS > 1: 52 | # Make model replica operate on the current device 53 | model = torch.nn.parallel.DistributedDataParallel( 54 | module=model, device_ids=[cur_device], output_device=cur_device, 55 | find_unused_parameters=True, 56 | ) 57 | else: 58 | model = model.to(cur_device) 59 | return model, cur_device 60 | -------------------------------------------------------------------------------- /src/models/build_vit_backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | call vit with DAP 4 | """ 5 | 6 | import numpy as np 7 | import os 8 | 9 | from .vit_backbones.vit import VisionTransformer 10 | from .vit_dap.vit import ADPT_VisionTransformer 11 | 12 | 13 | MODEL_ZOO = { 14 | "sup_vitb16_imagenet21k": "imagenet21k_ViT-B_16.npz", 15 | } 16 | 17 | def build_vit_sup_models( 18 | model_type, crop_size, model_root=None, dap_cfg=None, load_pretrain=True, cfg=None, vis=False, transfer_type=None 19 | ): 20 | m2featdim = { 21 | "sup_vitb16_imagenet21k": 768, 22 | } 23 | model = ADPT_VisionTransformer(model_type, crop_size, num_classes=-1, dap_cfg=dap_cfg, model_root=model_root, total_cfg=cfg) 24 | 25 | if load_pretrain: 26 | model.load_from(np.load(os.path.join(model_root, MODEL_ZOO[model_type]))) 27 | 28 | return model, m2featdim[model_type] 29 | 30 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | modified from: fbcode/multimo/models/encoders/mlp.py 4 | """ 5 | 6 | import math 7 | import torch 8 | 9 | from torch import nn 10 | from typing import List, Type 11 | 12 | 13 | class MLP(nn.Module): 14 | def __init__( 15 | self, 16 | input_dim: int, 17 | mlp_dims: List[int], 18 | dropout: float = 0.1, 19 | nonlinearity: Type[nn.Module] = nn.ReLU, 20 | normalization: Type[nn.Module] = nn.BatchNorm1d, 21 | special_bias: bool = False, 22 | add_bn_first: bool = False, 23 | ): 24 | super(MLP, self).__init__() 25 | projection_prev_dim = input_dim 26 | projection_modulelist = [] 27 | last_dim = mlp_dims[-1] 28 | mlp_dims = mlp_dims[:-1] 29 | 30 | if add_bn_first: 31 | if normalization is not None: 32 | projection_modulelist.append(normalization(projection_prev_dim)) 33 | if 
dropout != 0: 34 | projection_modulelist.append(nn.Dropout(dropout)) 35 | 36 | for idx, mlp_dim in enumerate(mlp_dims): 37 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim) 38 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out') 39 | projection_modulelist.append(fc_layer) 40 | projection_modulelist.append(nonlinearity()) 41 | 42 | if normalization is not None: 43 | projection_modulelist.append(normalization(mlp_dim)) 44 | 45 | if dropout != 0: 46 | projection_modulelist.append(nn.Dropout(dropout)) 47 | projection_prev_dim = mlp_dim 48 | 49 | self.projection = nn.Sequential(*projection_modulelist) 50 | self.last_layer = nn.Linear(projection_prev_dim, last_dim) 51 | nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out') 52 | if special_bias: 53 | prior_prob = 0.01 54 | bias_value = -math.log((1 - prior_prob) / prior_prob) 55 | torch.nn.init.constant_(self.last_layer.bias, bias_value) 56 | 57 | def forward(self, x: torch.Tensor) -> torch.Tensor: 58 | """ 59 | input_arguments: 60 | @x: torch.FloatTensor 61 | """ 62 | x = self.projection(x) 63 | x = self.last_layer(x) 64 | return x 65 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | models for vits, borrowed from 5 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/modeling_resnet.py 6 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/modeling.py 7 | """ 8 | 9 | import copy 10 | 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | import torch.nn.functional as F 17 | import math 18 | 19 | from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm 20 | from torch.nn.modules.utils import _pair 21 | from scipy import ndimage 22 | 23 | from ...configs import vit_configs as configs 24 | 25 | 26 | CONFIGS = { 27 | "sup_vitb16_imagenet21k": configs.get_b16_config(), 28 | } 29 | 30 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 31 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 32 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 33 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 34 | FC_0 = "MlpBlock_3/Dense_0" 35 | FC_1 = "MlpBlock_3/Dense_1" 36 | ATTENTION_NORM = "LayerNorm_0" 37 | MLP_NORM = "LayerNorm_2" 38 | 39 | 40 | def np2th(weights, conv=False): 41 | """Possibly convert HWIO to OIHW.""" 42 | if conv: 43 | weights = weights.transpose([3, 2, 0, 1]) 44 | return torch.from_numpy(weights) 45 | 46 | 47 | def swish(x): 48 | return x * torch.sigmoid(x) 49 | 50 | 51 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, config, vis): 56 | super(Attention, self).__init__() 57 | self.vis = vis 58 | self.num_attention_heads = config.transformer["num_heads"] 59 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 60 | self.all_head_size = self.num_attention_heads * self.attention_head_size 61 | 62 | self.query = Linear(config.hidden_size, self.all_head_size) 63 | self.key = Linear(config.hidden_size, self.all_head_size) 64 | self.value = Linear(config.hidden_size, self.all_head_size) 65 | 66 | self.out = Linear(config.hidden_size, config.hidden_size) 67 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 68 | 
self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 69 | 70 | self.softmax = Softmax(dim=-1) 71 | 72 | def transpose_for_scores(self, x): 73 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 74 | x = x.view(*new_x_shape) 75 | return x.permute(0, 2, 1, 3) 76 | 77 | def forward(self, hidden_states): 78 | mixed_query_layer = self.query(hidden_states) # B, num_patches, head_size*num_head 79 | mixed_key_layer = self.key(hidden_states) 80 | mixed_value_layer = self.value(hidden_states) 81 | 82 | query_layer = self.transpose_for_scores(mixed_query_layer) # B, num_head, num_patches, head_size 83 | key_layer = self.transpose_for_scores(mixed_key_layer) 84 | value_layer = self.transpose_for_scores(mixed_value_layer) # B, num_head, num_patches, head_size 85 | 86 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # B, num_head, num_patches, num_patches 87 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 88 | attention_probs = self.softmax(attention_scores) # B, num_head, num_patches(query), num_patches(key) 89 | weights = attention_probs if self.vis else None 90 | attention_probs = self.attn_dropout(attention_probs) 91 | 92 | context_layer = torch.matmul(attention_probs, value_layer) # B, num_head, num_patches, head_size 93 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 94 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 95 | context_layer = context_layer.view(*new_context_layer_shape) 96 | attention_output = self.out(context_layer) 97 | attention_output = self.proj_dropout(attention_output) 98 | return attention_output, weights 99 | 100 | class Mlp(nn.Module): 101 | def __init__(self, config): 102 | super(Mlp, self).__init__() 103 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 104 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 105 | self.act_fn = ACT2FN["gelu"] 106 | self.dropout = Dropout(config.transformer["dropout_rate"]) 107 | 108 | self._init_weights() 109 | 110 | def _init_weights(self): 111 | nn.init.xavier_uniform_(self.fc1.weight) 112 | nn.init.xavier_uniform_(self.fc2.weight) 113 | nn.init.normal_(self.fc1.bias, std=1e-6) 114 | nn.init.normal_(self.fc2.bias, std=1e-6) 115 | 116 | def forward(self, x): 117 | x = self.fc1(x) 118 | x = self.act_fn(x) 119 | x = self.dropout(x) 120 | x = self.fc2(x) 121 | x = self.dropout(x) 122 | return x 123 | 124 | 125 | class Embeddings(nn.Module): 126 | """Construct the embeddings from patch, position embeddings. 
127 | """ 128 | def __init__(self, config, img_size, in_channels=3): 129 | super(Embeddings, self).__init__() 130 | self.hybrid = None 131 | img_size = _pair(img_size) 132 | 133 | if config.patches.get("grid") is not None: 134 | grid_size = config.patches["grid"] 135 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 136 | n_patches = (img_size[0] // 16) * (img_size[1] // 16) 137 | self.hybrid = True 138 | else: 139 | patch_size = _pair(config.patches["size"]) 140 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 141 | self.hybrid = False 142 | 143 | self.patch_embeddings = Conv2d(in_channels=in_channels, 144 | out_channels=config.hidden_size, 145 | kernel_size=patch_size, 146 | stride=patch_size) 147 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size)) 148 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 149 | 150 | self.dropout = Dropout(config.transformer["dropout_rate"]) 151 | 152 | def forward(self, x): 153 | B = x.shape[0] 154 | cls_tokens = self.cls_token.expand(B, -1, -1) 155 | 156 | if self.hybrid: 157 | x = self.hybrid_model(x) 158 | x = self.patch_embeddings(x) 159 | x = x.flatten(2) 160 | x = x.transpose(-1, -2) 161 | x = torch.cat((cls_tokens, x), dim=1) 162 | 163 | embeddings = x + self.position_embeddings 164 | embeddings = self.dropout(embeddings) 165 | return embeddings 166 | 167 | 168 | class Block(nn.Module): 169 | def __init__(self, config, vis): 170 | super(Block, self).__init__() 171 | self.hidden_size = config.hidden_size 172 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 173 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 174 | self.ffn = Mlp(config) 175 | self.attn = Attention(config, vis) 176 | 177 | def forward(self, x): 178 | h = x 179 | x = self.attention_norm(x) 180 | x, weights = self.attn(x) 181 | x = x + h 182 | 183 | h = x 184 | x = self.ffn_norm(x) 185 | x = self.ffn(x) 186 | x = x + h 187 | return x, weights 188 | 189 | def load_from(self, weights, n_block): 190 | ROOT = f"Transformer/encoderblock_{n_block}" 191 | with torch.no_grad(): 192 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() 193 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 194 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() 195 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() 196 | 197 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 198 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 199 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 200 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 201 | 202 | self.attn.query.weight.copy_(query_weight) 203 | self.attn.key.weight.copy_(key_weight) 204 | self.attn.value.weight.copy_(value_weight) 205 | self.attn.out.weight.copy_(out_weight) 206 | self.attn.query.bias.copy_(query_bias) 207 | self.attn.key.bias.copy_(key_bias) 208 | self.attn.value.bias.copy_(value_bias) 209 | self.attn.out.bias.copy_(out_bias) 210 | 211 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 212 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 213 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 214 | mlp_bias_1 = np2th(weights[pjoin(ROOT, 
FC_1, "bias")]).t() 215 | 216 | self.ffn.fc1.weight.copy_(mlp_weight_0) 217 | self.ffn.fc2.weight.copy_(mlp_weight_1) 218 | self.ffn.fc1.bias.copy_(mlp_bias_0) 219 | self.ffn.fc2.bias.copy_(mlp_bias_1) 220 | 221 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 222 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 223 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 224 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 225 | 226 | 227 | class Encoder(nn.Module): 228 | def __init__(self, config, vis): 229 | super(Encoder, self).__init__() 230 | self.vis = vis 231 | self.layer = nn.ModuleList() 232 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 233 | for _ in range(config.transformer["num_layers"]): 234 | layer = Block(config, vis) 235 | self.layer.append(copy.deepcopy(layer)) 236 | 237 | def forward(self, hidden_states): 238 | attn_weights = [] 239 | for layer_block in self.layer: 240 | hidden_states, weights = layer_block(hidden_states) 241 | if self.vis: 242 | attn_weights.append(weights) 243 | encoded = self.encoder_norm(hidden_states) 244 | return encoded, attn_weights 245 | 246 | def forward_cls_layerwise(self, hidden_states): 247 | # hidden_states: B, 1+n_patches, dim 248 | 249 | if hidden_states.size(0) != 1: 250 | raise ValueError('not support batch-wise cls forward yet') 251 | 252 | cls_embeds = [] 253 | cls_embeds.append(hidden_states[0][0]) 254 | for i,layer_block in enumerate(self.layer): 255 | hidden_states, _ = layer_block(hidden_states) 256 | if i < len(self.layer)-1: 257 | cls_embeds.append(hidden_states[0][0]) 258 | encoded = self.encoder_norm(hidden_states) 259 | cls_embeds.append(hidden_states[0][0]) 260 | 261 | cls_embeds = torch.stack(cls_embeds) # 12, dim 262 | return cls_embeds 263 | 264 | 265 | 266 | class Transformer(nn.Module): 267 | def __init__(self, config, img_size, vis): 268 | super(Transformer, self).__init__() 269 | self.embeddings = Embeddings(config, img_size=img_size) 270 | self.encoder = Encoder(config, vis) 271 | 272 | def forward(self, input_ids): 273 | embedding_output = self.embeddings(input_ids) 274 | 275 | encoded, attn_weights = self.encoder(embedding_output) 276 | return encoded, attn_weights 277 | 278 | def forward_cls_layerwise(self, input_ids): 279 | embedding_output = self.embeddings(input_ids) 280 | 281 | cls_embeds = self.encoder.forward_cls_layerwise(embedding_output) 282 | return cls_embeds 283 | 284 | 285 | class VisionTransformer(nn.Module): 286 | def __init__( 287 | self, model_type, 288 | img_size=224, num_classes=21843, vis=False 289 | ): 290 | super(VisionTransformer, self).__init__() 291 | config = CONFIGS[model_type] 292 | self.num_classes = num_classes 293 | self.classifier = config.classifier 294 | 295 | self.transformer = Transformer(config, img_size, vis) 296 | self.head = Linear(config.hidden_size, num_classes) if num_classes > 0 else nn.Identity() 297 | 298 | def forward(self, x, vis=False): 299 | x, attn_weights = self.transformer(x) 300 | logits = self.head(x[:, 0]) 301 | 302 | if not vis: 303 | return logits 304 | return logits, attn_weights # attn_weights: num_layers, B, num_head, num_patches, num_patches 305 | 306 | def forward_last_cls(self, x): 307 | x, attn_weights = self.transformer(x) 308 | return x[:, 0] 309 | 310 | def forward_cls_layerwise(self, x): 311 | cls_embeds = self.transformer.forward_cls_layerwise(x) 312 | return cls_embeds 313 | 314 | def load_from(self, 
weights): 315 | with torch.no_grad(): 316 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 317 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 318 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) 319 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 320 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 321 | 322 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 323 | posemb_new = self.transformer.embeddings.position_embeddings 324 | if posemb.size() == posemb_new.size(): 325 | self.transformer.embeddings.position_embeddings.copy_(posemb) 326 | else: 327 | ntok_new = posemb_new.size(1) 328 | 329 | if self.classifier == "token": 330 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 331 | ntok_new -= 1 332 | else: 333 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 334 | 335 | gs_old = int(np.sqrt(len(posemb_grid))) 336 | gs_new = int(np.sqrt(ntok_new)) 337 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 338 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 339 | 340 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 341 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) 342 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 343 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) 344 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 345 | 346 | for bname, block in self.transformer.encoder.named_children(): 347 | for uname, unit in block.named_children(): 348 | unit.load_from(weights, n_block=uname) 349 | 350 | if self.transformer.embeddings.hybrid: 351 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True)) 352 | gn_weight = np2th(weights["gn_root/scale"]).view(-1) 353 | gn_bias = np2th(weights["gn_root/bias"]).view(-1) 354 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 355 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 356 | 357 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 358 | for uname, unit in block.named_children(): 359 | unit.load_from(weights, n_block=bname, n_unit=uname) 360 | 361 | -------------------------------------------------------------------------------- /src/models/vit_dap/vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit with DAP 4 | """ 5 | 6 | import os 7 | from ..vit_backbones.vit import * 8 | 9 | from functools import reduce 10 | from operator import mul 11 | from torch.nn.modules.utils import _pair 12 | 13 | 14 | MODEL_ZOO = { 15 | "sup_vitb16_imagenet21k": "imagenet21k_ViT-B_16.npz", 16 | } 17 | 18 | class ADPT_Block(nn.Module): 19 | def __init__(self, config, vis, dap_config): 20 | super(ADPT_Block, self).__init__() 21 | self.hidden_size = config.hidden_size 22 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 23 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 24 | self.ffn = Mlp(config) 25 | 26 | self.config = config 27 | self.attn = Attention(config, vis) 28 | 29 | # domain-adaptive prompts 30 | self.dap_config = dap_config 31 | self.dap_downsample = nn.Linear(197, dap_config.NUM_DAP_TOKENS) 32 | nn.init.zeros_(self.dap_downsample.weight) 33 | nn.init.zeros_(self.dap_downsample.bias) 34 | 
self.dap_film = nn.Linear(dap_config.TASK_EMB, config.hidden_size * 2) 35 | self.dap_norm = LayerNorm(config.hidden_size, eps=1e-6) 36 | 37 | def forward(self, x, task_id_estimated_emb=None, layer_index=None, cfg=None): 38 | if x.shape[1] == 197: # first layer 39 | x_norm = self.dap_norm(x) 40 | x_tran = torch.transpose(x_norm, 2, 1) 41 | down = self.dap_downsample(x_tran) 42 | 43 | film = self.dap_film(task_id_estimated_emb) 44 | gamma4 = film[:, :self.config.hidden_size] 45 | beta4 = film[:, self.config.hidden_size:] 46 | gamma_norm = gamma4.norm(p=2, dim=1, keepdim=True).detach() 47 | beta_norm = beta4.norm(p=2, dim=1, keepdim=True).detach() 48 | 49 | gamma4 = gamma4.div(gamma_norm).view(film.size(0), -1, 1) 50 | beta4 = beta4.div(beta_norm).view(film.size(0), -1, 1) 51 | down = gamma4 * down + beta4 52 | down = torch.transpose(down, 2, 1) 53 | 54 | x = torch.cat(( 55 | x[:, :1, :], 56 | down, 57 | x[:, 1:, :] 58 | ), dim=1) 59 | else: 60 | x = torch.cat(( 61 | x[:, :1, :], 62 | x[:, (1+self.dap_config.NUM_DAP_TOKENS):, :] 63 | ), dim=1) 64 | 65 | x_norm = self.dap_norm(x) 66 | x_tran = torch.transpose(x_norm, 2, 1) 67 | down = self.dap_downsample(x_tran) 68 | 69 | film = self.dap_film(task_id_estimated_emb) 70 | gamma4 = film[:, :self.config.hidden_size] 71 | beta4 = film[:, self.config.hidden_size:] 72 | gamma_norm = gamma4.norm(p=2, dim=1, keepdim=True).detach() 73 | beta_norm = beta4.norm(p=2, dim=1, keepdim=True).detach() 74 | 75 | gamma4 = gamma4.div(gamma_norm).view(film.size(0), -1, 1) 76 | beta4 = beta4.div(beta_norm).view(film.size(0), -1, 1) 77 | down = gamma4 * down + beta4 78 | down = torch.transpose(down, 2, 1) 79 | 80 | if not (layer_index == 11 and cfg.DATA.NAME == 'imagenet_r'): 81 | # for imagenet_r, do not append prompts on the last layer 82 | x = torch.cat(( 83 | x[:, :1, :], 84 | down, 85 | x[:, 1:, :] 86 | ), dim=1) 87 | 88 | h = x 89 | x = self.attention_norm(x) 90 | x, weights = self.attn(x) 91 | x = x + h 92 | 93 | h = x 94 | x = self.ffn_norm(x) 95 | 96 | x = self.ffn(x) 97 | x = x + h 98 | return x, weights 99 | 100 | def load_from(self, weights, n_block): 101 | ROOT = f"Transformer/encoderblock_{n_block}" 102 | with torch.no_grad(): 103 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, 104 | self.hidden_size).t() 105 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 106 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, 107 | self.hidden_size).t() 108 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, 109 | self.hidden_size).t() 110 | 111 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 112 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 113 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 114 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 115 | 116 | self.attn.query.weight.copy_(query_weight) 117 | self.attn.key.weight.copy_(key_weight) 118 | self.attn.value.weight.copy_(value_weight) 119 | self.attn.out.weight.copy_(out_weight) 120 | self.attn.query.bias.copy_(query_bias) 121 | self.attn.key.bias.copy_(key_bias) 122 | self.attn.value.bias.copy_(value_bias) 123 | self.attn.out.bias.copy_(out_bias) 124 | 125 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 126 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 127 | mlp_bias_0 = np2th(weights[pjoin(ROOT, 
FC_0, "bias")]).t() 128 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 129 | 130 | self.ffn.fc1.weight.copy_(mlp_weight_0) 131 | self.ffn.fc2.weight.copy_(mlp_weight_1) 132 | self.ffn.fc1.bias.copy_(mlp_bias_0) 133 | self.ffn.fc2.bias.copy_(mlp_bias_1) 134 | 135 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 136 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 137 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 138 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 139 | 140 | 141 | class ADPT_Encoder(nn.Module): 142 | def __init__(self, config, vis, dap_cfg): 143 | super(ADPT_Encoder, self).__init__() 144 | self.vis = vis 145 | self.layer = nn.ModuleList() 146 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 147 | 148 | self.num_layers = config.transformer["num_layers"] 149 | for _ in range(self.num_layers): 150 | layer = ADPT_Block(config, vis, dap_cfg) 151 | self.layer.append(copy.deepcopy(layer)) 152 | 153 | def forward_films(self, task_id_emb): 154 | films = [] 155 | for layer_block in self.layer: 156 | films.append(layer_block.dap_film(task_id_emb)) 157 | return films 158 | 159 | def forward(self, hidden_states, task_id_estimated_emb=None, cfg=None): 160 | attn_weights = [] 161 | for layer_index, layer_block in enumerate(self.layer): 162 | hidden_states, weights = layer_block(hidden_states, task_id_estimated_emb=task_id_estimated_emb, layer_index=layer_index, cfg=cfg) 163 | if self.vis: 164 | attn_weights.append(weights) 165 | encoded = self.encoder_norm(hidden_states) 166 | return encoded, attn_weights 167 | 168 | def expand_to_batch(x, batch_size, dim=0): 169 | shape = [1 for _ in x.shape] 170 | shape.insert(dim, batch_size) 171 | return torch.tile(torch.unsqueeze(x, dim=dim), shape).cuda() 172 | 173 | class ADPT_Transformer(nn.Module): 174 | def __init__(self, config, img_size, vis, dap_cfg, model_type=None, model_root=None, total_cfg=None): 175 | super(ADPT_Transformer, self).__init__() 176 | self.embeddings = Embeddings(config, img_size=img_size) 177 | self.encoder = ADPT_Encoder(config, vis, dap_cfg) 178 | 179 | self.pretrained_enc = VisionTransformer(model_type, img_size, num_classes=-1, vis=vis) 180 | self.pretrained_enc.load_from(np.load(os.path.join(model_root, MODEL_ZOO[model_type]))) 181 | 182 | self.patch_size = _pair(config.patches["size"]) 183 | self.prompt_dim = config.hidden_size 184 | self.pool_size = total_cfg.MODEL.DAP.PROMPT_POOL 185 | 186 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_size, 1) + self.prompt_dim)) 187 | self.dap_key_embeddings = nn.Parameter(torch.zeros(self.pool_size, self.prompt_dim)) 188 | nn.init.uniform_(self.dap_key_embeddings.data, -val, val) 189 | self.dap_emb = torch.nn.Embedding(dap_cfg.NUM_TASKS_FOR_EMB, dap_cfg.TASK_EMB) 190 | 191 | self.dap_cfg = dap_cfg 192 | self.cfg = total_cfg 193 | self.top_k = 1 194 | 195 | def forward(self, input_ids, task_id=None, is_train=None, cfg=None): 196 | B = input_ids.shape[0] 197 | x_cls_embed = self.pretrained_enc.forward_last_cls(input_ids).detach() 198 | 199 | if is_train: 200 | start = task_id * self.top_k 201 | end = (task_id + 1) * self.top_k 202 | prompt_mask = torch.arange(start, end).cuda() 203 | if end > self.pool_size: 204 | prompt_mask = None 205 | else: 206 | prompt_mask = None 207 | 208 | dap_prompt_key_norm = F.normalize(self.dap_key_embeddings, dim=-1) 209 | 210 | x_embed_norm = F.normalize(x_cls_embed, dim=-1) 211 | sim = torch.matmul(dap_prompt_key_norm, 212 | torch.transpose(x_embed_norm, 1, 0)) 213 | 214 | sim = torch.transpose(sim, 1, 0) 215 | (sim_top_k, idx) = torch.topk(sim, self.top_k) 216 | idx = idx.squeeze(dim=-1) 217 | 218 | prompt_id, id_counts = torch.unique(idx, return_counts=True) 219 | _, major_idx = torch.topk(id_counts, self.top_k) 220 | major_prompt_id = prompt_id[major_idx] 221 | idx = expand_to_batch(major_prompt_id, x_cls_embed.shape[0]).squeeze(dim=-1) 222 | 223 | task_id = major_prompt_id.cpu()[0] 224 | 225 | if prompt_mask is not None: 226 | idx = prompt_mask 227 | task_id = idx.cpu()[0] 228 | idx = expand_to_batch(idx, x_cls_embed.shape[0]).squeeze(dim=-1) 229 | 230 | task_id_estimated_emb = self.dap_emb(idx) 231 | 232 | i = torch.arange(B).reshape(B, 1, 1) 233 | l = torch.arange(self.prompt_dim).reshape(1, 1, self.prompt_dim) 234 | 235 | selected_prompt_key = dap_prompt_key_norm.repeat(B, 1, 1)[ 236 | i, idx.unsqueeze(-1), l] 237 | 238 | x_embed_norm = x_embed_norm.unsqueeze(1) 239 | sim_pull = selected_prompt_key * x_embed_norm 240 | reduce_sim = torch.sum(sim_pull) / x_cls_embed.shape[0] 241 | 242 | embedding_output = self.embeddings(input_ids) 243 | encoded, attn_weights = self.encoder(embedding_output, task_id_estimated_emb=task_id_estimated_emb, cfg=cfg) 244 | 245 | return encoded, attn_weights, reduce_sim, task_id 246 | 247 | 248 | class ADPT_VisionTransformer(nn.Module): 249 | def __init__( 250 | self, model_type, 251 | img_size=224, num_classes=21843, vis=False, dap_cfg=None, model_root=None, total_cfg=None 252 | ): 253 | super(ADPT_VisionTransformer, self).__init__() 254 | config = CONFIGS[model_type] 255 | self.num_classes = num_classes 256 | self.classifier = config.classifier 257 | 258 | self.transformer = ADPT_Transformer(config, img_size, vis, dap_cfg, model_type=model_type, model_root=model_root, total_cfg=total_cfg) 259 | self.head = Linear(config.hidden_size, num_classes) if num_classes > 0 else nn.Identity() 260 | 261 | self.cfg = total_cfg 262 | self.top_k = 1 263 | 264 | def forward(self, x, task_id, vis=False, is_train=None): 265 | x, attn_weights, reduce_sim, task_id_out = self.transformer(x, task_id, is_train=is_train, cfg=self.cfg) 266 | 267 | logits = self.head(x[:, 0]) 268 | 269 | if not vis: 270 | return logits, reduce_sim, task_id_out 271 | return logits, attn_weights, reduce_sim, task_id_out 272 | 273 | def load_from(self, weights): 274 | with torch.no_grad(): 275 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 276 | 
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 277 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) 278 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 279 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 280 | 281 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 282 | posemb_new = self.transformer.embeddings.position_embeddings 283 | if posemb.size() == posemb_new.size(): 284 | self.transformer.embeddings.position_embeddings.copy_(posemb) 285 | else: 286 | ntok_new = posemb_new.size(1) 287 | 288 | if self.classifier == "token": 289 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 290 | ntok_new -= 1 291 | else: 292 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 293 | 294 | gs_old = int(np.sqrt(len(posemb_grid))) 295 | gs_new = int(np.sqrt(ntok_new)) 296 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 297 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 298 | 299 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 300 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) 301 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 302 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) 303 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 304 | 305 | for bname, block in self.transformer.encoder.named_children(): 306 | for uname, unit in block.named_children(): 307 | unit.load_from(weights, n_block=uname) 308 | 309 | if self.transformer.embeddings.hybrid: 310 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_( 311 | np2th(weights["conv_root/kernel"], conv=True)) 312 | gn_weight = np2th(weights["gn_root/scale"]).view(-1) 313 | gn_bias = np2th(weights["gn_root/bias"]).view(-1) 314 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 315 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 316 | 317 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 318 | for uname, unit in block.named_children(): 319 | unit.load_from(weights, n_block=bname, n_unit=uname) -------------------------------------------------------------------------------- /src/models/vit_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | ViT-related models 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from collections import OrderedDict 10 | from torchvision import models 11 | 12 | from .build_vit_backbone import build_vit_sup_models 13 | from .mlp import MLP 14 | 15 | class ViT(nn.Module): 16 | """ViT-related model.""" 17 | def __init__(self, cfg, load_pretrain=True, vis=False): 18 | super(ViT, self).__init__() 19 | 20 | self.froze_enc = True 21 | 22 | dap_cfg = cfg.MODEL.DAP 23 | 24 | self.build_backbone( 25 | cfg, dap_cfg, load_pretrain, vis=vis) 26 | self.cfg = cfg 27 | self.setup_head(cfg) 28 | 29 | def build_backbone(self, cfg, dap_cfg, load_pretrain, vis): 30 | self.enc, self.feat_dim = build_vit_sup_models( 31 | cfg.DATA.FEATURE, cfg.DATA.CROPSIZE, cfg.MODEL.MODEL_ROOT, dap_cfg, load_pretrain, cfg, vis, cfg.MODEL.TRANSFER_TYPE 32 | ) 33 | 34 | for k, p in self.enc.named_parameters(): 35 | if "dap" not in k: 36 | p.requires_grad = False 37 | 38 | def setup_head(self, cfg): 39 | self.head = MLP( 40 | input_dim=self.feat_dim, 41 | mlp_dims=[self.feat_dim] * self.cfg.MODEL.MLP_NUM + 
\ 42 | [cfg.DATA.NUMBER_CLASSES], 43 | special_bias=True 44 | ) 45 | 46 | def forward(self, x, return_feature=False, task_id=None, is_train=None, cfg=None): 47 | if self.froze_enc and self.enc.training: 48 | self.enc.eval() 49 | 50 | x, reduce_sim, task_id_out = self.enc(x, task_id=task_id, is_train=is_train) 51 | 52 | if return_feature: 53 | return x, x 54 | x = self.head(x) 55 | 56 | if cfg.DATA.NAME == 'imagenet_r': 57 | # only for imagenet_r 58 | offset1 = task_id_out * cfg.CONTINUAL.INCREMENT 59 | offset2 = (task_id_out + 1) * cfg.CONTINUAL.INCREMENT 60 | if offset1 > 0: 61 | x[:, :offset1].data.fill_(-10e10) 62 | if offset2 < cfg.DATA.NUMBER_CLASSES: 63 | x[:, int(offset2):cfg.DATA.NUMBER_CLASSES].data.fill_(-10e10) 64 | 65 | return x, reduce_sim 66 | 67 | def forward_cls_layerwise(self, x): 68 | cls_embeds = self.enc.forward_cls_layerwise(x) 69 | return cls_embeds 70 | 71 | def get_features(self, x): 72 | """get a (batch_size, self.feat_dim) feature""" 73 | x = self.enc(x) # batch_size x self.feat_dim 74 | return x 75 | -------------------------------------------------------------------------------- /src/solver/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | loss functions 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from typing import Optional 10 | 11 | 12 | class SigmoidLoss(nn.Module): 13 | def __init__(self, cfg=None): 14 | super(SigmoidLoss, self).__init__() 15 | 16 | def is_single(self): 17 | return True 18 | 19 | def is_local(self): 20 | return False 21 | 22 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> torch.Tensor: 23 | labels = labels.unsqueeze(1) # (batch_size, 1) 24 | target = torch.zeros( 25 | labels.size(0), nb_classes, device=labels.device 26 | ).scatter_(1, labels, 1.) 
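# e.g. (hypothetical inputs) labels = tensor([2, 0]) with nb_classes = 4
# yields target = [[0., 0., 1., 0.], [1., 0., 0., 0.]]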
27 | # (batch_size, num_classes) 28 | return target 29 | 30 | def loss( 31 | self, logits, targets, per_cls_weights, 32 | multihot_targets: Optional[bool] = False 33 | ): 34 | # targets: 1d-tensor of integer 35 | # Only support single label at this moment 36 | # if len(targets.shape) != 2: 37 | num_classes = logits.shape[1] 38 | targets = self.multi_hot(targets, num_classes) 39 | 40 | loss = F.binary_cross_entropy_with_logits( 41 | logits, targets, reduction="none") 42 | weight = torch.tensor( 43 | per_cls_weights, device=logits.device 44 | ).unsqueeze(0) 45 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32)) 46 | return torch.sum(loss) / targets.shape[0] 47 | 48 | def forward( 49 | self, pred_logits, targets, per_cls_weights=None, multihot_targets=False 50 | ): 51 | loss = self.loss( 52 | pred_logits, targets, per_cls_weights, multihot_targets) 53 | return loss 54 | 55 | 56 | class SoftmaxLoss(SigmoidLoss): 57 | def __init__(self, cfg=None): 58 | super(SoftmaxLoss, self).__init__() 59 | 60 | def loss(self, logits, targets, per_cls_weights, kwargs): 61 | loss = F.cross_entropy(logits, targets, per_cls_weights, reduction="none") 62 | return torch.sum(loss) / targets.shape[0] 63 | 64 | 65 | LOSS = { 66 | "softmax": SoftmaxLoss, 67 | } 68 | 69 | 70 | def build_loss(cfg): 71 | loss_name = cfg.SOLVER.LOSS 72 | assert loss_name in LOSS, \ 73 | f'loss name {loss_name} is not supported' 74 | loss_fn = LOSS[loss_name] 75 | if not loss_fn: 76 | return None 77 | else: 78 | return loss_fn(cfg) 79 | -------------------------------------------------------------------------------- /src/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | learning rate scheduler 4 | """ 5 | 6 | import math 7 | 8 | import torch.optim as optim 9 | from fvcore.common.config import CfgNode 10 | from torch.optim.lr_scheduler import LambdaLR 11 | 12 | def make_scheduler( 13 | optimizer: optim.Optimizer, train_params: CfgNode 14 | ) -> LambdaLR: 15 | warmup = train_params.WARMUP_EPOCH 16 | total_iters = train_params.TOTAL_EPOCH 17 | 18 | if train_params.SCHEDULER == "cosine": 19 | scheduler = WarmupCosineSchedule( 20 | optimizer, 21 | warmup_steps=warmup, 22 | t_total=total_iters 23 | ) 24 | elif train_params.SCHEDULER == "cosine_hardrestart": 25 | scheduler = WarmupCosineWithHardRestartsSchedule( 26 | optimizer, 27 | warmup_steps=warmup, 28 | t_total=total_iters 29 | ) 30 | elif train_params.SCHEDULER == "linear": 31 | scheduler = WarmupConstantSchedule( 32 | optimizer, 33 | warmup_steps=warmup, 34 | ) 35 | elif train_params.SCHEDULER == "plateau": 36 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 37 | optimizer, 38 | "max", 39 | patience=5, 40 | verbose=True, 41 | factor=train_params.LR_DECAY_FACTOR, 42 | ) 43 | else: 44 | scheduler = None 45 | return scheduler 46 | 47 | 48 | class WarmupConstantSchedule(LambdaLR): 49 | """ Linear warmup and then constant. 50 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 51 | Keeps learning rate schedule equal to 1. after warmup_steps. 52 | """ 53 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 56 | 57 | def lr_lambda(self, step): 58 | if step < self.warmup_steps: 59 | return float(step) / float(max(1.0, self.warmup_steps)) 60 | return 1. 
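# illustrative numbers (hypothetical warmup_steps = 5): lr_lambda returns
# 0.0, 0.2, 0.4, 0.6, 0.8 for steps 0-4 and 1.0 from step 5 onward;
# LambdaLR multiplies the optimizer's base LR by this factor at every step.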
61 | 62 | class WarmupCosineSchedule(LambdaLR): 63 | """ Linear warmup and then cosine decay. 64 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 65 | Decreases learning rate from 1. to 0. over remaining 66 | `t_total - warmup_steps` steps following a cosine curve. 67 | If `cycles` (default=0.5) is different from default, learning rate 68 | follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__( 75 | optimizer, self.lr_lambda, last_epoch=last_epoch) 76 | 77 | def lr_lambda(self, step): 78 | if step < self.warmup_steps: 79 | return float(step) / float(max(1.0, self.warmup_steps)) 80 | # progress after warmup 81 | progress = float(step - self.warmup_steps) / float(max( 82 | 1, self.t_total - self.warmup_steps)) 83 | return max( 84 | 0.0, 85 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)) 86 | ) 87 | 88 | 89 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 90 | """ Linear warmup and then cosine cycles with hard restarts. 91 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 92 | If `cycles` (default=1.) is different from default, learning rate 93 | follows `cycles` times a cosine decaying learning rate 94 | (with hard restarts). 95 | """ 96 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 97 | self.warmup_steps = warmup_steps 98 | self.t_total = t_total 99 | self.cycles = cycles 100 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 101 | optimizer, self.lr_lambda, last_epoch=last_epoch) 102 | 103 | def lr_lambda(self, step): 104 | if step < self.warmup_steps: 105 | return float(step) / float(max(1, self.warmup_steps)) 106 | # progress after warmup 107 | progress = float(step - self.warmup_steps) / float( 108 | max(1, self.t_total - self.warmup_steps)) 109 | if progress >= 1.0: 110 | return 0.0 111 | return max( 112 | 0.0, 113 | 0.5 * (1. + math.cos( 114 | math.pi * ((float(self.cycles) * progress) % 1.0))) 115 | ) 116 | -------------------------------------------------------------------------------- /src/solver/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 4 | optimizer, ref: https://github.com/huggingface/transformers/blob/master/transformers/optimization.property 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | from fvcore.common.config import CfgNode 11 | from torch.optim import Optimizer 12 | import torch.optim as optim 13 | from typing import Any, Callable, Iterable, List, Tuple, Optional 14 | 15 | def make_optimizer( 16 | models: List[Any], train_params: CfgNode 17 | ) -> Optimizer: 18 | params = [] 19 | for model in models: 20 | for key, value in model.named_parameters(): 21 | 22 | if value.requires_grad: 23 | params.append((key, value)) 24 | 25 | if train_params.WEIGHT_DECAY > 0: 26 | if train_params.OPTIMIZER == 'adamw': 27 | 28 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 29 | optimizer_grouped_parameters = [ 30 | {'params': [p for n, p in params 31 | if not any(nd in n for nd in no_decay)], 32 | 'weight_decay': 0.01}, 33 | {'params': [p for n, p in params 34 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 35 | ] 36 | optimizer = AdamW( 37 | optimizer_grouped_parameters, 38 | lr=train_params.BASE_LR, 39 | betas=(train_params.BETA1, train_params.BETA2) 40 | ) 41 | else: 42 | _params = [] 43 | for p in params: 44 | key, value = p 45 | # print(key) 46 | # if not value.requires_grad: 47 | # continue 48 | lr = train_params.BASE_LR 49 | weight_decay = train_params.WEIGHT_DECAY 50 | if "last_layer.bias" in key: 51 | # no regularization (weight decay) for last layer's bias 52 | weight_decay = 0.0 53 | 54 | if train_params.BIAS_MULTIPLIER == 1.: 55 | _params += [{ 56 | "params": [value], 57 | "lr": lr, 58 | "weight_decay": weight_decay 59 | }] 60 | else: 61 | if "bias" in key and "last_layer.bias" not in key: 62 | # use updated lr for this param 63 | lr_value = lr * train_params.BIAS_MULTIPLIER 64 | else: 65 | lr_value = lr 66 | 67 | _params += [{ 68 | "params": [value], 69 | "lr": lr_value, 70 | "weight_decay": weight_decay 71 | }] 72 | 73 | if train_params.OPTIMIZER == 'adam': 74 | print(f"adam is used!") 75 | optimizer = optim.Adam( 76 | _params, 77 | lr=train_params.BASE_LR, 78 | weight_decay=train_params.WEIGHT_DECAY, 79 | betas = (train_params.BETA1, train_params.BETA2), 80 | amsgrad = True 81 | ) 82 | else: 83 | optimizer = optim.SGD( 84 | _params, 85 | train_params.BASE_LR, 86 | momentum=train_params.MOMENTUM, 87 | weight_decay=train_params.WEIGHT_DECAY 88 | ) 89 | return optimizer 90 | else: 91 | if train_params.OPTIMIZER == 'rmsprop': 92 | optimizer = optim.RMSprop( 93 | model.parameters(), 94 | lr=train_params.BASE_LR, 95 | momentum=train_params.MOMENTUM 96 | ) 97 | else: 98 | _params = [] 99 | for p in params: 100 | key, value = p 101 | 102 | lr = train_params.BASE_LR 103 | 104 | if train_params.BIAS_MULTIPLIER == 1.: 105 | _params += [{ 106 | "params": [value], 107 | "lr": lr, 108 | }] 109 | else: 110 | if "bias" in key and "last_layer.bias" not in key: 111 | # use updated lr for this param 112 | lr_value = lr * train_params.BIAS_MULTIPLIER 113 | else: 114 | lr_value = lr 115 | 116 | _params += [{ 117 | "params": [value], 118 | "lr": lr_value, 119 | }] 120 | if train_params.OPTIMIZER == 'sgd': 121 | optimizer = optim.SGD( 122 | _params, 123 | train_params.BASE_LR, 124 | momentum=train_params.MOMENTUM, 125 | ) 126 | elif train_params.OPTIMIZER == 'adam': 127 | #print(f"adam is used!") 128 | optimizer = optim.Adam( 129 | _params, 130 | train_params.BASE_LR, 131 | betas=(train_params.BETA1, train_params.BETA2), 132 | amsgrad=True 133 | ) 134 | elif train_params.OPTIMIZER == 
'asgd': 135 | optimizer = optim.ASGD( 136 | _params, 137 | train_params.BASE_LR 138 | ) 139 | return optimizer 140 | 141 | 142 | class AdamW(Optimizer): 143 | """ Implements Adam algorithm with weight decay fix. 144 | Parameters: 145 | lr (float): learning rate. Default 1e-3. 146 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 147 | eps (float): Adams epsilon. Default: 1e-6 148 | weight_decay (float): Weight decay. Default: 0.0 149 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 150 | """ 151 | 152 | def __init__( 153 | self, 154 | params: Iterable, 155 | lr: float = 1e-3, 156 | betas: Tuple[float, float] = (0.9, 0.999), 157 | eps: float = 1e-6, 158 | weight_decay: float = 0.0, 159 | correct_bias: bool = True 160 | ) -> None: 161 | if lr < 0.0: 162 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 163 | if not 0.0 <= betas[0] < 1.0: 164 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 165 | if not 0.0 <= betas[1] < 1.0: 166 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 167 | if not 0.0 <= eps: 168 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 169 | defaults = { 170 | "lr": lr, "betas": betas, "eps": eps, 171 | "weight_decay": weight_decay, "correct_bias": correct_bias 172 | } 173 | super(AdamW, self).__init__(params, defaults) 174 | 175 | def step(self, closure: Optional[Callable] = None) -> Optional[Callable]: 176 | """Performs a single optimization step. 177 | Arguments: 178 | closure (callable, optional): A closure that reevaluates the model 179 | and returns the loss. 180 | """ 181 | loss = None 182 | if closure is not None: 183 | loss = closure() 184 | 185 | for group in self.param_groups: 186 | for p in group['params']: 187 | if p.grad is None: 188 | continue 189 | grad = p.grad.data 190 | if grad.is_sparse: 191 | raise RuntimeError( 192 | "Adam does not support sparse gradients, " 193 | "please consider SparseAdam instead") 194 | 195 | state = self.state[p] 196 | 197 | # State initialization 198 | if len(state) == 0: 199 | state['step'] = 0 200 | # Exponential moving average of gradient values 201 | state['exp_avg'] = torch.zeros_like(p.data) 202 | # Exponential moving average of squared gradient values 203 | state['exp_avg_sq'] = torch.zeros_like(p.data) 204 | 205 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 206 | beta1, beta2 = group['betas'] 207 | 208 | state['step'] += 1 209 | 210 | # Decay the first and second moment running average coefficient 211 | # In-place operations to update the averages at the same time 212 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 213 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 214 | denom = exp_avg_sq.sqrt().add_(group['eps']) 215 | 216 | step_size = group['lr'] 217 | if group['correct_bias']: # No bias correction for Bert 218 | bias_correction1 = 1.0 - beta1 ** state['step'] 219 | bias_correction2 = 1.0 - beta2 ** state['step'] 220 | step_size = step_size * math.sqrt( 221 | bias_correction2) / bias_correction1 222 | 223 | p.data.addcdiv_(-step_size, exp_avg, denom) 224 | 225 | # Just adding the square of the weights to the loss function is *not* 226 | # the correct way of using L2 regularization/weight decay with Adam, 227 | # since that will interact with the m and v parameters in strange ways. 
228 | # 229 | # Instead we want to decay the weights in a manner that doesn't interact 230 | # with the m/v parameters. This is equivalent to adding the square 231 | # of the weights to the loss with plain (non-momentum) SGD. 232 | # Add weight decay at the end (fixed version) 233 | if group['weight_decay'] > 0.0: 234 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 235 | 236 | return loss 237 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | distributed helpers 4 | """ 5 | 6 | import torch 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | 10 | 11 | def get_world_size() -> int: 12 | if not dist.is_available(): 13 | return 1 14 | if not dist.is_initialized(): 15 | return 1 16 | return dist.get_world_size() 17 | 18 | 19 | def get_rank() -> int: 20 | if not dist.is_available(): 21 | return 0 22 | if not dist.is_initialized(): 23 | return 0 24 | return dist.get_rank() 25 | 26 | 27 | def is_master_process(num_gpus=8): 28 | """ 29 | Determines if the current process is the master process. 30 | """ 31 | if torch.distributed.is_initialized(): 32 | return dist.get_rank() % num_gpus == 0 33 | else: 34 | return True 35 | 36 | 37 | def run( 38 | local_rank, 39 | num_proc, 40 | func, 41 | init_method, 42 | shard_id, 43 | num_shards, 44 | backend, 45 | cfg, 46 | args, 47 | ): 48 | """ 49 | Runs a function from a child process. 50 | Args: 51 | local_rank (int): rank of the current process on the current machine. 52 | num_proc (int): number of processes per machine. 53 | func (function): function to execute on each of the process. 54 | init_method (string): method to initialize the distributed training. 55 | TCP initialization: equiring a network address reachable from all 56 | processes followed by the port. 57 | Shared file-system initialization: makes use of a file system that 58 | is shared and visible from all machines. The URL should start with 59 | file:// and contain a path to a non-existent file on a shared file 60 | system. 61 | shard_id (int): the rank of the current machine. 62 | num_shards (int): number of overall machines for the distributed 63 | training job. 64 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 65 | supports, each with different capabilities. Details can be found 66 | here: 67 | https://pytorch.org/docs/stable/distributed.html 68 | cfg (CfgNode): configs. Details can be found in 69 | loco/config/defaults.py 70 | """ 71 | # Initialize the process group. 72 | # shard_id = get_rank() 73 | world_size = num_proc * num_shards 74 | rank = shard_id * num_proc + local_rank 75 | 76 | try: 77 | torch.distributed.init_process_group( 78 | backend=backend, 79 | init_method=init_method, 80 | world_size=world_size, 81 | rank=rank, 82 | ) 83 | except Exception as e: 84 | raise e 85 | 86 | torch.cuda.set_device(local_rank) 87 | func(cfg, args) 88 | 89 | 90 | def destroy_process_group(): 91 | """Destroys the default process group.""" 92 | torch.distributed.destroy_process_group() 93 | 94 | 95 | def scaled_all_reduce(cfg, tensors): 96 | """Performs the scaled all_reduce operation on the provided tensors. 97 | 98 | The input tensors are modified in-place. Currently supports only the sum 99 | reduction operator. The reduced values are scaled by the inverse size of 100 | the process group (equivalent to cfg.NUM_GPUS). 
101 | """ 102 | # Queue the reductions 103 | reductions = [] 104 | for tensor in tensors: 105 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 106 | reductions.append(reduction) 107 | # Wait for reductions to finish 108 | for reduction in reductions: 109 | reduction.wait() 110 | # Scale the results 111 | for tensor in tensors: 112 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 113 | return tensors 114 | 115 | 116 | def cat_all_gather(tensors): 117 | """Performs the concatenated all_gather operation on the provided tensors. 118 | """ 119 | tensors_gather = [ 120 | torch.ones_like(tensors) 121 | for _ in range(torch.distributed.get_world_size()) 122 | ] 123 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 124 | 125 | output = torch.cat(tensors_gather, dim=0) 126 | return output 127 | 128 | 129 | def local_cat_all_gather(tensors): 130 | """Performs the concatenated all_gather operation on the provided tensors. 131 | """ 132 | tensors_gather = [ 133 | torch.ones_like(tensors) 134 | for _ in range(get_local_size()) 135 | ] 136 | torch.distributed.all_gather( 137 | tensors_gather, 138 | tensors, 139 | async_op=False, 140 | group=_LOCAL_PROCESS_GROUP, 141 | ) 142 | output = torch.cat(tensors_gather, dim=0) 143 | return output 144 | 145 | 146 | def get_local_size(): 147 | """ 148 | Returns: 149 | The size of the per-machine process group, 150 | i.e. the number of processes per machine. 151 | """ 152 | if not dist.is_available(): 153 | return 1 154 | if not dist.is_initialized(): 155 | return 1 156 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 157 | 158 | 159 | def get_local_rank(): 160 | """ 161 | Returns: 162 | The rank of the current process within the local (per-machine) process group. 163 | """ 164 | if not dist.is_available(): 165 | return 0 166 | if not dist.is_initialized(): 167 | return 0 168 | assert _LOCAL_PROCESS_GROUP is not None 169 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 170 | -------------------------------------------------------------------------------- /src/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | project specific pathmanagers for a project as recommended by Detectron2 4 | """ 5 | 6 | from iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | helper functions for read and write data 4 | """ 5 | 6 | import os 7 | import json 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from typing import Union 12 | from PIL import Image, ImageFile 13 | Image.MAX_IMAGE_PIXELS = None 14 | 15 | 16 | def save_or_append_df(out_path, df): 17 | if os.path.exists(out_path): 18 | previous_df = pd.read_pickle(out_path) 19 | df = pd.concat([previous_df, df], ignore_index=True) 20 | df.to_pickle(out_path) 21 | print(f"Saved output at {out_path}") 22 | 23 | 24 | class JSONEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.ndarray): 27 | return obj.tolist() 28 | elif isinstance(obj, bytes): 29 | return str(obj, encoding='utf-8') 30 | elif isinstance(obj, np.integer): 31 | return int(obj) 32 | elif isinstance(obj, np.floating): 33 
| return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 | else: 37 | # return super(MyEncoder, self).default(obj) 38 | 39 | raise TypeError( 40 | "Unserializable object {} of type {}".format(obj, type(obj)) 41 | ) 42 | 43 | 44 | def write_json(data: Union[list, dict], outfile: str) -> None: 45 | json_dir, _ = os.path.split(outfile) 46 | if json_dir and not os.path.exists(json_dir): 47 | os.makedirs(json_dir) 48 | 49 | with open(outfile, 'w') as f: 50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2) 51 | 52 | 53 | def read_json(filename: str) -> Union[list, dict]: 54 | """read json files""" 55 | with open(filename, "rb") as fin: 56 | data = json.load(fin, encoding="utf-8") 57 | return data 58 | 59 | 60 | def pil_loader(path: str) -> Image.Image: 61 | """load an image from path, and suppress warning""" 62 | # to avoid crashing for truncated (corrupted images) 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | # open path as file to avoid ResourceWarning 65 | # (https://github.com/python-pillow/Pillow/issues/835) 66 | with open(path, 'rb') as f: 67 | img = Image.open(f) 68 | return img.convert('RGB') 69 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | logger 4 | """ 5 | 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | def logger_all(metric, n_tasks=None): 10 | log_metric = OrderedDict() 11 | log_metric[metric] = np.zeros([n_tasks, n_tasks]) 12 | log_metric['final_acc'] = 0. 13 | log_metric['final_forget'] = 0. 14 | log_metric['final_la'] = 0. 15 | return log_metric 16 | 17 | def logger_eval(metric): 18 | log_metric = OrderedDict() 19 | log_metric[metric] = [] 20 | return log_metric 21 | 22 | def per_task_summary(log_metric, metric, task_id=0, task_t=0, value=0): 23 | if metric is 'acc': 24 | log_metric[metric][task_t, task_id] = value 25 | else: 26 | log_metric[metric] = value 27 | 28 | 29 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | average meter during training 4 | """ 5 | 6 | import torch 7 | 8 | def gpu_mem_usage(): 9 | """Computes the GPU memory usage for the current device (GB).""" 10 | if not torch.cuda.is_available(): 11 | return 0 12 | # Number of bytes in a megabyte 13 | _B_IN_GB = 1024 * 1024 * 1024 14 | 15 | mem_usage_bytes = torch.cuda.max_memory_allocated() 16 | return mem_usage_bytes / _B_IN_GB 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self, name, fmt=':f'): 22 | self.name = name 23 | self.fmt = fmt 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | def __str__(self): 39 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 40 | return fmtstr.format(**self.__dict__) 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | DAP: Generating Instance-level Prompts for Rehearsal-free Continual 
Learning (ICCV 2023 Oral) 4 | main function 5 | """ 6 | 7 | import os 8 | import torch 9 | import warnings 10 | 11 | import numpy as np 12 | import random 13 | 14 | from src.configs.config import get_cfg 15 | from src.data import loader as data_loader 16 | from src.engine.trainer import Trainer 17 | from src.models.build_model import build_model 18 | from src.utils.file_io import PathManager 19 | 20 | from launch import default_argument_parser 21 | warnings.filterwarnings("ignore") 22 | 23 | def setup(args): 24 | cfg = get_cfg() 25 | cfg.merge_from_file(args.config_file) 26 | cfg.merge_from_list(args.opts) 27 | 28 | output_path = os.path.join(cfg.OUTPUT_DIR, cfg.DATA.NAME) 29 | 30 | if PathManager.exists(output_path): 31 | raise ValueError(f"Already run for {output_path}") 32 | 33 | PathManager.mkdirs(output_path) 34 | cfg.OUTPUT_DIR = output_path 35 | return cfg 36 | 37 | def get_datasets(cfg): 38 | print("Loading training data...") 39 | train_dataset = data_loader._construct_dataset(cfg, split='train') 40 | print("Loading test data...") 41 | test_dataset = data_loader._construct_dataset(cfg, split='test') 42 | return train_dataset, test_dataset 43 | 44 | def train(cfg): 45 | if torch.cuda.is_available(): 46 | torch.cuda.empty_cache() 47 | 48 | if cfg.SEED is not None: 49 | torch.manual_seed(cfg.SEED) 50 | np.random.seed(cfg.SEED) 51 | random.seed(0) 52 | 53 | train_dataset, test_dataset = get_datasets(cfg) 54 | print("Constructing models...") 55 | model, cur_device = build_model(cfg) 56 | 57 | print("Setting up trainer...") 58 | trainer = Trainer(cfg, model, cur_device) 59 | trainer.train_classifier(train_dataset, test_dataset) 60 | 61 | def main(args): 62 | cfg = setup(args) 63 | train(cfg) 64 | 65 | if __name__ == '__main__': 66 | args = default_argument_parser().parse_args() 67 | main(args) 68 | --------------------------------------------------------------------------------
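
Putting the listed modules together: train.py reads a YAML config, builds the train/test splits, constructs the ViT-B/16 backbone whose encoder weights are frozen while the DAP parameters and the MLP head remain trainable, and hands everything to the Trainer. The sketch below is a minimal, hypothetical driver that issues the same calls from Python; the chosen config file and the MODEL.MODEL_ROOT override are placeholders (any YAML under configs/dap/ should work, and MODEL_ROOT must contain imagenet21k_ViT-B_16.npz), and it skips the OUTPUT_DIR bookkeeping that train.py's setup() performs. The command-line flags exposed by launch.default_argument_parser are not shown in this listing.

from src.configs.config import get_cfg
from src.data import loader as data_loader
from src.engine.trainer import Trainer
from src.models.build_model import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/dap/cifar.yaml")              # placeholder: any config under configs/dap/
cfg.merge_from_list(["MODEL.MODEL_ROOT", "./pretrained"])  # placeholder dir holding imagenet21k_ViT-B_16.npz

train_dataset = data_loader._construct_dataset(cfg, split="train")
test_dataset = data_loader._construct_dataset(cfg, split="test")

model, device = build_model(cfg)   # ViT with DAP prompts, moved to the current device
trainer = Trainer(cfg, model, device)
trainer.train_classifier(train_dataset, test_dataset)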