├── LICENSE.txt ├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── additional_transforms.cpython-36.pyc │ ├── datamgr.cpython-36.pyc │ ├── dataset.cpython-36.pyc │ └── feature_loader.cpython-36.pyc ├── additional_transforms.py ├── datamgr.py ├── dataset.py ├── feature_loader.py └── qmul_loader.py ├── filelists ├── CUB │ ├── download_CUB.sh │ └── write_CUB_filelist.py ├── cifar │ ├── cifar.sh │ └── make_json.py └── miniImagenet │ ├── make.py │ ├── miniImagenet.sh │ ├── test.csv │ ├── train.csv │ └── val.csv ├── methods ├── CSS.py ├── __pycache__ │ ├── CSS.cpython-36.pyc │ ├── SSL_three.cpython-36.pyc │ ├── SSL_two.cpython-36.pyc │ └── meta_template.cpython-36.pyc └── meta_template.py ├── run_css.py └── utils ├── __pycache__ ├── backbone.cpython-36.pyc ├── config.cpython-36.pyc └── utils.cpython-36.pyc ├── backbone.py ├── config.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. 
Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. 
Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. 
You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. 
If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 
341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 
408 | 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Self-Supervised Learning for Few-Shot Classification 2 | 3 | Code for "Conditional Self-Supervised Learning for Few-Shot Classification" (IJCAI 2021). 4 | 5 | If you use the code in this repo for your work, please cite the following BibTeX entry: 6 | 7 | ``` 8 | @inproceedings{An2021CSS, 9 | author = {Yuexuan An and 10 | Hui Xue and 11 | Xingyu Zhao and 12 | Lu Zhang}, 13 | title = {Conditional Self-Supervised Learning for Few-Shot Classification}, 14 | booktitle = {Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence, {IJCAI} 2021, Virtual Event / Montreal, Canada, 19-27 August 2021}, 15 | pages = {2140--2146}, 16 | year = {2021}, 17 | } 18 | ``` 19 | 20 | ## Environment 21 | 22 | Python 3 23 | 24 | PyTorch 25 | 26 | ## Getting started 27 | 28 | ### CIFAR-FS 29 | 30 | - Change directory to `./filelists/cifar` 31 | - Download [CIFAR-FS](https://drive.google.com/file/d/1i4atwczSI9NormW5SynaHa1iVN1IaOcs/view) 32 | - Run `bash ./cifar.sh` 33 | 34 | ### CUB 35 | 36 | - Change directory to `./filelists/CUB` 37 | - Run `bash ./download_CUB.sh` 38 | 39 | ### mini-ImageNet 40 | 41 | - Change directory to `./filelists/miniImagenet` 42 | - Download [mini-ImageNet](https://drive.google.com/file/d/1hQqDL16HTWv9Jz15SwYh3qq1E4F72UDC/view) 43 | - Run `bash ./miniImagenet.sh` 44 | 45 | ## Running 46 | 47 | ``` 48 | python run_css.py 49 | ``` 50 | 51 | ## Acknowledgment 52 | 53 | Our project references the code and datasets from the following repository and papers. 54 | 55 | [CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) 56 | 57 | Catherine Wah, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. The Caltech-UCSD Birds-200-2011 dataset. 2011. 58 | 59 | Luca Bertinetto, João F. Henriques, Philip H. S. Torr, Andrea Vedaldi. Meta-learning with differentiable closed-form solvers. ICLR 2019. 60 | 61 | Oriol Vinyals, Charles Blundell, Tim Lillicrap, Koray Kavukcuoglu, Daan Wierstra. Matching Networks for One Shot Learning. NIPS 2016: 3630-3638. 62 | 63 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datamgr 2 | from . import dataset 3 | from . import additional_transforms 4 | from .
import feature_loader 5 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/additional_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/data/__pycache__/additional_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/datamgr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/data/__pycache__/datamgr.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/feature_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/data/__pycache__/feature_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/additional_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import torch 9 | from PIL import ImageEnhance 10 | 11 | transformtypedict = dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, 12 | Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) 13 | 14 | 15 | class ImageJitter(object): 16 | def __init__(self, transformdict): 17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 18 | 19 | def __call__(self, img): 20 | out = img 21 | randtensor = torch.rand(len(self.transforms)) 22 | 23 | for i, (transformer, alpha) in enumerate(self.transforms): 24 | r = alpha * (randtensor[i] * 2.0 - 1.0) + 1 25 | out = transformer(out).enhance(r).convert('RGB') 26 | 27 | return out 28 | -------------------------------------------------------------------------------- /data/datamgr.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | import data.additional_transforms as add_transforms 7 | from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler 8 | from abc import abstractmethod 9 | 10 | 11 | def _init_fn(worker_id): 12 | np.random.seed(0) 13 | 14 | 15 | class TransformLoader: 16 | def __init__(self, image_size, 17 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 18 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 19 | self.image_size = image_size 20 | self.normalize_param = normalize_param 21 | self.jitter_param = jitter_param 22 | 23 | def parse_transform(self, transform_type): 24 | if transform_type == 'ImageJitter': 25 | method = add_transforms.ImageJitter(self.jitter_param) 26 | return method 27 | method = getattr(transforms, transform_type) 28 | if transform_type == 'RandomSizedCrop': 29 | return method(self.image_size) 30 | elif transform_type == 'CenterCrop': 31 | return method(self.image_size) 32 | elif transform_type == 'Scale': 33 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)]) 34 | elif transform_type == 'Normalize': 35 | return method(**self.normalize_param) 36 | else: 37 | return method() 38 | 39 | def get_composed_transform(self, aug=False): 40 | if aug: 41 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 42 | else: 43 | transform_list = ['Scale', 'CenterCrop', 'ToTensor', 'Normalize'] 44 | 45 | transform_funcs = [self.parse_transform(x) for x in transform_list] 46 | transform = transforms.Compose(transform_funcs) 47 | return transform 48 | 49 | 50 | class DataManager: 51 | @abstractmethod 52 | def get_data_loader(self, data_file, aug): 53 | pass 54 | 55 | 56 | class SimpleDataManager(DataManager): 57 | def __init__(self, image_size, batch_size): 58 | super(SimpleDataManager, self).__init__() 59 | self.batch_size = batch_size 60 | self.trans_loader = TransformLoader(image_size) 61 | 62 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set 63 | transform = self.trans_loader.get_composed_transform(aug) 64 | dataset = SimpleDataset(data_file, transform) 65 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True) 66 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 67 | return data_loader 68 | 69 | 70 | class SetDataManager(DataManager): 71 | def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100, 
noise_rate=0., num_workers=4): 72 | super(SetDataManager, self).__init__() 73 | self.image_size = image_size 74 | self.n_way = n_way 75 | self.batch_size = n_support + n_query 76 | self.n_eposide = n_eposide 77 | self.trans_loader = TransformLoader(image_size) 78 | self.noise_rate = noise_rate 79 | self.num_workers = num_workers 80 | 81 | def get_data_loader(self, data_file, aug): # parameters that would change on train/val set 82 | transform = self.trans_loader.get_composed_transform(aug) 83 | dataset = SetDataset(data_file, self.batch_size, transform, noise_rate=self.noise_rate) 84 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide) 85 | data_loader_params = dict(batch_sampler=sampler, num_workers=self.num_workers, pin_memory=True) 86 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 87 | return data_loader 88 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import json 6 | import numpy as np 7 | import torchvision.transforms as transforms 8 | import os 9 | 10 | identity = lambda x: x 11 | 12 | 13 | class SimpleDataset: 14 | def __init__(self, data_file, transform, target_transform=identity): 15 | with open(data_file, 'r') as f: 16 | self.meta = json.load(f) 17 | self.transform = transform 18 | self.target_transform = target_transform 19 | 20 | def __getitem__(self, i): 21 | image_path = os.path.join(self.meta['image_names'][i]) 22 | img = Image.open(image_path).convert('RGB') 23 | img = self.transform(img) 24 | target = self.target_transform(self.meta['image_labels'][i]) 25 | return img, target 26 | 27 | def __len__(self): 28 | return len(self.meta['image_names']) 29 | 30 | 31 | class SetDataset: 32 | def __init__(self, data_file, batch_size, transform, noise_rate=0.): 33 | with open(data_file, 'r') as f: 34 | self.meta = json.load(f) 35 | 36 | self.cl_list = np.unique(self.meta['image_labels']).tolist() 37 | 38 | self.sub_meta = {} 39 | for cl in self.cl_list: 40 | self.sub_meta[cl] = [] 41 | 42 | for x, y in zip(self.meta['image_names'], self.meta['image_labels']): 43 | self.sub_meta[y].append(x) 44 | 45 | self.sub_dataloader = [] 46 | sub_data_loader_params = dict(batch_size=batch_size, 47 | shuffle=True, 48 | num_workers=0, # use main thread only or may receive multiple batches 49 | pin_memory=False) 50 | for cl in self.cl_list: 51 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform, noise_rate=noise_rate) 52 | self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params)) 53 | 54 | def __getitem__(self, i): 55 | return next(iter(self.sub_dataloader[i])) 56 | 57 | def __len__(self): 58 | return len(self.cl_list) 59 | 60 | 61 | class SubDataset: 62 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity, noise_rate=0.): 63 | self.sub_meta = sub_meta 64 | self.cl = cl 65 | self.transform = transform 66 | self.target_transform = target_transform 67 | self.noise_rate = noise_rate 68 | 69 | def __getitem__(self, i): 70 | # print( '%d -%d' %(self.cl,i)) 71 | image_path = os.path.join(self.sub_meta[i]) 72 | img = Image.open(image_path).convert('RGB') 73 | if self.noise_rate > 0.: 74 | if np.random.random() > (1 - self.noise_rate): 75 | # img = np.array(img) 76 | # img = (img + 
np.random.randint(0, 255, size=img.shape)) // 2 77 | img = np.array(img) 78 | img = np.random.randint(0, 255, size=img.shape) 79 | img = Image.fromarray(img.astype('uint8')) 80 | img = self.transform(img) 81 | target = self.target_transform(self.cl) 82 | return img, target 83 | 84 | def __len__(self): 85 | return len(self.sub_meta) 86 | 87 | 88 | class EpisodicBatchSampler(object): 89 | def __init__(self, n_classes, n_way, n_episodes): 90 | self.n_classes = n_classes 91 | self.n_way = n_way 92 | self.n_episodes = n_episodes 93 | 94 | def __len__(self): 95 | return self.n_episodes 96 | 97 | def __iter__(self): 98 | for i in range(self.n_episodes): 99 | yield torch.randperm(self.n_classes)[:self.n_way] 100 | -------------------------------------------------------------------------------- /data/feature_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import h5py 4 | 5 | class SimpleHDF5Dataset: 6 | def __init__(self, file_handle = None): 7 | if file_handle == None: 8 | self.f = '' 9 | self.all_feats_dset = [] 10 | self.all_labels = [] 11 | self.total = 0 12 | else: 13 | self.f = file_handle 14 | self.all_feats_dset = self.f['all_feats'][...] 15 | self.all_labels = self.f['all_labels'][...] 16 | self.total = self.f['count'][0] 17 | # print('here') 18 | def __getitem__(self, i): 19 | return torch.Tensor(self.all_feats_dset[i,:]), int(self.all_labels[i]) 20 | 21 | def __len__(self): 22 | return self.total 23 | 24 | def init_loader(filename): 25 | with h5py.File(filename, 'r') as f: 26 | fileset = SimpleHDF5Dataset(f) 27 | 28 | #labels = [ l for l in fileset.all_labels if l != 0] 29 | feats = fileset.all_feats_dset 30 | labels = fileset.all_labels 31 | while np.sum(feats[-1]) == 0: 32 | feats = np.delete(feats,-1,axis = 0) 33 | labels = np.delete(labels,-1,axis = 0) 34 | 35 | class_list = np.unique(np.array(labels)).tolist() 36 | inds = range(len(labels)) 37 | 38 | cl_data_file = {} 39 | for cl in class_list: 40 | cl_data_file[cl] = [] 41 | for ind in inds: 42 | cl_data_file[labels[ind]].append( feats[ind]) 43 | 44 | return cl_data_file 45 | -------------------------------------------------------------------------------- /data/qmul_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | 6 | train_people = ['DennisPNoGlassesGrey', 'JohnGrey', 'SimonBGrey', 'SeanGGrey', 'DanJGrey', 'AdamBGrey', 'JackGrey', 7 | 'RichardHGrey', 'YongminYGrey', 'TomKGrey', 'PaulVGrey', 'DennisPGrey', 'CarlaBGrey', 'JamieSGrey', 8 | 'KateSGrey', 'DerekCGrey', 'KatherineWGrey', 'ColinPGrey', 'SueWGrey', 'GrahamWGrey', 'KrystynaNGrey', 9 | 'SeanGNoGlassesGrey', 'KeithCGrey', 'HeatherLGrey'] 10 | test_people = ['RichardBGrey', 'TasosHGrey', 'SarahLGrey', 'AndreeaVGrey', 'YogeshRGrey'] 11 | 12 | def num_to_str(num): 13 | str_ = '' 14 | if num == 0: 15 | str_ = '000' 16 | elif num < 100: 17 | str_ = '0' + str(int(num)) 18 | else: 19 | str_ = str(int(num)) 20 | return str_ 21 | 22 | 23 | def get_person_at_curve(person, curve, prefix='../filelists/QMUL/images/'): 24 | faces = [] 25 | targets = [] 26 | 27 | train_transforms = transforms.Compose([transforms.ToTensor()]) 28 | for pitch, angle in curve: 29 | fname = prefix + person + '/' + person[:-4] + '_' + num_to_str(pitch) + '_' + num_to_str(angle) + '.jpg' 30 | img = Image.open(fname).convert('RGB') 31 | img = train_transforms(img) 32 | 
33 | faces.append(img) 34 | pitch_norm = 2 * ((pitch - 60) / (120 - 60)) - 1 35 | angle_norm = 2 * ((angle - 0) / (180 - 0)) - 1 36 | targets.append(torch.Tensor([pitch_norm])) 37 | 38 | faces = torch.stack(faces) 39 | targets = torch.stack(targets).squeeze() 40 | return faces, targets 41 | 42 | 43 | def get_batch(train_people=train_people, num_samples=19): 44 | ## generate trajectory 45 | amp = np.random.uniform(-3, 3) 46 | phase = np.random.uniform(-5, 5) 47 | wave = [(amp * np.sin(phase + x)) for x in range(num_samples)] 48 | ## map trajectory to angles/pitches 49 | angles = list(range(num_samples)) 50 | angles = [x * 10 for x in angles] 51 | pitches = [int(round(((y + 3) * 10) + 60, -1)) for y in wave] 52 | curve = [(p, a) for p, a in zip(pitches, angles)] 53 | 54 | inputs = [] 55 | targets = [] 56 | for person in train_people: 57 | inps, targs = get_person_at_curve(person, curve) 58 | inputs.append(inps) 59 | targets.append(targs) 60 | return torch.stack(inputs), torch.stack(targets) 61 | -------------------------------------------------------------------------------- /filelists/CUB/download_CUB.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 3 | tar -zxvf CUB_200_2011.tgz 4 | python write_CUB_filelist.py 5 | -------------------------------------------------------------------------------- /filelists/CUB/write_CUB_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import json 6 | import random 7 | 8 | cwd = os.getcwd() 9 | data_path = join(cwd,'CUB_200_2011/images') 10 | savedir = './' 11 | dataset_list = ['base','val','novel'] 12 | 13 | #if not os.path.exists(savedir): 14 | # os.makedirs(savedir) 15 | 16 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 17 | folder_list.sort() 18 | label_dict = dict(zip(folder_list,range(0,len(folder_list)))) 19 | 20 | classfile_list_all = [] 21 | 22 | for i, folder in enumerate(folder_list): 23 | folder_path = join(data_path, folder) 24 | classfile_list_all.append( [ join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path,cf)) and cf[0] != '.')]) 25 | random.shuffle(classfile_list_all[i]) 26 | 27 | 28 | for dataset in dataset_list: 29 | file_list = [] 30 | label_list = [] 31 | for i, classfile_list in enumerate(classfile_list_all): 32 | if 'base' in dataset: 33 | if (i%2 == 0): 34 | file_list = file_list + classfile_list 35 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 36 | if 'val' in dataset: 37 | if (i%4 == 1): 38 | file_list = file_list + classfile_list 39 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 40 | if 'novel' in dataset: 41 | if (i%4 == 3): 42 | file_list = file_list + classfile_list 43 | label_list = label_list + np.repeat(i, len(classfile_list)).tolist() 44 | 45 | fo = open(savedir + dataset + ".json", "w") 46 | fo.write('{"label_names": [') 47 | fo.writelines(['"%s",' % item for item in folder_list]) 48 | fo.seek(0, os.SEEK_END) 49 | fo.seek(fo.tell()-1, os.SEEK_SET) 50 | fo.write('],') 51 | 52 | fo.write('"image_names": [') 53 | fo.writelines(['"%s",' % item.replace('\\','/') for item in file_list]) 54 | fo.seek(0, os.SEEK_END) 55 | fo.seek(fo.tell()-1, os.SEEK_SET) 56 | fo.write('],') 57 | 58 | fo.write('"image_labels": [') 59 | 
fo.writelines(['%d,' % item for item in label_list]) 60 | fo.seek(0, os.SEEK_END) 61 | fo.seek(fo.tell()-1, os.SEEK_SET) 62 | fo.write(']}') 63 | 64 | fo.close() 65 | print("%s -OK" %dataset) 66 | -------------------------------------------------------------------------------- /filelists/cifar/cifar.sh: -------------------------------------------------------------------------------- 1 | unzip cifar.zip 2 | python make_json.py 3 | -------------------------------------------------------------------------------- /filelists/cifar/make_json.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | test = {'label_names': [], 'image_names': [], 'image_labels': []} 6 | pathname = os.getcwd() 7 | pathname = pathname.replace('\\', '/') 8 | # pathname = pathname.split('filelists')[0].replace('\\','/') 9 | print(pathname) 10 | 11 | f = open(pathname + '/cifar/splits/test.txt') 12 | classes = f.readlines() 13 | 14 | count = 80 15 | for each in classes: 16 | each = each.strip() 17 | test['label_names'].append(each) 18 | files = glob.glob(pathname + '/cifar/data/' + each + '/*') 19 | for image_name in files: 20 | test['image_names'].append(image_name.replace('\\', '/')) 21 | test['image_labels'].append(count) 22 | count += 1 23 | 24 | json.dump(test, open('novel.json', 'w'), ensure_ascii=False) 25 | 26 | base = {'label_names': [], 'image_names': [], 'image_labels': []} 27 | f = open(pathname + '/cifar/splits/train.txt') 28 | classes = f.readlines() 29 | 30 | count = 0 31 | for each in classes: 32 | each = each.strip() 33 | base['label_names'].append(each) 34 | files = glob.glob(pathname + '/cifar/data/' + each + '/*') 35 | for image_name in files: 36 | base['image_names'].append(image_name.replace('\\', '/')) 37 | base['image_labels'].append(count) 38 | count += 1 39 | 40 | json.dump(base, open('base.json', 'w'), ensure_ascii=False) 41 | 42 | val = {'label_names': [], 'image_names': [], 'image_labels': []} 43 | f = open(pathname + '/cifar/splits/val.txt') 44 | classes = f.readlines() 45 | 46 | count = 64 47 | for each in classes: 48 | each = each.strip() 49 | val['label_names'].append(each) 50 | files = glob.glob(pathname + '/cifar/data/' + each + '/*') 51 | for image_name in files: 52 | val['image_names'].append(image_name.replace('\\', '/')) 53 | val['image_labels'].append(count) 54 | count += 1 55 | 56 | json.dump(val, open('val.json', 'w'), ensure_ascii=False) 57 | -------------------------------------------------------------------------------- /filelists/miniImagenet/make.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | cwd = os.getcwd() 8 | data_path = os.path.join(cwd, 'miniImagenet') 9 | 10 | all = {} 11 | all['label_names'] = [] 12 | all['image_names'] = [] 13 | all['image_labels'] = [] 14 | 15 | trains = np.array(pd.read_csv(os.path.join(cwd, 'train.csv'))) 16 | base = {} 17 | base['label_names'] = [] 18 | base['image_names'] = [] 19 | base['image_labels'] = [] 20 | for i in tqdm(range(trains.shape[0] // 600)): 21 | all['label_names'].append(trains[600 * i, 1]) 22 | base['label_names'].append(trains[600 * i, 1]) 23 | names = os.listdir(os.path.join(data_path, trains[600 * i, 1])) 24 | for name in names: 25 | all['image_names'].append(os.path.join(data_path, trains[600 * i, 1], name)) 26 | all['image_labels'].append(i) 27 | 
base['image_names'].append(os.path.join(data_path, trains[600 * i, 1], name)) 28 | base['image_labels'].append(i) 29 | 30 | vals = np.array(pd.read_csv(os.path.join(cwd, 'val.csv'))) 31 | val = {} 32 | val['label_names'] = [] 33 | val['image_names'] = [] 34 | val['image_labels'] = [] 35 | for i in tqdm(range(vals.shape[0] // 600)): 36 | all['label_names'].append(vals[600 * i, 1]) 37 | val['label_names'].append(vals[600 * i, 1]) 38 | names = os.listdir(os.path.join(data_path, vals[600 * i, 1])) 39 | for name in names: 40 | all['image_names'].append(os.path.join(data_path, vals[600 * i, 1], name)) 41 | all['image_labels'].append(i + trains.shape[0] // 600) 42 | val['image_names'].append(os.path.join(data_path, vals[600 * i, 1], name)) 43 | val['image_labels'].append(i + trains.shape[0] // 600) 44 | 45 | tests = np.array(pd.read_csv(os.path.join(cwd, 'test.csv'))) 46 | test = {} 47 | test['label_names'] = [] 48 | test['image_names'] = [] 49 | test['image_labels'] = [] 50 | for i in tqdm(range(tests.shape[0] // 600)): 51 | all['label_names'].append(tests[600 * i, 1]) 52 | test['label_names'].append(tests[600 * i, 1]) 53 | names = os.listdir(os.path.join(data_path, tests[600 * i, 1])) 54 | for name in names: 55 | all['image_names'].append(os.path.join(data_path, tests[600 * i, 1], name)) 56 | all['image_labels'].append(i + (trains.shape[0] + vals.shape[0]) // 600) 57 | test['image_names'].append(os.path.join(data_path, tests[600 * i, 1], name)) 58 | test['image_labels'].append(i + (trains.shape[0] + vals.shape[0]) // 600) 59 | 60 | json.dump(base, open('base.json', 'w')) 61 | json.dump(val, open('val.json', 'w')) 62 | json.dump(test, open('novel.json', 'w')) 63 | json.dump(all, open('all.json', 'w')) 64 | 65 | data = json.load(open('all.json')) 66 | print(data.keys()) 67 | print(len(data['label_names'])) 68 | print(len(data['image_names'])) 69 | print(len(data['image_labels']), np.min(data['image_labels']), np.max(data['image_labels'])) 70 | 71 | data = json.load(open('base.json')) 72 | print(data.keys()) 73 | print(len(data['label_names'])) 74 | print(len(data['image_names'])) 75 | print(len(data['image_labels']), np.min(data['image_labels']), np.max(data['image_labels'])) 76 | 77 | data = json.load(open('val.json')) 78 | print(data.keys()) 79 | print(len(data['label_names'])) 80 | print(len(data['image_names'])) 81 | print(len(data['image_labels']), np.min(data['image_labels']), np.max(data['image_labels'])) 82 | 83 | data = json.load(open('novel.json')) 84 | print(data.keys()) 85 | print(len(data['label_names'])) 86 | print(len(data['image_names'])) 87 | print(len(data['image_labels']), np.min(data['image_labels']), np.max(data['image_labels'])) 88 | -------------------------------------------------------------------------------- /filelists/miniImagenet/miniImagenet.sh: -------------------------------------------------------------------------------- 1 | unzip miniImagenet.zip 2 | python make.py 3 | -------------------------------------------------------------------------------- /methods/CSS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import copy 5 | from methods.meta_template import MetaTemplate 6 | from torchvision import transforms 7 | from PIL import Image 8 | 9 | 10 | class CSS(MetaTemplate): 11 | def __init__(self, model_func, n_way, n_support, use_cuda=True, adaptation=False, image_size=(84, 84), 12 | classification_head='cosine'): 13 | super(CSS, self).__init__(model_func, 
n_way, n_support, use_cuda=use_cuda, adaptation=adaptation) 14 | self.loss_fn = nn.CrossEntropyLoss() 15 | self.pre_feature_extractor = copy.deepcopy(self.feature_extractor) 16 | self.ssl_feature_extractor = copy.deepcopy(self.feature_extractor) 17 | self.projection_mlp_1 = nn.Sequential( 18 | nn.Linear(self.feature_extractor.final_feat_dim, 2048), 19 | ) 20 | self.projection_mlp_2 = nn.Sequential( 21 | nn.BatchNorm1d(2048), 22 | nn.ReLU(), 23 | nn.Linear(2048, 2048), 24 | nn.BatchNorm1d(2048), 25 | nn.ReLU(), 26 | nn.Linear(2048, 2048), 27 | nn.BatchNorm1d(2048) 28 | ) 29 | self.prediction_mlp = nn.Sequential( 30 | nn.Linear(2048, 512), 31 | nn.BatchNorm1d(512), 32 | nn.ReLU(), 33 | nn.Linear(512, 2048), 34 | ) 35 | self.alpha = nn.Parameter(torch.ones([1])) 36 | self.gamma = nn.Parameter(torch.ones([1]) * 2, requires_grad=False) 37 | self.image_size = image_size 38 | self.classification_head = classification_head 39 | 40 | def cosine_similarity(self, x, y): 41 | # x: m x d 42 | # y: n x d 43 | # return: m x n 44 | assert x.size(1) == y.size(1) 45 | x = torch.nn.functional.normalize(x, dim=1) 46 | y = torch.nn.functional.normalize(y, dim=1) 47 | x = x.unsqueeze(1).expand(x.size(0), y.size(0), x.size(1)) # [m,1*n,d] 48 | y = y.unsqueeze(0).expand(x.shape) # [1*m,n,d] 49 | return (x * y).sum(2) 50 | 51 | def euclidean_dist(self, x, y): 52 | # x: m x d 53 | # y: n x d 54 | # return: m x n 55 | assert x.size(1) == y.size(1) 56 | x = x.unsqueeze(1).expand(x.size(0), y.size(0), x.size(1)) # [m,1*n,d] 57 | y = y.unsqueeze(0).expand(x.shape) # [1*m,n,d] 58 | return torch.pow(x - y, 2).sum(2) 59 | 60 | def set_pre_train_forward(self, x): 61 | z_support, z_query = self.parse_feature(x) 62 | z_support = self.projection_mlp_1(z_support) 63 | z_query = self.projection_mlp_1(z_query) 64 | z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1) # [N,d] 65 | z_query = z_query.reshape(self.n_way * self.n_query, -1) # [N*Q,d] 66 | if self.classification_head == 'cosine': 67 | return self.cosine_similarity(z_query, z_proto) * 10 68 | else: 69 | return -self.euclidean_dist(z_query, z_proto) 70 | 71 | def set_pre_train_forward_loss(self, x): 72 | y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).long() 73 | if self.use_cuda: 74 | y_query = y_query.cuda() 75 | scores = self.set_pre_train_forward(x) 76 | return self.loss_fn(scores, y_query) 77 | 78 | def pre_train_loop(self, epoch, train_loader, optimizer): 79 | print_freq = 10 80 | avg_loss = 0 81 | for i, (x, _) in enumerate(train_loader): 82 | if self.use_cuda: 83 | x = x.cuda() 84 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 85 | self.n_way = x.size(0) 86 | optimizer.zero_grad() 87 | loss = self.set_pre_train_forward_loss(x) 88 | loss.backward() 89 | optimizer.step() 90 | avg_loss = avg_loss + loss.item() 91 | if self.verbose and (i % print_freq) == 0: 92 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), 93 | avg_loss / float(i + 1))) 94 | if not self.verbose: 95 | print('Epoch {:d} | Loss {:f}'.format(epoch, avg_loss / float(i + 1))) 96 | return avg_loss 97 | 98 | def pre_train_test_loop(self, test_loader, record=None, return_std=False): 99 | acc_all = [] 100 | iter_num = len(test_loader) 101 | for i, (x, _) in enumerate(test_loader): 102 | if self.use_cuda: 103 | x = x.cuda() 104 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 105 | self.n_way = x.size(0) 106 | with torch.no_grad(): 107 | x = x.reshape(self.n_way * (self.n_support
+ self.n_query), *x.size()[2:]) 108 | z_all = self.feature_extractor.forward(x) 109 | z_all = z_all.reshape(self.n_way, self.n_support + self.n_query, *z_all.shape[1:]) # [N, S+Q, d] 110 | z_support = z_all[:, :self.n_support] # [N, S, d] 111 | z_query = z_all[:, self.n_support:] # [N, Q, d] 112 | z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1) # [N,d] 113 | z_query = z_query.reshape(self.n_way * self.n_query, -1) # [N*Q,d] 114 | scores = self.cosine_similarity(z_query, z_proto) 115 | y_query = np.repeat(range(self.n_way), self.n_query) # [0 0 0 1 1 1 2 2 2 3 3 3 4 4 4] 116 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) # top1, dim=1, largest, sorted 117 | topk_ind = topk_labels.cpu().numpy() # index of topk 118 | acc_all.append(np.sum(topk_ind[:, 0] == y_query) / len(y_query) * 100) 119 | acc_all = np.asarray(acc_all) 120 | acc_mean = np.mean(acc_all) 121 | acc_std = np.std(acc_all) 122 | if self.verbose: 123 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) 124 | if return_std: 125 | return acc_mean, acc_std 126 | else: 127 | return acc_mean 128 | 129 | def f(self, x): 130 | # x:[N*(S+Q),n_channel,h,w] 131 | x = self.ssl_feature_extractor(x) 132 | x = self.projection_mlp_1(x) 133 | x = self.projection_mlp_2(x) 134 | return x 135 | 136 | def h(self, x): 137 | # x:[N*(S+Q),2048] 138 | x = self.prediction_mlp(x) 139 | return x 140 | 141 | def D(self, p, z): 142 | z = z.detach() 143 | p = torch.nn.functional.normalize(p, dim=1) 144 | z = torch.nn.functional.normalize(z, dim=1) 145 | return -(p * z).sum(dim=1).mean() 146 | 147 | def data_augmentation(self, img): 148 | # x:[n_channel,h,w], torch.Tensor 149 | x = transforms.RandomResizedCrop(self.image_size, interpolation=Image.BICUBIC)(img) 150 | x = transforms.RandomHorizontalFlip()(x) 151 | if np.random.random() < 0.8: 152 | x = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)(x) 153 | else: 154 | x = transforms.RandomGrayscale(p=1.0)(x) 155 | x = transforms.GaussianBlur((5, 5))(x) 156 | return x 157 | 158 | def contrastive_loss(self, x): 159 | # x:[N*(S+Q),n_channel,h,w] 160 | x1 = x.clone() 161 | x2 = x.clone() 162 | for index in range(x.shape[0]): 163 | x1[index] = self.data_augmentation(x[index]) 164 | x2[index] = self.data_augmentation(x[index]) 165 | z1, z2 = self.f(x1), self.f(x2) 166 | p1, p2 = self.h(z1), self.h(z2) 167 | loss = self.D(p1, z2) / 2 + self.D(p2, z1) / 2 168 | return loss 169 | 170 | def ssl_train_loop(self, epoch, train_loader, optimizer): 171 | self.train() 172 | print_freq = 10 173 | avg_loss = 0 174 | for i, (x, _) in enumerate(train_loader): # x:[N, S+Q, n_channel, h, w] 175 | if self.use_cuda: 176 | x = x.cuda() 177 | x = x.reshape([x.shape[0] * x.shape[1], *x.shape[2:]]) # x:[N*(S+Q),n_channel,h,w] 178 | x_ssl = torch.nn.functional.normalize(self.ssl_feature_extractor(x), dim=1) 179 | x_pre = torch.nn.functional.normalize(self.feature_extractor(x).detach(), dim=1) 180 | loss = self.contrastive_loss(x) - torch.mean(torch.sum((x_ssl * x_pre), dim=1)) 181 | optimizer.zero_grad() 182 | loss.backward() 183 | optimizer.step() 184 | avg_loss = avg_loss + loss.item() 185 | if self.verbose and (i % print_freq) == 0: 186 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), 187 | avg_loss / float(i + 1))) 188 | if not self.verbose: 189 | print('Epoch {:d} | Loss {:f}'.format(epoch, avg_loss / float(i + 1))) 190 | 191 | def ssl_test_loop(self, test_loader, record=None, return_std=False): 192 | acc_all = [] 
193 | iter_num = len(test_loader) 194 | for i, (x, _) in enumerate(test_loader): 195 | if self.use_cuda: 196 | x = x.cuda() 197 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 198 | self.n_way = x.size(0) 199 | with torch.no_grad(): 200 | x = x.reshape(self.n_way * (self.n_support + self.n_query), *x.size()[2:]) 201 | z_all = self.ssl_feature_extractor.forward(x) 202 | z_all = z_all.reshape(self.n_way, self.n_support + self.n_query, *z_all.shape[1:]) # [N, S+Q, d] 203 | z_support = z_all[:, :self.n_support] # [N, S, d] 204 | z_query = z_all[:, self.n_support:] # [N, Q, d] 205 | z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1) # [N,d] 206 | z_query = z_query.reshape(self.n_way * self.n_query, -1) # [N*Q,d] 207 | scores = self.cosine_similarity(z_query, z_proto) 208 | y_query = np.repeat(range(self.n_way), self.n_query) # [0 0 0 1 1 1 2 2 2 3 3 3 4 4 4] 209 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) # top1, dim=1, largest, sorted 210 | topk_ind = topk_labels.cpu().numpy() # index of topk 211 | acc_all.append(np.sum(topk_ind[:, 0] == y_query) / len(y_query) * 100) 212 | acc_all = np.asarray(acc_all) 213 | acc_mean = np.mean(acc_all) 214 | acc_std = np.std(acc_all) 215 | if self.verbose: 216 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) 217 | if return_std: 218 | return acc_mean, acc_std 219 | else: 220 | return acc_mean 221 | 222 | def set_forward(self, x): 223 | z_support, z_query = self.parse_feature(x) 224 | z_support = self.projection_mlp_1(z_support) 225 | z_query = self.projection_mlp_1(z_query) 226 | z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1) # [N,d] 227 | z_query = z_query.reshape(self.n_way * self.n_query, -1) # [N*Q,d] 228 | if self.classification_head == 'cosine': 229 | return self.cosine_similarity(z_query, z_proto) * 10 230 | else: 231 | return -self.euclidean_dist(z_query, z_proto) 232 | 233 | def set_forward_loss(self, x): 234 | y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).long() 235 | if self.use_cuda: 236 | y_query = y_query.cuda() 237 | scores = self.set_forward(x) 238 | return self.loss_fn(scores, y_query) 239 | 240 | def meta_train_loop(self, epoch, train_loader, optimizer): 241 | self.train() 242 | print_freq = 10 243 | avg_loss = 0 244 | for i, (x, _) in enumerate(train_loader): # x:[N, S+Q, n_channel, h, w] 245 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 246 | self.n_way = x.size(0) 247 | if self.use_cuda: 248 | x = x.cuda() 249 | xx = x.reshape([x.shape[0] * x.shape[1], *x.shape[2:]]) # x:[N*(S+Q),n_channel,h,w] 250 | with torch.no_grad(): 251 | x_pre = self.pre_feature_extractor(xx) 252 | x_ssl = self.ssl_feature_extractor(xx) 253 | x_aggregation = nn.functional.normalize(torch.cat([x_pre, x_ssl], dim=1), dim=1) 254 | similarity = torch.sum(torch.unsqueeze(x_aggregation, dim=0) * torch.unsqueeze(x_aggregation, dim=1), 255 | dim=2) 256 | for index in range(similarity.shape[0]): 257 | similarity[index, index] = 0 258 | D = torch.diag(torch.sum(similarity, dim=1) ** -0.5) 259 | A = D @ similarity @ D 260 | if self.use_cuda: 261 | augment = (self.alpha * torch.eye(A.shape[0], A.shape[0]).cuda() + A) ** self.gamma 262 | else: 263 | augment = (self.alpha * torch.eye(A.shape[0], A.shape[0]) + A) ** self.gamma 264 | z_all = augment @ self.feature_extractor(xx) 265 | z_all = z_all.reshape(self.n_way, self.n_support + self.n_query, -1) 266 | z_support = z_all[:,
:self.n_support] 267 | z_query = z_all[:, self.n_support:] 268 | z_support = self.projection_mlp_1(z_support) 269 | z_query = self.projection_mlp_1(z_query) 270 | z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1) # [N,d] 271 | z_query = z_query.reshape(self.n_way * self.n_query, -1) # [N*Q,d] 272 | y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).long() 273 | if self.use_cuda: 274 | y_query = y_query.cuda() 275 | scores = self.cosine_similarity(z_query, z_proto) * 10 276 | loss = self.loss_fn(scores, y_query) + self.set_forward_loss(x) 277 | optimizer.zero_grad() 278 | loss.backward() 279 | optimizer.step() 280 | avg_loss = avg_loss + loss.item() 281 | if self.verbose and (i % print_freq) == 0: 282 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), 283 | avg_loss / float(i + 1))) 284 | if not self.verbose: 285 | print('Epoch {:d} | Loss {:f}'.format(epoch, avg_loss / float(i + 1))) 286 | return avg_loss 287 | -------------------------------------------------------------------------------- /methods/__pycache__/CSS.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/methods/__pycache__/CSS.cpython-36.pyc -------------------------------------------------------------------------------- /methods/__pycache__/SSL_three.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/methods/__pycache__/SSL_three.cpython-36.pyc -------------------------------------------------------------------------------- /methods/__pycache__/SSL_two.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/methods/__pycache__/SSL_two.cpython-36.pyc -------------------------------------------------------------------------------- /methods/__pycache__/meta_template.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/methods/__pycache__/meta_template.cpython-36.pyc -------------------------------------------------------------------------------- /methods/meta_template.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from abc import abstractmethod 5 | 6 | 7 | class MetaTemplate(nn.Module): 8 | def __init__(self, model_func, n_way, n_support, verbose=False, use_cuda=True, adaptation=False): 9 | super(MetaTemplate, self).__init__() 10 | self.n_way = n_way # N, n_classes 11 | self.n_support = n_support # S, sample num of support set 12 | self.n_query = -1 # Q, sample num of query set(change depends on input) 13 | self.feature_extractor = model_func() # feature extractor 14 | self.feat_dim = self.feature_extractor.final_feat_dim 15 | self.verbose = verbose 16 | self.use_cuda = use_cuda 17 | self.adaptation = adaptation 18 | 19 | @abstractmethod 20 | def set_forward(self, x): 21 | # x -> predicted score 22 | pass 23 | 24 | @abstractmethod 25 | def set_forward_loss(self, x): 26 | # x -> loss value 27 | pass 28 | 29 | def forward(self, x): 30 | # x-> feature embedding 31 | out = self.feature_extractor.forward(x) 32 | return out 33 | 34 | 
def parse_feature(self, x): 35 | x = x.requires_grad_(True) 36 | x = x.reshape(self.n_way * (self.n_support + self.n_query), *x.size()[2:]) 37 | z_all = self.feature_extractor.forward(x) 38 | z_all = z_all.reshape(self.n_way, self.n_support + self.n_query, *z_all.shape[1:]) # [N, S+Q, d] 39 | z_support = z_all[:, :self.n_support] # [N, S, d] 40 | z_query = z_all[:, self.n_support:] # [N, Q, d] 41 | return z_support, z_query 42 | 43 | def correct(self, x): 44 | if self.adaptation: 45 | scores = self.set_forward_adaptation(x) 46 | else: 47 | scores = self.set_forward(x) 48 | y_query = np.repeat(range(self.n_way), self.n_query) # [0 0 0 1 1 1 2 2 2 3 3 3 4 4 4] 49 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) # top1, dim=1, largest, sorted 50 | topk_ind = topk_labels.cpu().numpy() # index of topk 51 | top1_correct = np.sum(topk_ind[:, 0] == y_query) 52 | return float(top1_correct), len(y_query) 53 | 54 | def train_loop(self, epoch, train_loader, optimizer): 55 | print_freq = 10 56 | avg_loss = 0 57 | for i, (x, _) in enumerate(train_loader): 58 | if self.use_cuda: 59 | x = x.cuda() 60 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 61 | self.n_way = x.size(0) 62 | optimizer.zero_grad() 63 | loss = self.set_forward_loss(x) 64 | loss.backward() 65 | optimizer.step() 66 | avg_loss = avg_loss + loss.item() 67 | if self.verbose and (i % print_freq) == 0: 68 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), 69 | avg_loss / float(i + 1))) 70 | if not self.verbose: 71 | print('Epoch {:d} | Loss {:f}'.format(epoch, avg_loss / float(i + 1))) 72 | return avg_loss 73 | 74 | def test_loop(self, test_loader, record=None, return_std=False): 75 | acc_all = [] 76 | iter_num = len(test_loader) 77 | for i, (x, _) in enumerate(test_loader): 78 | if self.use_cuda: 79 | x = x.cuda() 80 | self.n_query = x.size(1) - self.n_support # x:[N, S+Q, n_channel, h, w] 81 | self.n_way = x.size(0) 82 | correct_this, count_this = self.correct(x) 83 | acc_all.append(correct_this / count_this * 100) 84 | acc_all = np.asarray(acc_all) 85 | acc_mean = np.mean(acc_all) 86 | acc_std = np.std(acc_all) 87 | if self.verbose: 88 | # Confidence Interval 90% -> 1.645 95% -> 1.96 99% -> 2.576 89 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num))) 90 | if return_std: 91 | return acc_mean, acc_std 92 | else: 93 | return acc_mean -------------------------------------------------------------------------------- /run_css.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * 2 | from methods.CSS import CSS 3 | import copy 4 | 5 | datasets = ['cifar', 'CUB', 'miniImagenet'] 6 | 7 | classification_head = 'cosine' 8 | 9 | for dataset in datasets: 10 | print(dataset) 11 | # region parameters 12 | algorithm = 'css' # protonet/matchingnet/relationnet 13 | model_name = 'Conv4' # Conv4/Conv6/ResNet10/ResNet18/ResNet34/ResNet50/ResNet101 14 | n_shot = 5 # number of labeled data in each class, same as n_support 15 | stop_epoch = -1 16 | pre_train_epoch = -1 17 | ssl_train_epoch = -1 18 | if_pre_train = True 19 | if_ssl_train = True 20 | if_meta_train = True 21 | if_test = True 22 | if 'Conv' in model_name: 23 | image_resize = (84, 84) 24 | else: 25 | image_resize = (224, 224) 26 | # endregion 27 | 28 | image_size = get_image_size(model_name=model_name, dataset=dataset) 29 | 30 | if stop_epoch == -1: 31 | stop_epoch = get_stop_epoch(algorithm=algorithm, 
dataset=dataset, n_shot=n_shot) 32 | if pre_train_epoch == -1: 33 | pre_train_epoch = stop_epoch 34 | if ssl_train_epoch == -1: 35 | ssl_train_epoch = stop_epoch 36 | checkpoint_dir = get_checkpoint_dir(algorithm=algorithm, model_name=model_name, dataset=dataset, 37 | train_n_way=train_n_way, n_shot=n_shot, addition='%f' % noise_rate) 38 | base_file, val_file = get_train_files(dataset=dataset) 39 | base_loader, val_loader = get_train_loader(algorithm=algorithm, image_size=image_size, base_file=base_file, 40 | val_file=val_file, train_n_way=train_n_way, test_n_way=test_n_way, 41 | n_shot=n_shot, noise_rate=noise_rate, val_noise=True, 42 | num_workers=num_workers) 43 | 44 | 45 | def pre_train(): 46 | print('Start pre-training!') 47 | model = CSS(model_dict[model_name], n_way=train_n_way, n_support=n_shot, use_cuda=use_cuda, 48 | adaptation=adaptation, image_size=image_resize, classification_head=classification_head) 49 | if use_cuda: 50 | model = model.cuda() 51 | max_acc = 0 52 | optimizer = torch.optim.Adam([{'params': model.feature_extractor.parameters(), 'lr': 1e-3}, 53 | {'params': model.projection_mlp_1.parameters(), 'lr': 1e-6}]) 54 | for pre_epoch in range(0, pre_train_epoch): 55 | model.train() 56 | model.pre_train_loop(pre_epoch, base_loader, optimizer) # model are called by reference, no need to return 57 | if not os.path.isdir(checkpoint_dir): 58 | os.makedirs(checkpoint_dir) 59 | model.eval() 60 | acc = model.pre_train_test_loop(val_loader) 61 | if not os.path.isdir(checkpoint_dir): 62 | os.makedirs(checkpoint_dir) 63 | if acc > max_acc: # for baseline and baseline++, we don't use validation here so we let acc = -1 64 | print('epoch:', pre_epoch, 'pre_train val acc:', acc, 'best!') 65 | max_acc = acc 66 | outfile = os.path.join(checkpoint_dir, 'pre_train_best.tar') 67 | torch.save({'epoch': pre_epoch, 'state': model.state_dict()}, outfile) 68 | if (pre_epoch % save_freq == 0) or (pre_epoch == stop_epoch - 1): 69 | outfile = os.path.join(checkpoint_dir, 'pre_train_{:d}.tar'.format(pre_epoch)) 70 | torch.save({'epoch': pre_epoch, 'state': model.state_dict()}, outfile) 71 | return model 72 | 73 | 74 | def ssl_train(): 75 | print('Start ssl-training!') 76 | model = CSS(model_dict[model_name], n_way=train_n_way, n_support=n_shot, use_cuda=use_cuda, 77 | adaptation=adaptation, image_size=image_resize, classification_head=classification_head) 78 | if use_cuda: 79 | model = model.cuda() 80 | outfile = os.path.join(checkpoint_dir, 'pre_train_best.tar') 81 | tmp = torch.load(outfile) 82 | model.load_state_dict(tmp['state']) 83 | max_acc = 0 84 | optimizer = torch.optim.Adam([{'params': model.ssl_feature_extractor.parameters(), 'lr': 1e-3}, 85 | {'params': model.projection_mlp_1.parameters(), 'lr': 1e-3}, 86 | {'params': model.projection_mlp_2.parameters(), 'lr': 1e-3}, 87 | {'params': model.prediction_mlp.parameters(), 'lr': 1e-3}, ]) 88 | for ssl_epoch in range(0, ssl_train_epoch): 89 | model.train() 90 | model.ssl_train_loop(ssl_epoch, base_loader, optimizer) 91 | model.eval() 92 | acc = model.ssl_test_loop(val_loader) 93 | if acc > max_acc: 94 | print('epoch:', ssl_epoch, 'ssl_train val acc:', acc, 'best!') 95 | max_acc = acc 96 | outfile = os.path.join(checkpoint_dir, 'ssl_train_best.tar') 97 | torch.save({'epoch': ssl_epoch, 'state': model.state_dict()}, outfile) 98 | if (ssl_epoch % save_freq == 0) or (ssl_epoch == ssl_train_epoch - 1): 99 | outfile = os.path.join(checkpoint_dir, 'ssl_train_{:d}.tar'.format(ssl_epoch)) 100 | torch.save({'epoch': ssl_epoch, 'state': 
model.state_dict()}, outfile) 101 | return model 102 | 103 | 104 | def meta_train(): 105 | print('Start meta-training!') 106 | model = CSS(model_dict[model_name], n_way=train_n_way, n_support=n_shot, use_cuda=use_cuda, 107 | adaptation=adaptation, image_size=image_resize, classification_head=classification_head) 108 | if use_cuda: 109 | model = model.cuda() 110 | model.pre_feature_extractor = copy.deepcopy(model.feature_extractor) 111 | outfile = os.path.join(checkpoint_dir, 'ssl_train_best.tar') 112 | tmp = torch.load(outfile) 113 | model.load_state_dict(tmp['state']) 114 | max_acc = 0 115 | optimizer = torch.optim.Adam([{'params': model.feature_extractor.parameters(), 'lr': 1e-3}, 116 | {'params': model.projection_mlp_1.parameters(), 'lr': 1e-3}, 117 | {'params': model.alpha, 'lr': 1e-3}]) 118 | for epoch in range(start_epoch, stop_epoch): 119 | model.train() 120 | model.meta_train_loop(epoch, base_loader, optimizer) # model are called by reference, no need to return 121 | model.eval() 122 | acc = model.test_loop(val_loader) 123 | if acc > max_acc: # for baseline and baseline++, we don't use validation here so we let acc = -1 124 | print("--> Best model! save...", acc) 125 | max_acc = acc 126 | outfile = os.path.join(checkpoint_dir, 'best_model.tar') 127 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile) 128 | if not os.path.isdir(checkpoint_dir): 129 | os.makedirs(checkpoint_dir) 130 | if (epoch % save_freq == 0) or (epoch == stop_epoch - 1): 131 | outfile = os.path.join(checkpoint_dir, '{:d}.tar'.format(epoch)) 132 | torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile) 133 | return model 134 | 135 | 136 | def test(phase='test'): 137 | print('Start testing!') 138 | model = CSS(model_dict[model_name], n_way=train_n_way, n_support=n_shot, use_cuda=use_cuda, 139 | adaptation=adaptation, image_size=image_resize, classification_head=classification_head) 140 | if use_cuda: 141 | model = model.cuda() 142 | if phase == 'pre': 143 | modelfile = os.path.join(checkpoint_dir, 'pre_train_best.tar') 144 | assert modelfile is not None 145 | tmp = torch.load(modelfile) 146 | model.load_state_dict(tmp['state']) 147 | elif phase == 'ssl': 148 | modelfile = os.path.join(checkpoint_dir, 'ssl_train_best.tar') 149 | assert modelfile is not None 150 | tmp = torch.load(modelfile) 151 | model.load_state_dict(tmp['state']) 152 | elif phase == 'test': 153 | modelfile = get_best_file(checkpoint_dir) 154 | assert modelfile is not None 155 | tmp = torch.load(modelfile) 156 | model.load_state_dict(tmp['state']) 157 | 158 | loadfile = get_novel_file(dataset=dataset, split='novel') 159 | datamgr = SetDataManager(image_size, n_eposide=test_iter_num, n_query=15, n_way=test_n_way, n_support=n_shot, 160 | noise_rate=0., num_workers=num_workers) 161 | novel_loader = datamgr.get_data_loader(loadfile, aug=False) 162 | model.eval() 163 | if phase == 'pre': 164 | acc_mean, acc_std = model.pre_train_test_loop(novel_loader, return_std=True) 165 | elif phase == 'ssl': 166 | acc_mean, acc_std = model.ssl_test_loop(novel_loader, return_std=True) 167 | else: 168 | acc_mean, acc_std = model.test_loop(novel_loader, return_std=True) 169 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % (test_iter_num, acc_mean, 1.96 * acc_std / np.sqrt(test_iter_num))) 170 | return model 171 | 172 | 173 | if if_pre_train: 174 | pre_train() 175 | if if_test: 176 | test('pre') 177 | if if_ssl_train: 178 | ssl_train() 179 | if if_test: 180 | test('ssl') 181 | if if_meta_train: 182 | meta_train() 183 | test() 184 | 
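For orientation, the episodic evaluation used by pre_train_test_loop, ssl_test_loop and test_loop above reduces to a prototype classifier: support embeddings are averaged per class, and each query is assigned to the prototype with the highest cosine similarity (the computation behind the cosine_similarity(z_query, z_proto) calls). Below is a minimal, standalone sketch of that scoring step only; the function name and the toy shapes are illustrative assumptions, not part of the repository's API.

import torch
import torch.nn.functional as F

def prototype_cosine_scores(z_support, z_query):
    # z_support: [N, S, d] support embeddings; z_query: [N*Q, d] query embeddings
    z_proto = z_support.mean(dim=1)        # [N, d] class prototypes (mean over the S shots)
    z_proto = F.normalize(z_proto, dim=1)  # unit-normalise so the dot product is a cosine
    z_query = F.normalize(z_query, dim=1)
    return z_query @ z_proto.t()           # [N*Q, N] cosine scores, one row per query

# Toy 5-way, 5-shot episode with 15 queries per class and 64-d features.
scores = prototype_cosine_scores(torch.randn(5, 5, 64), torch.randn(5 * 15, 64))
predictions = scores.argmax(dim=1)         # predicted class index for each query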
-------------------------------------------------------------------------------- /utils/__pycache__/backbone.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/utils/__pycache__/backbone.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/utils/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anyuexuan/CSS/2f44f8f359e53eb216331182f6c37c28852a572c/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/backbone.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torch.nn.functional as F 6 | from torch.nn.utils.weight_norm import WeightNorm 7 | 8 | # Basic ResNet model 9 | def init_layer(L): 10 | if isinstance(L, nn.Conv2d): 11 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels 12 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n))) 13 | elif isinstance(L, nn.BatchNorm2d): 14 | L.weight.data.fill_(1) 15 | L.bias.data.fill_(0) 16 | 17 | 18 | class distLinear(nn.Module): 19 | def __init__(self, indim, outdim): 20 | super(distLinear, self).__init__() 21 | self.L = nn.Linear(indim, outdim, bias=False) 22 | self.class_wise_learnable_norm = True # See the issue#4&8 in the github 23 | if self.class_wise_learnable_norm: 24 | WeightNorm.apply(self.L, 'weight', dim=0) # split the weight update component to direction and norm 25 | if outdim <= 200: 26 | self.scale_factor = 2 # a fixed scale factor to scale the output of cos value into a reasonably large input for softmax 27 | else: 28 | self.scale_factor = 10 # in omniglot, a larger scale factor is required to handle >1000 output classes. 
29 | 30 | def forward(self, x): 31 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 32 | x_normalized = x.div(x_norm + 0.00001) 33 | if not self.class_wise_learnable_norm: 34 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 35 | self.L.weight.data = self.L.weight.data / (L_norm + 0.00001) 36 | # matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 37 | cos_dist = self.L(x_normalized) 38 | scores = self.scale_factor * cos_dist 39 | return scores 40 | 41 | 42 | class Flatten(nn.Module): 43 | def __init__(self): 44 | super(Flatten, self).__init__() 45 | 46 | def forward(self, x): 47 | return x.view(x.size(0), -1) 48 | 49 | 50 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight 51 | def __init__(self, in_features, out_features): 52 | super(Linear_fw, self).__init__(in_features, out_features) 53 | self.weight.fast = None # Lazy hack to add fast weight link 54 | self.bias.fast = None 55 | 56 | def forward(self, x): 57 | if self.weight.fast is not None and self.bias.fast is not None: 58 | out = F.linear(x, self.weight.fast, 59 | self.bias.fast) # weight.fast (fast weight) is the temporaily adapted weight 60 | else: 61 | out = super(Linear_fw, self).forward(x) 62 | return out 63 | 64 | 65 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 67 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 68 | bias=bias) 69 | self.weight.fast = None 70 | if not self.bias is None: 71 | self.bias.fast = None 72 | 73 | def forward(self, x): 74 | if self.bias is None: 75 | if self.weight.fast is not None: 76 | out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding) 77 | else: 78 | out = super(Conv2d_fw, self).forward(x) 79 | else: 80 | if self.weight.fast is not None and self.bias.fast is not None: 81 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding) 82 | else: 83 | out = super(Conv2d_fw, self).forward(x) 84 | return out 85 | 86 | 87 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight 88 | def __init__(self, num_features): 89 | super(BatchNorm2d_fw, self).__init__(num_features) 90 | self.weight.fast = None 91 | self.bias.fast = None 92 | 93 | def forward(self, x): 94 | running_mean = torch.zeros(x.data.size()[1]) 95 | running_var = torch.ones(x.data.size()[1]) 96 | if torch.cuda.is_available(): 97 | running_mean = running_mean.cuda() 98 | running_var = running_var.cuda() 99 | if self.weight.fast is not None and self.bias.fast is not None: 100 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True, 101 | momentum=1) 102 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py 103 | else: 104 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1) 105 | return out 106 | 107 | 108 | # Simple Conv Block 109 | class ConvBlock(nn.Module): 110 | maml = False # Default 111 | 112 | def __init__(self, indim, outdim, pool=True, padding=1): 113 | super(ConvBlock, self).__init__() 114 | self.indim = indim 115 | self.outdim = outdim 116 | if self.maml: 117 | self.C = Conv2d_fw(indim, outdim, 3, padding=padding) 118 | self.BN = 
BatchNorm2d_fw(outdim) 119 | else: 120 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding) 121 | self.BN = nn.BatchNorm2d(outdim) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.parametrized_layers = [self.C, self.BN, self.relu] 124 | if pool: 125 | self.pool = nn.MaxPool2d(2) 126 | self.parametrized_layers.append(self.pool) 127 | for layer in self.parametrized_layers: 128 | init_layer(layer) 129 | self.trunk = nn.Sequential(*self.parametrized_layers) 130 | 131 | def forward(self, x): 132 | out = self.trunk(x) 133 | return out 134 | 135 | 136 | # Simple ResNet Block 137 | class SimpleBlock(nn.Module): 138 | maml = False # Default 139 | 140 | def __init__(self, indim, outdim, half_res): 141 | super(SimpleBlock, self).__init__() 142 | self.indim = indim 143 | self.outdim = outdim 144 | if self.maml: 145 | self.C1 = Conv2d_fw(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 146 | self.BN1 = BatchNorm2d_fw(outdim) 147 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False) 148 | self.BN2 = BatchNorm2d_fw(outdim) 149 | else: 150 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 151 | self.BN1 = nn.BatchNorm2d(outdim) 152 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False) 153 | self.BN2 = nn.BatchNorm2d(outdim) 154 | self.relu1 = nn.ReLU(inplace=True) 155 | self.relu2 = nn.ReLU(inplace=True) 156 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 157 | self.half_res = half_res 158 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 159 | if indim != outdim: 160 | if self.maml: 161 | self.shortcut = Conv2d_fw(indim, outdim, 1, 2 if half_res else 1, bias=False) 162 | self.BNshortcut = BatchNorm2d_fw(outdim) 163 | else: 164 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 165 | self.BNshortcut = nn.BatchNorm2d(outdim) 166 | self.parametrized_layers.append(self.shortcut) 167 | self.parametrized_layers.append(self.BNshortcut) 168 | self.shortcut_type = '1x1' 169 | else: 170 | self.shortcut_type = 'identity' 171 | for layer in self.parametrized_layers: 172 | init_layer(layer) 173 | 174 | def forward(self, x): 175 | out = self.C1(x) 176 | out = self.BN1(out) 177 | out = self.relu1(out) 178 | out = self.C2(out) 179 | out = self.BN2(out) 180 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 181 | out = out + short_out 182 | out = self.relu2(out) 183 | return out 184 | 185 | 186 | # Bottleneck block 187 | class BottleneckBlock(nn.Module): 188 | maml = False # Default 189 | 190 | def __init__(self, indim, outdim, half_res): 191 | super(BottleneckBlock, self).__init__() 192 | bottleneckdim = int(outdim / 4) 193 | self.indim = indim 194 | self.outdim = outdim 195 | if self.maml: 196 | self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False) 197 | self.BN1 = BatchNorm2d_fw(bottleneckdim) 198 | self.C2 = Conv2d_fw(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1) 199 | self.BN2 = BatchNorm2d_fw(bottleneckdim) 200 | self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False) 201 | self.BN3 = BatchNorm2d_fw(outdim) 202 | else: 203 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False) 204 | self.BN1 = nn.BatchNorm2d(bottleneckdim) 205 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1) 206 | self.BN2 = 
nn.BatchNorm2d(bottleneckdim) 207 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False) 208 | self.BN3 = nn.BatchNorm2d(outdim) 209 | self.relu = nn.ReLU() 210 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3] 211 | self.half_res = half_res 212 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 213 | if indim != outdim: 214 | if self.maml: 215 | self.shortcut = Conv2d_fw(indim, outdim, 1, stride=2 if half_res else 1, bias=False) 216 | else: 217 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False) 218 | self.parametrized_layers.append(self.shortcut) 219 | self.shortcut_type = '1x1' 220 | else: 221 | self.shortcut_type = 'identity' 222 | for layer in self.parametrized_layers: 223 | init_layer(layer) 224 | 225 | def forward(self, x): 226 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x) 227 | out = self.C1(x) 228 | out = self.BN1(out) 229 | out = self.relu(out) 230 | out = self.C2(out) 231 | out = self.BN2(out) 232 | out = self.relu(out) 233 | out = self.C3(out) 234 | out = self.BN3(out) 235 | out = out + short_out 236 | out = self.relu(out) 237 | return out 238 | 239 | 240 | class ConvNet(nn.Module): 241 | def __init__(self, depth, flatten=True): 242 | super(ConvNet, self).__init__() 243 | trunk = [] 244 | for i in range(depth): 245 | indim = 3 if i == 0 else 64 246 | outdim = 64 247 | B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for fist 4 layers 248 | trunk.append(B) 249 | if flatten: 250 | trunk.append(Flatten()) 251 | self.trunk = nn.Sequential(*trunk) 252 | self.final_feat_dim = 1600 253 | 254 | def forward(self, x): 255 | out = self.trunk(x) 256 | return out 257 | 258 | 259 | class ConvNetNopool( 260 | nn.Module): # Relation net use a 4 layer conv with pooling in only first two layers, else no pooling 261 | def __init__(self, depth): 262 | super(ConvNetNopool, self).__init__() 263 | trunk = [] 264 | for i in range(depth): 265 | indim = 3 if i == 0 else 64 266 | outdim = 64 267 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]), 268 | padding=0 if i in [0, 1] else 1) # only first two layer has pooling and no padding 269 | trunk.append(B) 270 | self.trunk = nn.Sequential(*trunk) 271 | self.final_feat_dim = [64, 19, 19] 272 | 273 | def forward(self, x): 274 | out = self.trunk(x) 275 | return out 276 | 277 | 278 | class ConvNetS(nn.Module): # For omniglot, only 1 input channel, output dim is 64 279 | def __init__(self, depth, flatten=True): 280 | super(ConvNetS, self).__init__() 281 | trunk = [] 282 | for i in range(depth): 283 | indim = 1 if i == 0 else 64 284 | outdim = 64 285 | B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for fist 4 layers 286 | trunk.append(B) 287 | if flatten: 288 | trunk.append(Flatten()) 289 | # trunk.append(nn.BatchNorm1d(64)) #TODO remove 290 | # trunk.append(nn.ReLU(inplace=True)) #TODO remove 291 | # trunk.append(nn.Linear(64, 64)) #TODO remove 292 | self.trunk = nn.Sequential(*trunk) 293 | self.final_feat_dim = 64 294 | 295 | def forward(self, x): 296 | out = x[:, 0:1, :, :] # only use the first dimension 297 | out = self.trunk(out) 298 | # out = torch.tanh(out) #TODO remove 299 | return out 300 | 301 | 302 | class ConvNetSNopool( 303 | nn.Module): # Relation net use a 4 layer conv with pooling in only first two layers, else no pooling. 
For omniglot, only 1 input channel, output dim is [64,5,5] 304 | def __init__(self, depth): 305 | super(ConvNetSNopool, self).__init__() 306 | trunk = [] 307 | for i in range(depth): 308 | indim = 1 if i == 0 else 64 309 | outdim = 64 310 | B = ConvBlock(indim, outdim, pool=(i in [0, 1]), 311 | padding=0 if i in [0, 1] else 1) # only first two layer has pooling and no padding 312 | trunk.append(B) 313 | self.trunk = nn.Sequential(*trunk) 314 | self.final_feat_dim = [64, 5, 5] 315 | 316 | def forward(self, x): 317 | out = x[:, 0:1, :, :] # only use the first dimension 318 | out = self.trunk(out) 319 | return out 320 | 321 | 322 | class ResNet(nn.Module): 323 | maml = False # Default 324 | 325 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True): 326 | # list_of_num_layers specifies number of layers in each stage 327 | # list_of_out_dims specifies number of output channel for each stage 328 | super(ResNet, self).__init__() 329 | assert len(list_of_num_layers) == 4, 'Can have only four stages' 330 | if self.maml: 331 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, 332 | bias=False) 333 | bn1 = BatchNorm2d_fw(64) 334 | else: 335 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 336 | bias=False) 337 | bn1 = nn.BatchNorm2d(64) 338 | relu = nn.ReLU() 339 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 340 | init_layer(conv1) 341 | init_layer(bn1) 342 | trunk = [conv1, bn1, relu, pool1] 343 | indim = 64 344 | for i in range(4): 345 | for j in range(list_of_num_layers[i]): 346 | half_res = (i >= 1) and (j == 0) 347 | B = block(indim, list_of_out_dims[i], half_res) 348 | trunk.append(B) 349 | indim = list_of_out_dims[i] 350 | if flatten: 351 | avgpool = nn.AvgPool2d(7) 352 | trunk.append(avgpool) 353 | trunk.append(Flatten()) 354 | self.final_feat_dim = indim 355 | else: 356 | self.final_feat_dim = [indim, 7, 7] 357 | self.trunk = nn.Sequential(*trunk) 358 | 359 | def forward(self, x): 360 | out = self.trunk(x) 361 | return out 362 | 363 | 364 | # Backbone for QMUL regression 365 | class Conv3(nn.Module): 366 | def __init__(self): 367 | super(Conv3, self).__init__() 368 | self.layer1 = nn.Conv2d(3, 36, 3, stride=2, dilation=2) 369 | self.layer2 = nn.Conv2d(36, 36, 3, stride=2, dilation=2) 370 | self.layer3 = nn.Conv2d(36, 36, 3, stride=2, dilation=2) 371 | 372 | def return_clones(self): 373 | layer1_w = self.layer1.weight.data.clone().detach() 374 | layer2_w = self.layer2.weight.data.clone().detach() 375 | layer3_w = self.layer3.weight.data.clone().detach() 376 | return [layer1_w, layer2_w, layer3_w] 377 | 378 | def assign_clones(self, weights_list): 379 | self.layer1.weight.data.copy_(weights_list[0]) 380 | self.layer2.weight.data.copy_(weights_list[1]) 381 | self.layer3.weight.data.copy_(weights_list[2]) 382 | 383 | def forward(self, x): 384 | out = F.relu(self.layer1(x)) 385 | out = F.relu(self.layer2(out)) 386 | out = F.relu(self.layer3(out)) 387 | out = out.view(out.size(0), -1) 388 | return out 389 | 390 | 391 | def Conv4(): 392 | return ConvNet(4) 393 | 394 | 395 | def Conv6(): 396 | return ConvNet(6) 397 | 398 | 399 | def Conv4NP(): 400 | return ConvNetNopool(4) 401 | 402 | 403 | def Conv6NP(): 404 | return ConvNetNopool(6) 405 | 406 | 407 | def Conv4S(): 408 | return ConvNetS(4) 409 | 410 | 411 | def Conv4SNP(): 412 | return ConvNetSNopool(4) 413 | 414 | 415 | def ResNet10(flatten=True): 416 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten) 417 | 418 | 419 | def ResNet18(flatten=True): 420 | return 
ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten) 421 | 422 | 423 | def ResNet34(flatten=True): 424 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten) 425 | 426 | 427 | def ResNet50(flatten=True): 428 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten) 429 | 430 | 431 | def ResNet101(flatten=True): 432 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten) 433 | 434 | 435 | # --- feature-wise transformation layer --- 436 | class FeatureWiseTransformation2d_fw(nn.BatchNorm2d): 437 | feature_augment = False 438 | 439 | def __init__(self, num_features, momentum=0.1, track_running_stats=True): 440 | super(FeatureWiseTransformation2d_fw, self).__init__(num_features, momentum=momentum, 441 | track_running_stats=track_running_stats) 442 | self.weight.fast = None 443 | self.bias.fast = None 444 | if self.track_running_stats: 445 | self.register_buffer('running_mean', torch.zeros(num_features)) 446 | self.register_buffer('running_var', torch.zeros(num_features)) 447 | if self.feature_augment: # initialize {gamma, beta} with {0.3, 0.5} 448 | self.gamma = torch.nn.Parameter(torch.ones(1, num_features, 1, 1) * 0.3) 449 | self.beta = torch.nn.Parameter(torch.ones(1, num_features, 1, 1) * 0.5) 450 | self.reset_parameters() 451 | 452 | def reset_running_stats(self): 453 | if self.track_running_stats: 454 | self.running_mean.zero_() 455 | self.running_var.fill_(1) 456 | 457 | def forward(self, x, step=0): 458 | if self.weight.fast is not None and self.bias.fast is not None: 459 | weight = self.weight.fast 460 | bias = self.bias.fast 461 | else: 462 | weight = self.weight 463 | bias = self.bias 464 | if self.track_running_stats: 465 | out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, 466 | momentum=self.momentum) 467 | else: 468 | out = F.batch_norm(x, torch.zeros_like(x), torch.ones_like(x), weight, bias, training=True, momentum=1) 469 | 470 | # apply feature-wise transformation 471 | if self.feature_augment and self.training: 472 | gamma = (1 + torch.randn(1, self.num_features, 1, 1, dtype=self.gamma.dtype, 473 | device=self.gamma.device) * nn.functional.softplus(self.gamma, 474 | beta=100)).expand_as(out) 475 | beta = (torch.randn(1, self.num_features, 1, 1, dtype=self.beta.dtype, 476 | device=self.beta.device) * nn.functional.softplus(self.beta, beta=100)).expand_as(out) 477 | out = gamma * out + beta 478 | return out 479 | 480 | # --- LSTMCell module for matchingnet --- 481 | class LSTMCell(nn.Module): 482 | maml = False 483 | 484 | def __init__(self, input_size, hidden_size, bias=True): 485 | super(LSTMCell, self).__init__() 486 | self.input_size = input_size 487 | self.hidden_size = hidden_size 488 | self.bias = bias 489 | if self.maml: 490 | self.x2h = Linear_fw(input_size, 4 * hidden_size, bias=bias) 491 | self.h2h = Linear_fw(hidden_size, 4 * hidden_size, bias=bias) 492 | else: 493 | self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) 494 | self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) 495 | self.reset_parameters() 496 | 497 | def reset_parameters(self): 498 | std = 1.0 / math.sqrt(self.hidden_size) 499 | for w in self.parameters(): 500 | w.data.uniform_(-std, std) 501 | 502 | def forward(self, x, hidden=None): 503 | if hidden is None: 504 | hx = torch.zeors_like(x) 505 | cx = torch.zeros_like(x) 506 | else: 507 | hx, cx = hidden 508 | 509 | gates = self.x2h(x) + self.h2h(hx) 510 | ingate, forgetgate, cellgate, outgate = 
torch.split(gates, self.hidden_size, dim=1) 511 | 512 | ingate = torch.sigmoid(ingate) 513 | forgetgate = torch.sigmoid(forgetgate) 514 | cellgate = torch.tanh(cellgate) 515 | outgate = torch.sigmoid(outgate) 516 | 517 | cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) 518 | hy = torch.mul(outgate, torch.tanh(cy)) 519 | return (hy, cy) 520 | 521 | 522 | # --- LSTM module for matchingnet --- 523 | class LSTM(nn.Module): 524 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, bidirectional=False): 525 | super(LSTM, self).__init__() 526 | 527 | self.input_size = input_size 528 | self.hidden_size = hidden_size 529 | self.num_layers = num_layers 530 | self.bias = bias 531 | self.batch_first = batch_first 532 | self.num_directions = 2 if bidirectional else 1 533 | assert (self.num_layers == 1) 534 | 535 | self.lstm = LSTMCell(input_size, hidden_size, self.bias) 536 | 537 | def forward(self, x, hidden=None): 538 | # swap axis if batch first 539 | if self.batch_first: 540 | x = x.permute(1, 0, 2) 541 | 542 | # hidden state 543 | if hidden is None: 544 | h0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device) 545 | c0 = torch.zeros(self.num_directions, x.size(1), self.hidden_size, dtype=x.dtype, device=x.device) 546 | else: 547 | h0, c0 = hidden 548 | 549 | # forward 550 | outs = [] 551 | hn = h0[0] 552 | cn = c0[0] 553 | for seq in range(x.size(0)): 554 | hn, cn = self.lstm(x[seq], (hn, cn)) 555 | outs.append(hn.unsqueeze(0)) 556 | outs = torch.cat(outs, dim=0) 557 | 558 | # reverse foward 559 | if self.num_directions == 2: 560 | outs_reverse = [] 561 | hn = h0[1] 562 | cn = c0[1] 563 | for seq in range(x.size(0)): 564 | seq = x.size(1) - 1 - seq 565 | hn, cn = self.lstm(x[seq], (hn, cn)) 566 | outs_reverse.append(hn.unsqueeze(0)) 567 | outs_reverse = torch.cat(outs_reverse, dim=0) 568 | outs = torch.cat([outs, outs_reverse], dim=2) 569 | 570 | # swap axis if batch first 571 | if self.batch_first: 572 | outs = outs.permute(1, 0, 2) 573 | return outs 574 | 575 | # --- BatchNorm1d --- 576 | class BatchNorm1d_fw(nn.BatchNorm1d): 577 | def __init__(self, num_features, momentum=0.1, track_running_stats=True): 578 | super(BatchNorm1d_fw, self).__init__(num_features, momentum=momentum, track_running_stats=track_running_stats) 579 | self.weight.fast = None 580 | self.bias.fast = None 581 | if self.track_running_stats: 582 | self.register_buffer('running_mean', torch.zeros(num_features)) 583 | self.register_buffer('running_var', torch.zeros(num_features)) 584 | self.reset_parameters() 585 | 586 | def reset_running_stats(self): 587 | if self.track_running_stats: 588 | self.running_mean.zero_() 589 | self.running_var.fill_(1) 590 | 591 | def forward(self, x, step=0): 592 | if self.weight.fast is not None and self.bias.fast is not None: 593 | weight = self.weight.fast 594 | bias = self.bias.fast 595 | else: 596 | weight = self.weight 597 | bias = self.bias 598 | if self.track_running_stats: 599 | out = F.batch_norm(x, self.running_mean, self.running_var, weight, bias, training=self.training, 600 | momentum=self.momentum) 601 | else: 602 | out = F.batch_norm(x, torch.zeros(x.size(1), dtype=x.dtype, device=x.device), 603 | torch.ones(x.size(1), dtype=x.dtype, device=x.device), weight, bias, training=True, 604 | momentum=1) 605 | return out 606 | 607 | model_dict = dict( 608 | Conv4=Conv4, 609 | Conv4S=Conv4S, 610 | Conv6=Conv6, 611 | ResNet10=ResNet10, 612 | ResNet18=ResNet18, 613 | ResNet34=ResNet34, 614 | 
ResNet50=ResNet50, 615 | ResNet101=ResNet101) 616 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | data_dir = dict( 4 | CUB=os.path.dirname(__file__) + '/../filelists/CUB/', 5 | miniImagenet=os.path.dirname(__file__) + '/../filelists/miniImagenet/', 6 | omniglot=os.path.dirname(__file__) + '/../filelists/omniglot/', 7 | emnist=os.path.dirname(__file__) + '/../filelists/emnist/', 8 | cifar=os.path.dirname(__file__) + '/../filelists/cifar/', 9 | fc100=os.path.dirname(__file__) + '/../filelists/fc100/', 10 | ) 11 | 12 | 13 | num_workers = 4 14 | test_iter_num = 600 15 | 16 | 17 | def get_stop_epoch(algorithm, dataset, n_shot=5): 18 | if algorithm in ['baseline', 'baseline++']: 19 | if dataset in ['omniglot', 'cross_char']: 20 | stop_epoch = 5 21 | elif dataset in ['CUB']: 22 | stop_epoch = 200 # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting 23 | elif dataset in ['miniImagenet', 'cross']: 24 | stop_epoch = 400 25 | else: 26 | stop_epoch = 400 # default 27 | else: # meta-learning methods 28 | stop_epoch = 400 29 | return stop_epoch -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from . import backbone 4 | from .config import * 5 | import os 6 | import random 7 | import glob 8 | from data.datamgr import SimpleDataManager, SetDataManager 9 | import h5py 10 | 11 | # region common parameters 12 | 13 | base_path = os.path.dirname(__file__).replace('\\', '/') + '/..' 14 | 15 | model_dict = dict( 16 | Conv4=backbone.Conv4, 17 | Conv4S=backbone.Conv4S, 18 | Conv6=backbone.Conv6, 19 | ResNet10=backbone.ResNet10, 20 | ResNet18=backbone.ResNet18, 21 | ResNet34=backbone.ResNet34, 22 | ResNet50=backbone.ResNet50, 23 | ResNet101=backbone.ResNet101 24 | ) 25 | 26 | start_epoch = 0 # Starting epoch 27 | save_freq = 50 # Save frequency 28 | train_n_way = 5 # class num to classify for training 29 | test_n_way = 5 # class num to classify for testing (validation) 30 | adaptation = False 31 | noise_rate = 0. 
32 | 33 | if torch.cuda.is_available(): 34 | use_cuda = True 35 | print('GPU detected, running with GPU!') 36 | else: 37 | print('GPU not detected, running with CPU!') 38 | use_cuda = False 39 | 40 | 41 | def set_seed(seed=0): 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) 47 | torch.backends.cudnn.deterministic = True 48 | torch.backends.cudnn.benchmark = False 49 | 50 | 51 | seed = 0 52 | set_seed(seed) 53 | 54 | 55 | # endregion 56 | 57 | def one_hot(y, num_class): 58 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1).long(), 1) 59 | 60 | 61 | def DBindex(cl_data_file): 62 | class_list = cl_data_file.keys() 63 | cl_num = len(class_list) 64 | cl_means = [] 65 | stds = [] 66 | DBs = [] 67 | for cl in class_list: 68 | cl_means.append(np.mean(cl_data_file[cl], axis=0)) 69 | stds.append(np.sqrt(np.mean(np.sum(np.square(cl_data_file[cl] - cl_means[-1]), axis=1)))) 70 | 71 | mu_i = np.tile(np.expand_dims(np.array(cl_means), axis=0), (len(class_list), 1, 1)) 72 | mu_j = np.transpose(mu_i, (1, 0, 2)) 73 | mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2)) 74 | 75 | for i in range(cl_num): 76 | DBs.append(np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i])) 77 | return np.mean(DBs) 78 | 79 | 80 | def sparsity(cl_data_file): 81 | class_list = cl_data_file.keys() 82 | cl_sparsity = [] 83 | for cl in class_list: 84 | cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]])) 85 | return np.mean(cl_sparsity) 86 | 87 | 88 | def get_image_size(model_name, dataset): 89 | if 'Conv' in model_name: 90 | if dataset in ['omniglot', 'cross_char']: 91 | image_size = 28 92 | else: 93 | image_size = 84 94 | else: 95 | image_size = 224 96 | return image_size 97 | 98 | 99 | def get_train_files(dataset): 100 | if dataset == 'cross': 101 | base_file = data_dir['miniImagenet'] + 'all.json' 102 | val_file = data_dir['CUB'] + 'val.json' 103 | elif dataset == 'cross_char': 104 | base_file = data_dir['omniglot'] + 'noLatin.json' 105 | val_file = data_dir['emnist'] + 'val.json' 106 | else: 107 | base_file = data_dir[dataset] + 'base.json' 108 | val_file = data_dir[dataset] + 'val.json' 109 | return base_file, val_file 110 | 111 | 112 | def get_train_loader(algorithm, image_size, base_file, val_file, train_n_way, test_n_way, n_shot, noise_rate=0., 113 | val_noise=True, num_workers=4): 114 | if algorithm in ['baseline', 'baseline++']: 115 | base_datamgr = SimpleDataManager(image_size, batch_size=16) 116 | base_loader = base_datamgr.get_data_loader(base_file, aug=True) 117 | val_datamgr = SimpleDataManager(image_size, batch_size=64) 118 | val_loader = val_datamgr.get_data_loader(val_file, aug=False) 119 | else: 120 | n_query = max(1, int( 121 | 16 * test_n_way / train_n_way)) # if test_n_way