├── CIL.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY-LICENSES ├── configs ├── multi_step │ ├── createyaml.py │ ├── createyaml_fusion.py │ ├── imnet_base_resnet10_500.yml │ ├── imnet_base_resnet18_500.yml │ ├── imnet_delta_fc_resnet10_500_tmpl.yml │ ├── imnet_delta_fc_resnet18_20shot_500_tmpl.yml │ ├── imnet_delta_layer4_resnet10_500_tmpl.yml │ ├── imnet_delta_layer4_resnet18_20shot_500_tmpl.yml │ └── test │ │ ├── resnet10 │ │ ├── createyaml.py │ │ └── deltacls_branches_tmpl.yaml │ │ └── resnet18 │ │ ├── createyaml.py │ │ └── deltacls_branches_tmpl.yaml └── table1 │ ├── imnet_base_resnet10.yml │ └── imnet_base_resnet18.yml ├── prepro ├── data │ └── TODO-link-imagenet-here ├── gen_split500.py ├── gen_split800_40_40.py └── utils.py ├── scripts ├── collectresults50050.py ├── createyaml.py ├── getresults50050.sh └── getresults80040.sh └── src ├── hyper_search.py ├── loader ├── __init__.py ├── baldata.py ├── basedata.py ├── img_flist.py ├── nshotdata.py └── nshotdatanovel.py ├── metrics.py ├── models ├── __init__.py ├── deltacls.py ├── deltapoolcls.py ├── linearcls.py ├── resnet.py └── resnet_2br.py ├── optimizers ├── __init__.py └── scheduler.py ├── parse_result.py ├── test_fusion.py ├── test_fusion_ms.py ├── train_base.py ├── train_fusion.py ├── train_fusion_ms.py ├── train_novel.py ├── train_route.py └── utils.py /CIL.yml: -------------------------------------------------------------------------------- 1 | name: CIL 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.6.24=0 10 | - certifi=2018.8.24=py35_1 11 | - cudatoolkit=10.2.89=hfd86e86_1 12 | - freetype=2.10.2=h5ab3b9f_0 13 | - intel-openmp=2019.4=243 14 | - jpeg=9b=h024ee3a_2 15 | - libedit=3.1.20191231=h14c3975_1 16 | - libffi=3.2.1=hd88cf55_4 17 | - libgcc-ng=9.1.0=hdf63c60_0 18 | - libgfortran-ng=7.3.0=hdf63c60_0 19 | - libpng=1.6.37=hbc83047_0 20 | - libstdcxx-ng=9.1.0=hdf63c60_0 21 | - libtiff=4.1.0=h2733197_1 22 | - lz4-c=1.9.2=he6710b0_0 23 | - mkl=2018.0.3=1 24 | - mkl_fft=1.0.6=py35h7dd41cf_0 25 | - mkl_random=1.0.1=py35h4414c95_1 26 | - ncurses=6.2=he6710b0_1 27 | - ninja=1.8.2=py35h6bb024c_1 28 | - numpy=1.15.2=py35h1d66e8a_0 29 | - numpy-base=1.15.2=py35h81de0dd_0 30 | - olefile=0.46=py_0 31 | - openssl=1.0.2u=h7b6447c_0 32 | - pillow=5.2.0=py35heded4f4_0 33 | - pip=10.0.1=py35_0 34 | - python=3.5.6=hc3d631a_0 35 | - pytorch=1.5.1=py3.5_cuda10.2.89_cudnn7.6.5_0 36 | - readline=7.0=h7b6447c_5 37 | - setuptools=40.2.0=py35_0 38 | - sqlite=3.32.3=h62c20be_0 39 | - tbb=2020.0=hfd86e86_0 40 | - tbb4py=2018.0.5=py35h6bb024c_0 41 | - tk=8.6.10=hbc83047_0 42 | - torchvision=0.6.1=py35_cu102 43 | - wheel=0.31.1=py35_0 44 | - xz=5.2.5=h7b6447c_0 45 | - zlib=1.2.11=h7b6447c_3 46 | - zstd=1.4.5=h0b5b093_0 47 | - pip: 48 | - absl-py==0.13.0 49 | - astor==0.8.1 50 | - future==0.18.2 51 | - gast==0.5.0 52 | - google-pasta==0.2.0 53 | - grpcio==1.38.1 54 | - h5py==2.10.0 55 | - importlib-metadata==2.1.1 56 | - keras-applications==1.0.8 57 | - keras-preprocessing==1.1.2 58 | - markdown==3.2.2 59 | - pandas==0.25.3 60 | - protobuf==3.12.2 61 | - python-dateutil==2.8.1 62 | - pytz==2021.1 63 | - pyyaml==5.1.2 64 | - six==1.15.0 65 | - tensorboard==1.14.0 66 | - tensorboardx==1.8 67 | - tensorflow-estimator==1.14.0 68 | - tensorflow-gpu==1.14.0 69 | - termcolor==1.1.0 70 | - werkzeug==1.0.1 71 | - wrapt==1.12.1 72 | - zipp==1.2.0 73 | prefix: /home/ubuntu/anaconda3/envs/CIL 74 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | SP-CIL 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Strongly Pretrained Class Incremental Learning 2 | 3 | This is code base for the following paper: 4 | 5 | ### [Class-Incremental Learning with Strong Pre-trained Models](https://arxiv.org/abs/2204.03634) 6 | Tz-Ying Wu, Gurumurthy Swaminathan, Zhizhong Li, Avinash Ravichandran, Nuno Vasconcelos, Rahul Bhotika, Stefano Soatto 7 | 8 | Please read our paper for details! 9 | 10 | ## Installation 11 | To create a conda environment to run the project, simply run 12 | `conda env create -f CIL.yml`. 13 | 14 | ## Data 15 | Create a soft link of the imagenet folder (the root folder that includes train/val image folders) at `prepro/data/imagenet`. 16 | 17 | ## Experiments 18 | ### 800-40 19 | ```shell 20 | bash scripts/getresults80040.sh -l layer4 -n 10 # for resnet10 21 | bash scripts/getresults80040.sh -l layer4 -n 18 # for resnet18 22 | bash scripts/getresults80040.sh -l fc -n 10 # for resnet10, fc-only 23 | bash scripts/getresults80040.sh -l fc -n 18 # for resnet18, fc-only 24 | ``` 25 | 26 | ### 500-50 27 | ```shell 28 | bash scripts/getresults50050.sh -l layer4 -n 10 29 | bash scripts/getresults50050.sh -l layer4 -n 18 30 | bash scripts/getresults50050.sh -l fc -n 10 31 | bash scripts/getresults50050.sh -l fc -n 18 32 | ``` 33 | 34 | ## Security 35 | 36 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 37 | 38 | ## License 39 | 40 | This project is licensed under the Apache-2.0 License. 41 | -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | The Amazon SP-CIL Product includes the following third-party software/licensing: 2 | 3 | From PyTorch: 4 | 5 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 6 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 7 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 8 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 9 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 10 | Copyright (c) 2011-2013 NYU (Clement Farabet) 11 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 12 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 13 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 14 | From Caffe2: 15 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 16 | All contributions by Facebook: 17 | Copyright (c) 2016 Facebook Inc. 18 | All contributions by Google: 19 | Copyright (c) 2015 Google Inc. 20 | All rights reserved. 21 | All contributions by Yangqing Jia: 22 | Copyright (c) 2015 Yangqing Jia 23 | All rights reserved. 24 | All contributions by Kakao Brain: 25 | Copyright 2019-2020 Kakao Brain 26 | All contributions by Cruise LLC: 27 | Copyright (c) 2022 Cruise LLC. 28 | All rights reserved. 29 | All contributions from Caffe: 30 | Copyright(c) 2013, 2014, 2015, the respective contributors 31 | All rights reserved. 32 | All other contributions: 33 | Copyright(c) 2015, 2016 the respective contributors 34 | All rights reserved. 35 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 36 | copyright over their contributions to Caffe2. The project versioning records 37 | all such contribution and copyright details. If a contributor wants to further 38 | mark their specific copyright on a particular contribution, they should 39 | indicate their copyright solely in the commit message of the change when it is 40 | committed. 41 | All rights reserved. 42 | Redistribution and use in source and binary forms, with or without 43 | modification, are permitted provided that the following conditions are met: 44 | 1. Redistributions of source code must retain the above copyright 45 | notice, this list of conditions and the following disclaimer. 46 | 2. Redistributions in binary form must reproduce the above copyright 47 | notice, this list of conditions and the following disclaimer in the 48 | documentation and/or other materials provided with the distribution. 49 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 50 | and IDIAP Research Institute nor the names of its contributors may be 51 | used to endorse or promote products derived from this software without 52 | specific prior written permission. 53 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 54 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 55 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 56 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 57 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 58 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 59 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 60 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 61 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 62 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 63 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /configs/multi_step/createyaml.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import os 6 | import yaml 7 | 8 | 9 | def create_novel(baseyml, size, branch, outputdir): 10 | with open(baseyml, "r") as f: 11 | doc = yaml.full_load(f) 12 | 13 | for step in range(50, 550, 50): 14 | doc["seed"] = 1 15 | arch = doc["model"]["feature_extractor"]["arch"] 16 | if arch[-2:] != str(size): 17 | raise Exception("Incorrect base yml {baseyml} for size {size}".format(baseyml=baseyml, size=size)) 18 | 19 | if branch == "fc": 20 | doc["model"]["feature_extractor"]["n_freeze"] = 4 21 | else: 22 | doc["model"]["feature_extractor"]["n_freeze"] = 3 23 | 24 | doc["model"]["classifier"]["n_class"] = 50 25 | doc["data"]["loader"] = "NshotDataLoader" 26 | doc["data"]["train"] = "prepro/splits/imagenet/split500/novel_train_{step}.csv".format(step=step) 27 | doc["data"]["val"] = "prepro/splits/imagenet/split500/novel_val_{step}.csv".format(step=step) 28 | doc["data"]["n_shot"] = 0 29 | doc["data"]["lbl_offset"] = 0 30 | 31 | doc["training"]["epoch"] = 30 32 | doc["training"]["save_interval"] = 10 33 | doc["training"]["scheduler"]["step_size"] = 10 34 | doc["training"]["resume"]["model"] = "runs/imnet_base_{arch}_500/{arch}_500_norm4/ep-90_model.pkl".format( 35 | arch=arch 36 | ) 37 | doc["exp"] = "FT_{arch}_500_norm4_{branch}_all_s{step}".format(arch=arch, branch=branch, step=step) 38 | 39 | with open("{}/imnet_novel_{}_{}_{}.yml".format(outputdir, arch, branch, step), "w") as f: 40 | yaml.dump(doc, f) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description="config") 45 | parser.add_argument( 46 | "--baseyml", 47 | type=str, 48 | help="base yaml file", 49 | ) 50 | parser.add_argument( 51 | "--size", 52 | type=int, 53 | default=10, 54 | help="network size (10 or 18)", 55 | ) 56 | parser.add_argument( 57 | "--branch", 58 | type=str, 59 | help="branch layer (layer4 or fc)", 60 | ) 61 | parser.add_argument( 62 | "--outputdir", 63 | type=str, 64 | help="output dir for config files", 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | if args.branch != "fc" and args.branch != "layer4": 70 | raise Exception("Incorrect branch argument {}".format(args.branch)) 71 | if args.size != 10 and args.size != 18: 72 | raise Exception("Incorrect network size argument {}".format(args.size)) 73 | if not os.path.exists(args.baseyml): 74 | raise Exception("Base yaml file {} does not exist".format(args.baseyml)) 75 | 76 | create_novel(args.baseyml, args.size, args.branch, args.outputdir) 77 | -------------------------------------------------------------------------------- /configs/multi_step/createyaml_fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import yaml 6 | import string 7 | 8 | 9 | def get_yaml_list(strs, indents=0): 10 | return "\n".join(" " * indents + "- " + s for s in strs) 11 | 12 | 13 | templates = [ 14 | "imnet_delta_layer4_resnet10_500_tmpl.yml", 15 | "imnet_delta_layer4_resnet18_20shot_500_tmpl.yml", # Note that this is 20-shot to compare with prior work 16 | "imnet_delta_fc_resnet10_500_tmpl.yml", 17 | "imnet_delta_fc_resnet18_20shot_500_tmpl.yml", # Note that this is 20-shot to compare with prior work 18 | ] 19 | model_templates = [ 20 | ( 21 | "runs/imnet_base_resnet10_500/resnet10_500_norm4/ep-90_model.pkl", 22 | "runs/imnet_novel_resnet10_layer4_{n_novel_cls}/FT_resnet10_500_norm4_layer4_all_s{n_novel_cls}/ep-30_model.pkl", 23 | ), 24 | ( 25 | "runs/imnet_base_resnet18_500/resnet18_500_norm4/ep-90_model.pkl", 26 | "runs/imnet_novel_resnet18_layer4_{n_novel_cls}/FT_resnet18_500_norm4_layer4_all_s{n_novel_cls}/ep-30_model.pkl", 27 | ), 28 | ( 29 | "runs/imnet_base_resnet10_500/resnet10_500_norm4/ep-90_model.pkl", 30 | "runs/imnet_novel_resnet10_fc_{n_novel_cls}/FT_resnet10_500_norm4_fc_all_s{n_novel_cls}/ep-30_model.pkl", 31 | ), 32 | ( 33 | "runs/imnet_base_resnet18_500/resnet18_500_norm4/ep-90_model.pkl", 34 | "runs/imnet_novel_resnet18_fc_{n_novel_cls}/FT_resnet18_500_norm4_fc_all_s{n_novel_cls}/ep-30_model.pkl", 35 | ), 36 | ] 37 | 38 | for template, (base_model, model_template) in zip(templates, model_templates): 39 | with open(template, "rt") as fp: 40 | cfg = fp.read() 41 | 42 | cfg_tmpl = string.Template(cfg) 43 | for n in range(1, 11): 44 | with open(template.replace("tmpl", "split%d" % n), "wt") as fp: 45 | models_3_indents = """ 46 | - {base_model} 47 | """.format( 48 | base_model=base_model 49 | ) + get_yaml_list( 50 | [model_template.format(n_novel_cls=(x * 50)) for x in range(1, n + 1)], 51 | 12, 52 | ) 53 | fp.write( 54 | cfg_tmpl.substitute( 55 | n_novel=n, 56 | n_novel_plus_1=n + 1, 57 | n_novel_times_50=n * 50, 58 | n_base_cls_3_indents="\n" + get_yaml_list([str(x) for x in [500] + [50] * (n - 1)], 12), 59 | models_3_indents=models_3_indents, 60 | ) 61 | ) 62 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_base_resnet10_500.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1357 5 | model: 6 | feature_extractor: 7 | arch: resnet10 8 | n_freeze: 0 9 | pretrained: false 10 | classifier: 11 | arch: linearcls 12 | feat_size: 512 13 | n_class: 500 14 | norm: 4 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | train: prepro/splits/imagenet/split500/base_train.csv 19 | val: prepro/splits/imagenet/split500/base_val-test.csv 20 | n_workers: 12 21 | training: 22 | norm: 4 23 | epoch: 90 24 | batch_size: 256 25 | val_interval: 10 26 | save_interval: 10 27 | print_interval: 1 28 | optimizer_main: 29 | name: sgd 30 | lr: 0.1 31 | momentum: 0.9 32 | weight_decay: 0.0001 33 | scheduler: 34 | step_size: 30 35 | gamma: 0.1 36 | resume: 37 | model: 38 | param_only: true 39 | exp: resnet10_500_norm4 40 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_base_resnet18_500.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1357 5 | model: 6 | feature_extractor: 7 | arch: resnet18 8 | n_freeze: 0 9 | pretrained: false 10 | classifier: 11 | arch: linearcls 12 | feat_size: 512 13 | n_class: 500 14 | norm: 4 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | train: prepro/splits/imagenet/split500/base_train.csv 19 | val: prepro/splits/imagenet/split500/base_val-test.csv 20 | n_workers: 12 21 | training: 22 | norm: 4 23 | epoch: 90 24 | batch_size: 256 25 | val_interval: 10 26 | save_interval: 10 27 | print_interval: 1 28 | optimizer_main: 29 | name: sgd 30 | lr: 0.1 31 | momentum: 0.9 32 | weight_decay: 0.0001 33 | scheduler: 34 | step_size: 30 35 | gamma: 0.1 36 | resume: 37 | model: 38 | param_only: true 39 | exp: resnet18_500_norm4 40 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_delta_fc_resnet10_500_tmpl.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1 5 | model: 6 | feature_extractor: 7 | arch: resnet10_nbr # two branches for layer4, the rest is the same 8 | pretrained: false 9 | n_freeze: 5 # freeze all layers in resnet 10 | n_branch: ${n_novel_plus_1} 11 | classifier: 12 | arch: deltaclsn 13 | feat_size: 512 14 | n_classes: ${n_base_cls_3_indents} 15 | - 50 16 | norm: 4 17 | data: 18 | #loader: BalancedDataLoader # balanced sampling of base and novel 19 | loader: NshotDataLoader # random sampling of base and novel 20 | root_dir: prepro/data/imagenet 21 | train: prepro/splits/imagenet/split500/train_${n_novel_times_50}.csv 22 | val: prepro/splits/imagenet/split500/val_${n_novel_times_50}-dev.csv 23 | n_shot: 10 # 0 means using all data 24 | n_workers: 12 25 | training: 26 | norm: 4 27 | n_base_cls: ${n_base_cls_3_indents} 28 | n_novel_cls: 50 29 | epoch: 10 30 | batch_size: 256 31 | val_interval: 10 32 | save_interval: 10 33 | print_interval: 1 34 | dp_scaling: no_dp 35 | optimizer_main: 36 | name: sgd 37 | lr: 0.1 38 | momentum: 0.9 39 | weight_decay: 0.0001 40 | scheduler: 41 | step_size: 10 42 | gamma: 0.1 43 | resume: 44 | models: ${models_3_indents} 45 | param_only: true 46 | exp: split500/resnet10_nbr_500_fc_50_sp${n_novel}_a{:g}dp{:g}_10shot_s{:d}-dev 47 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_delta_fc_resnet18_20shot_500_tmpl.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1 5 | model: 6 | feature_extractor: 7 | arch: resnet18_nbr # two branches for layer4, the rest is the same 8 | pretrained: false 9 | n_freeze: 5 # freeze all layers in resnet 10 | n_branch: ${n_novel_plus_1} 11 | classifier: 12 | arch: deltaclsn 13 | feat_size: 512 14 | n_classes: ${n_base_cls_3_indents} 15 | - 50 16 | norm: 4 17 | data: 18 | #loader: BalancedDataLoader # balanced sampling of base and novel 19 | loader: NshotDataLoader # random sampling of base and novel 20 | root_dir: prepro/data/imagenet 21 | train: prepro/splits/imagenet/split500/train_${n_novel_times_50}.csv 22 | val: prepro/splits/imagenet/split500/val_${n_novel_times_50}-dev.csv 23 | n_shot: 20 # 0 means using all data 24 | n_workers: 12 25 | training: 26 | norm: 4 27 | n_base_cls: ${n_base_cls_3_indents} 28 | n_novel_cls: 50 29 | epoch: 5 30 | batch_size: 256 31 | val_interval: 5 32 | save_interval: 5 33 | print_interval: 1 34 | dp_scaling: no_dp 35 | optimizer_main: 36 | name: sgd 37 | lr: 0.1 38 | momentum: 0.9 39 | weight_decay: 0.0001 40 | scheduler: 41 | step_size: 5 42 | gamma: 0.1 43 | resume: 44 | models: ${models_3_indents} 45 | param_only: true 46 | exp: split500/resnet18_nbr_500_fc_50_sp${n_novel}_a{:g}dp{:g}_20shot_s{:d}-dev 47 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_delta_layer4_resnet10_500_tmpl.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1 5 | model: 6 | feature_extractor: 7 | arch: resnet10_nbr # two branches for layer4, the rest is the same 8 | pretrained: false 9 | n_freeze: 5 # freeze all layers in resnet 10 | n_branch: ${n_novel_plus_1} 11 | classifier: 12 | arch: deltaclsn 13 | feat_size: 512 14 | n_classes: ${n_base_cls_3_indents} 15 | - 50 16 | norm: 4 17 | data: 18 | #loader: BalancedDataLoader # balanced sampling of base and novel 19 | loader: NshotDataLoader # random sampling of base and novel 20 | root_dir: prepro/data/imagenet 21 | train: prepro/splits/imagenet/split500/train_${n_novel_times_50}.csv 22 | val: prepro/splits/imagenet/split500/val_${n_novel_times_50}-dev.csv 23 | n_shot: 10 # 0 means using all data 24 | n_workers: 12 25 | training: 26 | norm: 4 27 | n_base_cls: ${n_base_cls_3_indents} 28 | n_novel_cls: 50 29 | epoch: 10 30 | batch_size: 256 31 | val_interval: 10 32 | save_interval: 10 33 | print_interval: 1 34 | dp_scaling: no_dp 35 | optimizer_main: 36 | name: sgd 37 | lr: 0.1 38 | momentum: 0.9 39 | weight_decay: 0.0001 40 | scheduler: 41 | step_size: 10 42 | gamma: 0.1 43 | resume: 44 | models: ${models_3_indents} 45 | param_only: true 46 | exp: split500/resnet10_nbr_500_layer4_50_sp${n_novel}_a{:g}dp{:g}_10shot_s{:d}-dev 47 | -------------------------------------------------------------------------------- /configs/multi_step/imnet_delta_layer4_resnet18_20shot_500_tmpl.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1 5 | model: 6 | feature_extractor: 7 | arch: resnet18_nbr # two branches for layer4, the rest is the same 8 | pretrained: false 9 | n_freeze: 5 # freeze all layers in resnet 10 | n_branch: ${n_novel_plus_1} 11 | classifier: 12 | arch: deltaclsn 13 | feat_size: 512 14 | n_classes: ${n_base_cls_3_indents} 15 | - 50 16 | norm: 4 17 | data: 18 | #loader: BalancedDataLoader # balanced sampling of base and novel 19 | loader: NshotDataLoader # random sampling of base and novel 20 | root_dir: prepro/data/imagenet 21 | train: prepro/splits/imagenet/split500/train_${n_novel_times_50}.csv 22 | val: prepro/splits/imagenet/split500/val_${n_novel_times_50}-dev.csv 23 | n_shot: 20 # 0 means using all data 24 | n_workers: 12 25 | training: 26 | norm: 4 27 | n_base_cls: ${n_base_cls_3_indents} 28 | n_novel_cls: 50 29 | epoch: 5 30 | batch_size: 256 31 | val_interval: 5 32 | save_interval: 5 33 | print_interval: 1 34 | dp_scaling: no_dp 35 | optimizer_main: 36 | name: sgd 37 | lr: 0.1 38 | momentum: 0.9 39 | weight_decay: 0.0001 40 | scheduler: 41 | step_size: 5 42 | gamma: 0.1 43 | resume: 44 | models: ${models_3_indents} 45 | param_only: true 46 | exp: split500/resnet18_nbr_500_layer4_50_sp${n_novel}_a{:g}dp{:g}_20shot_s{:d}-dev 47 | -------------------------------------------------------------------------------- /configs/multi_step/test/resnet10/createyaml.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import yaml 6 | import string 7 | 8 | with open("deltacls_branches_tmpl.yaml", "rt") as fp: 9 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 10 | 11 | cfg_data_val_tmpl = string.Template(cfg["data"]["val"]) 12 | for x in range(1, 11): 13 | cfg["model"]["feature_extractor"]["n_branch"] = x + 1 14 | cfg["model"]["classifier"]["n_classes"] = [500] + [50] * x 15 | cfg["config"]["n_base_cls"] = [500] + [50] * (x - 1) 16 | cfg["data"]["val"] = cfg_data_val_tmpl.substitute(n_novel=50 * x) 17 | with open("deltacls_branches_split%d.yaml" % x, "wt") as fp: 18 | fp.write(yaml.dump(cfg, sort_keys=False)) 19 | -------------------------------------------------------------------------------- /configs/multi_step/test/resnet10/deltacls_branches_tmpl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # score fusion network (resnet10, split500) 5 | model: 6 | feature_extractor: 7 | arch: resnet10_nbr 8 | n_branch: TBD 9 | classifier: 10 | arch: deltaclsn 11 | feat_size: 512 12 | n_classes: TBD 13 | norm: 4 14 | 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | val: prepro/splits/imagenet/split500/val_${n_novel}-test.csv 19 | n_workers: 12 20 | 21 | config: 22 | n_base_cls: TBD 23 | n_novel_cls: 50 24 | norm: 4 25 | 26 | checkpoint: 27 | model1: null 28 | -------------------------------------------------------------------------------- /configs/multi_step/test/resnet18/createyaml.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import yaml 6 | import string 7 | 8 | with open("deltacls_branches_tmpl.yaml", "rt") as fp: 9 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 10 | 11 | cfg_data_val_tmpl = string.Template(cfg["data"]["val"]) 12 | for x in range(1, 11): 13 | cfg["model"]["feature_extractor"]["n_branch"] = x + 1 14 | cfg["model"]["classifier"]["n_classes"] = [500] + [50] * x 15 | cfg["config"]["n_base_cls"] = [500] + [50] * (x - 1) 16 | cfg["data"]["val"] = cfg_data_val_tmpl.substitute(n_novel=50 * x) 17 | with open("deltacls_branches_split%d.yaml" % x, "wt") as fp: 18 | fp.write(yaml.dump(cfg, sort_keys=False)) 19 | -------------------------------------------------------------------------------- /configs/multi_step/test/resnet18/deltacls_branches_tmpl.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # score fusion network (resnet18, split500) 5 | model: 6 | feature_extractor: 7 | arch: resnet18_nbr 8 | n_branch: TBD 9 | classifier: 10 | arch: deltaclsn 11 | feat_size: 512 12 | n_classes: TBD 13 | norm: 4 14 | 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | val: prepro/splits/imagenet/split500/val_${n_novel}-test.csv 19 | n_workers: 12 20 | 21 | config: 22 | n_base_cls: TBD 23 | n_novel_cls: 50 24 | norm: 4 25 | 26 | checkpoint: 27 | model1: null 28 | -------------------------------------------------------------------------------- /configs/table1/imnet_base_resnet10.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1357 5 | model: 6 | feature_extractor: 7 | arch: resnet10 8 | n_freeze: 0 9 | pretrained: false 10 | classifier: 11 | arch: linearcls 12 | feat_size: 512 13 | n_class: 800 14 | norm: 4 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | train: prepro/splits/imagenet/split80040/base_train.csv 19 | val: prepro/splits/imagenet/split80040/base_val-test.csv 20 | n_workers: 12 21 | training: 22 | norm: 4 23 | epoch: 90 24 | batch_size: 256 25 | val_interval: 10 26 | save_interval: 10 27 | print_interval: 10 28 | optimizer_main: 29 | name: sgd 30 | lr: 0.1 31 | momentum: 0.9 32 | weight_decay: 0.0001 33 | scheduler: 34 | step_size: 30 35 | gamma: 0.1 36 | resume: 37 | model: 38 | param_only: true 39 | exp: resnet10_800_norm4 40 | -------------------------------------------------------------------------------- /configs/table1/imnet_base_resnet18.yml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | seed: 1357 5 | model: 6 | feature_extractor: 7 | arch: resnet18 8 | n_freeze: 0 9 | pretrained: false 10 | classifier: 11 | arch: linearcls 12 | feat_size: 512 13 | n_class: 800 14 | norm: 4 15 | data: 16 | loader: BaseDataLoader 17 | root_dir: prepro/data/imagenet 18 | train: prepro/splits/imagenet/split80040/base_train.csv 19 | val: prepro/splits/imagenet/split80040/base_val-test.csv 20 | n_workers: 12 21 | training: 22 | norm: 4 23 | epoch: 90 24 | batch_size: 256 25 | val_interval: 10 26 | save_interval: 10 27 | print_interval: 1 28 | optimizer_main: 29 | name: sgd 30 | lr: 0.1 31 | momentum: 0.9 32 | weight_decay: 0.0001 33 | scheduler: 34 | step_size: 30 35 | gamma: 0.1 36 | resume: 37 | model: 38 | param_only: true 39 | exp: resnet18_800_norm4 40 | -------------------------------------------------------------------------------- /prepro/data/TODO-link-imagenet-here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/sp-cil/6243e4510eccd0672136b1c5e05a045255290a61/prepro/data/TODO-link-imagenet-here -------------------------------------------------------------------------------- /prepro/gen_split500.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | from collections import defaultdict 6 | import numpy as np 7 | 8 | from utils import * 9 | 10 | ### original splits of ImageNet 11 | data_dir = "data/imagenet" 12 | train = read_imagenet_data(data_dir, "imagenet_train.csv") 13 | val = read_imagenet_data(data_dir, "imagenet_val.csv") 14 | cls2id = get_imagenet_classdict(data_dir, "imagenet_classes.csv") 15 | 16 | ### base / novel class splits 17 | seed = 1357 18 | np.random.seed(seed) 19 | perm = np.random.permutation(1000) 20 | base_cls = perm[:500] 21 | novel_cls = perm[500:] 22 | n_novel = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500] 23 | 24 | # ImageNet id to base/novel class id 25 | id2bcid = {x: i for i, x in enumerate(base_cls)} 26 | id2ncid = {x: i for i, x in enumerate(novel_cls)} 27 | id2allcid = {x: i for i, x in enumerate(perm)} 28 | 29 | ### filter images for each split and relabel them with base/novel class Id 30 | tgt_dir = "splits/imagenet/split500" 31 | if not os.path.exists(tgt_dir): 32 | os.makedirs(tgt_dir) 33 | 34 | # train 35 | csv_file = os.path.join(tgt_dir, "base_train.csv") 36 | gen_split(csv_file, train, base_cls, cls2id, id2bcid, prefix="train") 37 | 38 | start = 0 39 | for n in n_novel: 40 | # only novel 41 | csv_file = os.path.join(tgt_dir, "novel_train_{}.csv".format(n)) 42 | gen_split( 43 | csv_file, 44 | train, 45 | novel_cls[start:n], 46 | cls2id, 47 | {x: i for i, x in enumerate(novel_cls[start:n])}, 48 | prefix="train", 49 | ) 50 | 51 | # base + novel 52 | csv_file = os.path.join(tgt_dir, "train_{}.csv".format(n)) 53 | cls_list = np.unique(base_cls.tolist() + novel_cls[:n].tolist()) 54 | gen_split(csv_file, train, cls_list, cls2id, id2allcid, prefix="train") 55 | 56 | start = n 57 | 58 | # val 59 | csv_file = os.path.join(tgt_dir, "base_val.csv") 60 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val") 61 | 62 | start = 0 63 | for n in n_novel: 64 | # only novel 65 | csv_file = os.path.join(tgt_dir, "novel_val_{}.csv".format(n)) 66 | gen_split( 67 | csv_file, 68 | val, 69 | novel_cls[start:n], 70 | cls2id, 71 | {x: i for i, x in enumerate(novel_cls[start:n])}, 72 | prefix="val", 73 | ) 74 | 75 | # base + novel 76 | csv_file = os.path.join(tgt_dir, "val_{}.csv".format(n)) 77 | cls_list = np.unique(base_cls.tolist() + novel_cls[:n].tolist()) 78 | gen_split(csv_file, val, cls_list, cls2id, id2allcid, prefix="val") 79 | 80 | start = n 81 | 82 | 83 | ### split validation set into dev and test 84 | np.random.seed(seed) 85 | 86 | # accumulate image id for each class 87 | cls2imgid_val = defaultdict(list) 88 | for idx, cls in enumerate(val["LabelName"]): 89 | cls2imgid_val[cls2id[cls]].append(idx) 90 | 91 | # split image id into two disjoint sets 92 | for cls in range(1000): 93 | idx_list = cls2imgid_val.get(cls) 94 | n_sample = len(idx_list) 95 | np.random.shuffle(idx_list) 96 | cls2imgid_val[cls] = [idx_list[: n_sample // 2], idx_list[n_sample // 2 :]] 97 | 98 | ### 500 base classes 99 | csv_file = os.path.join(tgt_dir, "base_val-dev.csv") 100 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in base_cls]).tolist() 101 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val", allowed_idx=idx_list1) 102 | 103 | csv_file = os.path.join(tgt_dir, "base_val-test.csv") 104 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in base_cls]).tolist() 105 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val", allowed_idx=idx_list2) 106 | 107 | # only novel 108 | start = 0 109 | for n in n_novel: 110 | csv_file = os.path.join(tgt_dir, "novel_val_{n}-dev.csv".format(n=n)) 111 | cls_list = novel_cls[start:n].tolist() 112 | id2ncid = {x: i for i, x in enumerate(cls_list)} 113 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 114 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list1) 115 | 116 | csv_file = os.path.join(tgt_dir, "novel_val_{n}-test.csv".format(n=n)) 117 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 118 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list2) 119 | 120 | # all base + novel 121 | csv_file = os.path.join(tgt_dir, "val_{n}-dev.csv".format(n=n)) 122 | cls_list = base_cls.tolist() + novel_cls[:n].tolist() 123 | id2cid = {x: i for i, x in enumerate(cls_list)} 124 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 125 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list1) 126 | 127 | csv_file = os.path.join(tgt_dir, "val_{n}-test.csv".format(n=n)) 128 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 129 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list2) 130 | start = n 131 | -------------------------------------------------------------------------------- /prepro/gen_split800_40_40.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | from collections import defaultdict 6 | import numpy as np 7 | 8 | from utils import * 9 | 10 | ### original splits of ImageNet 11 | data_dir = "data/imagenet" 12 | train = read_imagenet_data(data_dir, "imagenet_train.csv") 13 | val = read_imagenet_data(data_dir, "imagenet_val.csv") 14 | cls2id = get_imagenet_classdict(data_dir, "imagenet_classes.csv") 15 | 16 | ### base / novel class splits 17 | seed = 1357 18 | np.random.seed(seed) 19 | perm = np.random.permutation(1000) 20 | base_cls = perm[:800] 21 | novel_cls = perm[800:] 22 | n_novel = [40, 80] 23 | 24 | # ImageNet id to base/novel class id 25 | id2bcid = {x: i for i, x in enumerate(base_cls)} 26 | id2ncid = {x: i for i, x in enumerate(novel_cls)} 27 | id2allcid = {x: i for i, x in enumerate(perm)} 28 | 29 | ### filter images for each split and relabel them with base/novel class Id 30 | tgt_dir = "splits/imagenet/split80040" 31 | if not os.path.exists(tgt_dir): 32 | os.makedirs(tgt_dir) 33 | 34 | # train 35 | csv_file = os.path.join(tgt_dir, "base_train.csv") 36 | gen_split(csv_file, train, base_cls, cls2id, id2bcid, prefix="train") 37 | 38 | start = 0 39 | for n in n_novel: 40 | # only novel 41 | csv_file = os.path.join(tgt_dir, "novel_train_{}.csv".format(n)) 42 | gen_split(csv_file, train, novel_cls[start:n], cls2id, id2ncid, prefix="train") 43 | 44 | # base + novel 45 | csv_file = os.path.join(tgt_dir, "train_{}.csv".format(n)) 46 | cls_list = np.unique(base_cls.tolist() + novel_cls[:n].tolist()) 47 | gen_split(csv_file, train, cls_list, cls2id, id2allcid, prefix="train") 48 | 49 | start = n 50 | 51 | # val 52 | csv_file = os.path.join(tgt_dir, "base_val.csv") 53 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val") 54 | 55 | start = 0 56 | for n in n_novel: 57 | # only novel 58 | csv_file = os.path.join(tgt_dir, "novel_val_{}.csv".format(n)) 59 | gen_split(csv_file, val, novel_cls[start:n], cls2id, id2ncid, prefix="val") 60 | 61 | # base + novel 62 | csv_file = os.path.join(tgt_dir, "val_{}.csv".format(n)) 63 | cls_list = np.unique(base_cls.tolist() + novel_cls[:n].tolist()) 64 | gen_split(csv_file, val, cls_list, cls2id, id2allcid, prefix="val") 65 | 66 | start = 0 67 | 68 | ### split validation set into dev and test 69 | np.random.seed(seed) 70 | 71 | # accumulate image id for each class 72 | cls2imgid_val = defaultdict(list) 73 | for idx, cls in enumerate(val["LabelName"]): 74 | cls2imgid_val[cls2id[cls]].append(idx) 75 | 76 | # split image id into two disjoint sets 77 | for cls in range(1000): 78 | idx_list = cls2imgid_val.get(cls) 79 | n_sample = len(idx_list) 80 | np.random.shuffle(idx_list) 81 | cls2imgid_val[cls] = [idx_list[: n_sample // 2], idx_list[n_sample // 2 :]] 82 | 83 | ### 800 base classes 84 | csv_file = os.path.join(tgt_dir, "base_val-dev.csv") 85 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in base_cls]).tolist() 86 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val", allowed_idx=idx_list1) 87 | 88 | csv_file = os.path.join(tgt_dir, "base_val-test.csv") 89 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in base_cls]).tolist() 90 | gen_split(csv_file, val, base_cls, cls2id, id2bcid, prefix="val", allowed_idx=idx_list2) 91 | 92 | ### 40 novel classes 93 | # only novel 94 | csv_file = os.path.join(tgt_dir, "novel_val_40-dev.csv") 95 | cls_list = novel_cls[:40].tolist() 96 | id2ncid = {x: i for i, x in enumerate(cls_list)} 97 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 98 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list1) 99 | 100 | csv_file = os.path.join(tgt_dir, "novel_val_40-test.csv") 101 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 102 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list2) 103 | 104 | # all base + novel 105 | csv_file = os.path.join(tgt_dir, "val_40-dev.csv") 106 | cls_list = base_cls.tolist() + novel_cls[:40].tolist() 107 | id2cid = {x: i for i, x in enumerate(cls_list)} 108 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 109 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list1) 110 | 111 | csv_file = os.path.join(tgt_dir, "val_40-test.csv") 112 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 113 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list2) 114 | 115 | ### 200 novel classes 116 | # only novel 117 | csv_file = os.path.join(tgt_dir, "novel_val_200-dev.csv") 118 | cls_list = novel_cls[:200].tolist() 119 | id2ncid = {x: i for i, x in enumerate(cls_list)} 120 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 121 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list1) 122 | 123 | csv_file = os.path.join(tgt_dir, "novel_val_200-test.csv") 124 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 125 | gen_split(csv_file, val, cls_list, cls2id, id2ncid, prefix="val", allowed_idx=idx_list2) 126 | 127 | # all base + novel 128 | csv_file = os.path.join(tgt_dir, "val_200-dev.csv") 129 | cls_list = base_cls.tolist() + novel_cls[:200].tolist() 130 | id2cid = {x: i for i, x in enumerate(cls_list)} 131 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 132 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list1) 133 | 134 | csv_file = os.path.join(tgt_dir, "val_200-test.csv") 135 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 136 | gen_split(csv_file, val, cls_list, cls2id, id2cid, prefix="val", allowed_idx=idx_list2) 137 | 138 | ### balanced number of base and novel classes 139 | for n_cls in [40, 200]: 140 | cls_list = base_cls[:n_cls].tolist() + novel_cls[:n_cls].tolist() 141 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 142 | csv_file = os.path.join(tgt_dir, "base{}_novel{}_val-dev.csv".format(n_cls, n_cls)) 143 | gen_split(csv_file, val, cls_list, cls2id, id2allcid, prefix="val", allowed_idx=idx_list1) 144 | 145 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 146 | csv_file = os.path.join(tgt_dir, "base{}_novel{}_val-test.csv".format(n_cls, n_cls)) 147 | gen_split(csv_file, val, cls_list, cls2id, id2allcid, prefix="val", allowed_idx=idx_list2) 148 | 149 | ### subset of base classes 150 | n_base = [40, 100, 200, 400] 151 | for n_cls in n_base: 152 | cls_list = base_cls[:n_cls].tolist() 153 | idx_list1 = np.concatenate([cls2imgid_val[x][0] for x in cls_list]).tolist() 154 | csv_file = os.path.join(tgt_dir, "base{}_val-dev.csv".format(n_cls)) 155 | gen_split(csv_file, val, cls_list, cls2id, id2bcid, prefix="val", allowed_idx=idx_list1) 156 | 157 | idx_list2 = np.concatenate([cls2imgid_val[x][1] for x in cls_list]).tolist() 158 | csv_file = os.path.join(tgt_dir, "base{}_val-test.csv".format(n_cls)) 159 | gen_split(csv_file, val, cls_list, cls2id, id2bcid, prefix="val", allowed_idx=idx_list2) 160 | -------------------------------------------------------------------------------- /prepro/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | import os 6 | import pandas as pd 7 | 8 | 9 | def read_imagenet_data(data_dir, csv_file): 10 | """Read ImageNet data (image name and labels) from csv_file 11 | and return a dictonary with keys 'ImageID' and 'LabelName'. 12 | """ 13 | df = pd.read_csv(os.path.join(data_dir, csv_file), sep=",", header=0) 14 | return df.iloc[:, [0, 2]].to_dict("list") 15 | 16 | 17 | def get_imagenet_classdict(data_dir, csv_file): 18 | """Read ImageNet class mapping (id->cls_name) and return a dictionary.""" 19 | df = pd.read_csv(os.path.join(data_dir, csv_file), sep="\t", header=None) 20 | return {x[1]: x[0] for x in df.values} 21 | 22 | 23 | def gen_split( 24 | csv_file, 25 | data, 26 | cls_list, 27 | cls2id, 28 | id_dict, 29 | prefix="", 30 | allowed_idx=None, 31 | store_sp_lbl=False, 32 | ): 33 | """Generate data splits for base/novel classes and save to csv_file.""" 34 | split = [] 35 | if store_sp_lbl: 36 | 37 | if allowed_idx is None: 38 | allowed_idx = [list(range(len(data["ImageID"])))] 39 | 40 | # store split label (0: base / 1: novel) 41 | for sp_lbl, idx_list in enumerate(allowed_idx): 42 | for idx in idx_list: 43 | cls = data["LabelName"][idx] 44 | lbl = cls2id[cls] 45 | if lbl not in cls_list: 46 | continue 47 | 48 | if prefix == "val": 49 | img_name = os.path.join(prefix, cls, data["ImageID"][idx]) 50 | else: 51 | img_name = os.path.join(prefix, data["ImageID"][idx]) 52 | 53 | new_lbl = id_dict[lbl] 54 | split.append([img_name, cls, lbl, idx, new_lbl, sp_lbl]) 55 | 56 | df = pd.DataFrame(split, columns=["ImageID", "LabelName", "Label", "Idx", "ClsID", "SPLbl"]) 57 | 58 | else: 59 | 60 | if allowed_idx is None: 61 | allowed_idx = list(range(len(data["ImageID"]))) 62 | 63 | for idx in allowed_idx: 64 | cls = data["LabelName"][idx] 65 | lbl = cls2id[cls] 66 | if lbl not in cls_list: 67 | continue 68 | 69 | if prefix == "val": 70 | img_name = os.path.join(prefix, cls, data["ImageID"][idx]) 71 | else: 72 | img_name = os.path.join(prefix, data["ImageID"][idx]) 73 | 74 | new_lbl = id_dict[lbl] 75 | split.append([img_name, cls, lbl, idx, new_lbl]) 76 | df = pd.DataFrame(split, columns=["ImageID", "LabelName", "Label", "Idx", "ClsID"]) 77 | 78 | df.to_csv(csv_file, index=False) 79 | print("{} generated.".format(csv_file)) 80 | -------------------------------------------------------------------------------- /scripts/collectresults50050.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import sys 5 | import glob 6 | import numpy as np 7 | import pandas as pd 8 | import re 9 | 10 | argv = sys.argv[1:] 11 | assert len(argv) == 2, "Please call using two parameters, arch and layers. {}".format(argv) 12 | supported_input = { 13 | ("resnet10", "layer4"), 14 | ("resnet18", "layer4"), 15 | ("resnet10", "fc"), 16 | ("resnet18", "fc"), 17 | } 18 | assert tuple(argv) in supported_input, 'arch+layers combination "{}" not recognized. supported:\n{}'.format( 19 | argv, list(supported_input) 20 | ) 21 | 22 | network, branch = argv 23 | n_sp = 10 24 | shot = "_20shot" if network == "resnet18" else "" 25 | result_files_tmpl = "runs/imnet_delta_{branch}_{network}{shot}_500_split{{n}}/split500/test_result.csv".format( 26 | network=network, branch=branch, shot=shot 27 | ) 28 | base_result_file = sorted( 29 | glob.glob("runs/imnet_base_{network}_500/{network}_500_norm4/run_*.log".format(network=network)) 30 | )[-1] 31 | 32 | results = [pd.read_csv(result_files_tmpl.format(n=n)) for n in range(1, n_sp + 1)] 33 | for result in results: 34 | result["method"] = ["best-all", "best-bal", "best-avg"] 35 | 36 | # deal with base performance: extract from training logs 37 | with open(base_result_file, "rt") as f: 38 | base_result_line = f.readlines()[-2] 39 | base_accu = re.search(r"Prec@1\s([\d\.]+)\t\sPrec@5\s[\d\.]+$", base_result_line).group(1) 40 | base_result = results[0].drop("Acc (1)", axis=1) 41 | base_result.loc[:, [x for x in base_result.columns if "Acc" in x]] = float(base_accu) / 100 42 | results = [base_result] + results 43 | 44 | results_agg = pd.concat([result[["method", "Acc (all)", "Avg. Acc"]] for result in results]) 45 | 46 | results_mean = results_agg.groupby("method", sort=False).mean() 47 | print((results_mean * 100).applymap("{:.2f}".format).to_string()) 48 | -------------------------------------------------------------------------------- /scripts/createyaml.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import os 6 | import yaml 7 | 8 | 9 | def create_novel(baseyml, size, branch, outputdir): 10 | with open(baseyml, "r") as f: 11 | doc = yaml.full_load(f) 12 | 13 | doc["seed"] = 1 14 | arch = doc["model"]["feature_extractor"]["arch"] 15 | if arch[-2:] != str(size): 16 | raise Exception("Incorrect base yml {baseyml} for size {size}".format(baseyml=baseyml, size=size)) 17 | 18 | if branch == "fc": 19 | doc["model"]["feature_extractor"]["n_freeze"] = 4 20 | else: 21 | doc["model"]["feature_extractor"]["n_freeze"] = 3 22 | 23 | doc["model"]["classifier"]["n_class"] = 40 24 | doc["data"]["loader"] = "NshotDataLoader" 25 | doc["data"]["train"] = "prepro/splits/imagenet/split80040/novel_train_40.csv" 26 | doc["data"]["val"] = "prepro/splits/imagenet/split80040/novel_val_40.csv" 27 | doc["data"]["n_shot"] = 0 28 | doc["data"]["lbl_offset"] = 0 29 | 30 | doc["training"]["epoch"] = 30 31 | doc["training"]["save_interval"] = 10 32 | doc["training"]["scheduler"]["step_size"] = 10 33 | doc["training"]["resume"]["model"] = "runs/imnet_base_{arch}/{arch}_800_norm4/ep-90_model.pkl".format(arch=arch) 34 | doc["exp"] = "FT_{arch}_800_norm4_{branch}_all_s1".format(arch=arch, branch=branch) 35 | 36 | with open("{}/imnet_novel_{}_{}.yml".format(outputdir, arch, branch), "w") as f: 37 | yaml.dump(doc, f) 38 | 39 | 40 | def create_delta(baseyml, size, branch, outputdir): 41 | with open(baseyml, "r") as f: 42 | doc = yaml.full_load(f) 43 | 44 | doc["seed"] = 1 45 | arch = doc["model"]["feature_extractor"]["arch"] 46 | if arch[-2:] != str(size): 47 | raise Exception("Incorrect base yml {baseyml} for size {size}".format(baseyml=baseyml, size=size)) 48 | 49 | doc["model"]["feature_extractor"]["arch"] = arch + "_2br" 50 | doc["model"]["feature_extractor"]["n_freeze"] = 5 51 | doc["model"]["classifier"]["arch"] = "deltacls" 52 | doc["model"]["classifier"]["n_bclass"] = 800 53 | doc["model"]["classifier"]["n_nclass"] = 40 54 | del doc["model"]["classifier"]["n_class"] 55 | doc["data"]["loader"] = "NshotDataLoader" 56 | doc["data"]["train"] = "prepro/splits/imagenet/split80040/train_40.csv" 57 | doc["data"]["val"] = "prepro/splits/imagenet/split80040/val_40-dev.csv" 58 | doc["data"]["n_shot"] = 10 59 | doc["training"]["epoch"] = 10 60 | doc["training"]["save_interval"] = 10 61 | doc["training"]["n_base_cls"] = 800 62 | doc["training"]["scheduler"]["step_size"] = 10 63 | doc["training"]["resume"]["model1"] = "runs/imnet_base_{arch}/{arch}_800_norm4/ep-90_model.pkl".format(arch=arch) 64 | doc["training"]["resume"][ 65 | "model2" 66 | ] = "runs/imnet_novel_{arch}_{branch}/FT_{arch}_800_norm4_{branch}_all_s1/ep-30_model.pkl".format( 67 | arch=arch, branch=branch 68 | ) 69 | doc["training"]["dp_scaling"] = "no_dp" 70 | del doc["training"]["resume"]["model"] 71 | doc["exp"] = ( 72 | "split80040/{arch}_2br_800_{branch}_".format(arch=arch, branch=branch) + "40_a{:g}dp{:g}_10shot_s{:d}-dev" 73 | ) 74 | 75 | with open("{}/imnet_delta_{}_{}.yml".format(outputdir, arch, branch), "w") as f: 76 | yaml.dump(doc, f) 77 | 78 | 79 | def create_test(size, branch, outputdir): 80 | doc = dict() 81 | doc["seed"] = 1 82 | doc["model"] = dict() 83 | doc["model"] = dict() 84 | doc["model"]["feature_extractor"] = dict() 85 | doc["model"]["feature_extractor"]["arch"] = "resnet{size}_2br".format(size=size) 86 | doc["model"]["classifier"] = dict() 87 | doc["model"]["classifier"]["arch"] = "deltapoolcls" 88 | doc["model"]["classifier"]["feat_size"] = 512 89 | doc["model"]["classifier"]["n_bclass"] = 800 90 | doc["model"]["classifier"]["n_nclass"] = 40 91 | doc["model"]["classifier"]["norm"] = 4 92 | 93 | doc["data"] = dict() 94 | doc["data"]["loader"] = "BaseDataLoader" 95 | doc["data"]["root_dir"] = "prepro/data/imagenet" 96 | doc["data"]["val"] = "prepro/splits/imagenet/split80040/val_40-test.csv" 97 | doc["data"]["n_workers"] = 12 98 | 99 | doc["config"] = dict() 100 | doc["config"]["n_base_cls"] = 800 101 | doc["config"]["n_sel_cls"] = 800 102 | doc["config"]["norm"] = 4 103 | 104 | doc["checkpoint"] = dict() 105 | doc["checkpoint"][ 106 | "model1" 107 | ] = "runs/imnet_delta_resnet{size}_{branch}/split80040/resnet{size}_2br_800_{branch}_40_a0.4dp0.2_10shot_s3-dev/ep-10_model.pkl".format( 108 | branch=branch, size=size 109 | ) 110 | with open("{}/test_imnet_delta_resnet{}_{}.yml".format(outputdir, size, branch), "w") as f: 111 | yaml.dump(doc, f) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser(description="config") 116 | parser.add_argument( 117 | "--baseyml", 118 | type=str, 119 | help="base yaml file", 120 | ) 121 | parser.add_argument( 122 | "--size", 123 | type=int, 124 | default=10, 125 | help="network size (10 or 18)", 126 | ) 127 | parser.add_argument( 128 | "--branch", 129 | type=str, 130 | help="branch layer (layer4 or fc)", 131 | ) 132 | parser.add_argument( 133 | "--outputdir", 134 | type=str, 135 | help="output dir for config files", 136 | ) 137 | 138 | args = parser.parse_args() 139 | 140 | if args.branch != "fc" and args.branch != "layer4": 141 | raise Exception("Incorrect branch argument {}".format(args.branch)) 142 | if args.size != 10 and args.size != 18: 143 | raise Exception("Incorrect network size argument {}".format(args.size)) 144 | if not os.path.exists(args.baseyml): 145 | raise Exception("Base yaml file {} does not exist".format(args.baseyml)) 146 | 147 | create_delta(args.baseyml, args.size, args.branch, args.outputdir) 148 | create_novel(args.baseyml, args.size, args.branch, args.outputdir) 149 | create_test(args.size, args.branch, args.outputdir) 150 | -------------------------------------------------------------------------------- /scripts/getresults50050.sh: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | set -e 5 | branch="fc" 6 | size="18" 7 | branch="layer4" 8 | size="10" 9 | while getopts l:n: flag 10 | do 11 | case "${flag}" in 12 | l) branch=${OPTARG};; 13 | n) size=${OPTARG};; 14 | esac 15 | done 16 | 17 | if [ "$branch" != "fc" ] && [ "$branch" != "layer4" ]; then 18 | echo 'not a valid argument in layer layer existing' 19 | exit 20 | fi 21 | 22 | if [ "$size" == "10" ]; then 23 | shots="" 24 | ep=10 25 | elif [ "$size" == "18" ]; then 26 | shots="_20shot" # Note ResNet18 is run with 20-shot to compare to prior work 27 | ep=5 28 | else 29 | echo 'not a valid argument in network setting ' $config 30 | exit 31 | fi 32 | network=resnet$size 33 | 34 | echo 'Creating dataset splits ....' 35 | (cd prepro/ && python gen_split500.py) 36 | 37 | echo 'Creating config scripts for stage-I training ....' 38 | # base is already there and this will generate novel training files 39 | (cd configs/multi_step && python createyaml.py --baseyml imnet_base_${network}_500.yml --size $size --branch $branch --outputdir='.') 40 | echo 'Creating config scripts for stage-II fusion ....' 41 | (cd configs/multi_step && python createyaml_fusion.py) 42 | echo 'Creating config scripts for testing ....' 43 | (cd configs/multi_step/test/${network} && python createyaml.py) 44 | 45 | echo 'Starting base training for ' $network 46 | python3 src/train_base.py --config configs/multi_step/imnet_base_${network}_500.yml 47 | 48 | echo 'Starting novel training ....' 49 | echo 'Note: these can run parallel.' 50 | for step in 50 100 150 200 250 300 350 400 450 500; do 51 | python src/train_novel.py --config configs/multi_step/imnet_novel_${network}_${branch}_${step}.yml 52 | done 53 | 54 | echo 'Starting score fusion training ....' 55 | echo 'Note: these can run parallel as soon as the corresponding novel training is finished.' 56 | batch_size=512 57 | for step in 1 2 3 4 5 6 7 8 9 10; do 58 | train_config=configs/multi_step/imnet_delta_${branch}_${network}${shots}_500_split${step}.yml 59 | test_config=configs/multi_step/test/${network}/deltacls_branches_split${step}.yaml 60 | python src/hyper_search.py --train_config ${train_config} --test_config ${test_config} --b ${batch_size} --script _ms --ep ${ep} 61 | done 62 | 63 | echo 'Aggregating step results...' 64 | python scripts/collectresults50050.py ${network} ${branch} -------------------------------------------------------------------------------- /scripts/getresults80040.sh: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | set -e 5 | branch="layer4" 6 | size="10" 7 | while getopts l:n: flag 8 | do 9 | case "${flag}" in 10 | l) branch=${OPTARG};; 11 | n) size=${OPTARG};; 12 | esac 13 | done 14 | if [ "$branch" != "fc" ] && [ "$branch" != "layer4" ] 15 | then 16 | echo 'not a valid argument in layer layer existing' 17 | exit 18 | fi 19 | if [ "$size" != "10" ] && [ "$size" != "18" ] 20 | then 21 | echo 'not a valid argument in network setting ' $config 22 | exit 23 | fi 24 | network=resnet$size 25 | 26 | echo 'Creating dataset splits ....' 27 | (cd prepro/ && python gen_split800_40_40.py) 28 | 29 | echo 'Creating config scripts ....' 30 | python3 scripts/createyaml.py --baseyml configs/table1/imnet_base_${network}.yml --size $size --branch $branch --outputdir=configs/table1 31 | 32 | echo 'Starting base training for' $network 33 | python3 src/train_base.py --config configs/table1/imnet_base_${network}.yml 34 | 35 | echo 'Starting novel training ....' 36 | python src/train_novel.py --config configs/table1/imnet_novel_${network}_${branch}.yml 37 | 38 | echo 'Starting score fusion training ....' 39 | train_config=configs/table1/imnet_delta_${network}_${branch}.yml 40 | test_config=configs/table1/test_imnet_delta_${network}_${branch}.yml 41 | batch_size=1024 42 | python src/hyper_search.py --train_config ${train_config} --test_config ${test_config} --b ${batch_size} 43 | -------------------------------------------------------------------------------- /src/hyper_search.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | 6 | os.environ["MKL_THREADING_LAYER"] = "GNU" 7 | import argparse 8 | import yaml 9 | import pandas as pd 10 | from collections import defaultdict 11 | from termcolor import colored 12 | 13 | from metrics import averageMeter 14 | from parse_result import parse_result 15 | 16 | 17 | def hyper_search(config, exp, n_seed=1): 18 | alpha_range = [0, 0.2, 0.4, 0.6, 0.8, 1] 19 | dp_range = [0, 0.2, 0.4, 0.6, 0.8, 1] 20 | 21 | for seed in range(1, n_seed + 1): 22 | for dp in dp_range: 23 | for alpha in alpha_range: 24 | if not os.path.isdir(exp.format(alpha, dp, seed)): 25 | print(train.format(config, seed, dp, alpha)) 26 | os.system(train.format(config, seed, dp, alpha)) 27 | 28 | 29 | def main(): 30 | print(colored("[ Start hyper-parameter search ]", "cyan")) 31 | # run hyper-parameter search with development set 32 | exp = os.path.join("runs", os.path.basename(args.train_config)[:-4], cfg["exp"]) 33 | hyper_search(args.train_config, exp, args.n_seed) 34 | 35 | # parse log and get the results of the selected hyper-parameters 36 | mode = None 37 | if cfg["model"]["classifier"].get("n_classes", None): 38 | mode = "ms" 39 | elif cfg["model"]["classifier"].get("n_overlap", 0) > 0: 40 | mode = "overlap" 41 | sel_hyper = parse_result(exp, mode, args.n_seed) 42 | sel_all = False 43 | if sel_all: 44 | alpha_range = [0, 0.4, 1] 45 | dp_range = [0, 0.2, 0.4, 0.6, 0.8, 1] 46 | sel_hyper = [(dp, alpha) for alpha in alpha_range for dp in dp_range] 47 | 48 | for seed in range(args.n_seed + 1, 4): 49 | for (dp, alpha) in set(sel_hyper): 50 | if not os.path.isdir(exp.format(alpha, dp, seed)): 51 | print(train.format(args.train_config, seed, dp, alpha)) 52 | os.system(train.format(args.train_config, seed, dp, alpha)) 53 | 54 | # evaluate on testing set 55 | print(colored("[ Start testing ]", "cyan")) 56 | result = defaultdict(list) 57 | for (dp, alpha) in sel_hyper: 58 | Acc = averageMeter() 59 | Bacc = averageMeter() 60 | Nacc = averageMeter() 61 | Oacc = averageMeter() 62 | SPacc = defaultdict(averageMeter) 63 | Avg_acc = averageMeter() 64 | for seed in range(1, max(3, args.n_seed) + 1): 65 | ckpt = os.path.join(exp.format(alpha, dp, seed), "ep-%d_model.pkl" % args.ep) 66 | # run testing 67 | logfile = os.path.join(os.path.dirname(ckpt), "test_result.log") 68 | if not os.path.isfile(logfile): 69 | print(test.format(args.test_config, ckpt, logfile)) 70 | os.system(test.format(args.test_config, ckpt, logfile)) 71 | 72 | # read results from the logfile 73 | with open(logfile, "r") as f: 74 | res = f.readlines()[-1] 75 | res = [x for x in res.split(" ") if x] 76 | 77 | # get metrics 78 | acc = float(res[res.index("Acc") + 1]) 79 | if mode == "ms": 80 | spacc = [] 81 | idx_spacc = res.index("SPacc") + 1 82 | while True: 83 | spacc_item = res[idx_spacc] 84 | spacc.append(float(spacc_item.rstrip("/"))) 85 | if not spacc_item.endswith("/"): 86 | break 87 | idx_spacc += 1 88 | avg_acc = sum(spacc) / len(spacc) 89 | else: 90 | bacc = float(res[res.index("Bacc") + 1]) 91 | nacc = float(res[res.index("Nacc") + 1]) 92 | oacc = float(res[res.index("Oacc") + 1]) 93 | if mode == "overlap": 94 | avg_acc = (bacc + nacc + oacc) / 3 95 | else: 96 | avg_acc = (bacc + nacc) / 2 97 | 98 | # append results 99 | Acc.update(acc, 1) 100 | if mode == "ms": 101 | for i, spacc_item in enumerate(spacc): 102 | SPacc[i].update(spacc_item, 1) 103 | else: 104 | Bacc.update(bacc, 1) 105 | Nacc.update(nacc, 1) 106 | Oacc.update(oacc, 1) 107 | Avg_acc.update(avg_acc, 1) 108 | 109 | # append results 110 | result["alpha"].append(alpha) 111 | result["dp"].append(dp) 112 | result["Acc (all)"].append(Acc.avg) 113 | if mode == "ms": 114 | for i, SPacc_item in SPacc.items(): 115 | result["Acc ({i})".format(i=i)].append(SPacc_item.avg) 116 | else: 117 | result["Acc (base)"].append(Bacc.avg) 118 | result["Acc (novel)"].append(Nacc.avg) 119 | result["Acc (overlap)"].append(Oacc.avg) 120 | result["Avg. Acc"].append(Avg_acc.avg) 121 | 122 | pd.set_option("precision", 4) 123 | if mode == "ms": 124 | df = pd.DataFrame( 125 | result, 126 | columns=["alpha", "dp", "Acc (all)"] + ["Acc ({i})".format(i=i) for i in SPacc] + ["Avg. Acc"], 127 | ) 128 | elif mode == "overlap": 129 | df = pd.DataFrame( 130 | result, 131 | columns=[ 132 | "alpha", 133 | "dp", 134 | "Acc (all)", 135 | "Acc (base)", 136 | "Acc (novel)", 137 | "Acc (overlap)", 138 | "Avg. Acc", 139 | ], 140 | ) 141 | else: 142 | df = pd.DataFrame( 143 | result, 144 | columns=[ 145 | "alpha", 146 | "dp", 147 | "Acc (all)", 148 | "Acc (base)", 149 | "Acc (novel)", 150 | "Avg. Acc", 151 | ], 152 | ) 153 | print(colored("[ Result ]", "cyan")) 154 | print(df) 155 | 156 | # save searching results 157 | path = os.path.dirname(exp) 158 | csv_file = os.path.join(path, "test_result.csv") 159 | df.to_csv(csv_file, index=False, float_format="%.4f") 160 | print("result saved to {}.".format(csv_file)) 161 | 162 | 163 | if __name__ == "__main__": 164 | global args, cfg, train, test 165 | parser = argparse.ArgumentParser(description="config") 166 | parser.add_argument( 167 | "--script", 168 | type=str, 169 | default="", 170 | help="which script, train_fusion.py (default or train_fusion_ms.py", 171 | ) 172 | parser.add_argument( 173 | "--train_config", 174 | type=str, 175 | default="configs/imnet_delta.yml", 176 | help="config file for training", 177 | ) 178 | parser.add_argument( 179 | "--test_config", 180 | type=str, 181 | default="test_configs/split2_800_40/resnet10/deltapoolcls.yaml", 182 | help="config file for testing", 183 | ) 184 | parser.add_argument( 185 | "--n_seed", 186 | type=int, 187 | default=1, 188 | help="number of seeds for hyperparameter search", 189 | ) 190 | parser.add_argument( 191 | "--b", 192 | type=int, 193 | default=256, 194 | help="batch size for testing", 195 | ) 196 | parser.add_argument( 197 | "--ep", 198 | type=int, 199 | default=10, 200 | help="epoch of fusion training", 201 | ) 202 | 203 | args = parser.parse_args() 204 | 205 | train = "python src/train_fusion%s.py --config {} --seed {} --dp {} --alpha {}" % (args.script) 206 | test = "python src/test_fusion%s.py --config {} --checkpoint1 {} --log {} --b %d" % (args.script, args.b) 207 | 208 | if not os.path.isfile(args.train_config): 209 | raise BaseException("train_config: '{}' not found".format(args.train_config)) 210 | 211 | if not os.path.isfile(args.test_config): 212 | raise BaseException("test_config: '{}' not found".format(args.test_config)) 213 | 214 | with open(args.train_config) as fp: 215 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 216 | 217 | main() 218 | -------------------------------------------------------------------------------- /src/loader/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from loader.nshotdatanovel import NshotNovelDataLoader 5 | from .basedata import BaseDataLoader 6 | from .nshotdata import NshotDataLoader 7 | from .baldata import BalancedDataLoader 8 | from .nshotdatanovel import NshotNovelDataLoader 9 | 10 | 11 | def get_dataloader(cfg, splits, batch_size, seed=None): 12 | loader = _get_loader_instance(cfg["loader"]) 13 | if loader is NshotDataLoader or loader is NshotNovelDataLoader: 14 | data_loader = loader(cfg, splits, batch_size, seed=seed) 15 | else: 16 | data_loader = loader(cfg, splits, batch_size) 17 | return data_loader 18 | 19 | 20 | def _get_loader_instance(name): 21 | try: 22 | return { 23 | "BaseDataLoader": BaseDataLoader, 24 | "NshotDataLoader": NshotDataLoader, 25 | "NshotNovelDataLoader": NshotNovelDataLoader, 26 | "BalancedDataLoader": BalancedDataLoader, 27 | }[name] 28 | except: 29 | raise ("Loader type {} not available".format(name)) 30 | -------------------------------------------------------------------------------- /src/loader/baldata.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | from collections import Counter 6 | import logging 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from torch.utils.data.sampler import Sampler 11 | 12 | from .img_flist import ImageFilelist 13 | 14 | logger = logging.getLogger("mylogger") 15 | 16 | 17 | class BalancedSampler(Sampler): 18 | def __init__(self, dataset, tgt_transform=None): 19 | 20 | self.idx = dataset.indices 21 | class_ids = np.asarray(dataset.dataset.data["ClsID"])[self.idx] 22 | if tgt_transform is not None: 23 | class_ids = list(map(tgt_transform, class_ids)) 24 | 25 | self.n_samples = len(class_ids) 26 | 27 | # compute class frequencies and set them as sampling weights 28 | counts = Counter(class_ids) 29 | get_freq = lambda x: 1.0 / counts[x] 30 | self.weights = torch.DoubleTensor(list(map(get_freq, class_ids))) 31 | 32 | def __iter__(self): 33 | sampled_idx = torch.multinomial(self.weights, self.n_samples, replacement=True) 34 | return (i for i in sampled_idx) 35 | 36 | def __len__(self): 37 | return self.n_samples 38 | 39 | 40 | def get_nshot_data(dataset, nshot=0): 41 | class_ids = np.array(dataset.data["ClsID"]) 42 | classes = np.unique(class_ids) 43 | if nshot > 0: 44 | sampled_idx = [] 45 | for i in classes: 46 | idx = np.where(class_ids == i)[0] 47 | selected = torch.randperm(len(idx))[:nshot] 48 | if nshot > 1: 49 | sampled_idx.extend(idx[selected].tolist()) 50 | else: 51 | sampled_idx.append(idx[selected]) 52 | 53 | else: 54 | sampled_idx = list(range(len(class_ids))) 55 | 56 | return sampled_idx 57 | 58 | 59 | def BalancedDataLoader(cfg, splits, batch_size): 60 | 61 | data_loader = dict() 62 | 63 | tgt_transform = lambda x: int(x >= 800) 64 | 65 | for split in splits: 66 | 67 | if split == "train": 68 | # data augmentation 69 | transform = transforms.Compose( 70 | [ 71 | transforms.RandomResizedCrop(224), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 74 | transforms.ToTensor(), 75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 76 | ] 77 | ) 78 | load_sp_lbl = cfg.get("load_sp_lbl", False) 79 | print("load_sp_lbl", load_sp_lbl) 80 | 81 | else: 82 | transform = transforms.Compose( 83 | [ 84 | transforms.Resize(256), 85 | transforms.CenterCrop(224), 86 | transforms.ToTensor(), 87 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 88 | ] 89 | ) 90 | load_sp_lbl = False 91 | 92 | dataset = ImageFilelist(cfg["root_dir"], cfg[split], transform, load_sp_lbl=load_sp_lbl) 93 | if split == "train": 94 | dataset = data.Subset(dataset, get_nshot_data(dataset, nshot=cfg.get("n_shot", 0))) 95 | sampler = BalancedSampler(dataset, tgt_transform) 96 | data_loader[split] = data.DataLoader( 97 | dataset, 98 | batch_size=batch_size, 99 | shuffle=False, 100 | drop_last=False, 101 | sampler=sampler, 102 | pin_memory=True, 103 | num_workers=cfg["n_workers"], 104 | ) 105 | else: 106 | data_loader[split] = data.DataLoader( 107 | dataset, 108 | batch_size=batch_size, 109 | shuffle=False, 110 | drop_last=False, 111 | pin_memory=True, 112 | num_workers=cfg["n_workers"], 113 | ) 114 | print("{split}: {size}".format(split=split, size=len(dataset))) 115 | logger.info("{split}: {size}".format(split=split, size=len(dataset))) 116 | 117 | print("Building data loader with {} workers.".format(cfg["n_workers"])) 118 | logger.info("Building data loader with {} workers.".format(cfg["n_workers"])) 119 | 120 | return data_loader 121 | -------------------------------------------------------------------------------- /src/loader/basedata.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | 8 | from .img_flist import ImageFilelist 9 | 10 | logger = logging.getLogger("mylogger") 11 | 12 | 13 | def BaseDataLoader(cfg, splits, batch_size): 14 | 15 | data_loader = dict() 16 | 17 | for split in splits: 18 | 19 | if split == "train": 20 | # data augmentation 21 | transform = transforms.Compose( 22 | [ 23 | transforms.RandomResizedCrop(224), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 28 | ] 29 | ) 30 | shuffle = True 31 | load_sp_lbl = cfg.get("load_sp_lbl", False) 32 | print("load_sp_lbl", load_sp_lbl) 33 | 34 | else: 35 | transform = transforms.Compose( 36 | [ 37 | transforms.Resize(256), 38 | transforms.CenterCrop(224), 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 41 | ] 42 | ) 43 | shuffle = False 44 | load_sp_lbl = False 45 | 46 | dataset = ImageFilelist(cfg["root_dir"], cfg[split], transform, load_sp_lbl=load_sp_lbl) 47 | data_loader[split] = data.DataLoader( 48 | dataset, 49 | batch_size=batch_size, 50 | shuffle=shuffle, 51 | drop_last=False, 52 | pin_memory=True, 53 | num_workers=cfg["n_workers"], 54 | ) 55 | print("{split}: {size}".format(split=split, size=len(dataset))) 56 | logger.info("{split}: {size}".format(split=split, size=len(dataset))) 57 | 58 | print("Building data loader with {} workers.".format(cfg["n_workers"])) 59 | logger.info("Building data loader with {} workers.".format(cfg["n_workers"])) 60 | 61 | return data_loader 62 | -------------------------------------------------------------------------------- /src/loader/img_flist.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pandas as pd 5 | import os 6 | from PIL import Image 7 | 8 | import torch.utils.data as data 9 | 10 | 11 | class ImageFilelist(data.Dataset): 12 | def __init__(self, root_dir, flist, transform=None, target_transform=None, load_sp_lbl=False): 13 | 14 | self.root_dir = root_dir # root dir of images 15 | 16 | if load_sp_lbl: 17 | self.data = pd.read_csv(flist).iloc[:, [0, 3, 4, 5]].to_dict("list") # only select the cols of interest 18 | else: 19 | self.data = pd.read_csv(flist).iloc[:, [0, 3, 4]].to_dict("list") # only select the cols of interest 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | self.load_sp_lbl = load_sp_lbl 23 | 24 | def __getitem__(self, index): 25 | img_path = os.path.join(self.root_dir, self.data["ImageID"][index]) 26 | img = Image.open(img_path).convert("RGB") 27 | target = self.data["ClsID"][index] 28 | idx = self.data["Idx"][index] 29 | if self.load_sp_lbl: 30 | sp_lbl = self.data["SPLbl"][index] 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | if self.target_transform is not None: 36 | target = self.target_transform(target) 37 | 38 | if self.load_sp_lbl: 39 | return img, target, idx, sp_lbl 40 | else: 41 | return img, target, idx 42 | 43 | def __len__(self): 44 | return len(self.data["ImageID"]) 45 | -------------------------------------------------------------------------------- /src/loader/nshotdata.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import logging 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from torch.utils.data.dataloader import default_collate 10 | 11 | from .img_flist import ImageFilelist 12 | 13 | logger = logging.getLogger("mylogger") 14 | 15 | 16 | def get_collate_func(rot): 17 | if rot: 18 | assert not rot, "Rotation no longer supported" 19 | return None 20 | else: 21 | return default_collate 22 | 23 | 24 | def get_nshot_data(dataset, nshot=0, seed=None): 25 | class_ids = np.array(dataset.data["ClsID"]) 26 | classes = np.unique(class_ids) 27 | seed = 2147483647 if seed is None else 65535 * seed 28 | if nshot > 0: 29 | sampled_idx = [] 30 | for i in classes: 31 | g_cpu = torch.Generator() 32 | g_cpu.manual_seed(seed + i) 33 | idx = np.where(class_ids == i)[0] 34 | selected = torch.randperm(len(idx), generator=g_cpu)[:nshot] 35 | if nshot > 1: 36 | sampled_idx.extend(idx[selected].tolist()) 37 | else: 38 | sampled_idx.append(idx[selected]) 39 | 40 | else: 41 | sampled_idx = list(range(len(class_ids))) 42 | 43 | return sampled_idx 44 | 45 | 46 | def NshotDataLoader(cfg, splits, batch_size, seed=None): 47 | 48 | data_loader = dict() 49 | 50 | tgt_transform = lambda x: x + cfg.get("lbl_offset", 0) 51 | collate_func = get_collate_func(cfg.get("rot", False)) 52 | 53 | for split in splits: 54 | 55 | if split == "train": 56 | # data augmentation 57 | transform = transforms.Compose( 58 | [ 59 | transforms.RandomResizedCrop(224), 60 | transforms.RandomHorizontalFlip(), 61 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 62 | transforms.ToTensor(), 63 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 64 | ] 65 | ) 66 | shuffle = True 67 | load_sp_lbl = cfg.get("load_sp_lbl", False) 68 | print("load_sp_lbl", load_sp_lbl) 69 | 70 | else: 71 | transform = transforms.Compose( 72 | [ 73 | transforms.Resize(256), 74 | transforms.CenterCrop(224), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 77 | ] 78 | ) 79 | shuffle = False 80 | load_sp_lbl = False 81 | 82 | dataset = ImageFilelist( 83 | cfg["root_dir"], 84 | cfg[split], 85 | transform, 86 | tgt_transform, 87 | load_sp_lbl=load_sp_lbl, 88 | ) 89 | if split == "train": 90 | dataset = data.Subset(dataset, get_nshot_data(dataset, nshot=cfg["n_shot"], seed=seed)) 91 | data_loader[split] = data.DataLoader( 92 | dataset, 93 | batch_size=batch_size, 94 | shuffle=shuffle, 95 | drop_last=False, 96 | collate_fn=collate_func, 97 | pin_memory=True, 98 | num_workers=cfg["n_workers"], 99 | ) 100 | print("{split}: {size}".format(split=split, size=len(dataset))) 101 | logger.info("{split}: {size}".format(split=split, size=len(dataset))) 102 | 103 | print("Building data loader with {} workers.".format(cfg["n_workers"])) 104 | logger.info("Building data loader with {} workers.".format(cfg["n_workers"])) 105 | 106 | return data_loader 107 | -------------------------------------------------------------------------------- /src/loader/nshotdatanovel.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import logging 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from torch.utils.data.dataloader import default_collate 10 | 11 | from .img_flist import ImageFilelist 12 | 13 | logger = logging.getLogger("mylogger") 14 | 15 | 16 | def get_collate_func(rot): 17 | if rot: 18 | assert not rot, "Rotation no longer supported" 19 | return None 20 | else: 21 | return default_collate 22 | 23 | 24 | def get_nshot_data(dataset, nshot=0, nshotclass=0, seed=None): 25 | class_ids = np.array(dataset.data["ClsID"]) 26 | classes = np.unique(class_ids) 27 | seed = 2147483647 if seed is None else 65535 * seed 28 | if nshot > 0: 29 | sampled_idx = [] 30 | for i in classes: 31 | g_cpu = torch.Generator() 32 | g_cpu.manual_seed(seed + i) 33 | idx = np.where(class_ids == i)[0] 34 | if i < nshotclass: 35 | selected = torch.randperm(len(idx), generator=g_cpu)[:nshot] 36 | else: 37 | selected = torch.randperm(len(idx), generator=g_cpu) 38 | if nshot > 1: 39 | sampled_idx.extend(idx[selected].tolist()) 40 | else: 41 | sampled_idx.append(idx[selected]) 42 | 43 | else: 44 | sampled_idx = list(range(len(class_ids))) 45 | 46 | return sampled_idx 47 | 48 | 49 | def NshotNovelDataLoader(cfg, splits, batch_size, seed=None): 50 | 51 | data_loader = dict() 52 | 53 | tgt_transform = lambda x: x + cfg.get("lbl_offset", 0) 54 | collate_func = get_collate_func(cfg.get("rot", False)) 55 | 56 | for split in splits: 57 | 58 | if split == "train": 59 | # data augmentation 60 | transform = transforms.Compose( 61 | [ 62 | transforms.RandomResizedCrop(224), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2), 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 67 | ] 68 | ) 69 | shuffle = True 70 | load_sp_lbl = cfg.get("load_sp_lbl", False) 71 | print("load_sp_lbl", load_sp_lbl) 72 | 73 | else: 74 | transform = transforms.Compose( 75 | [ 76 | transforms.Resize(256), 77 | transforms.CenterCrop(224), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 80 | ] 81 | ) 82 | shuffle = False 83 | load_sp_lbl = False 84 | 85 | dataset = ImageFilelist( 86 | cfg["root_dir"], 87 | cfg[split], 88 | transform, 89 | tgt_transform, 90 | load_sp_lbl=load_sp_lbl, 91 | ) 92 | if split == "train": 93 | dataset = data.Subset( 94 | dataset, 95 | get_nshot_data( 96 | dataset, 97 | nshot=cfg["n_shot"], 98 | nshotclass=cfg["nshotclass"], 99 | seed=seed, 100 | ), 101 | ) 102 | data_loader[split] = data.DataLoader( 103 | dataset, 104 | batch_size=batch_size, 105 | shuffle=shuffle, 106 | drop_last=False, 107 | collate_fn=collate_func, 108 | pin_memory=True, 109 | num_workers=cfg["n_workers"], 110 | ) 111 | print("{split}: {size}".format(split=split, size=len(dataset))) 112 | logger.info("{split}: {size}".format(split=split, size=len(dataset))) 113 | 114 | print("Building data loader with {} workers.".format(cfg["n_workers"])) 115 | logger.info("Building data loader with {} workers.".format(cfg["n_workers"])) 116 | 117 | return data_loader 118 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | maxk = max(topk) 7 | batch_size = target.size(0) 8 | 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() 11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 12 | 13 | res = [] 14 | for k in topk: 15 | correct_k = correct[:k].float().sum(0).sum(0, keepdim=True) 16 | res.append(correct_k.mul_(100.0 / batch_size)) 17 | return res 18 | 19 | 20 | class averageMeter(object): 21 | def __init__(self): 22 | self.reset() 23 | 24 | def reset(self): 25 | self.val = 0 26 | self.avg = 0 27 | self.sum = 0 28 | self.count = 0 29 | 30 | def update(self, val, n=1): 31 | self.val = val 32 | self.sum += val * n 33 | self.count += n 34 | self.avg = self.sum / self.count 35 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import copy 5 | import torch.nn as nn 6 | import logging 7 | 8 | from .resnet import resnet10, resnet18, resnet50 9 | from .resnet_2br import ( 10 | resnet10_2br, 11 | resnet10_nbr, 12 | resnet18_2br, 13 | resnet18_nbr, 14 | resnet50_2br, 15 | resnet10_2brl3, 16 | ) 17 | from .linearcls import linearcls 18 | from .deltacls import deltacls, deltaclsn 19 | from .deltapoolcls import deltapoolcls 20 | 21 | logger = logging.getLogger("mylogger") 22 | 23 | 24 | def get_model(model_dict, verbose=False): 25 | 26 | name = model_dict["arch"] 27 | model = _get_model_instance(name) 28 | param_dict = copy.deepcopy(model_dict) 29 | param_dict.pop("arch") 30 | 31 | if "resnet" in name: 32 | model = model(**param_dict) 33 | model.fc = nn.Identity() 34 | 35 | else: 36 | model = model(**param_dict) 37 | 38 | if verbose: 39 | logger.info(model) 40 | 41 | return model 42 | 43 | 44 | def _get_model_instance(name): 45 | try: 46 | return { 47 | "resnet10": resnet10, 48 | "resnet10_2br": resnet10_2br, 49 | "resnet10_nbr": resnet10_nbr, 50 | "resnet10_2brl3": resnet10_2brl3, 51 | "resnet18": resnet18, 52 | "resnet18_2br": resnet18_2br, 53 | "resnet18_nbr": resnet18_nbr, 54 | "resnet50": resnet50, 55 | "resnet50_2br": resnet50_2br, 56 | "linearcls": linearcls, 57 | "deltacls": deltacls, 58 | "deltaclsn": deltaclsn, 59 | "deltapoolcls": deltapoolcls, 60 | }[name] 61 | except: 62 | raise BaseException("Model {} not available".format(name)) 63 | -------------------------------------------------------------------------------- /src/models/deltacls.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class DeltaCls(nn.Module): 10 | def __init__(self, feat_size, n_bclass, n_nclass, norm=-1): 11 | super(DeltaCls, self).__init__() 12 | self.feat_size = feat_size 13 | self.n_bclass = n_bclass 14 | self.n_nclass = n_nclass 15 | self.norm = norm 16 | if self.norm > 0: 17 | self.linear_base = nn.Linear(feat_size, n_bclass, bias=False) 18 | self.linear_novel = nn.Linear(feat_size, n_nclass, bias=False) 19 | else: 20 | self.linear_base = nn.Linear(feat_size, n_bclass) 21 | self.linear_novel = nn.Linear(feat_size, n_nclass) 22 | 23 | self.delta_base = nn.Linear(feat_size, n_bclass, bias=False) 24 | self.delta_novel = nn.Linear(feat_size, n_nclass, bias=False) 25 | 26 | for name, param in self.named_parameters(): 27 | if "linear" in name: 28 | param.requires_grad = False 29 | 30 | def forward(self, x_b, x_n, route=False): 31 | if self.norm > 0: 32 | weight_b = F.normalize(self.linear_base.weight, p=2, dim=1) * self.norm 33 | weight_n = F.normalize(self.linear_novel.weight, p=2, dim=1) * self.norm 34 | z_b0 = torch.mm(x_b, weight_b.t()) 35 | z_n0 = torch.mm(x_n, weight_n.t()) 36 | else: 37 | z_b0 = self.linear_base(x_b) 38 | z_n0 = self.linear_novel(x_n) 39 | 40 | delta_z_b = self.delta_base(x_n) 41 | delta_z_n = self.delta_novel(x_b) 42 | 43 | z_b = z_b0 + delta_z_b 44 | z_n = z_n0 + delta_z_n 45 | 46 | if route: 47 | route_out = torch.cat( 48 | (z_b.max(dim=1)[0].unsqueeze(-1), z_n.max(dim=1)[0].unsqueeze(-1)), 49 | dim=1, 50 | ) 51 | return torch.cat([z_b, z_n], dim=1), route_out 52 | else: 53 | return torch.cat([z_b, z_n], dim=1) 54 | 55 | def load_state_dict2(self, state_dict1, state_dict2): 56 | own_state = self.state_dict() 57 | # load the base classifier 58 | for name, param in state_dict1.items(): 59 | name = name.replace("linear", "linear_base") 60 | if name in own_state: 61 | if "bias" in name: 62 | own_state[name].copy_(param[: self.n_bclass]) 63 | else: 64 | own_state[name].copy_(param[: self.n_bclass, :]) 65 | else: 66 | print(name) 67 | 68 | # load the novel classifier 69 | for name, param in state_dict2.items(): 70 | name = name.replace("linear", "linear_novel") 71 | if name in own_state: 72 | own_state[name].copy_(param) 73 | else: 74 | print(name) 75 | 76 | 77 | class DeltaClsN(nn.Module): 78 | def __init__(self, feat_size, n_classes, norm=-1): 79 | super(DeltaClsN, self).__init__() 80 | self.feat_size = feat_size 81 | self.n_classes = n_classes 82 | assert norm > 0, "Not implemented for norm = 0 and with bias" 83 | self.norm = norm 84 | self.linear = nn.ModuleList() 85 | for n_cls in n_classes: 86 | self.linear.append(nn.Linear(feat_size, n_cls, bias=False)) 87 | 88 | self.delta = nn.Linear(feat_size * len(n_classes), sum(n_classes), bias=False) 89 | 90 | for name, param in self.named_parameters(): 91 | if "linear" in name: 92 | param.requires_grad = False 93 | with torch.no_grad(): 94 | self.delta.weight *= 1 - self.block_diag_indicator([l.weight for l in self.linear]) 95 | 96 | @classmethod 97 | def block_diag_indicator(cls, blocks): 98 | mask0 = torch.tensor(sum([[i] * x.shape[0] for i, x in enumerate(blocks)], [])).to(blocks[0]) 99 | mask1 = torch.tensor(sum([[i] * x.shape[1] for i, x in enumerate(blocks)], [])).to(blocks[0]) 100 | mask = (mask0[:, None] == mask1[None, :]).to(blocks[0]) 101 | return mask 102 | 103 | def forward(self, *xs, route=False): 104 | assert len(xs) == len(self.linear) 105 | # block-diagonal entries 106 | weights = [F.normalize(l.weight, p=2, dim=1) * self.norm for l in self.linear] 107 | z_0s = [torch.mm(x, weight.t()) for x, weight in zip(xs, weights)] 108 | z_0 = torch.cat(z_0s, dim=1) 109 | 110 | # off-block-diagonal entries 111 | mask = self.block_diag_indicator(weights) 112 | z_delta = torch.mm(torch.cat(xs, dim=1), (self.delta.weight * (1 - mask)).t()) 113 | 114 | z = z_0 + z_delta 115 | if route: 116 | start = 0 117 | route_out = [] 118 | for n_cls in self.n_classes: 119 | end = start + n_cls 120 | route_out.append(z[:, start:end].max(dim=1)[0].unsqueeze(-1)) 121 | start = end 122 | route_out = torch.cat(route_out, dim=1) 123 | return z, route_out 124 | else: 125 | return z 126 | 127 | def load_state_dict2(self, *state_dicts): 128 | own_state = self.state_dict() 129 | for idx, state_dict in enumerate(state_dicts): 130 | # load the idx-th classifier 131 | for name, param in state_dict.items(): 132 | name = name.replace("linear", "linear.{}".format(idx)) 133 | if name in own_state: 134 | if "bias" in name: 135 | own_state[name].copy_(param[: self.n_classes[idx]]) 136 | else: 137 | own_state[name].copy_(param[: self.n_classes[idx], :]) 138 | else: 139 | print(name) 140 | 141 | 142 | def deltacls(**kwargs): 143 | model = DeltaCls(**kwargs) 144 | return model 145 | 146 | 147 | def deltaclsn(**kwargs): 148 | model = DeltaClsN(**kwargs) 149 | return model 150 | -------------------------------------------------------------------------------- /src/models/deltapoolcls.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class DeltaPoolCls(nn.Module): 10 | def __init__(self, feat_size, n_bclass, n_nclass, n_overlap=0, pool="max", norm=-1): 11 | super(DeltaPoolCls, self).__init__() 12 | assert n_overlap <= n_bclass and n_overlap <= n_nclass 13 | self.feat_size = feat_size 14 | self.n_bclass = n_bclass 15 | self.n_nclass = n_nclass 16 | self.n_overlap = n_overlap 17 | self.norm = norm 18 | if norm > 0: 19 | self.linear_base = nn.Linear(feat_size, n_bclass, bias=False) 20 | self.linear_novel = nn.Linear(feat_size, n_nclass, bias=False) 21 | else: 22 | self.linear_base = nn.Linear(feat_size, n_bclass) 23 | self.linear_novel = nn.Linear(feat_size, n_nclass) 24 | 25 | self.delta_base = nn.Linear(feat_size, n_bclass, bias=False) 26 | self.delta_novel = nn.Linear(feat_size, n_nclass, bias=False) 27 | 28 | if n_overlap > 0: 29 | if pool == "avg": 30 | self.pool = nn.AvgPool2d(kernel_size=(2, 1)) 31 | else: 32 | self.pool = nn.MaxPool2d(kernel_size=(2, 1)) 33 | 34 | for name, param in self.named_parameters(): 35 | if "linear" in name: 36 | param.requires_grad = False 37 | 38 | def forward(self, x_b, x_n, route=False): 39 | if self.norm > 0: 40 | weight_b = F.normalize(self.linear_base.weight, p=2, dim=1) * self.norm 41 | weight_n = F.normalize(self.linear_novel.weight, p=2, dim=1) * self.norm 42 | z_b0 = torch.mm(x_b, weight_b.t()) 43 | z_n0 = torch.mm(x_n, weight_n.t()) 44 | else: 45 | z_b0 = self.linear_base(x_b) 46 | z_n0 = self.linear_novel(x_n) 47 | 48 | delta_z_b = self.delta_base(x_n) 49 | delta_z_n = self.delta_novel(x_b) 50 | 51 | z_b = z_b0 + delta_z_b 52 | z_n = z_n0 + delta_z_n 53 | 54 | if self.n_overlap > 0: 55 | z_overlap = torch.cat( 56 | [ 57 | z_b[:, : self.n_overlap].unsqueeze(1), 58 | z_n[:, : self.n_overlap].unsqueeze(1), 59 | ], 60 | dim=1, 61 | ) 62 | z_overlap = self.pool(z_overlap).squeeze() 63 | z = torch.cat([z_overlap, z_b[:, self.n_overlap :], z_n[:, self.n_overlap :]], dim=1) 64 | else: 65 | z = torch.cat([z_b, z_n], dim=1) 66 | 67 | if route: 68 | route_out = torch.cat( 69 | (z_b.max(dim=1)[0].unsqueeze(-1), z_n.max(dim=1)[0].unsqueeze(-1)), 70 | dim=1, 71 | ) 72 | return z, route_out 73 | 74 | return z 75 | 76 | def load_state_dict2(self, state_dict1, state_dict2): 77 | own_state = self.state_dict() 78 | # load the base classifier 79 | for name, param in state_dict1.items(): 80 | name = name.replace("linear", "linear_base") 81 | if name in own_state: 82 | if "bias" in name: 83 | own_state[name].copy_(param[: self.n_bclass]) 84 | else: 85 | own_state[name].copy_(param[: self.n_bclass, :]) 86 | else: 87 | print(name) 88 | 89 | # load the novel classifier 90 | for name, param in state_dict2.items(): 91 | name = name.replace("linear", "linear_novel") 92 | if name in own_state: 93 | own_state[name].copy_(param) 94 | else: 95 | print(name) 96 | 97 | 98 | def deltapoolcls(**kwargs): 99 | model = DeltaPoolCls(**kwargs) 100 | return model 101 | -------------------------------------------------------------------------------- /src/models/linearcls.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class LinearCls(nn.Module): 10 | def __init__(self, feat_size, n_class, norm=0): 11 | super(LinearCls, self).__init__() 12 | self.feat_size = feat_size 13 | self.n_class = n_class 14 | self.norm = norm 15 | if self.norm > 0: 16 | self.linear = nn.Linear(feat_size, n_class, bias=False) 17 | else: 18 | self.linear = nn.Linear(feat_size, n_class) 19 | 20 | def forward(self, x): 21 | if self.norm > 0: 22 | weight = F.normalize(self.linear.weight, p=2, dim=1) * self.norm 23 | return torch.mm(x, weight.t()) 24 | 25 | return self.linear(x) 26 | 27 | 28 | def linearcls(**kwargs): 29 | model = LinearCls(**kwargs) 30 | return model 31 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | logger = logging.getLogger("mylogger") 9 | 10 | __all__ = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] 11 | 12 | 13 | model_urls = { 14 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 15 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 16 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 17 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 18 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | def __init__(self, block, layers, num_classes=1000, n_freeze=0): 99 | self.n_classes = num_classes 100 | self.n_freeze = min(n_freeze, 4) 101 | self.inplanes = 64 102 | super(ResNet, self).__init__() 103 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 111 | self.avgpool = nn.AvgPool2d(7, stride=1) 112 | self.fc = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 117 | elif isinstance(m, nn.BatchNorm2d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | if self.n_freeze > 0: 122 | self._freeze_layers() 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1): 125 | downsample = None 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.Conv2d( 129 | self.inplanes, 130 | planes * block.expansion, 131 | kernel_size=1, 132 | stride=stride, 133 | bias=False, 134 | ), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample)) 140 | self.inplanes = planes * block.expansion 141 | for i in range(1, blocks): 142 | layers.append(block(self.inplanes, planes)) 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, x, feat=False): 147 | x = self.conv1(x) 148 | x = self.bn1(x) 149 | x = self.relu(x) 150 | x = self.maxpool(x) 151 | 152 | x = self.layer1(x) 153 | x = self.layer2(x) 154 | x = self.layer3(x) 155 | x = self.layer4(x) 156 | 157 | x = self.avgpool(x) 158 | x_f = x.view(x.size(0), -1) 159 | x = self.fc(x_f) 160 | 161 | if not feat: 162 | return x 163 | 164 | return x, x_f 165 | 166 | def _freeze_layers(self): 167 | for module in list(self.children())[: self.n_freeze + 4]: 168 | for param in module.parameters(): 169 | param.requires_grad = False 170 | 171 | if self.n_freeze > 1: 172 | print("First {} layers of resnet are frozen.".format(self.n_freeze)) 173 | logger.info("First {} layers of resnet are frozen.".format(self.n_freeze)) 174 | else: 175 | print("First layer of resnet is frozen.".format(self.n_freeze)) 176 | logger.info("First layer of resnet is frozen.".format(self.n_freeze)) 177 | 178 | def train(self, mode=True): 179 | """Function overloaded from https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module 180 | The revision is for freezing batch norm parameters of designated layers. 181 | """ 182 | if not isinstance(mode, bool): 183 | raise ValueError("training mode is expected to be boolean") 184 | self.training = mode 185 | for i, module in enumerate(self.children()): 186 | if self.n_freeze > 0 and i < (self.n_freeze + 4): 187 | module.train(False) 188 | else: 189 | module.train(mode) 190 | 191 | return self 192 | 193 | 194 | def resnet10(pretrained=False, **kwargs): 195 | """Constructs a ResNet-10 model. 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 200 | return model 201 | 202 | 203 | def resnet18(pretrained=False, **kwargs): 204 | """Constructs a ResNet-18 model. 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(model_zoo.load_url(model_urls["resnet18"])) 211 | return model 212 | 213 | 214 | def resnet34(pretrained=False, **kwargs): 215 | """Constructs a ResNet-34 model. 216 | Args: 217 | pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | """ 219 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 220 | if pretrained: 221 | model.load_state_dict(model_zoo.load_url(model_urls["resnet34"])) 222 | return model 223 | 224 | 225 | def resnet50(pretrained=False, **kwargs): 226 | """Constructs a ResNet-50 model. 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 231 | if pretrained: 232 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"])) 233 | return model 234 | 235 | 236 | def resnet101(pretrained=False, **kwargs): 237 | """Constructs a ResNet-101 model. 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls["resnet101"])) 244 | return model 245 | 246 | 247 | def resnet152(pretrained=False, **kwargs): 248 | """Constructs a ResNet-152 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls["resnet152"])) 255 | return model 256 | -------------------------------------------------------------------------------- /src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import copy 6 | from torch.optim import SGD, Adam 7 | 8 | from .scheduler import step_scheduler 9 | 10 | logger = logging.getLogger("mylogger") 11 | 12 | 13 | def get_optimizer(opt_dict): 14 | """Function to get the optimizer instance.""" 15 | name = opt_dict["name"] 16 | optimizer = _get_opt_instance(name) 17 | param_dict = copy.deepcopy(opt_dict) 18 | param_dict.pop("name") 19 | logger.info("Using {} optimizer".format(name)) 20 | 21 | return optimizer, param_dict 22 | 23 | 24 | def _get_opt_instance(name): 25 | try: 26 | return { 27 | "sgd": SGD, 28 | "adam": Adam, 29 | }[name] 30 | except: 31 | raise ("Optimizer {} not available".format(name)) 32 | -------------------------------------------------------------------------------- /src/optimizers/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | class step_scheduler(object): 6 | def __init__(self, optimizer, step_size, gamma): 7 | self.opt = optimizer 8 | self.step_size = step_size 9 | self.decay_rate = gamma 10 | self.epoch = 1 11 | 12 | def step(self): 13 | if self.epoch % self.step_size == 0: 14 | for param_group in self.opt.param_groups: 15 | param_group["lr"] = param_group["lr"] * self.decay_rate 16 | self.epoch += 1 17 | 18 | def get_lr(self): 19 | return self.opt.param_groups[0]["lr"] 20 | 21 | def state_dict(self): 22 | return { 23 | "lr": self.get_lr(), 24 | "epoch": self.epoch, 25 | "step_size": self.step_size, 26 | "decay_rate": self.decay_rate, 27 | } 28 | 29 | def load_state_dict(self, state_dict): 30 | self.epoch = state_dict["epoch"] 31 | self.step_size = state_dict["step_size"] 32 | self.decay_rate = state_dict["decay_rate"] 33 | self.step() 34 | -------------------------------------------------------------------------------- /src/parse_result.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import argparse 6 | import pandas as pd 7 | from collections import defaultdict 8 | 9 | from metrics import averageMeter 10 | import collections 11 | 12 | 13 | def parse_result(exp, mode=None, n_seed=1): 14 | assert mode in [None, "overlap", "ms"] 15 | path = os.path.dirname(exp) 16 | seed_range = range(1, n_seed + 1) 17 | alpha_range = [0, 0.2, 0.4, 0.6, 0.8, 1] 18 | dp_range = [0, 0.2, 0.4, 0.6, 0.8, 1] 19 | 20 | result = defaultdict(list) 21 | for alpha in alpha_range: 22 | for dp in dp_range: 23 | Acc = averageMeter() 24 | Bacc = averageMeter() 25 | Nacc = averageMeter() 26 | Oacc = averageMeter() 27 | SPacc = collections.defaultdict(averageMeter) 28 | Avg_acc = averageMeter() 29 | Area = averageMeter() 30 | for seed in seed_range: 31 | # read results from the logfile 32 | logdir = exp.format(alpha, dp, seed) 33 | log = sorted([x for x in os.listdir(logdir) if "run" in x])[-1] 34 | log = os.path.join(logdir, log) 35 | with open(log, "r") as f: 36 | res = f.readlines()[-2] 37 | res = [x for x in res.split(" ") if x] 38 | 39 | # get metrics 40 | acc = float(res[res.index("Acc") + 1]) 41 | if mode == "ms": 42 | spacc = [] 43 | idx_spacc = res.index("SPacc") + 1 44 | while True: 45 | spacc_item = res[idx_spacc] 46 | spacc.append(float(spacc_item.rstrip("/"))) 47 | if not spacc_item.endswith("/"): 48 | break 49 | idx_spacc += 1 50 | avg_acc = sum(spacc) / len(spacc) 51 | else: 52 | bacc = float(res[res.index("Bacc") + 1]) 53 | nacc = float(res[res.index("Nacc") + 1]) 54 | oacc = float(res[res.index("Oacc") + 1]) 55 | if mode == "overlap": 56 | avg_acc = (bacc + nacc + oacc) / 3 57 | else: 58 | avg_acc = (bacc + nacc) / 2 59 | area = acc * avg_acc 60 | 61 | # append results 62 | Acc.update(acc, 1) 63 | if mode == "ms": 64 | for i, spacc_item in enumerate(spacc): 65 | SPacc[i].update(spacc_item, 1) 66 | else: 67 | Bacc.update(bacc, 1) 68 | Nacc.update(nacc, 1) 69 | Oacc.update(oacc, 1) 70 | Avg_acc.update(avg_acc, 1) 71 | Area.update(area, 1) 72 | 73 | # append results 74 | result["alpha"].append(alpha) 75 | result["dp"].append(dp) 76 | result["Acc (all)"].append(Acc.avg) 77 | if mode == "ms": 78 | for i, SPacc_item in SPacc.items(): 79 | result["Acc ({i})".format(i=i)].append(SPacc_item.avg) 80 | else: 81 | result["Acc (base)"].append(Bacc.avg) 82 | result["Acc (novel)"].append(Nacc.avg) 83 | result["Acc (overlap)"].append(Oacc.avg) 84 | result["Avg. Acc"].append(Avg_acc.avg) 85 | result["Area"].append(Area.avg) 86 | 87 | pd.set_option("precision", 4) 88 | if mode == "ms": 89 | df = pd.DataFrame( 90 | result, 91 | columns=["alpha", "dp", "Acc (all)"] + ["Acc ({i})".format(i=i) for i in SPacc] + ["Avg. Acc", "Area"], 92 | ) 93 | elif mode == "overlap": 94 | df = pd.DataFrame( 95 | result, 96 | columns=[ 97 | "alpha", 98 | "dp", 99 | "Acc (all)", 100 | "Acc (base)", 101 | "Acc (novel)", 102 | "Acc (overlap)", 103 | "Avg. Acc", 104 | "Area", 105 | ], 106 | ) 107 | else: 108 | df = pd.DataFrame( 109 | result, 110 | columns=[ 111 | "alpha", 112 | "dp", 113 | "Acc (all)", 114 | "Acc (base)", 115 | "Acc (novel)", 116 | "Avg. Acc", 117 | "Area", 118 | ], 119 | ) 120 | 121 | # hyper-parameter selection 122 | best_all_idx = df.iloc[:, 2].idxmax() 123 | best_avg_idx = df.iloc[:, -2].idxmax() 124 | best_area_idx = df.iloc[:, -1].idxmax() 125 | sel_idx = [best_all_idx, best_area_idx, best_avg_idx] 126 | df_sel = df.iloc[sel_idx] 127 | print(df_sel) 128 | 129 | # save searching results 130 | csv_file = os.path.join(path, "sel_result.csv") 131 | df_sel.to_csv(csv_file, index=False, float_format="%.4f") 132 | 133 | csv_file = os.path.join(path, "all_result.csv") 134 | df.to_csv(csv_file, index=False, float_format="%.4f") 135 | 136 | sel_hyper = df_sel.iloc[:, [0, 1]].to_dict("list") 137 | return list(zip(sel_hyper["dp"], sel_hyper["alpha"])) 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description="config") 142 | parser.add_argument( 143 | "--exp", 144 | type=str, 145 | default="runs/imnet_delta/split2/resnet10_2br_800_layer4_40_a{:g}dp{:g}_10shot_s{:g}-dev", 146 | help="experiment", 147 | ) 148 | parser.add_argument( 149 | "--mode", 150 | default=None, 151 | choices=["overlap", "ms"], 152 | help="is overlapping case", 153 | ) 154 | parser.add_argument( 155 | "--n_seed", 156 | type=int, 157 | default=1, 158 | help="number of seeds for hyperparameter search", 159 | ) 160 | args = parser.parse_args() 161 | 162 | sel_hyper = parse_result(args.exp, args.mode, args.n_seed) 163 | -------------------------------------------------------------------------------- /src/test_fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import os 6 | import yaml 7 | import logging 8 | import torch 9 | from torch import nn 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | 13 | from loader import get_dataloader 14 | from models import get_model 15 | from metrics import averageMeter 16 | from utils import cvt2normal_state 17 | 18 | 19 | def main(): 20 | global norm, n_base_cls, n_overlap_cls, n_sel_cls 21 | 22 | if not torch.cuda.is_available(): 23 | raise SystemExit("GPU is needed.") 24 | 25 | # setup mode 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | # setup data loader 30 | splits = ["val"] 31 | data_loader = get_dataloader(cfg["data"], splits, args.b) 32 | 33 | # config 34 | n_overlap_cls = cfg["model"]["classifier"].get("n_overlap", 0) 35 | n_base_cls = cfg["config"].get("n_base_cls", 800) 36 | n_sel_cls = cfg["config"].get("n_sel_cls", n_base_cls) 37 | norm = cfg["config"].get("norm", -1) 38 | 39 | # setup model (feature extractor + classifier) 40 | n_gpu = torch.cuda.device_count() 41 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 42 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 43 | 44 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 45 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 46 | print("{} gpu(s) available.".format(n_gpu)) 47 | 48 | # load checkpoints 49 | ckpt1 = cfg["checkpoint"].get("model1", None) 50 | if ckpt1 is None: 51 | raise BaseException("Checkpoint needs to be specified.") 52 | 53 | if not os.path.isfile(ckpt1): 54 | raise BaseException("No checkpoint found at '{}'".format(ckpt1)) 55 | 56 | ckpt2 = cfg["checkpoint"].get("model2", None) 57 | if ckpt2: 58 | if not os.path.isfile(ckpt2): 59 | raise BaseException("No checkpoint found at '{}'".format(ckpt2)) 60 | 61 | if ckpt2 is None: 62 | checkpoint1 = torch.load(ckpt1) 63 | model_fe.module.load_state_dict(cvt2normal_state(checkpoint1["model_fe_state"])) 64 | model_cls.module.load_state_dict(cvt2normal_state(checkpoint1["model_cls_state"])) 65 | print("Loading model from checkpoint '{}'".format(ckpt1)) 66 | if logger: 67 | logger.info("Loading model from checkpoint '{}'".format(ckpt1)) 68 | else: 69 | checkpoint1 = torch.load(ckpt1) 70 | checkpoint2 = torch.load(ckpt2) 71 | model_fe.module.load_state_dict2( 72 | cvt2normal_state(checkpoint1["model_fe_state"]), 73 | cvt2normal_state(checkpoint2["model_fe_state"]), 74 | ) 75 | model_cls.module.load_state_dict2( 76 | cvt2normal_state(checkpoint1["model_cls_state"]), 77 | cvt2normal_state(checkpoint2["model_cls_state"]), 78 | ) 79 | print("Loading model from checkpoint '{}' and '{}'".format(ckpt1, ckpt2)) 80 | if logger: 81 | logger.info("Loading model from checkpoint '{}' and '{}'".format(ckpt1, ckpt2)) 82 | 83 | with torch.no_grad(): 84 | val(data_loader["val"], model_fe, model_cls) 85 | 86 | 87 | def val(data_loader, model_fe, model_cls): 88 | 89 | # setup average meters 90 | racc = averageMeter() 91 | acc = averageMeter() 92 | bacc = averageMeter() 93 | nacc = averageMeter() 94 | oacc = averageMeter() 95 | base2novel = averageMeter() 96 | novel2base = averageMeter() 97 | base_blogit = averageMeter() 98 | base_nlogit = averageMeter() 99 | novel_blogit = averageMeter() 100 | novel_nlogit = averageMeter() 101 | 102 | # setting evaluation mode 103 | model_fe.eval() 104 | model_cls.eval() 105 | 106 | one = torch.tensor([1]).cuda() 107 | for (step, value) in enumerate(data_loader): 108 | 109 | image = value[0].cuda() 110 | target = value[1].cuda(non_blocking=True) 111 | isnovel = target >= n_base_cls 112 | isoverlap = target < n_overlap_cls 113 | isbase = (~isnovel) * (~isoverlap) 114 | target[isnovel] -= n_base_cls - n_sel_cls 115 | 116 | # forward 117 | _, feat1, feat2 = model_fe(image, feat=True) 118 | if norm > 0: 119 | feat1 = F.normalize(feat1, p=2, dim=1) * norm 120 | feat2 = F.normalize(feat2, p=2, dim=1) * norm 121 | output = model_cls(feat1, feat2) 122 | 123 | # measure accuracy 124 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 125 | iscorrect = torch.eq(pred, target) 126 | all_acc = iscorrect.float().mean() 127 | acc.update(all_acc.item(), image.size(0)) 128 | 129 | # measure base and novel accuracy 130 | n_base = isbase.long().sum() 131 | n_novel = isnovel.long().sum() 132 | n_overlap = isoverlap.long().sum() 133 | assert (n_base + n_novel + n_overlap) == image.size(0) 134 | b_acc = iscorrect[isbase].float().mean() 135 | n_acc = iscorrect[isnovel].float().mean() 136 | o_acc = iscorrect[isoverlap].float().mean() 137 | if n_base > 0: 138 | bacc.update(b_acc.item(), n_base) 139 | if n_novel > 0: 140 | nacc.update(n_acc.item(), n_novel) 141 | if n_overlap > 0: 142 | oacc.update(o_acc.item(), n_overlap) 143 | 144 | # other analysis 145 | b2n = (pred[~isnovel] >= n_sel_cls).float().mean() if n_base > 0 else one 146 | n2b = (pred[isnovel] < n_sel_cls).float().mean() if n_novel > 0 else one 147 | if n_base > 0: 148 | base2novel.update(b2n.item(), n_base) 149 | if n_novel > 0: 150 | novel2base.update(n2b.item(), n_novel) 151 | r_acc = ((1 - b2n) * (n_base + n_overlap) + (1 - n2b) * n_novel) / image.size(0) 152 | racc.update(r_acc.item(), image.size(0)) 153 | 154 | flag_has_nlogit = n_sel_cls < output.shape[1] 155 | blogit = output[:, :n_sel_cls].max(dim=1)[0] / (norm**2) 156 | if flag_has_nlogit: 157 | nlogit = output[:, n_sel_cls:].max(dim=1)[0] / (norm**2) 158 | if n_base > 0: 159 | base_blogit.update(blogit[~isnovel].mean().item(), n_base) 160 | if flag_has_nlogit: 161 | base_nlogit.update(nlogit[~isnovel].mean().item(), n_base) 162 | if n_novel > 0: 163 | novel_blogit.update(blogit[isnovel].mean().item(), n_novel) 164 | if flag_has_nlogit: 165 | novel_nlogit.update(nlogit[isnovel].mean().item(), n_novel) 166 | 167 | print_str = ( 168 | "[Val] Acc {acc.avg:.4f} " 169 | "Racc {racc.avg: .3f} " 170 | "Bacc {bacc.avg: .4f} " 171 | "Nacc {nacc.avg: .4f} " 172 | "Oacc {oacc.avg: .4f} " 173 | "Base2novel {b2n.avg:.3f} " 174 | "Novel2base {n2b.avg:.3f} " 175 | "Blogit [B/N] [{bbl.avg:.3f}/{bnl.avg:.3f}] " 176 | "Nlogit [B/N] [{nbl.avg:.3f}/{nnl.avg:.3f}]".format( 177 | acc=acc, 178 | racc=racc, 179 | bacc=bacc, 180 | nacc=nacc, 181 | oacc=oacc, 182 | b2n=base2novel, 183 | n2b=novel2base, 184 | bbl=base_blogit, 185 | bnl=base_nlogit, 186 | nbl=novel_blogit, 187 | nnl=novel_nlogit, 188 | ) 189 | ) 190 | print(print_str) 191 | if logger: 192 | logger.info(print_str) 193 | 194 | 195 | if __name__ == "__main__": 196 | global cfg, args, logger 197 | 198 | parser = argparse.ArgumentParser(description="config") 199 | parser.add_argument( 200 | "--config", 201 | type=str, 202 | default="test_configs/split2_800_40/resnet10/poolcls.yaml", 203 | help="config file for testing", 204 | ) 205 | parser.add_argument( 206 | "--checkpoint1", 207 | type=str, 208 | default=None, 209 | help="checkpoint of base classifier / whole score fusion net", 210 | ) 211 | parser.add_argument( 212 | "--checkpoint2", 213 | type=str, 214 | default=None, 215 | help="checkpoint of novel classifier", 216 | ) 217 | parser.add_argument( 218 | "--pool", 219 | type=str, 220 | default=None, 221 | help="max / avg", 222 | ) 223 | parser.add_argument( 224 | "--log", 225 | type=str, 226 | default=None, 227 | help="path to the logfile (default: None)", 228 | ) 229 | parser.add_argument( 230 | "--b", 231 | type=int, 232 | default=256, 233 | help="batch size", 234 | ) 235 | 236 | args = parser.parse_args() 237 | print(args) 238 | 239 | if args.log: 240 | logger = logging.getLogger("mylogger") 241 | hdlr = logging.FileHandler(args.log) 242 | logger.addHandler(hdlr) 243 | logger.setLevel(logging.INFO) 244 | else: 245 | logger = None 246 | 247 | with open(args.config) as fp: 248 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 249 | 250 | if args.checkpoint1: 251 | cfg["checkpoint"]["model1"] = args.checkpoint1 252 | cfg["checkpoint"]["model2"] = args.checkpoint2 253 | 254 | if args.pool and cfg["model"]["classifier"].get("pool", None): 255 | cfg["model"]["classifier"]["pool"] = args.pool 256 | 257 | print(cfg) 258 | if logger: 259 | logger.info(cfg) 260 | 261 | main() 262 | -------------------------------------------------------------------------------- /src/test_fusion_ms.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import os 6 | import yaml 7 | import logging 8 | import torch 9 | from torch import nn 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | 13 | from loader import get_dataloader 14 | from models import get_model 15 | from metrics import averageMeter 16 | from utils import cvt2normal_state 17 | 18 | 19 | def main(): 20 | # global norm, n_base_cls, n_overlap_cls, n_sel_cls 21 | global norm, n_base_cls, n_all_cls 22 | 23 | if not torch.cuda.is_available(): 24 | raise SystemExit("GPU is needed.") 25 | 26 | # setup mode 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | # setup data loader 31 | splits = ["val"] 32 | data_loader = get_dataloader(cfg["data"], splits, args.b) 33 | 34 | # config 35 | n_novel_cls = cfg["config"].get("n_novel_cls", 40) 36 | n_base_cls = cfg["config"].get("n_base_cls", [800]) 37 | assert cfg["config"].get("n_sel_cls", n_base_cls) == n_base_cls 38 | assert cfg["config"].get("n_overlap", 0) == 0 39 | assert cfg["model"]["classifier"].get("n_overlap", 0) == 0 40 | n_all_cls = n_base_cls + [n_novel_cls] 41 | norm = cfg["config"].get("norm", -1) 42 | 43 | # setup model (feature extractor + classifier) 44 | n_gpu = torch.cuda.device_count() 45 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 46 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 47 | 48 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 49 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 50 | print("{} gpu(s) available.".format(n_gpu)) 51 | 52 | # load checkpoints 53 | ckpt1 = cfg["checkpoint"].get("model1", None) 54 | if ckpt1 is None: 55 | raise BaseException("Checkpoint needs to be specified.") 56 | 57 | if not os.path.isfile(ckpt1): 58 | raise BaseException("No checkpoint found at '{}'".format(ckpt1)) 59 | 60 | checkpoint1 = torch.load(ckpt1) 61 | model_fe.module.load_state_dict(cvt2normal_state(checkpoint1["model_fe_state"])) 62 | model_cls.module.load_state_dict(cvt2normal_state(checkpoint1["model_cls_state"])) 63 | print("Loading model from checkpoint '{}'".format(ckpt1)) 64 | if logger: 65 | logger.info("Loading model from checkpoint '{}'".format(ckpt1)) 66 | 67 | with torch.no_grad(): 68 | val(data_loader["val"], model_fe, model_cls) 69 | 70 | 71 | def val(data_loader, model_fe, model_cls): 72 | 73 | # setup average meters 74 | racc = averageMeter() 75 | acc = averageMeter() 76 | spacc = [averageMeter() for _ in n_all_cls] 77 | base2novel = averageMeter() 78 | novel2base = averageMeter() 79 | base_blogit = averageMeter() 80 | base_nlogit = averageMeter() 81 | novel_blogit = averageMeter() 82 | novel_nlogit = averageMeter() 83 | 84 | # setting evaluation mode 85 | model_fe.eval() 86 | model_cls.eval() 87 | split_lookup = sum([[idx] * n for idx, n in enumerate(n_all_cls)], []) 88 | split_lookup = torch.tensor(split_lookup).long().cuda() 89 | 90 | one = torch.tensor([1]).cuda() 91 | for (step, value) in enumerate(data_loader): 92 | 93 | image = value[0].cuda() 94 | target = value[1].cuda(non_blocking=True) 95 | split_ids = split_lookup[target] 96 | 97 | # forward 98 | _, *feats = model_fe(image, feat=True) 99 | if norm > 0: 100 | feats = [F.normalize(feat, p=2, dim=1) * norm for feat in feats] 101 | output = model_cls(*feats) 102 | 103 | # measure accuracy 104 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 105 | iscorrect = torch.eq(pred, target) 106 | all_acc = iscorrect.float().mean() 107 | acc.update(all_acc.item(), image.size(0)) 108 | 109 | # measure base and novel accuracy 110 | n_split_samples = [] 111 | for idx_split, n_split_cls in enumerate(n_all_cls): 112 | is_curr_split = split_ids == idx_split 113 | n_split_sample = is_curr_split.long().sum() 114 | split_acc = iscorrect[is_curr_split].float().mean() 115 | if n_split_sample: 116 | spacc[idx_split].update(split_acc.item(), n_split_sample) 117 | n_split_samples.append(n_split_sample) 118 | assert sum(n_split_samples) == image.size(0) 119 | 120 | # other analysis 121 | n_splits = len(n_all_cls) 122 | pred_split = split_lookup[pred] 123 | n_split_samples_prv = sum(n_split_samples[:-1]) 124 | 125 | b2n = (pred_split[split_ids < n_splits - 1] == n_splits - 1).float().mean() if n_split_samples_prv > 0 else one 126 | n2b = (pred_split[split_ids == n_splits - 1] < n_splits - 1).float().mean() if n_split_samples[-1] > 0 else one 127 | if n_split_samples_prv > 0: 128 | base2novel.update(b2n.item(), n_split_samples_prv) 129 | if n_split_samples[-1] > 0: 130 | novel2base.update(n2b.item(), n_split_samples[-1]) 131 | r_acc = (pred_split == split_ids).float().mean() 132 | racc.update(r_acc.item(), image.size(0)) 133 | 134 | n_last_novel_cls = n_all_cls[-1] 135 | flag_has_nlogit = bool(n_last_novel_cls) 136 | blogit = output[:, :-n_last_novel_cls].max(dim=1)[0] / (norm**2) 137 | if flag_has_nlogit: 138 | nlogit = output[:, -n_last_novel_cls:].max(dim=1)[0] / (norm**2) 139 | if n_split_samples_prv > 0: 140 | base_blogit.update(blogit[split_ids < n_splits - 1].mean().item(), n_split_samples_prv) 141 | if flag_has_nlogit: 142 | base_nlogit.update(nlogit[split_ids < n_splits - 1].mean().item(), n_split_samples_prv) 143 | if n_split_samples[-1] > 0: 144 | novel_blogit.update(blogit[split_ids == n_splits - 1].mean().item(), n_split_samples[-1]) 145 | if flag_has_nlogit: 146 | novel_nlogit.update(nlogit[split_ids == n_splits - 1].mean().item(), n_split_samples[-1]) 147 | 148 | print_str = ( 149 | "[Val] Acc {acc.avg:.4f} " 150 | "Racc {racc.avg: .3f} " 151 | "SPacc {spacc} " 152 | "Base2novel {b2n.avg:.3f} " 153 | "Novel2base {n2b.avg:.3f} " 154 | "Blogit [B/N] [{bbl.avg:.3f}/{bnl.avg:.3f}] " 155 | "Nlogit [B/N] [{nbl.avg:.3f}/{nnl.avg:.3f}]".format( 156 | acc=acc, 157 | racc=racc, 158 | b2n=base2novel, 159 | n2b=novel2base, 160 | spacc="/".join("{: .4f}".format(x.avg) for x in spacc), 161 | bbl=base_blogit, 162 | bnl=base_nlogit, 163 | nbl=novel_blogit, 164 | nnl=novel_nlogit, 165 | ) 166 | ) 167 | print(print_str) 168 | if logger: 169 | logger.info(print_str) 170 | 171 | 172 | if __name__ == "__main__": 173 | global cfg, args, logger 174 | 175 | parser = argparse.ArgumentParser(description="config") 176 | parser.add_argument( 177 | "--config", 178 | type=str, 179 | default="test_configs/split2_800_40/resnet10/poolcls.yaml", 180 | help="config file for testing", 181 | ) 182 | parser.add_argument( 183 | "--checkpoint1", 184 | type=str, 185 | default=None, 186 | help="checkpoint of the whole score fusion net", 187 | ) 188 | parser.add_argument( 189 | "--pool", 190 | type=str, 191 | default=None, 192 | help="max / avg", 193 | ) 194 | parser.add_argument( 195 | "--log", 196 | type=str, 197 | default=None, 198 | help="path to the logfile (default: None)", 199 | ) 200 | parser.add_argument( 201 | "--b", 202 | type=int, 203 | default=256, 204 | help="batch size", 205 | ) 206 | 207 | args = parser.parse_args() 208 | print(args) 209 | 210 | if args.log: 211 | logger = logging.getLogger("mylogger") 212 | hdlr = logging.FileHandler(args.log) 213 | logger.addHandler(hdlr) 214 | logger.setLevel(logging.INFO) 215 | else: 216 | logger = None 217 | 218 | with open(args.config) as fp: 219 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 220 | 221 | if args.checkpoint1: 222 | cfg["checkpoint"]["model1"] = args.checkpoint1 223 | 224 | if args.pool and cfg["model"]["classifier"].get("pool", None): 225 | cfg["model"]["classifier"]["pool"] = args.pool 226 | 227 | print(cfg) 228 | if logger: 229 | logger.info(cfg) 230 | 231 | main() 232 | -------------------------------------------------------------------------------- /src/train_base.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import argparse 6 | import os 7 | import yaml 8 | import shutil 9 | import torch 10 | from torch import nn 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.functional as F 13 | 14 | from loader import get_dataloader 15 | from models import get_model 16 | from optimizers import get_optimizer, step_scheduler 17 | from metrics import averageMeter, accuracy 18 | from utils import get_logger, cvt2normal_state 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | 23 | def main(): 24 | global norm 25 | 26 | if not torch.cuda.is_available(): 27 | raise SystemExit("GPU is needed.") 28 | 29 | # setup random seed 30 | torch.manual_seed(cfg.get("seed", 1)) 31 | torch.cuda.manual_seed(cfg.get("seed", 1)) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | # setup data loader 36 | splits = ["train", "val"] 37 | data_loader = get_dataloader(cfg["data"], splits, cfg["training"]["batch_size"]) 38 | 39 | # setup model (feature extractor + classifier) 40 | n_gpu = torch.cuda.device_count() 41 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 42 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 43 | 44 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 45 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 46 | print("{} gpu(s) available.".format(n_gpu)) 47 | 48 | # loss function 49 | criterion = nn.CrossEntropyLoss(reduction="none") 50 | norm = cfg["training"].get("norm", -1) 51 | 52 | # setup optimizer 53 | opt_main_cls, opt_main_params = get_optimizer(cfg["training"]["optimizer_main"]) 54 | cnn_params = list(model_fe.parameters()) + list(model_cls.parameters()) 55 | opt_main = opt_main_cls(cnn_params, **opt_main_params) 56 | logger.info("Using optimizer {}".format(opt_main)) 57 | 58 | # setup scheduler 59 | scheduler = step_scheduler(opt_main, **cfg["training"]["scheduler"]) 60 | 61 | # load checkpoint 62 | start_ep = 0 63 | if cfg["training"]["resume"].get("model", None): 64 | resume = cfg["training"]["resume"] 65 | if os.path.isfile(resume["model"]): 66 | print("Loading model from checkpoint '{}'".format(resume["model"])) 67 | logger.info("Loading model from checkpoint '{}'".format(resume["model"])) 68 | checkpoint = torch.load(resume["model"]) 69 | model_fe.module.load_state_dict(cvt2normal_state(checkpoint["model_fe_state"])) 70 | if resume.get("load_cls", False): 71 | model_cls.module.load_state_dict(cvt2normal_state(checkpoint["model_cls_state"])) 72 | logger.info("Loading classifier") 73 | if resume["param_only"] is False: 74 | start_ep = checkpoint["epoch"] 75 | opt_main.load_state_dict(checkpoint["opt_main_state"]) 76 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 77 | logger.info("Loaded checkpoint '{}' (iter {})".format(resume["model"], checkpoint["epoch"])) 78 | else: 79 | print("No checkpoint found at '{}'".format(resume["model"])) 80 | logger.info("No checkpoint found at '{}'".format(resume["model"])) 81 | 82 | print("Start training from epoch {}".format(start_ep)) 83 | logger.info("Start training from epoch {}".format(start_ep)) 84 | 85 | for ep in range(start_ep, cfg["training"]["epoch"]): 86 | 87 | train(data_loader["train"], model_fe, model_cls, opt_main, ep, criterion) 88 | 89 | if (ep + 1) % cfg["training"]["val_interval"] == 0: 90 | with torch.no_grad(): 91 | val(data_loader["val"], model_fe, model_cls, ep, criterion) 92 | 93 | if (ep + 1) % cfg["training"]["save_interval"] == 0: 94 | state = { 95 | "epoch": ep + 1, 96 | "model_fe_state": model_fe.state_dict(), 97 | "model_cls_state": model_cls.state_dict(), 98 | "opt_main_state": opt_main.state_dict(), 99 | "scheduler_state": scheduler.state_dict(), 100 | } 101 | ckpt_path = os.path.join(writer.file_writer.get_logdir(), "ep-{ep}_model.pkl") 102 | save_path = ckpt_path.format(ep=ep + 1) 103 | last_path = ckpt_path.format(ep=ep + 1 - cfg["training"]["save_interval"]) 104 | torch.save(state, save_path) 105 | if os.path.isfile(last_path): 106 | os.remove(last_path) 107 | print_str = "[Checkpoint]: {} saved".format(save_path) 108 | print(print_str) 109 | logger.info(print_str) 110 | 111 | scheduler.step() 112 | 113 | 114 | def train(data_loader, model_fe, model_cls, opt_main, epoch, criterion): 115 | 116 | # setup average meters 117 | batch_time = averageMeter() 118 | data_time = averageMeter() 119 | losses = averageMeter() 120 | top1 = averageMeter() 121 | top5 = averageMeter() 122 | 123 | # setting training mode 124 | model_fe.train() 125 | model_cls.train() 126 | 127 | n_step = int(len(data_loader.dataset) // float(data_loader.batch_size)) 128 | end = time.time() 129 | for (step, value) in enumerate(data_loader): 130 | 131 | # measure data loading time 132 | data_time.update(time.time() - end) 133 | 134 | image = value[0].cuda() 135 | target = value[1].cuda(non_blocking=True) 136 | 137 | # forward 138 | imfeat = model_fe(image) 139 | if norm > 0: 140 | imfeat = F.normalize(imfeat, p=2, dim=1) * norm 141 | output = model_cls(imfeat) 142 | 143 | loss = torch.mean(criterion(output, target).squeeze()) 144 | losses.update(loss.item(), image.size(0)) 145 | 146 | # measure accuracy 147 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 148 | top1.update(prec1[0], image.size(0)) 149 | top5.update(prec5[0], image.size(0)) 150 | 151 | # back propagation 152 | opt_main.zero_grad() 153 | loss.backward() 154 | opt_main.step() 155 | 156 | # measure elapsed time 157 | batch_time.update(time.time() - end) 158 | end = time.time() 159 | 160 | if (step + 1) % 10 == 0: 161 | curr_lr_main = opt_main.param_groups[0]["lr"] 162 | print_str = ( 163 | "Epoch [{0}/{1}] " 164 | "Step: [{2}/{3}] " 165 | "LR: [{4}] " 166 | "Time {batch_time.avg:.3f} " 167 | "Data {data_time.avg:.3f} " 168 | "Loss {loss.avg:.4f} " 169 | "Top1 {top1.avg:.3f} " 170 | "Top5 {top5.avg:.3f}".format( 171 | epoch + 1, 172 | cfg["training"]["epoch"], 173 | step + 1, 174 | n_step, 175 | curr_lr_main, 176 | batch_time=batch_time, 177 | data_time=data_time, 178 | loss=losses, 179 | top1=top1, 180 | top5=top5, 181 | ) 182 | ) 183 | 184 | print(print_str) 185 | logger.info(print_str) 186 | 187 | if (epoch + 1) % cfg["training"]["print_interval"] == 0: 188 | curr_lr_main = opt_main.param_groups[0]["lr"] 189 | print_str = ( 190 | "Epoch: [{0}/{1}] " 191 | "LR: [{2}] " 192 | "Time {batch_time.avg:.3f} " 193 | "Data {data_time.avg:.3f} " 194 | "Loss {loss.avg:.4f} " 195 | "Top1 {top1.avg:.3f} " 196 | "Top5 {top5.avg:.3f}".format( 197 | epoch + 1, 198 | cfg["training"]["epoch"], 199 | curr_lr_main, 200 | batch_time=batch_time, 201 | data_time=data_time, 202 | loss=losses, 203 | top1=top1, 204 | top5=top5, 205 | ) 206 | ) 207 | 208 | print(print_str) 209 | logger.info(print_str) 210 | writer.add_scalar("train/lr", curr_lr_main, epoch + 1) 211 | writer.add_scalar("train/loss", losses.avg, epoch + 1) 212 | writer.add_scalar("train/top1", top1.avg, epoch + 1) 213 | writer.add_scalar("train/top5", top5.avg, epoch + 1) 214 | 215 | 216 | def val(data_loader, model_fe, model_cls, epoch, criterion): 217 | 218 | # setup average meters 219 | losses = averageMeter() 220 | top1 = averageMeter() 221 | top5 = averageMeter() 222 | 223 | # setting evaluation mode 224 | model_fe.eval() 225 | model_cls.eval() 226 | 227 | for (step, value) in enumerate(data_loader): 228 | 229 | image = value[0].cuda() 230 | target = value[1].cuda(non_blocking=True) 231 | 232 | # forward 233 | imfeat = model_fe(image) 234 | if norm > 0: 235 | imfeat = F.normalize(imfeat, p=2, dim=1) * norm 236 | output = model_cls(imfeat) 237 | 238 | loss = torch.mean(criterion(output, target).squeeze()) 239 | losses.update(loss.item(), image.size(0)) 240 | 241 | # measure accuracy 242 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 243 | top1.update(prec1[0], image.size(0)) 244 | top5.update(prec5[0], image.size(0)) 245 | 246 | print_str = "[Val] Prec@1 {top1.avg:.3f}\t Prec@5 {top5.avg:.3f}".format(top1=top1, top5=top5) 247 | print(print_str) 248 | logger.info(print_str) 249 | 250 | writer.add_scalar("val/loss", losses.avg, epoch + 1) 251 | writer.add_scalar("val/top1", top1.avg, epoch + 1) 252 | writer.add_scalar("val/top5", top5.avg, epoch + 1) 253 | 254 | 255 | if __name__ == "__main__": 256 | global cfg, args, writer, logger 257 | 258 | parser = argparse.ArgumentParser(description="config") 259 | parser.add_argument( 260 | "--config", 261 | nargs="?", 262 | type=str, 263 | default="configs/imnet_base.yml", 264 | help="Configuration file to use", 265 | ) 266 | 267 | args = parser.parse_args() 268 | 269 | with open(args.config) as fp: 270 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 271 | 272 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], cfg["exp"]) 273 | writer = SummaryWriter(log_dir=logdir) 274 | 275 | print("RUNDIR: {}".format(logdir)) 276 | shutil.copy(args.config, logdir) 277 | 278 | logger = get_logger(logdir) 279 | logger.info("Start logging") 280 | 281 | print(args) 282 | logger.info(args) 283 | 284 | main() 285 | -------------------------------------------------------------------------------- /src/train_fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import argparse 6 | import os 7 | import yaml 8 | import shutil 9 | import torch 10 | from torch import nn 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.functional as F 13 | 14 | from loader import get_dataloader 15 | from models import get_model 16 | from optimizers import get_optimizer, step_scheduler 17 | from metrics import averageMeter, accuracy 18 | from utils import get_logger, cvt2normal_state 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | 23 | def main(): 24 | global norm, n_base_cls, n_novel_cls, n_overlap_cls, n_sel_cls 25 | 26 | if not torch.cuda.is_available(): 27 | raise SystemExit("GPU is needed.") 28 | 29 | # setup random seed 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | # setup data loader 36 | splits = ["train", "val"] 37 | data_loader = get_dataloader(cfg["data"], splits, cfg["training"]["batch_size"]) 38 | 39 | # config 40 | n_novel_cls = cfg["model"]["classifier"].get("n_nclass", 40) 41 | n_overlap_cls = cfg["model"]["classifier"].get("n_overlap", 0) 42 | n_base_cls = cfg["training"].get("n_base_cls", 800) 43 | n_sel_cls = cfg["training"].get("n_sel_cls", n_base_cls) 44 | 45 | # setup model (feature extractor + classifier) 46 | n_gpu = torch.cuda.device_count() 47 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 48 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 49 | 50 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 51 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 52 | print("{} gpu(s) available.".format(n_gpu)) 53 | 54 | # loss function 55 | criterion = nn.CrossEntropyLoss(reduction="none") 56 | norm = cfg["training"].get("norm", -1) 57 | 58 | # setup optimizer 59 | opt_main_cls, opt_main_params = get_optimizer(cfg["training"]["optimizer_main"]) 60 | cnn_params = list(model_fe.parameters()) + list(model_cls.parameters()) 61 | opt_main = opt_main_cls(cnn_params, **opt_main_params) 62 | logger.info("Using optimizer {}".format(opt_main)) 63 | 64 | # setup scheduler 65 | scheduler = step_scheduler(opt_main, **cfg["training"]["scheduler"]) 66 | 67 | # load checkpoint 68 | start_ep = 0 69 | if cfg["training"]["resume"].get("model1", False) and cfg["training"]["resume"].get("model2", False): 70 | model1 = cfg["training"]["resume"]["model1"] 71 | model2 = cfg["training"]["resume"]["model2"] 72 | if not os.path.isfile(model1): 73 | print("No checkpoint found at '{}'".format(model1)) 74 | logger.info("No checkpoint found at '{}'".format(model1)) 75 | elif not os.path.isfile(model2): 76 | print("No checkpoint found at '{}'".format(model2)) 77 | logger.info("No checkpoint found at '{}'".format(model2)) 78 | else: 79 | print("Loading model from checkpoint '{}' and '{}'".format(model1, model2)) 80 | logger.info("Loading model from checkpoint '{}' and '{}'".format(model1, model2)) 81 | checkpoint1 = torch.load(model1) 82 | checkpoint2 = torch.load(model2) 83 | model_fe.module.load_state_dict2( 84 | cvt2normal_state(checkpoint1["model_fe_state"]), 85 | cvt2normal_state(checkpoint2["model_fe_state"]), 86 | ) 87 | model_cls.module.load_state_dict2( 88 | cvt2normal_state(checkpoint1["model_cls_state"]), 89 | cvt2normal_state(checkpoint2["model_cls_state"]), 90 | ) 91 | logger.info("Loading classifier") 92 | 93 | print("Start training from epoch {}".format(start_ep)) 94 | logger.info("Start training from epoch {}".format(start_ep)) 95 | 96 | for ep in range(start_ep, cfg["training"]["epoch"]): 97 | 98 | train(data_loader["train"], model_fe, model_cls, opt_main, ep, criterion) 99 | 100 | if (ep + 1) % cfg["training"]["val_interval"] == 0: 101 | with torch.no_grad(): 102 | val(data_loader["val"], model_fe, model_cls, ep, criterion) 103 | 104 | if (ep + 1) % cfg["training"]["save_interval"] == 0: 105 | state = { 106 | "epoch": ep + 1, 107 | "model_fe_state": model_fe.state_dict(), 108 | "model_cls_state": model_cls.state_dict(), 109 | "opt_main_state": opt_main.state_dict(), 110 | "scheduler_state": scheduler.state_dict(), 111 | } 112 | ckpt_path = os.path.join(writer.file_writer.get_logdir(), "ep-{ep}_model.pkl") 113 | save_path = ckpt_path.format(ep=ep + 1) 114 | last_path = ckpt_path.format(ep=ep + 1 - cfg["training"]["save_interval"]) 115 | torch.save(state, save_path) 116 | if os.path.isfile(last_path): 117 | os.remove(last_path) 118 | print_str = "[Checkpoint]: {} saved".format(save_path) 119 | print(print_str) 120 | logger.info(print_str) 121 | 122 | scheduler.step() 123 | 124 | 125 | def train(data_loader, model_fe, model_cls, opt_main, epoch, criterion): 126 | 127 | # setup average meters 128 | batch_time = averageMeter() 129 | data_time = averageMeter() 130 | losses = averageMeter() 131 | acc = averageMeter() 132 | bacc = averageMeter() 133 | nacc = averageMeter() 134 | oacc = averageMeter() 135 | 136 | # setting training mode 137 | model_fe.train() 138 | model_cls.train() 139 | 140 | n_step = int(len(data_loader.dataset) // float(data_loader.batch_size)) 141 | end = time.time() 142 | for (step, value) in enumerate(data_loader): 143 | # measure data loading time 144 | data_time.update(time.time() - end) 145 | 146 | image = value[0].cuda() 147 | target = value[1].cuda(non_blocking=True) 148 | isnovel = target >= n_base_cls 149 | isoverlap = target < n_overlap_cls 150 | isbase = (~isnovel) * (~isoverlap) 151 | target[isnovel] -= n_base_cls - n_sel_cls # for partial selection (e.g. 40 / 800) 152 | 153 | # split label 154 | if len(value) > 3: 155 | sp_lbl = value[3].cuda() 156 | else: 157 | sp_lbl = isnovel.long() 158 | 159 | # forward 160 | _, feat1, feat2 = model_fe(image, feat=True) 161 | if norm > 0: 162 | feat1 = F.normalize(feat1, p=2, dim=1) * norm 163 | feat2 = F.normalize(feat2, p=2, dim=1) * norm 164 | 165 | # hn dropout 166 | dp_scaling = cfg["training"].get("dp_scaling", True) 167 | if dp_scaling is True: 168 | feat2[(sp_lbl == 0), :] = F.dropout(feat2[(sp_lbl == 0), :], p=dp) * (1 - dp) 169 | elif dp_scaling == "no_dp": 170 | feat2[(sp_lbl == 0), :] = feat2[(sp_lbl == 0), :] * (1 - dp) 171 | else: 172 | assert not dp_scaling 173 | feat2[(sp_lbl == 0), :] = F.dropout(feat2[(sp_lbl == 0), :], p=dp) 174 | 175 | output, route_out = model_cls(feat1, feat2, route=True) 176 | 177 | # compute loss 178 | r_ce = criterion(route_out, sp_lbl).squeeze() 179 | ce_b = r_ce[(sp_lbl == 0)].mean() 180 | ce_n = r_ce[(sp_lbl > 0)].mean() 181 | rloss = (ce_b + ce_n) / 2 182 | closs = torch.mean(criterion(output, target).squeeze()) 183 | 184 | loss = (1 - alpha) * closs + alpha * rloss 185 | losses.update(loss.item(), image.size(0)) 186 | 187 | # measure accuracy 188 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 189 | iscorrect = torch.eq(pred, target) 190 | all_acc = iscorrect.float().mean() 191 | acc.update(all_acc.item(), image.size(0)) 192 | 193 | # measure base and novel accuracy 194 | n_base = isbase.long().sum() 195 | n_novel = isnovel.long().sum() 196 | n_overlap = isoverlap.long().sum() 197 | assert (n_base + n_novel + n_overlap) == image.size(0) 198 | b_acc = iscorrect[isbase].float().mean() 199 | n_acc = iscorrect[isnovel].float().mean() 200 | o_acc = iscorrect[isoverlap].float().mean() 201 | if n_base > 0: 202 | bacc.update(b_acc.item(), n_base) 203 | if n_novel > 0: 204 | nacc.update(n_acc.item(), n_novel) 205 | if n_overlap > 0: 206 | oacc.update(o_acc.item(), n_overlap) 207 | 208 | # back propagation 209 | opt_main.zero_grad() 210 | loss.backward() 211 | opt_main.step() 212 | 213 | # measure elapsed time 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | if (step + 1) % 10 == 0: 218 | curr_lr_main = opt_main.param_groups[0]["lr"] 219 | print_str = ( 220 | "Epoch [{0}/{1}] " 221 | "Step: [{2}/{3}] " 222 | "LR: [{4}] " 223 | "Time {batch_time.avg:.3f} " 224 | "Data {data_time.avg:.3f} " 225 | "Loss {loss.avg:.4f} " 226 | "Acc {acc.avg:.3f} " 227 | "BaseAcc {bacc.avg:.3f} " 228 | "NovelAcc {nacc.avg:.3f} " 229 | "OverlapAcc {oacc.avg:.3f}".format( 230 | epoch + 1, 231 | cfg["training"]["epoch"], 232 | step + 1, 233 | n_step, 234 | curr_lr_main, 235 | batch_time=batch_time, 236 | data_time=data_time, 237 | loss=losses, 238 | acc=acc, 239 | bacc=bacc, 240 | nacc=nacc, 241 | oacc=oacc, 242 | ) 243 | ) 244 | 245 | print(print_str) 246 | logger.info(print_str) 247 | 248 | if (epoch + 1) % cfg["training"]["print_interval"] == 0: 249 | curr_lr_main = opt_main.param_groups[0]["lr"] 250 | print_str = ( 251 | "Epoch: [{0}/{1}] " 252 | "LR: [{2}] " 253 | "Time {batch_time.avg:.3f} " 254 | "Data {data_time.avg:.3f} " 255 | "Loss {loss.avg:.4f} " 256 | "Acc {acc.avg:.3f} " 257 | "BaseAcc {bacc.avg:.3f} " 258 | "NovelAcc {nacc.avg:.3f} " 259 | "OverlapAcc {oacc.avg:.3f}".format( 260 | epoch + 1, 261 | cfg["training"]["epoch"], 262 | curr_lr_main, 263 | batch_time=batch_time, 264 | data_time=data_time, 265 | loss=losses, 266 | acc=acc, 267 | bacc=bacc, 268 | nacc=nacc, 269 | oacc=oacc, 270 | ) 271 | ) 272 | 273 | print(print_str) 274 | logger.info(print_str) 275 | writer.add_scalar("train/lr", curr_lr_main, epoch + 1) 276 | writer.add_scalar("train/loss", losses.avg, epoch + 1) 277 | writer.add_scalar("train/acc", acc.avg, epoch + 1) 278 | writer.add_scalar("train/bacc", bacc.avg, epoch + 1) 279 | writer.add_scalar("train/nacc", nacc.avg, epoch + 1) 280 | writer.add_scalar("train/oacc", oacc.avg, epoch + 1) 281 | 282 | 283 | def val(data_loader, model_fe, model_cls, epoch, criterion): 284 | 285 | # setup average meters 286 | losses = averageMeter() 287 | racc = averageMeter() 288 | acc = averageMeter() 289 | bacc = averageMeter() 290 | nacc = averageMeter() 291 | oacc = averageMeter() 292 | base2novel = averageMeter() 293 | novel2base = averageMeter() 294 | 295 | # setting evaluation mode 296 | model_fe.eval() 297 | model_cls.eval() 298 | 299 | one = torch.tensor([1]).cuda() 300 | for (step, value) in enumerate(data_loader): 301 | 302 | image = value[0].cuda() 303 | target = value[1].cuda(non_blocking=True) 304 | isnovel = target >= n_base_cls 305 | isoverlap = target < n_overlap_cls 306 | isbase = (~isnovel) * (~isoverlap) 307 | target[isnovel] -= n_base_cls - n_sel_cls # for partial selection (e.g. 40 / 800) 308 | 309 | # forward 310 | _, feat1, feat2 = model_fe(image, feat=True) 311 | if norm > 0: 312 | feat1 = F.normalize(feat1, p=2, dim=1) * norm 313 | feat2 = F.normalize(feat2, p=2, dim=1) * norm 314 | output = model_cls(feat1, feat2) 315 | 316 | loss = torch.mean(criterion(output, target).squeeze()) 317 | losses.update(loss.item(), image.size(0)) 318 | 319 | # measure accuracy 320 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 321 | iscorrect = torch.eq(pred, target) 322 | all_acc = iscorrect.float().mean() 323 | acc.update(all_acc.item(), image.size(0)) 324 | 325 | # measure base and novel accuracy 326 | n_base = isbase.long().sum() 327 | n_novel = isnovel.long().sum() 328 | n_overlap = isoverlap.long().sum() 329 | assert (n_base + n_novel + n_overlap) == image.size(0) 330 | b_acc = iscorrect[isbase].float().mean() 331 | n_acc = iscorrect[isnovel].float().mean() 332 | o_acc = iscorrect[isoverlap].float().mean() 333 | if n_base > 0: 334 | bacc.update(b_acc.item(), n_base) 335 | if n_novel > 0: 336 | nacc.update(n_acc.item(), n_novel) 337 | if n_overlap > 0: 338 | oacc.update(o_acc.item(), n_overlap) 339 | 340 | # other analysis 341 | b2n = (pred[~isnovel] >= n_sel_cls).float().mean() if (n_base + n_overlap) > 0 else one 342 | n2b = (pred[isnovel] < n_sel_cls).float().mean() if n_novel > 0 else one 343 | if (n_base + n_overlap) > 0: 344 | base2novel.update(b2n.item(), n_base + n_overlap) 345 | if n_novel > 0: 346 | novel2base.update(n2b.item(), n_novel) 347 | r_acc = ((1 - b2n) * (n_base + n_overlap) + (1 - n2b) * n_novel) / image.size(0) 348 | racc.update(r_acc.item(), image.size(0)) 349 | 350 | print_str = ( 351 | "[Val] Acc {acc.avg:.4f} " 352 | "Racc {racc.avg: .3f} " 353 | "Bacc {bacc.avg: .4f} " 354 | "Nacc {nacc.avg: .4f} " 355 | "Oacc {oacc.avg: .4f} " 356 | "Base2novel {b2n.avg:.3f} " 357 | "Novel2base {n2b.avg:.3f}".format( 358 | acc=acc, 359 | racc=racc, 360 | bacc=bacc, 361 | nacc=nacc, 362 | oacc=oacc, 363 | b2n=base2novel, 364 | n2b=novel2base, 365 | ) 366 | ) 367 | print(print_str) 368 | logger.info(print_str) 369 | 370 | writer.add_scalar("val/loss", losses.avg, epoch + 1) 371 | writer.add_scalar("val/acc", acc.avg, epoch + 1) 372 | writer.add_scalar("val/racc", racc.avg, epoch + 1) 373 | writer.add_scalar("val/bacc", bacc.avg, epoch + 1) 374 | writer.add_scalar("val/nacc", nacc.avg, epoch + 1) 375 | writer.add_scalar("val/oacc", oacc.avg, epoch + 1) 376 | writer.add_scalar("val/base2novel", base2novel.avg, epoch + 1) 377 | writer.add_scalar("val/novel2base", novel2base.avg, epoch + 1) 378 | 379 | 380 | if __name__ == "__main__": 381 | global cfg, args, writer, logger 382 | global alpha, dp, seed 383 | 384 | parser = argparse.ArgumentParser(description="config") 385 | parser.add_argument( 386 | "--config", 387 | nargs="?", 388 | type=str, 389 | default="configs/imnet_delta.yml", 390 | help="Configuration file to use", 391 | ) 392 | parser.add_argument( 393 | "--seed", 394 | type=int, 395 | default=None, 396 | help="random seed", 397 | ) 398 | parser.add_argument( 399 | "--alpha", 400 | type=float, 401 | default=None, 402 | help="weight for routing loss", 403 | ) 404 | parser.add_argument( 405 | "--dp", 406 | type=float, 407 | default=None, 408 | help="hn dropout rate", 409 | ) 410 | 411 | args = parser.parse_args() 412 | 413 | with open(args.config) as fp: 414 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 415 | 416 | seed = args.seed if args.seed else cfg.get("seed", 1) 417 | alpha = args.alpha if args.alpha else cfg.get("alpha", 0) 418 | dp = args.dp if args.dp else cfg.get("dp", 0) 419 | 420 | exp = cfg["exp"].format(alpha, dp, seed) 421 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], exp) 422 | writer = SummaryWriter(log_dir=logdir) 423 | 424 | print("RUNDIR: {}".format(logdir)) 425 | shutil.copy(args.config, logdir) 426 | 427 | logger = get_logger(logdir) 428 | logger.info("Start logging") 429 | 430 | print(args) 431 | logger.info(args) 432 | 433 | main() 434 | -------------------------------------------------------------------------------- /src/train_fusion_ms.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | import yaml 9 | import shutil 10 | import torch 11 | from torch import nn 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn.functional as F 14 | 15 | from loader import get_dataloader 16 | from models import get_model 17 | from optimizers import get_optimizer, step_scheduler 18 | from metrics import averageMeter, accuracy 19 | from utils import get_logger, cvt2normal_state 20 | 21 | from tensorboardX import SummaryWriter 22 | 23 | 24 | def main(): 25 | # global norm, n_base_cls, n_novel_cls, n_overlap_cls, n_sel_cls 26 | global norm, n_base_cls, n_novel_cls, n_all_cls 27 | 28 | if not torch.cuda.is_available(): 29 | raise SystemExit("GPU is needed.") 30 | 31 | # setup random seed 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | 37 | # setup data loader 38 | splits = ["train", "val"] 39 | data_loader = get_dataloader(cfg["data"], splits, cfg["training"]["batch_size"], seed=seed) 40 | 41 | # config 42 | n_novel_cls = cfg["training"].get("n_novel_cls", 40) 43 | assert cfg["model"]["classifier"].get("n_overlap", 0) == 0 44 | n_base_cls = cfg["training"].get("n_base_cls", [800]) 45 | assert cfg["training"].get("n_sel_cls", n_base_cls) == n_base_cls 46 | n_all_cls = n_base_cls + [n_novel_cls] 47 | 48 | # setup model (feature extractor + classifier) 49 | n_gpu = torch.cuda.device_count() 50 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 51 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 52 | 53 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 54 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 55 | print("{} gpu(s) available.".format(n_gpu)) 56 | 57 | # loss function 58 | criterion = nn.CrossEntropyLoss(reduction="none") 59 | norm = cfg["training"].get("norm", -1) 60 | 61 | # setup optimizer 62 | opt_main_cls, opt_main_params = get_optimizer(cfg["training"]["optimizer_main"]) 63 | cnn_params = list(model_fe.parameters()) + list(model_cls.parameters()) 64 | opt_main = opt_main_cls(cnn_params, **opt_main_params) 65 | logger.info("Using optimizer {}".format(opt_main)) 66 | 67 | # setup scheduler 68 | scheduler = step_scheduler(opt_main, **cfg["training"]["scheduler"]) 69 | 70 | # load checkpoint 71 | start_ep = 0 72 | if cfg["training"]["resume"].get("models", []): 73 | models = cfg["training"]["resume"]["models"] 74 | models_nexist = [model for model in models if not os.path.isfile(model)] 75 | if models_nexist: 76 | print_str = "\n".join(["No checkpoint found at:"] + models_nexist) 77 | print(print_str) 78 | logger.info(print_str) 79 | import pdb 80 | 81 | pdb.set_trace() 82 | print() 83 | else: 84 | print_str = "\n".join(["Loading model from checkpoints:"] + models) 85 | print(print_str) 86 | logger.info(print_str) 87 | checkpoints = [torch.load(model) for model in models] 88 | model_fe.module.load_state_dict2(*[cvt2normal_state(c["model_fe_state"]) for c in checkpoints]) 89 | model_cls.module.load_state_dict2(*[cvt2normal_state(c["model_cls_state"]) for c in checkpoints]) 90 | logger.info("Loading classifier") 91 | 92 | print("Start training from epoch {}".format(start_ep)) 93 | logger.info("Start training from epoch {}".format(start_ep)) 94 | 95 | for ep in range(start_ep, cfg["training"]["epoch"]): 96 | 97 | train(data_loader["train"], model_fe, model_cls, opt_main, ep, criterion) 98 | 99 | if (ep + 1) % cfg["training"]["val_interval"] == 0: 100 | with torch.no_grad(): 101 | val(data_loader["val"], model_fe, model_cls, ep, criterion) 102 | 103 | if (ep + 1) % cfg["training"]["save_interval"] == 0: 104 | state = { 105 | "epoch": ep + 1, 106 | "model_fe_state": model_fe.state_dict(), 107 | "model_cls_state": model_cls.state_dict(), 108 | "opt_main_state": opt_main.state_dict(), 109 | "scheduler_state": scheduler.state_dict(), 110 | } 111 | ckpt_path = os.path.join(writer.file_writer.get_logdir(), "ep-{ep}_model.pkl") 112 | save_path = ckpt_path.format(ep=ep + 1) 113 | last_path = ckpt_path.format(ep=ep + 1 - cfg["training"]["save_interval"]) 114 | torch.save(state, save_path) 115 | if os.path.isfile(last_path): 116 | os.remove(last_path) 117 | print_str = "[Checkpoint]: {} saved".format(save_path) 118 | print(print_str) 119 | logger.info(print_str) 120 | 121 | scheduler.step() 122 | 123 | 124 | def train(data_loader, model_fe, model_cls, opt_main, epoch, criterion): 125 | # setup average meters 126 | batch_time = averageMeter() 127 | data_time = averageMeter() 128 | losses = averageMeter() 129 | acc = averageMeter() 130 | spacc = [averageMeter() for _ in n_all_cls] 131 | 132 | # setting training mode 133 | model_fe.train() 134 | model_cls.train() 135 | split_lookup = sum([[idx] * n for idx, n in enumerate(n_all_cls)], []) 136 | split_lookup = torch.tensor(split_lookup).long().cuda() 137 | 138 | n_step = int(len(data_loader.dataset) // float(data_loader.batch_size)) 139 | end = time.time() 140 | for (step, value) in enumerate(data_loader): 141 | # measure data loading time 142 | data_time.update(time.time() - end) 143 | 144 | image = value[0].cuda() 145 | target = value[1].cuda(non_blocking=True) 146 | split_ids = split_lookup[target] 147 | 148 | # split label 149 | if len(value) > 3: 150 | sp_lbl = value[3].cuda() 151 | raise NotImplementedError("Provided split label usage not implemented") 152 | else: 153 | sp_lbl = split_ids 154 | 155 | # forward 156 | _, *feats = model_fe(image, feat=True) 157 | if norm > 0: 158 | feats = [F.normalize(feat, p=2, dim=1) * norm for feat in feats] 159 | 160 | # hn dropout 161 | dp_scaling = cfg["training"].get("dp_scaling", True) 162 | # Design decision: only for newest, or all new? 163 | for feat in feats[1:]: 164 | dp_scaling_applicable = sp_lbl == 0 165 | if dp_scaling is True: 166 | feat[dp_scaling_applicable, :] = F.dropout(feat[dp_scaling_applicable, :], p=dp) * (1 - dp) 167 | elif dp_scaling == "no_dp": 168 | feat[dp_scaling_applicable, :] = feat[dp_scaling_applicable, :] * (1 - dp) 169 | else: 170 | assert not dp_scaling 171 | feat[dp_scaling_applicable, :] = F.dropout(feat[dp_scaling_applicable, :], p=dp) 172 | 173 | output, route_out = model_cls(*feats, route=True) 174 | 175 | # compute loss 176 | r_ce = criterion(route_out, sp_lbl).squeeze() 177 | ce_sp = [] 178 | for idx_split, n_split_cls in enumerate(n_all_cls): 179 | is_curr_split = split_ids == idx_split 180 | ce_sp.append(r_ce[is_curr_split].mean() if is_curr_split.any() else 0) 181 | rloss = sum(ce_sp) / len(ce_sp) 182 | closs = torch.mean(criterion(output, target).squeeze()) 183 | 184 | loss = (1 - alpha) * closs + alpha * rloss 185 | losses.update(loss.item(), image.size(0)) 186 | 187 | # measure accuracy 188 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 189 | iscorrect = torch.eq(pred, target) 190 | all_acc = iscorrect.float().mean() 191 | acc.update(all_acc.item(), image.size(0)) 192 | 193 | # measure base and novel accuracy 194 | n_split_samples = [] 195 | for idx_split, n_split_cls in enumerate(n_all_cls): 196 | is_curr_split = split_ids == idx_split 197 | n_split_sample = is_curr_split.long().sum() 198 | split_acc = iscorrect[is_curr_split].float().mean() 199 | if n_split_sample: 200 | spacc[idx_split].update(split_acc.item(), n_split_sample) 201 | n_split_samples.append(n_split_sample) 202 | assert sum(n_split_samples) == image.size(0) 203 | 204 | # back propagation 205 | opt_main.zero_grad() 206 | loss.backward() 207 | opt_main.step() 208 | 209 | # measure elapsed time 210 | batch_time.update(time.time() - end) 211 | end = time.time() 212 | 213 | if (step + 1) % 10 == 0: 214 | curr_lr_main = opt_main.param_groups[0]["lr"] 215 | print_str = ( 216 | "Epoch [{0}/{1}] " 217 | "Step: [{2}/{3}] " 218 | "LR: [{4}] " 219 | "Time {batch_time.avg:.3f} " 220 | "Data {data_time.avg:.3f} " 221 | "Loss {loss.avg:.4f} " 222 | "Acc {acc.avg:.3f} " 223 | "SpAcc {spacc} ".format( 224 | epoch + 1, 225 | cfg["training"]["epoch"], 226 | step + 1, 227 | n_step, 228 | curr_lr_main, 229 | batch_time=batch_time, 230 | data_time=data_time, 231 | loss=losses, 232 | acc=acc, 233 | spacc="/".join("{: .3f}".format(x.avg) for x in spacc), 234 | ) 235 | ) 236 | 237 | print(print_str) 238 | logger.info(print_str) 239 | 240 | if (epoch + 1) % cfg["training"]["print_interval"] == 0: 241 | curr_lr_main = opt_main.param_groups[0]["lr"] 242 | print_str = ( 243 | "Epoch: [{0}/{1}] " 244 | "LR: [{2}] " 245 | "Time {batch_time.avg:.3f} " 246 | "Data {data_time.avg:.3f} " 247 | "Loss {loss.avg:.4f} " 248 | "Acc {acc.avg:.3f} " 249 | "SpAcc {spacc}".format( 250 | epoch + 1, 251 | cfg["training"]["epoch"], 252 | curr_lr_main, 253 | batch_time=batch_time, 254 | data_time=data_time, 255 | loss=losses, 256 | acc=acc, 257 | spacc="/".join("{: .3f}".format(x.avg) for x in spacc), 258 | ) 259 | ) 260 | 261 | print(print_str) 262 | logger.info(print_str) 263 | writer.add_scalar("train/lr", curr_lr_main, epoch + 1) 264 | writer.add_scalar("train/loss", losses.avg, epoch + 1) 265 | writer.add_scalar("train/acc", acc.avg, epoch + 1) 266 | for idx, x in enumerate(spacc): 267 | writer.add_scalar("val/spacc/{}".format(idx), x.avg, epoch + 1) 268 | 269 | 270 | def val(data_loader, model_fe, model_cls, epoch, criterion): 271 | # setup average meters 272 | losses = averageMeter() 273 | racc = averageMeter() 274 | acc = averageMeter() 275 | spacc = [averageMeter() for _ in n_all_cls] 276 | base2novel = averageMeter() 277 | novel2base = averageMeter() 278 | 279 | # setting evaluation mode 280 | model_fe.eval() 281 | model_cls.eval() 282 | split_lookup = sum([[idx] * n for idx, n in enumerate(n_all_cls)], []) 283 | split_lookup = torch.tensor(split_lookup).long().cuda() 284 | 285 | one = torch.tensor([1]).cuda() 286 | for (step, value) in enumerate(data_loader): 287 | 288 | image = value[0].cuda() 289 | target = value[1].cuda(non_blocking=True) 290 | split_ids = split_lookup[target] 291 | 292 | # forward 293 | _, *feats = model_fe(image, feat=True) 294 | if norm > 0: 295 | feats = [F.normalize(feat, p=2, dim=1) * norm for feat in feats] 296 | output = model_cls(*feats) 297 | 298 | loss = torch.mean(criterion(output, target).squeeze()) 299 | losses.update(loss.item(), image.size(0)) 300 | 301 | # measure accuracy 302 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 303 | iscorrect = torch.eq(pred, target) 304 | all_acc = iscorrect.float().mean() 305 | acc.update(all_acc.item(), image.size(0)) 306 | 307 | # measure base and novel accuracy 308 | n_split_samples = [] 309 | for idx_split, n_split_cls in enumerate(n_all_cls): 310 | is_curr_split = split_ids == idx_split 311 | n_split_sample = is_curr_split.long().sum() 312 | split_acc = iscorrect[is_curr_split].float().mean() 313 | if n_split_sample: 314 | spacc[idx_split].update(split_acc.item(), n_split_sample) 315 | n_split_samples.append(n_split_sample) 316 | assert sum(n_split_samples) == image.size(0) 317 | 318 | # other analysis 319 | n_splits = len(n_all_cls) 320 | pred_split = split_lookup[pred] 321 | n_split_samples_prv = sum(n_split_samples[:-1]) 322 | 323 | b2n = (pred_split[split_ids < n_splits - 1] == n_splits - 1).float().mean() if n_split_samples_prv > 0 else one 324 | n2b = (pred_split[split_ids == n_splits - 1] < n_splits - 1).float().mean() if n_split_samples[-1] > 0 else one 325 | if n_split_samples_prv > 0: 326 | base2novel.update(b2n.item(), n_split_samples_prv) 327 | if n_split_samples[-1] > 0: 328 | novel2base.update(n2b.item(), n_split_samples[-1]) 329 | r_acc = (pred_split == split_ids).float().mean() 330 | racc.update(r_acc.item(), image.size(0)) 331 | 332 | print_str = ( 333 | "[Val] Acc {acc.avg:.4f} " 334 | "Racc {racc.avg: .3f} " 335 | "SPacc {spacc} " 336 | "Base2novel {b2n.avg:.3f} " 337 | "Novel2base {n2b.avg:.3f}".format( 338 | acc=acc, 339 | racc=racc, 340 | b2n=base2novel, 341 | n2b=novel2base, 342 | spacc="/".join("{: .4f}".format(x.avg) for x in spacc), 343 | ) 344 | ) 345 | print(print_str) 346 | logger.info(print_str) 347 | 348 | writer.add_scalar("val/loss", losses.avg, epoch + 1) 349 | writer.add_scalar("val/acc", acc.avg, epoch + 1) 350 | writer.add_scalar("val/racc", racc.avg, epoch + 1) 351 | for idx, x in enumerate(spacc): 352 | writer.add_scalar("val/spacc/{}".format(idx), x.avg, epoch + 1) 353 | writer.add_scalar("val/base2novel", base2novel.avg, epoch + 1) 354 | writer.add_scalar("val/novel2base", novel2base.avg, epoch + 1) 355 | 356 | 357 | if __name__ == "__main__": 358 | global cfg, args, writer, logger 359 | global alpha, dp, seed 360 | 361 | parser = argparse.ArgumentParser(description="config") 362 | parser.add_argument( 363 | "--config", 364 | nargs="?", 365 | type=str, 366 | default="configs/imnet_delta.yml", 367 | help="Configuration file to use", 368 | ) 369 | parser.add_argument( 370 | "--seed", 371 | type=int, 372 | default=None, 373 | help="random seed", 374 | ) 375 | parser.add_argument( 376 | "--alpha", 377 | type=float, 378 | default=None, 379 | help="weight for routing loss", 380 | ) 381 | parser.add_argument( 382 | "--dp", 383 | type=float, 384 | default=None, 385 | help="hn dropout rate", 386 | ) 387 | 388 | args = parser.parse_args() 389 | 390 | with open(args.config) as fp: 391 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 392 | 393 | seed = args.seed if args.seed else cfg.get("seed", 1) 394 | alpha = args.alpha if args.alpha else cfg.get("alpha", 0) 395 | dp = args.dp if args.dp else cfg.get("dp", 0) 396 | 397 | exp = cfg["exp"].format(alpha, dp, seed) 398 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], exp) 399 | writer = SummaryWriter(log_dir=logdir) 400 | 401 | print("RUNDIR: {}".format(logdir)) 402 | shutil.copy(args.config, logdir) 403 | 404 | logger = get_logger(logdir) 405 | logger.info("Start logging") 406 | 407 | print(args) 408 | logger.info(args) 409 | 410 | main() 411 | -------------------------------------------------------------------------------- /src/train_novel.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import argparse 6 | import os 7 | import yaml 8 | import shutil 9 | import torch 10 | from torch import nn 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.functional as F 13 | 14 | from loader import get_dataloader 15 | from models import get_model 16 | from optimizers import get_optimizer, step_scheduler 17 | from metrics import averageMeter, accuracy 18 | from utils import get_logger, cvt2normal_state 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | 23 | def main(): 24 | global norm 25 | 26 | if not torch.cuda.is_available(): 27 | raise SystemExit("GPU is needed.") 28 | 29 | # setup random seed 30 | torch.manual_seed(cfg.get("seed", 1)) 31 | torch.cuda.manual_seed(cfg.get("seed", 1)) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | # setup data loader 36 | splits = ["train", "val"] 37 | data_loader = get_dataloader(cfg["data"], splits, cfg["training"]["batch_size"]) 38 | 39 | # setup model (feature extractor + classifier) 40 | n_gpu = torch.cuda.device_count() 41 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 42 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 43 | 44 | model_cls = get_model(cfg["model"]["classifier"], verbose=True).cuda() 45 | model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu)) 46 | print("{} gpu(s) available.".format(n_gpu)) 47 | 48 | # loss function 49 | criterion = nn.CrossEntropyLoss(reduction="none") 50 | norm = cfg["training"].get("norm", -1) 51 | 52 | # setup optimizer 53 | opt_main_cls, opt_main_params = get_optimizer(cfg["training"]["optimizer_main"]) 54 | params = [x for x in model_fe.parameters() if x.requires_grad] + list(model_cls.parameters()) 55 | print("Trainable parameters: {}".format(sum([x.flatten().shape[0] for x in params]))) 56 | logger.info("Trainable parameters: {}".format(sum([x.flatten().shape[0] for x in params]))) 57 | opt_main = opt_main_cls(params, **opt_main_params) 58 | logger.info("Using optimizer {}".format(opt_main)) 59 | 60 | # setup scheduler 61 | scheduler = step_scheduler(opt_main, **cfg["training"]["scheduler"]) 62 | 63 | # load checkpoint 64 | start_ep = 0 65 | if cfg["training"]["resume"].get("model", None): 66 | resume = cfg["training"]["resume"] 67 | if os.path.isfile(resume["model"]): 68 | print("Loading model from checkpoint '{}'".format(resume["model"])) 69 | logger.info("Loading model from checkpoint '{}'".format(resume["model"])) 70 | checkpoint = torch.load(resume["model"]) 71 | model_fe.module.load_state_dict(cvt2normal_state(checkpoint["model_fe_state"])) 72 | if resume.get("load_cls", False): 73 | model_cls.module.load_state_dict(cvt2normal_state(checkpoint["model_cls_state"])) 74 | logger.info("Loading classifier") 75 | if resume["param_only"] is False: 76 | start_ep = checkpoint["epoch"] 77 | opt_main.load_state_dict(checkpoint["opt_main_state"]) 78 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 79 | logger.info("Loaded checkpoint '{}' (iter {})".format(resume["model"], checkpoint["epoch"])) 80 | else: 81 | print("No checkpoint found at '{}'".format(resume["model"])) 82 | logger.info("No checkpoint found at '{}'".format(resume["model"])) 83 | 84 | print("Start training from epoch {}".format(start_ep)) 85 | logger.info("Start training from epoch {}".format(start_ep)) 86 | 87 | for ep in range(start_ep, cfg["training"]["epoch"]): 88 | 89 | train(data_loader["train"], model_fe, model_cls, opt_main, ep, criterion) 90 | 91 | if (ep + 1) % cfg["training"]["val_interval"] == 0: 92 | with torch.no_grad(): 93 | val(data_loader["val"], model_fe, model_cls, ep, criterion) 94 | 95 | if (ep + 1) % cfg["training"]["save_interval"] == 0: 96 | state = { 97 | "epoch": ep + 1, 98 | "model_fe_state": model_fe.state_dict(), 99 | "model_cls_state": model_cls.state_dict(), 100 | "opt_main_state": opt_main.state_dict(), 101 | "scheduler_state": scheduler.state_dict(), 102 | } 103 | ckpt_path = os.path.join(writer.file_writer.get_logdir(), "ep-{ep}_model.pkl") 104 | save_path = ckpt_path.format(ep=ep + 1) 105 | last_path = ckpt_path.format(ep=ep + 1 - cfg["training"]["save_interval"]) 106 | torch.save(state, save_path) 107 | if os.path.isfile(last_path): 108 | os.remove(last_path) 109 | print_str = "[Checkpoint]: {} saved".format(save_path) 110 | print(print_str) 111 | logger.info(print_str) 112 | 113 | scheduler.step() 114 | 115 | 116 | def train(data_loader, model_fe, model_cls, opt_main, epoch, criterion): 117 | 118 | # setup average meters 119 | batch_time = averageMeter() 120 | data_time = averageMeter() 121 | losses = averageMeter() 122 | top1 = averageMeter() 123 | top5 = averageMeter() 124 | 125 | # setting training mode 126 | model_fe.train() 127 | model_cls.train() 128 | 129 | n_step = int(len(data_loader.dataset) // float(data_loader.batch_size)) 130 | end = time.time() 131 | for (step, value) in enumerate(data_loader): 132 | 133 | # measure data loading time 134 | data_time.update(time.time() - end) 135 | 136 | image = value[0].cuda() 137 | target = value[1].cuda(non_blocking=True) 138 | 139 | # forward 140 | imfeat = model_fe(image) 141 | if norm > 0: 142 | imfeat = F.normalize(imfeat, p=2, dim=1) * norm 143 | output = model_cls(imfeat) 144 | 145 | loss = torch.mean(criterion(output, target).squeeze()) 146 | losses.update(loss.item(), image.size(0)) 147 | 148 | # measure accuracy 149 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 150 | top1.update(prec1[0], image.size(0)) 151 | top5.update(prec5[0], image.size(0)) 152 | 153 | # back propagation 154 | opt_main.zero_grad() 155 | loss.backward() 156 | opt_main.step() 157 | 158 | # measure elapsed time 159 | batch_time.update(time.time() - end) 160 | end = time.time() 161 | 162 | if (step + 1) % 10 == 0: 163 | curr_lr_main = opt_main.param_groups[0]["lr"] 164 | print_str = ( 165 | "Epoch [{0}/{1}] " 166 | "Step: [{2}/{3}] " 167 | "LR: [{4}] " 168 | "Time {batch_time.avg:.3f} " 169 | "Data {data_time.avg:.3f} " 170 | "Loss {loss.avg:.4f} " 171 | "Top1 {top1.avg:.3f} " 172 | "Top5 {top5.avg:.3f}".format( 173 | epoch + 1, 174 | cfg["training"]["epoch"], 175 | step + 1, 176 | n_step, 177 | curr_lr_main, 178 | batch_time=batch_time, 179 | data_time=data_time, 180 | loss=losses, 181 | top1=top1, 182 | top5=top5, 183 | ) 184 | ) 185 | 186 | print(print_str) 187 | logger.info(print_str) 188 | 189 | if (epoch + 1) % cfg["training"]["print_interval"] == 0: 190 | curr_lr_main = opt_main.param_groups[0]["lr"] 191 | print_str = ( 192 | "Epoch: [{0}/{1}] " 193 | "LR: [{2}] " 194 | "Time {batch_time.avg:.3f} " 195 | "Data {data_time.avg:.3f} " 196 | "Loss {loss.avg:.4f} " 197 | "Top1 {top1.avg:.3f} " 198 | "Top5 {top5.avg:.3f}".format( 199 | epoch + 1, 200 | cfg["training"]["epoch"], 201 | curr_lr_main, 202 | batch_time=batch_time, 203 | data_time=data_time, 204 | loss=losses, 205 | top1=top1, 206 | top5=top5, 207 | ) 208 | ) 209 | 210 | print(print_str) 211 | logger.info(print_str) 212 | writer.add_scalar("train/lr", curr_lr_main, epoch + 1) 213 | writer.add_scalar("train/loss", losses.avg, epoch + 1) 214 | writer.add_scalar("train/top1", top1.avg, epoch + 1) 215 | writer.add_scalar("train/top5", top5.avg, epoch + 1) 216 | 217 | 218 | def val(data_loader, model_fe, model_cls, epoch, criterion): 219 | 220 | # setup average meters 221 | losses = averageMeter() 222 | top1 = averageMeter() 223 | top5 = averageMeter() 224 | 225 | # setting evaluation mode 226 | model_fe.eval() 227 | model_cls.eval() 228 | 229 | for (step, value) in enumerate(data_loader): 230 | 231 | image = value[0].cuda() 232 | target = value[1].cuda(non_blocking=True) 233 | 234 | # forward 235 | imfeat = model_fe(image) 236 | if norm > 0: 237 | imfeat = F.normalize(imfeat, p=2, dim=1) * norm 238 | output = model_cls(imfeat) 239 | 240 | loss = torch.mean(criterion(output, target).squeeze()) 241 | losses.update(loss.item(), image.size(0)) 242 | 243 | # measure accuracy 244 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 245 | top1.update(prec1[0], image.size(0)) 246 | top5.update(prec5[0], image.size(0)) 247 | 248 | print_str = "[Val] Prec@1 {top1.avg:.3f}\t Prec@5 {top5.avg:.3f}".format(top1=top1, top5=top5) 249 | print(print_str) 250 | logger.info(print_str) 251 | 252 | writer.add_scalar("val/loss", losses.avg, epoch + 1) 253 | writer.add_scalar("val/top1", top1.avg, epoch + 1) 254 | writer.add_scalar("val/top5", top5.avg, epoch + 1) 255 | 256 | 257 | if __name__ == "__main__": 258 | global cfg, args, writer, logger 259 | 260 | parser = argparse.ArgumentParser(description="config") 261 | parser.add_argument( 262 | "--config", 263 | nargs="?", 264 | type=str, 265 | default="configs/imnet_novel.yml", 266 | help="Configuration file to use", 267 | ) 268 | 269 | args = parser.parse_args() 270 | 271 | with open(args.config) as fp: 272 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 273 | 274 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], cfg["exp"]) 275 | writer = SummaryWriter(log_dir=logdir) 276 | 277 | print("RUNDIR: {}".format(logdir)) 278 | shutil.copy(args.config, logdir) 279 | 280 | logger = get_logger(logdir) 281 | logger.info("Start logging") 282 | 283 | print(args) 284 | logger.info(args) 285 | 286 | main() 287 | -------------------------------------------------------------------------------- /src/train_route.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import argparse 6 | import os 7 | import yaml 8 | import shutil 9 | import torch 10 | from torch import nn 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.functional as F 13 | 14 | from loader import get_dataloader 15 | from models import get_model 16 | from optimizers import get_optimizer, step_scheduler 17 | from metrics import averageMeter, accuracy 18 | from utils import get_logger, cvt2normal_state 19 | 20 | from tensorboardX import SummaryWriter 21 | 22 | 23 | def main(): 24 | global norm, n_base_cls, n_novel_cls, n_overlap_cls, n_sel_cls, loss_type 25 | 26 | if not torch.cuda.is_available(): 27 | raise SystemExit("GPU is needed.") 28 | 29 | # setup random seed 30 | torch.manual_seed(cfg.get("seed", 1)) 31 | torch.cuda.manual_seed(cfg.get("seed", 1)) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | # setup data loader 36 | splits = ["train", "val"] 37 | data_loader = get_dataloader(cfg["data"], splits, cfg["training"]["batch_size"]) 38 | 39 | # config 40 | n_base_cls = cfg["model"]["base_classifier"].get("n_class", 800) 41 | n_novel_cls = cfg["model"]["novel_classifier"].get("n_class", 40) 42 | n_overlap_cls = cfg["training"].get("n_overlap_cls", 0) 43 | n_sel_cls = cfg["training"].get("n_sel_cls", n_base_cls) 44 | 45 | # setup model (feature extractor + classifier) 46 | n_gpu = torch.cuda.device_count() 47 | model_fe = get_model(cfg["model"]["feature_extractor"], verbose=True).cuda() 48 | model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu)) 49 | 50 | model_rcls = get_model(cfg["model"]["route_classifier"], verbose=True).cuda() 51 | model_rcls = nn.DataParallel(model_rcls, device_ids=range(n_gpu)) 52 | 53 | model_bcls = get_model(cfg["model"]["base_classifier"], verbose=True).cuda() 54 | model_bcls = nn.DataParallel(model_bcls, device_ids=range(n_gpu)) 55 | 56 | model_ncls = get_model(cfg["model"]["novel_classifier"], verbose=True).cuda() 57 | model_ncls = nn.DataParallel(model_ncls, device_ids=range(n_gpu)) 58 | print("{} gpu(s) available.".format(n_gpu)) 59 | 60 | # loss function 61 | criterion = nn.CrossEntropyLoss(reduction="none") 62 | norm = cfg["training"].get("norm", -1) 63 | loss_type = cfg["training"].get("loss_type", "balanced") 64 | 65 | # setup optimizer 66 | opt_main_cls, opt_main_params = get_optimizer(cfg["training"]["optimizer_main"]) 67 | cnn_params = list(model_fe.parameters()) + list(model_rcls.parameters()) 68 | opt_main = opt_main_cls(cnn_params, **opt_main_params) 69 | logger.info("Using optimizer {}".format(opt_main)) 70 | 71 | # setup scheduler 72 | scheduler = step_scheduler(opt_main, **cfg["training"]["scheduler"]) 73 | 74 | # load checkpoint 75 | start_ep = 0 76 | if cfg["training"]["resume"].get("model1", False) and cfg["training"]["resume"].get("model2", False): 77 | model1 = cfg["training"]["resume"]["model1"] 78 | model2 = cfg["training"]["resume"]["model2"] 79 | if not os.path.isfile(model1): 80 | print("No checkpoint found at '{}'".format(model1)) 81 | logger.info("No checkpoint found at '{}'".format(model1)) 82 | elif not os.path.isfile(model2): 83 | print("No checkpoint found at '{}'".format(model2)) 84 | logger.info("No checkpoint found at '{}'".format(model2)) 85 | else: 86 | print("Loading model from checkpoint '{}' and '{}'".format(model1, model2)) 87 | logger.info("Loading model from checkpoint '{}' and '{}'".format(model1, model2)) 88 | checkpoint1 = torch.load(model1) 89 | checkpoint2 = torch.load(model2) 90 | model_fe.module.load_state_dict2( 91 | cvt2normal_state(checkpoint1["model_fe_state"]), 92 | cvt2normal_state(checkpoint2["model_fe_state"]), 93 | ) 94 | model_bcls.module.load_state_dict(cvt2normal_state(checkpoint1["model_cls_state"])) 95 | model_ncls.module.load_state_dict(cvt2normal_state(checkpoint2["model_cls_state"])) 96 | logger.info("Loading classifier") 97 | 98 | print("Start training from epoch {}".format(start_ep)) 99 | logger.info("Start training from epoch {}".format(start_ep)) 100 | 101 | for ep in range(start_ep, cfg["training"]["epoch"]): 102 | 103 | train(data_loader["train"], model_fe, model_rcls, opt_main, ep, criterion) 104 | 105 | if (ep + 1) % cfg["training"]["val_interval"] == 0: 106 | with torch.no_grad(): 107 | val( 108 | data_loader["val"], 109 | model_fe, 110 | model_rcls, 111 | model_bcls, 112 | model_ncls, 113 | ep, 114 | criterion, 115 | ) 116 | 117 | if (ep + 1) % cfg["training"]["save_interval"] == 0: 118 | state = { 119 | "epoch": ep + 1, 120 | "model_fe_state": model_fe.state_dict(), 121 | "model_rcls_state": model_rcls.state_dict(), 122 | "model_bcls_state": model_bcls.state_dict(), 123 | "model_ncls_state": model_ncls.state_dict(), 124 | "opt_main_state": opt_main.state_dict(), 125 | "scheduler_state": scheduler.state_dict(), 126 | } 127 | ckpt_path = os.path.join(writer.file_writer.get_logdir(), "ep-{ep}_model.pkl") 128 | save_path = ckpt_path.format(ep=ep + 1) 129 | last_path = ckpt_path.format(ep=ep + 1 - cfg["training"]["save_interval"]) 130 | torch.save(state, save_path) 131 | if os.path.isfile(last_path): 132 | os.remove(last_path) 133 | print_str = "[Checkpoint]: {} saved".format(save_path) 134 | print(print_str) 135 | logger.info(print_str) 136 | 137 | scheduler.step() 138 | 139 | 140 | def train(data_loader, model_fe, model_rcls, opt_main, epoch, criterion): 141 | 142 | # setup average meters 143 | batch_time = averageMeter() 144 | data_time = averageMeter() 145 | losses = averageMeter() 146 | racc = averageMeter() 147 | base2novel = averageMeter() 148 | novel2base = averageMeter() 149 | 150 | # setting training mode 151 | model_fe.train() 152 | model_rcls.train() 153 | 154 | one = torch.tensor([1]).cuda() 155 | n_step = int(len(data_loader.dataset) // float(data_loader.batch_size)) 156 | end = time.time() 157 | for (step, value) in enumerate(data_loader): 158 | if step == n_step: 159 | break 160 | # measure data loading time 161 | data_time.update(time.time() - end) 162 | 163 | image = value[0].cuda() 164 | target = value[1].cuda(non_blocking=True) 165 | isnovel = target >= n_base_cls 166 | 167 | # split label 168 | if len(value) > 3: 169 | sp_lbl = value[3].cuda() 170 | else: 171 | sp_lbl = isnovel.long() 172 | 173 | # forward 174 | _, feat1, feat2 = model_fe(image, feat=True) 175 | if norm > 0: 176 | feat1 = F.normalize(feat1, p=2, dim=1) * norm 177 | feat2 = F.normalize(feat2, p=2, dim=1) * norm 178 | imfeat = torch.cat([feat1, feat2], dim=1) 179 | output = model_rcls(imfeat) 180 | 181 | # compute loss 182 | if loss_type == "reweight": 183 | n_class = float(n_base_cls + n_novel_cls) 184 | weight = torch.tensor([n_class / n_base_cls, n_class / n_novel_cls]).cuda() 185 | alpha = weight.gather(0, sp_lbl.data) 186 | loss = torch.mean(criterion(output, sp_lbl).squeeze() * alpha) 187 | 188 | elif loss_type == "balanced": 189 | r_ce = criterion(output, sp_lbl).squeeze() 190 | ce_b = r_ce[(sp_lbl == 0)].mean() 191 | ce_n = r_ce[(sp_lbl > 0)].mean() 192 | loss = (ce_b + ce_n) / 2 193 | 194 | else: 195 | loss = torch.mean(criterion(output, sp_lbl).squeeze()) 196 | 197 | losses.update(loss.item(), image.size(0)) 198 | 199 | # measure accuracy 200 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 201 | iscorrect = torch.eq(pred, sp_lbl) 202 | r_acc = iscorrect.float().mean() 203 | racc.update(r_acc.item(), image.size(0)) 204 | 205 | # other analysis 206 | n_novel = isnovel.long().sum() 207 | n_base = image.size(0) - n_novel 208 | b2n = ((~isnovel) * (~iscorrect)).float().sum() / n_base if n_base > 0 else one 209 | n2b = (isnovel * (~iscorrect)).float().sum() / n_novel if n_novel > 0 else one 210 | if n_base > 0: 211 | base2novel.update(b2n.item(), n_base) 212 | if n_novel > 0: 213 | novel2base.update(n2b.item(), n_novel) 214 | 215 | # back propagation 216 | opt_main.zero_grad() 217 | loss.backward() 218 | opt_main.step() 219 | 220 | # measure elapsed time 221 | batch_time.update(time.time() - end) 222 | end = time.time() 223 | 224 | if (step + 1) % 10 == 0: 225 | curr_lr_main = opt_main.param_groups[0]["lr"] 226 | print_str = ( 227 | "Epoch [{0}/{1}] " 228 | "Step: [{2}/{3}] " 229 | "LR: [{4}] " 230 | "Time {batch_time.avg:.3f} " 231 | "Data {data_time.avg:.3f} " 232 | "Loss {loss.avg:.4f} " 233 | "Acc {acc.avg:.3f}".format( 234 | epoch + 1, 235 | cfg["training"]["epoch"], 236 | step + 1, 237 | n_step, 238 | curr_lr_main, 239 | batch_time=batch_time, 240 | data_time=data_time, 241 | loss=losses, 242 | acc=racc, 243 | ) 244 | ) 245 | 246 | print(print_str) 247 | logger.info(print_str) 248 | 249 | if (epoch + 1) % cfg["training"]["print_interval"] == 0: 250 | curr_lr_main = opt_main.param_groups[0]["lr"] 251 | print_str = ( 252 | "Epoch: [{0}/{1}] " 253 | "LR: [{2}] " 254 | "Time {batch_time.avg:.3f} " 255 | "Data {data_time.avg:.3f} " 256 | "Loss {loss.avg:.4f} " 257 | "Acc {acc.avg:.3f}".format( 258 | epoch + 1, 259 | cfg["training"]["epoch"], 260 | curr_lr_main, 261 | batch_time=batch_time, 262 | data_time=data_time, 263 | loss=losses, 264 | acc=racc, 265 | ) 266 | ) 267 | 268 | print(print_str) 269 | logger.info(print_str) 270 | writer.add_scalar("train/lr", curr_lr_main, epoch + 1) 271 | writer.add_scalar("train/loss", losses.avg, epoch + 1) 272 | writer.add_scalar("train/racc", racc.avg, epoch + 1) 273 | writer.add_scalar("train/base2novel", base2novel.avg, epoch + 1) 274 | writer.add_scalar("train/novel2base", novel2base.avg, epoch + 1) 275 | 276 | 277 | def val(data_loader, model_fe, model_rcls, model_bcls, model_ncls, epoch, criterion): 278 | 279 | # setup average meters 280 | losses = averageMeter() 281 | racc = averageMeter() 282 | acc = averageMeter() 283 | bacc = averageMeter() 284 | nacc = averageMeter() 285 | oacc = averageMeter() 286 | base2novel = averageMeter() 287 | novel2base = averageMeter() 288 | 289 | # setting evaluation mode 290 | model_fe.eval() 291 | model_rcls.eval() 292 | model_bcls.eval() 293 | model_ncls.eval() 294 | 295 | zero = torch.tensor([0]).cuda() 296 | one = torch.tensor([1]).cuda() 297 | for (step, value) in enumerate(data_loader): 298 | 299 | image = value[0].cuda() 300 | target = value[1].cuda(non_blocking=True) 301 | isnovel = target >= n_base_cls 302 | isoverlap = target < n_overlap_cls 303 | isbase = (~isnovel) * (~isoverlap) 304 | target[isnovel] -= n_base_cls # for partial selection (e.g. 40 / 800) 305 | 306 | # split label 307 | if len(value) > 3: 308 | sp_lbl = value[3].cuda() 309 | else: 310 | sp_lbl = isnovel.long() 311 | 312 | # forward 313 | _, feat1, feat2 = model_fe(image, feat=True) 314 | if norm > 0: 315 | feat1 = F.normalize(feat1, p=2, dim=1) * norm 316 | feat2 = F.normalize(feat2, p=2, dim=1) * norm 317 | imfeat = torch.cat([feat1, feat2], dim=1) 318 | output = model_rcls(imfeat) 319 | out1 = model_bcls(feat1) 320 | out2 = model_ncls(feat2) 321 | 322 | loss = torch.mean(criterion(output, sp_lbl).squeeze()) 323 | losses.update(loss.item(), image.size(0)) 324 | 325 | # measure routing accuracy 326 | conf, pred = torch.max(torch.softmax(output, dim=1), dim=1) 327 | iscorrect = torch.eq(pred, sp_lbl) 328 | r_acc = iscorrect.float().mean() 329 | racc.update(r_acc.item(), image.size(0)) 330 | 331 | # other analysis 332 | n_novel = isnovel.long().sum() 333 | n_base = image.size(0) - n_novel 334 | b2n = ((~isnovel) * (~iscorrect)).float().sum() / n_base if n_base > 0 else one 335 | n2b = (isnovel * (~iscorrect)).float().sum() / n_novel if n_novel > 0 else one 336 | if n_base > 0: 337 | base2novel.update(b2n.item(), n_base) 338 | if n_novel > 0: 339 | novel2base.update(n2b.item(), n_novel) 340 | 341 | # measure classification accuracy 342 | n_base = isbase.long().sum() 343 | b_idx = isbase * (pred == 0) 344 | _, pred1 = torch.max(torch.softmax(out1[:, :n_sel_cls], dim=1), dim=1) 345 | b_iscorrect = torch.eq(pred1[b_idx], target[b_idx]) 346 | b_acc = b_iscorrect.float().sum() / n_base if n_base > 0 else zero 347 | 348 | n_idx = isnovel * (pred == 1) 349 | _, pred2 = torch.max(torch.softmax(out2, dim=1), dim=1) 350 | n_iscorrect = torch.eq(pred2[n_idx], target[n_idx] + n_overlap_cls) 351 | n_acc = n_iscorrect.float().sum() / n_novel if n_novel > 0 else zero 352 | 353 | n_overlap = isoverlap.long().sum() 354 | o_idx1 = isoverlap * (pred == 0) 355 | o_iscorrect1 = torch.eq(pred1[o_idx1], target[o_idx1]) 356 | o_idx2 = isoverlap * (pred == 1) 357 | o_iscorrect2 = torch.eq(pred2[o_idx2], target[o_idx2]) 358 | o_acc = (o_iscorrect1.float().sum() + o_iscorrect2.float().sum()) / n_overlap if n_overlap > 0 else zero 359 | 360 | assert (n_base + n_novel + n_overlap) == image.size(0) 361 | all_acc = (b_acc * n_base + n_acc * n_novel + o_acc * n_overlap) / image.size(0) 362 | if n_base > 0: 363 | bacc.update(b_acc.item(), n_base) 364 | if n_novel > 0: 365 | nacc.update(n_acc.item(), n_novel) 366 | if n_overlap > 0: 367 | oacc.update(o_acc.item(), n_overlap) 368 | acc.update(all_acc.item(), image.size(0)) 369 | 370 | print_str = ( 371 | "[Val] Acc {acc.avg:.4f} " 372 | "Racc {racc.avg: .3f} " 373 | "Bacc {bacc.avg: .4f} " 374 | "Nacc {nacc.avg: .4f} " 375 | "Oacc {oacc.avg: .4f} " 376 | "Base2novel {b2n.avg:.3f} " 377 | "Novel2base {n2b.avg:.3f}".format( 378 | acc=acc, 379 | racc=racc, 380 | bacc=bacc, 381 | nacc=nacc, 382 | oacc=oacc, 383 | b2n=base2novel, 384 | n2b=novel2base, 385 | ) 386 | ) 387 | print(print_str) 388 | logger.info(print_str) 389 | 390 | writer.add_scalar("val/loss", losses.avg, epoch + 1) 391 | writer.add_scalar("val/acc", acc.avg, epoch + 1) 392 | writer.add_scalar("val/racc", racc.avg, epoch + 1) 393 | writer.add_scalar("val/bacc", bacc.avg, epoch + 1) 394 | writer.add_scalar("val/nacc", nacc.avg, epoch + 1) 395 | writer.add_scalar("val/oacc", oacc.avg, epoch + 1) 396 | writer.add_scalar("val/base2novel", base2novel.avg, epoch + 1) 397 | writer.add_scalar("val/novel2base", novel2base.avg, epoch + 1) 398 | 399 | 400 | if __name__ == "__main__": 401 | global cfg, args, writer, logger 402 | 403 | parser = argparse.ArgumentParser(description="config") 404 | parser.add_argument( 405 | "--config", 406 | nargs="?", 407 | type=str, 408 | default="configs/imnet_base.yml", 409 | help="Configuration file to use", 410 | ) 411 | 412 | args = parser.parse_args() 413 | 414 | with open(args.config) as fp: 415 | cfg = yaml.load(fp, Loader=yaml.SafeLoader) 416 | 417 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], cfg["exp"]) 418 | writer = SummaryWriter(log_dir=logdir) 419 | 420 | print("RUNDIR: {}".format(logdir)) 421 | shutil.copy(args.config, logdir) 422 | 423 | logger = get_logger(logdir) 424 | logger.info("Start logging") 425 | 426 | print(args) 427 | logger.info(args) 428 | 429 | main() 430 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import logging 6 | import datetime 7 | from collections import OrderedDict 8 | 9 | 10 | def adjust_learning_rate(optimizer, decay_rate=0.9): 11 | for param_group in optimizer.param_groups: 12 | param_group["lr"] = param_group["lr"] * decay_rate 13 | 14 | 15 | def assign_learning_rate(optimizer, lr=0.1): 16 | for param_group in optimizer.param_groups: 17 | param_group["lr"] = lr 18 | 19 | 20 | def add_weight_decay(params, l2_value, skip_list=()): 21 | decay, no_decay = [], [] 22 | for name, param in params: 23 | if not param.requires_grad: 24 | continue # frozen weights 25 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 26 | no_decay.append(param) 27 | else: 28 | decay.append(param) 29 | return [ 30 | {"params": no_decay, "weight_decay": 0.0}, 31 | {"params": decay, "weight_decay": l2_value}, 32 | ] 33 | 34 | 35 | def get_logger(logdir): 36 | """Function to build the logger.""" 37 | logger = logging.getLogger("mylogger") 38 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 39 | ts = ts.replace(":", "_").replace("-", "_") 40 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 41 | hdlr = logging.FileHandler(file_path) 42 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 43 | hdlr.setFormatter(formatter) 44 | logger.addHandler(hdlr) 45 | logger.setLevel(logging.INFO) 46 | return logger 47 | 48 | 49 | def cvt2normal_state(state_dict): 50 | """Converts a state dict saved from a dataParallel module to normal 51 | module state_dict inplace, i.e. removing "module" in the string. 52 | """ 53 | new_state_dict = OrderedDict() 54 | for name, param in state_dict.items(): 55 | name = name.replace("module.", "") 56 | new_state_dict[name] = param 57 | return new_state_dict 58 | --------------------------------------------------------------------------------