├── .gitignore ├── LADE_poster.pdf ├── LICENSE ├── README.md ├── config ├── CIFAR100_LT │ ├── balanced_softmax.yaml │ ├── causal.yaml │ ├── ce.yaml │ └── lade.yaml ├── ImageNet_LT │ ├── balanced_softmax.yaml │ ├── causal.yaml │ ├── ce.yaml │ └── lade.yaml ├── Places_LT │ ├── balanced_softmax.yaml │ ├── causal.yaml │ ├── ce.yaml │ └── lade.yaml └── iNaturalist18 │ ├── balanced_softmax.yaml │ ├── causal.yaml │ ├── ce.yaml │ └── lade.yaml ├── data ├── ClassAwareSampler.py ├── ClassPrioritySampler.py ├── ImageNet_LT │ ├── ImageNet_LT_test.txt │ ├── ImageNet_LT_test_org.txt │ ├── ImageNet_LT_train.txt │ └── ImageNet_LT_val.txt ├── ImbalanceCIFAR.py ├── MixedPrioritizedSampler.py ├── Places_LT_v2 │ ├── Places_LT_open.txt │ ├── Places_LT_test.txt │ ├── Places_LT_train.txt │ └── Places_LT_val.txt ├── dataloader.py └── iNaturalist18 │ ├── iNaturalist18_train.txt │ └── iNaturalist18_val.txt ├── logger.py ├── loss ├── FocalLoss.py ├── LADELoss.py ├── PriorCELoss.py ├── SoftmaxLoss.py └── WeightedSoftmaxLoss.py ├── main.py ├── models ├── CausalNormClassifier.py ├── DotProductClassifier.py ├── ResNet101Feature.py ├── ResNet152Feature.py ├── ResNet32Feature.py ├── ResNet50Feature.py ├── ResNetFeature.py ├── ResNext101Feature.py ├── ResNext152Feature.py ├── ResNext50Feature.py ├── ResNextFeature.py └── TauNormClassifier.py ├── notebooks ├── cifar100-shift-calib.ipynb └── imagenet-shift-calib.ipynb ├── parse_json.py ├── run_networks.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python,pycharm+all 2 | # Edit at https://www.gitignore.io/?templates=python,pycharm+all 3 | 4 | ### PyCharm+all ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # Generated files 16 | .idea/**/contentModel.xml 17 | 18 | # Sensitive or high-churn files 19 | .idea/**/dataSources/ 20 | .idea/**/dataSources.ids 21 | .idea/**/dataSources.local.xml 22 | .idea/**/sqlDataSources.xml 23 | .idea/**/dynamic.xml 24 | .idea/**/uiDesigner.xml 25 | .idea/**/dbnavigator.xml 26 | 27 | # Gradle 28 | .idea/**/gradle.xml 29 | .idea/**/libraries 30 | 31 | # Gradle and Maven with auto-import 32 | # When using Gradle or Maven with auto-import, you should exclude module files, 33 | # since they will be recreated, and may cause churn. Uncomment if using 34 | # auto-import. 
35 | # .idea/modules.xml 36 | # .idea/*.iml 37 | # .idea/modules 38 | # *.iml 39 | # *.ipr 40 | 41 | # CMake 42 | cmake-build-*/ 43 | 44 | # Mongo Explorer plugin 45 | .idea/**/mongoSettings.xml 46 | 47 | # File-based project format 48 | *.iws 49 | 50 | # IntelliJ 51 | out/ 52 | 53 | # mpeltonen/sbt-idea plugin 54 | .idea_modules/ 55 | 56 | # JIRA plugin 57 | atlassian-ide-plugin.xml 58 | 59 | # Cursive Clojure plugin 60 | .idea/replstate.xml 61 | 62 | # Crashlytics plugin (for Android Studio and IntelliJ) 63 | com_crashlytics_export_strings.xml 64 | crashlytics.properties 65 | crashlytics-build.properties 66 | fabric.properties 67 | 68 | # Editor-based Rest Client 69 | .idea/httpRequests 70 | 71 | # Android studio 3.1+ serialized cache file 72 | .idea/caches/build_file_checksums.ser 73 | 74 | ### PyCharm+all Patch ### 75 | # Ignores the whole .idea folder and all .iml files 76 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 77 | 78 | .idea/ 79 | 80 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 81 | 82 | *.iml 83 | modules.xml 84 | .idea/misc.xml 85 | *.ipr 86 | 87 | # Sonarlint plugin 88 | .idea/sonarlint 89 | 90 | ### Python ### 91 | # Byte-compiled / optimized / DLL files 92 | __pycache__/ 93 | *.py[cod] 94 | *$py.class 95 | 96 | # C extensions 97 | *.so 98 | 99 | # Distribution / packaging 100 | .Python 101 | build/ 102 | develop-eggs/ 103 | dist/ 104 | downloads/ 105 | eggs/ 106 | .eggs/ 107 | lib/ 108 | lib64/ 109 | parts/ 110 | sdist/ 111 | var/ 112 | wheels/ 113 | pip-wheel-metadata/ 114 | share/python-wheels/ 115 | *.egg-info/ 116 | .installed.cfg 117 | *.egg 118 | MANIFEST 119 | 120 | # PyInstaller 121 | # Usually these files are written by a python script from a template 122 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 123 | *.manifest 124 | *.spec 125 | 126 | # Installer logs 127 | pip-log.txt 128 | pip-delete-this-directory.txt 129 | 130 | # Unit test / coverage reports 131 | htmlcov/ 132 | .tox/ 133 | .nox/ 134 | .coverage 135 | .coverage.* 136 | .cache 137 | nosetests.xml 138 | coverage.xml 139 | *.cover 140 | .hypothesis/ 141 | .pytest_cache/ 142 | 143 | # Translations 144 | *.mo 145 | *.pot 146 | 147 | # Scrapy stuff: 148 | .scrapy 149 | 150 | # Sphinx documentation 151 | docs/_build/ 152 | 153 | # PyBuilder 154 | target/ 155 | 156 | # pyenv 157 | .python-version 158 | 159 | # pipenv 160 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 161 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 162 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 163 | # install all needed dependencies. 
164 | #Pipfile.lock 165 | 166 | # celery beat schedule file 167 | celerybeat-schedule 168 | 169 | # SageMath parsed files 170 | *.sage.py 171 | 172 | # Spyder project settings 173 | .spyderproject 174 | .spyproject 175 | 176 | # Rope project settings 177 | .ropeproject 178 | 179 | # Mr Developer 180 | .mr.developer.cfg 181 | .project 182 | .pydevproject 183 | 184 | # mkdocs documentation 185 | /site 186 | 187 | # mypy 188 | .mypy_cache/ 189 | .dmypy.json 190 | dmypy.json 191 | 192 | # Pyre type checker 193 | .pyre/ 194 | 195 | # End of https://www.gitignore.io/api/python,pycharm+all 196 | 197 | .vscode -------------------------------------------------------------------------------- /LADE_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/LADE/5db5d134720872620706025bc4dec8347ddc349b/LADE_poster.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For LADE software 4 | 5 | Copyright (c) Hyperconnect, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Hyperconnect nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | Portions of the source code are from the Classifier-Balancing project and are licensed under 33 | the following terms: 34 | 35 | BSD License 36 | 37 | For Classifier-Balacing software 38 | 39 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 40 | 41 | Redistribution and use in source and binary forms, with or without modification, 42 | are permitted provided that the following conditions are met: 43 | 44 | * Redistributions of source code must retain the above copyright notice, this 45 | list of conditions and the following disclaimer. 46 | 47 | * Redistributions in binary form must reproduce the above copyright notice, 48 | this list of conditions and the following disclaimer in the documentation 49 | and/or other materials provided with the distribution. 
50 | 51 | * Neither the name Facebook nor the names of its contributors may be used to 52 | endorse or promote products derived from this software without specific 53 | prior written permission. 54 | 55 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 56 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 57 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 58 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 59 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 60 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 61 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 62 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 63 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 64 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 65 | 66 | Portions of the source code are from the OLTR project and are licensed under 67 | the following terms: 68 | 69 | BSD 3-Clause License 70 | 71 | Copyright (c) 2019, Zhongqi Miao. All rights reserved. 72 | 73 | Redistribution and use in source and binary forms, with or without 74 | modification, are permitted provided that the following conditions are met: 75 | 76 | 1. Redistributions of source code must retain the above copyright notice, this 77 | list of conditions and the following disclaimer. 78 | 79 | 2. Redistributions in binary form must reproduce the above copyright notice, 80 | this list of conditions and the following disclaimer in the documentation 81 | and/or other materials provided with the distribution. 82 | 83 | 3. Neither the name of the copyright holder nor the names of its 84 | contributors may be used to endorse or promote products derived from 85 | this software without specific prior written permission. 86 | 87 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 88 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 89 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 90 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 91 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 92 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 93 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 94 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 95 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 96 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Disentangling Label Distribution for Long-tailed Visual Recognition (CVPR 2021) 2 | - [Arxiv link](https://arxiv.org/abs/2012.00321) 3 | - [Blog post](https://hyperconnect.github.io/2021/04/05/ailab-lade.html) 4 | - [Poster](./LADE_poster.pdf) 5 | - This codebase is built on [Causal Norm](https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch). 
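LADE disentangles the source (training) label distribution from the target (test) label distribution, so a model trained on long-tailed data can be evaluated under an arbitrary test-time class prior. As a rough, illustrative sketch of the prior-compensation idea behind the PC Softmax baseline used in the training commands below (this is not code from this repository; the function and argument names are placeholders):

```
import torch
import torch.nn.functional as F

def prior_compensated_probs(logits, source_prior, target_prior):
    # logits: [batch, num_classes]; priors: [num_classes] tensors, each summing to 1.
    # Shift each logit by log p_target(y) - log p_source(y), then renormalize.
    adjusted = logits + torch.log(target_prior) - torch.log(source_prior)
    return F.softmax(adjusted, dim=-1)
```

With a uniform target prior the adjustment reduces to subtracting log p_source(y) from the logits (a balanced-softmax-style correction); the notebooks referenced below evaluate models under such shifted test distributions.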
6 | 7 | ## Install 8 | 9 | ``` 10 | conda create -n longtail pip python=3.7 -y 11 | source activate longtail 12 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 13 | pip install pyyaml tqdm matplotlib scikit-learn h5py tensorboard 14 | ``` 15 | 16 | ## Training 17 | 18 | ### Preliminaries 19 | 20 | - Download the pretrained Caffe ResNet-152 model for Places-LT: please refer to [link](https://github.com/zhmiao/OpenLongTailRecognition-OLTR#download-caffe-pre-trained-models-for-places_lt-stage_1-training). 21 | 22 | - Prepare datasets: CIFAR-100, Places-LT, ImageNet-LT, iNaturalist 2018 23 | - Please download those datasets following [Decoupling](https://github.com/facebookresearch/classifier-balancing#dataset). 24 | 25 | ### CIFAR-100 training 26 | 27 | For CIFAR-100 with imbalance ratio 0.01, using LADE: 28 | 29 | ``` 30 | python main.py --seed 1 --cfg config/CIFAR100_LT/lade.yaml --exp_name lade2021/cifar100_imb0.01_lade --cifar_imb_ratio 0.01 --remine_lambda 0.01 --alpha 0.1 --gpu 0 31 | ``` 32 | 33 | ### Places-LT training 34 | 35 | For PC Softmax: 36 | 37 | ``` 38 | python main.py --seed 1 --cfg config/Places_LT/ce.yaml --exp_name lade2021/places_pc_softmax --lr 0.05 --gpu 0,1,2,3 39 | ``` 40 | 41 | For LADE: 42 | 43 | ``` 44 | python main.py --seed 1 --cfg config/Places_LT/lade.yaml --exp_name lade2021/places_lade --lr 0.05 --remine_lambda 0.1 --alpha 0.005 --gpu 0,1,2,3 45 | ``` 46 | 47 | ### ImageNet-LT training 48 | 49 | For LADE: 50 | 51 | ``` 52 | python main.py --seed 1 --cfg config/ImageNet_LT/lade.yaml --exp_name lade2021/imagenet_lade --lr 0.05 --remine_lambda 0.5 --alpha 0.05 --gpu 0,1,2,3 53 | ``` 54 | 55 | ### iNaturalist18 training 56 | 57 | For LADE: 58 | 59 | ``` 60 | python main.py --seed 1 --cfg ./config/iNaturalist18/lade.yaml --exp_name lade2021/inat_lade --lr 0.1 --alpha 0.05 --gpu 0,1,2,3 61 | ``` 62 | 63 | ## Evaluate on shifted test set & Confidence calibration 64 | For ImageNet (Sections 4.3, 4.4): 65 | ``` 66 | ./notebooks/imagenet-shift-calib.ipynb 67 | ``` 68 | 69 | For CIFAR-100 (Supplementary material): 70 | ``` 71 | ./notebooks/cifar100-shift-calib.ipynb 72 | ``` 73 | 74 | ### License 75 | This software is released under the BSD-3 license. 76 | 77 | ### Citation 78 | If you find our paper or this project helpful for your research, please consider citing our paper in your publications.
79 | ``` 80 | @inproceedings{hong2021disentangling, 81 | title={Disentangling label distribution for long-tailed visual recognition}, 82 | author={Hong, Youngkyu and Han, Seungju and Choi, Kwanghee and Seo, Seokjun and Kim, Beomsu and Chang, Buru}, 83 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 84 | pages={6626--6636}, 85 | year={2021} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /config/CIFAR100_LT/balanced_softmax.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 100, prior: 0.1, img_max: 500} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 | num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/DotProductClassifier.py 16 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005, threshold_momentum: 0.0} 17 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 18 | params: {dataset: CIFAR100_LT, feat_dim: 128, num_classes: 100} 19 | feat_model: 20 | def_file: ./models/ResNet32Feature.py 21 | fix: false 22 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 24 | params: {dataset: CIFAR100_LT, dropout: null, stage1_weights: false, use_fc: false, pretrain: false} 25 | shuffle: false 26 | training_opt: 27 | backbone: resnet32 28 | batch_size: 256 29 | dataset: CIFAR100_LT 30 | display_step: 10 31 | display_grad: False 32 | display_grad_step: 10 33 | feature_dim: 128 34 | log_dir: ./exp_results/test 35 | work_dir: ./exp_results 36 | exp_name: test 37 | num_classes: 100 38 | cifar_imb_ratio: 0.1 # 0.01, 0.02, 0.1 for 100, 50, 10 39 | num_epochs: 200 40 | num_workers: 12 41 | open_threshold: 0.1 42 | sampler: null 43 | sub_dir: models 44 | -------------------------------------------------------------------------------- /config/CIFAR100_LT/causal.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 | num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/CausalNormClassifier.py 16 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 17 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 18 | params: {dataset: CIFAR100_LT, feat_dim: 128, num_classes: 100, stage1_weights: false, use_effect: true, num_head: 2, tau: 16.0, alpha: 1.5, gamma: 0.03125} 19 | feat_model: 20 | def_file: ./models/ResNet32Feature.py 21 | fix: false 22 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 24 | params: {dataset: CIFAR100_LT, dropout: null, stage1_weights: false, use_fc: false, pretrain: false} 25 | 
shuffle: false 26 | training_opt: 27 | backbone: resnet32 28 | batch_size: 256 29 | dataset: CIFAR100_LT 30 | display_step: 10 31 | display_grad: False 32 | display_grad_step: 10 33 | feature_dim: 128 34 | log_dir: ./exp_results/test 35 | work_dir: ./exp_results 36 | exp_name: test 37 | num_classes: 100 38 | cifar_imb_ratio: 0.01 # 0.01, 0.02, 0.1 for 100, 50, 10 39 | num_epochs: 200 40 | num_workers: 12 41 | open_threshold: 0.1 42 | sampler: null 43 | sub_dir: models 44 | -------------------------------------------------------------------------------- /config/CIFAR100_LT/ce.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 | num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/DotProductClassifier.py 16 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005, threshold_momentum: 0.0} 17 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 18 | params: {dataset: CIFAR100_LT, feat_dim: 128, num_classes: 100} 19 | feat_model: 20 | def_file: ./models/ResNet32Feature.py 21 | fix: false 22 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 24 | params: {dataset: CIFAR100_LT, dropout: null, stage1_weights: false, use_fc: false, pretrain: false} 25 | shuffle: false 26 | training_opt: 27 | backbone: resnet32 28 | batch_size: 256 29 | dataset: CIFAR100_LT 30 | display_step: 10 31 | display_grad: False 32 | display_grad_step: 10 33 | feature_dim: 128 34 | log_dir: ./exp_results/test 35 | work_dir: ./exp_results 36 | exp_name: test 37 | num_classes: 100 38 | cifar_imb_ratio: 0.1 # 0.01, 0.02, 0.1 for 100, 50, 10 39 | num_epochs: 200 40 | num_workers: 12 41 | open_threshold: 0.1 42 | sampler: null 43 | sub_dir: models 44 | -------------------------------------------------------------------------------- /config/CIFAR100_LT/lade.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 100, prior: 0.1, img_max: 500} 5 | optim_params: null 6 | weight: 1.0 7 | RouteWeightLoss: 8 | def_file: ./loss/LADELoss.py 9 | loss_params: {num_classes: 100, prior: 0.1, img_max: 500, remine_lambda: 0.1} 10 | optim_params: null 11 | weight: 0.1 12 | last: false 13 | # apply incremental pca to remove main components 14 | apply_ipca: false 15 | num_components: 512 16 | model_dir: null 17 | tuning_memory: false 18 | networks: 19 | classifier: 20 | def_file: ./models/DotProductClassifier.py 21 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005, threshold_momentum: 0.0} 22 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 23 | params: {dataset: CIFAR100_LT, feat_dim: 128, num_classes: 100, use_route: True} 24 | feat_model: 25 | def_file: ./models/ResNet32Feature.py 26 | fix: false 27 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 28 | scheduler_params: {coslr: false, endlr: 0.0, gamma: 0.1, step_size: 30, 
warmup: true, lr_step: [120, 160], lr_factor: 0.01, warm_epoch: 5} 29 | params: {dataset: CIFAR100_LT, dropout: null, stage1_weights: false, use_fc: false, pretrain: false} 30 | shuffle: false 31 | training_opt: 32 | backbone: resnet32 33 | batch_size: 256 34 | dataset: CIFAR100_LT 35 | display_step: 10 36 | display_grad: False 37 | display_grad_step: 10 38 | feature_dim: 128 39 | log_dir: ./exp_results/test 40 | work_dir: ./exp_results 41 | exp_name: test 42 | num_classes: 100 43 | cifar_imb_ratio: 0.1 # 0.01, 0.02, 0.1 for 100, 50, 10 44 | num_epochs: 200 45 | num_workers: 12 46 | open_threshold: 0.1 47 | sampler: null 48 | sub_dir: models 49 | -------------------------------------------------------------------------------- /config/ImageNet_LT/balanced_softmax.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 1000, prior_txt: './data/ImageNet_LT/ImageNet_LT_train.txt'} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | model_dir: null 9 | networks: 10 | classifier: 11 | def_file: ./models/DotProductClassifier.py 12 | optim_params: {lr: 0.05, momentum: 0.9, weight_decay: 0.0005, threshold_momentum: 0.0} 13 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 14 | params: {dataset: ImageNet_LT, feat_dim: 2048, num_classes: 1000} 15 | feat_model: 16 | def_file: ./models/ResNext50Feature.py 17 | fix: false 18 | optim_params: {lr: 0.05, momentum: 0.9, weight_decay: 0.0005} 19 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 20 | params: {dataset: ImageNet_LT, dropout: null, stage1_weights: false, use_fc: false} 21 | shuffle: false 22 | training_opt: 23 | backbone: resnext50 24 | batch_size: 256 25 | dataset: ImageNet_LT 26 | display_step: 10 27 | display_grad: False 28 | display_grad_step: 10 29 | feature_dim: 2048 30 | log_dir: ./exp_results/test 31 | work_dir: ./exp_results 32 | exp_name: test 33 | num_classes: 1000 34 | num_epochs: 90 35 | num_workers: 12 36 | open_threshold: 0.1 37 | sampler: null 38 | sub_dir: models -------------------------------------------------------------------------------- /config/ImageNet_LT/causal.yaml: -------------------------------------------------------------------------------- 1 | # default num_head = 2 2 | criterions: 3 | PerformanceLoss: 4 | def_file: ./loss/SoftmaxLoss.py 5 | loss_params: {} 6 | optim_params: null 7 | weight: 1.0 8 | last: false 9 | # apply incremental pca to remove main components 10 | apply_ipca: false 11 | num_components: 512 12 | model_dir: null 13 | tuning_memory: false 14 | networks: 15 | classifier: 16 | def_file: ./models/CausalNormClassifier.py 17 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 18 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 19 | params: {dataset: ImageNet_LT, feat_dim: 2048, num_classes: 1000, stage1_weights: false, use_effect: true, num_head: 2, tau: 16.0, alpha: 3.0, gamma: 0.03125} 20 | feat_model: 21 | def_file: ./models/ResNext50Feature.py 22 | fix: false 23 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 24 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 25 | params: {dataset: ImageNet_LT, dropout: null, stage1_weights: false, use_fc: false} 26 | shuffle: false 27 | training_opt: 28 | backbone: resnext50 29 | batch_size: 256 30 | dataset: ImageNet_LT 31 | display_step: 10 32 | display_grad: False 33 | display_grad_step: 10 34 | 
feature_dim: 2048 35 | log_dir: ./logs/ImageNet_LT/models/resnext50_e90_causal_norm_classifier_head2 36 | log_root: /logs/ImageNet_LT 37 | num_classes: 1000 38 | num_epochs: 90 39 | num_workers: 12 40 | open_threshold: 0.1 41 | sampler: null 42 | sub_dir: models 43 | -------------------------------------------------------------------------------- /config/ImageNet_LT/ce.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | model_dir: null 9 | networks: 10 | classifier: 11 | def_file: ./models/DotProductClassifier.py 12 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 13 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 14 | params: {dataset: ImageNet_LT, feat_dim: 2048, num_classes: 1000, stage1_weights: false} 15 | feat_model: 16 | def_file: ./models/ResNext50Feature.py 17 | fix: false 18 | optim_params: {lr: 0.2, momentum: 0.9, weight_decay: 0.0005} 19 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 20 | params: {dataset: ImageNet_LT, dropout: null, stage1_weights: false, use_fc: false} 21 | shuffle: false 22 | training_opt: 23 | backbone: resnext50 24 | batch_size: 256 25 | dataset: ImageNet_LT 26 | display_step: 10 27 | display_grad: False 28 | display_grad_step: 10 29 | feature_dim: 2048 30 | log_dir: ./logs/ImageNet_LT/models/resnext50_cross_entropy_e90 31 | log_root: /logs/ImageNet_LT 32 | num_classes: 1000 33 | num_epochs: 90 34 | num_workers: 12 35 | open_threshold: 0.1 36 | sampler: null 37 | sub_dir: models 38 | -------------------------------------------------------------------------------- /config/ImageNet_LT/lade.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 1000, prior_txt: './data/ImageNet_LT/ImageNet_LT_train.txt'} 5 | optim_params: null 6 | weight: 1.0 7 | RouteWeightLoss: 8 | def_file: ./loss/LADELoss.py 9 | loss_params: {num_classes: 1000, prior_txt: './data/ImageNet_LT/ImageNet_LT_train.txt', remine_lambda: 0.1} 10 | optim_params: null 11 | weight: 0.1 12 | last: false 13 | model_dir: null 14 | networks: 15 | classifier: 16 | def_file: ./models/DotProductClassifier.py 17 | optim_params: {lr: 0.05, momentum: 0.9, weight_decay: 0.0005, threshold_momentum: 0.0} 18 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 19 | params: {dataset: ImageNet_LT, feat_dim: 2048, num_classes: 1000, use_route: true} 20 | feat_model: 21 | def_file: ./models/ResNext50Feature.py 22 | fix: false 23 | optim_params: {lr: 0.05, momentum: 0.9, weight_decay: 0.0005} 24 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 25 | params: {dataset: ImageNet_LT, dropout: null, stage1_weights: false, use_fc: false} 26 | shuffle: false 27 | training_opt: 28 | backbone: resnext50 29 | batch_size: 256 30 | dataset: ImageNet_LT 31 | display_step: 10 32 | display_grad: False 33 | display_grad_step: 10 34 | feature_dim: 2048 35 | log_dir: ./exp_results/test 36 | work_dir: ./exp_results 37 | exp_name: test 38 | num_classes: 1000 39 | num_epochs: 180 40 | num_workers: 12 41 | open_threshold: 0.1 42 | sampler: null 43 | sub_dir: models 44 | -------------------------------------------------------------------------------- /config/Places_LT/balanced_softmax.yaml: 
-------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 365, prior_txt: './data/Places_LT/Places_LT_train.txt'} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 | num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/DotProductClassifier.py 16 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 17 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 18 | params: {dataset: Places_LT, feat_dim: 2048, num_classes: 365} 19 | feat_model: 20 | def_file: ./models/ResNet152Feature.py 21 | fix: false 22 | optim_params: {lr: 0.001, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 24 | params: {dataset: Places_LT, dropout: null, stage1_weights: false, use_fc: false, caffe: True} 25 | shuffle: false 26 | training_opt: 27 | batch_size: 128 28 | dataset: Places_LT 29 | display_step: 10 30 | display_grad: False 31 | display_grad_step: 10 32 | feature_dim: 2048 33 | log_dir: ./logs/Places_LT/models/resnet152_ce_prior 34 | log_root: /logs/Places_LT 35 | num_classes: 365 36 | num_epochs: 30 37 | num_workers: 4 38 | open_threshold: 0.1 39 | sampler: null 40 | sub_dir: models 41 | -------------------------------------------------------------------------------- /config/Places_LT/causal.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 | num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/CausalNormClassifier.py 16 | optim_params: {lr: 0.05, momentum: 0.9, weight_decay: 0.0005} 17 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 18 | params: {dataset: Places_LT, feat_dim: 2048, num_classes: 365, stage1_weights: false, use_effect: true, num_head: 2, tau: 16.0, alpha: 3.0, gamma: 0.03125} 19 | feat_model: 20 | def_file: ./models/ResNet152Feature.py 21 | fix: false 22 | optim_params: {lr: 0.001, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 24 | params: {dataset: Places_LT, dropout: null, stage1_weights: false, use_fc: false, caffe: True} 25 | shuffle: false 26 | training_opt: 27 | batch_size: 128 28 | dataset: Places_LT 29 | display_step: 10 30 | display_grad: False 31 | display_grad_step: 10 32 | feature_dim: 2048 33 | log_dir: ./logs/Places_LT/models/resnet152_e90_causal_norm_classifier_head2 34 | log_root: /logs/Places_LT 35 | num_classes: 365 36 | num_epochs: 30 37 | num_workers: 4 38 | open_threshold: 0.1 39 | sampler: null 40 | sub_dir: models 41 | -------------------------------------------------------------------------------- /config/Places_LT/ce.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | # apply incremental pca to remove main components 9 | apply_ipca: false 10 
| num_components: 512 11 | model_dir: null 12 | tuning_memory: false 13 | networks: 14 | classifier: 15 | def_file: ./models/DotProductClassifier.py 16 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 17 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 18 | params: {dataset: Places_LT, feat_dim: 2048, num_classes: 365} 19 | feat_model: 20 | def_file: ./models/ResNet152Feature.py 21 | fix: false 22 | optim_params: {lr: 0.001, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 24 | params: {dataset: Places_LT, dropout: null, stage1_weights: false, use_fc: false, caffe: True} 25 | shuffle: false 26 | training_opt: 27 | batch_size: 128 28 | dataset: Places_LT 29 | display_step: 10 30 | display_grad: False 31 | display_grad_step: 10 32 | feature_dim: 2048 33 | log_dir: ./logs/Places_LT/models/resnet152_ce_prior 34 | log_root: /logs/Places_LT 35 | num_classes: 365 36 | num_epochs: 30 37 | num_workers: 4 38 | open_threshold: 0.1 39 | sampler: null 40 | sub_dir: models 41 | -------------------------------------------------------------------------------- /config/Places_LT/lade.yaml: -------------------------------------------------------------------------------- 1 | # default num_head = 2 2 | criterions: 3 | PerformanceLoss: 4 | def_file: ./loss/PriorCELoss.py 5 | loss_params: {num_classes: 365, prior_txt: 'data/Places_LT_v2/Places_LT_train.txt'} 6 | optim_params: null 7 | weight: 1.0 8 | RouteWeightLoss: 9 | def_file: ./loss/LADELoss.py 10 | loss_params: {num_classes: 365, remine_lambda: 0.1, prior_txt: 'data/Places_LT_v2/Places_LT_train.txt'} 11 | optim_params: null 12 | weight: 0.1 13 | last: false 14 | # apply incremental pca to remove main components 15 | apply_ipca: false 16 | num_components: 512 17 | model_dir: null 18 | tuning_memory: false 19 | networks: 20 | classifier: 21 | def_file: ./models/DotProductClassifier.py 22 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005} 23 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 24 | params: {dataset: Places_LT, feat_dim: 2048, num_classes: 365, use_route: True} 25 | feat_model: 26 | def_file: ./models/ResNet152Feature.py 27 | fix: false 28 | optim_params: {lr: 0.001, momentum: 0.9, weight_decay: 0.0005} 29 | scheduler_params: {coslr: false, endlr: 0.0, step_size: 10, gamma: 0.1, warmup: false} 30 | params: {dataset: Places_LT, dropout: null, stage1_weights: false, use_fc: false, caffe: True} 31 | shuffle: false 32 | training_opt: 33 | batch_size: 128 34 | dataset: Places_LT 35 | display_step: 10 36 | display_grad: False 37 | display_grad_step: 10 38 | feature_dim: 2048 39 | log_dir: ./logs/Places_LT/models/resnet152_ce_prior 40 | log_root: /logs/Places_LT 41 | num_classes: 365 42 | num_epochs: 30 43 | num_workers: 4 44 | open_threshold: 0.1 45 | sampler: null 46 | sub_dir: models 47 | -------------------------------------------------------------------------------- /config/iNaturalist18/balanced_softmax.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 8142, prior_txt: './data/iNaturalist18/iNaturalist18_train.txt'} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | model_dir: null 9 | networks: 10 | classifier: 11 | def_file: ./models/DotProductClassifier.py 12 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 
0.0001} 13 | scheduler_params: {coslr: false, warmup: false, step_size: 30, gamma: 0.1} 14 | params: {dataset: iNaturalist18, feat_dim: 2048, num_classes: 8142, stage1_weights: false} 15 | feat_model: 16 | def_file: ./models/ResNet50Feature.py 17 | fix: false 18 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 19 | scheduler_params: {coslr: false, warmup: false, step_size: 30, gamma: 0.1} 20 | params: {dataset: iNaturalist18, dropout: null, stage1_weights: false, use_fc: false} 21 | shuffle: false 22 | training_opt: 23 | backbone: resnet50 24 | batch_size: 256 25 | dataset: iNaturalist18 26 | display_step: 10 27 | display_grad: False 28 | display_grad_step: 10 29 | feature_dim: 2048 30 | log_dir: ./exp_results/test 31 | work_dir: ./exp_results 32 | exp_name: test 33 | num_classes: 8142 34 | num_epochs: 200 35 | num_workers: 12 36 | open_threshold: 0.1 37 | sampler: null 38 | sub_dir: models 39 | -------------------------------------------------------------------------------- /config/iNaturalist18/causal.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | apply_ipca: false 9 | num_components: 512 10 | model_dir: null 11 | tuning_memory: false 12 | networks: 13 | classifier: 14 | def_file: ./models/CausalNormClassifier.py 15 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 16 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 17 | params: {dataset: iNaturalist18, feat_dim: 2048, num_classes: 8142, stage1_weights: false, use_effect: true, num_head: 2, tau: 16.0, alpha: 3.0, gamma: 0.03125} 18 | feat_model: 19 | def_file: ./models/ResNet50Feature.py 20 | fix: false 21 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 22 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 23 | params: {dataset: iNaturalist18, dropout: null, stage1_weights: false, use_fc: false} 24 | shuffle: false 25 | training_opt: 26 | backbone: resnet50 27 | batch_size: 256 28 | dataset: iNaturalist18 29 | display_step: 10 30 | display_grad: False 31 | display_grad_step: 10 32 | feature_dim: 2048 33 | log_dir: ./exp_results/test 34 | work_dir: ./exp_results 35 | exp_name: test 36 | num_classes: 8142 37 | num_epochs: 200 38 | num_workers: 12 39 | open_threshold: 0.1 40 | sampler: null 41 | sub_dir: models 42 | -------------------------------------------------------------------------------- /config/iNaturalist18/ce.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/SoftmaxLoss.py 4 | loss_params: {} 5 | optim_params: null 6 | weight: 1.0 7 | last: false 8 | model_dir: null 9 | networks: 10 | classifier: 11 | def_file: ./models/DotProductClassifier.py 12 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 13 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 14 | params: {dataset: iNaturalist18, feat_dim: 2048, num_classes: 8142, stage1_weights: false} 15 | feat_model: 16 | def_file: ./models/ResNet50Feature.py 17 | fix: false 18 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 19 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 20 | params: {dataset: iNaturalist18, dropout: null, stage1_weights: false, use_fc: false} 21 | shuffle: false 22 | training_opt: 23 | backbone: resnet50 24 | batch_size: 256 
25 | dataset: iNaturalist18 26 | display_step: 10 27 | display_grad: False 28 | display_grad_step: 10 29 | feature_dim: 2048 30 | log_dir: ./exp_results/test 31 | work_dir: ./exp_results 32 | exp_name: test 33 | num_classes: 8142 34 | num_epochs: 200 35 | num_workers: 12 36 | open_threshold: 0.1 37 | sampler: null 38 | sub_dir: models 39 | -------------------------------------------------------------------------------- /config/iNaturalist18/lade.yaml: -------------------------------------------------------------------------------- 1 | criterions: 2 | PerformanceLoss: 3 | def_file: ./loss/PriorCELoss.py 4 | loss_params: {num_classes: 8142, prior_txt: './data/iNaturalist18/iNaturalist18_train.txt'} 5 | optim_params: null 6 | weight: 1.0 7 | RouteWeightLoss: 8 | def_file: ./loss/LADELoss.py 9 | loss_params: {num_classes: 8142, prior_txt: './data/iNaturalist18/iNaturalist18_train.txt', remine_lambda: 0.5} 10 | optim_params: null 11 | weight: 0.05 12 | last: false 13 | model_dir: null 14 | networks: 15 | classifier: 16 | def_file: ./models/DotProductClassifier.py 17 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001, threshold_momentum: 0.0} 18 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 19 | params: {dataset: iNaturalist18, feat_dim: 2048, num_classes: 8142, use_route: true} 20 | feat_model: 21 | def_file: ./models/ResNet50Feature.py 22 | fix: false 23 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0001} 24 | scheduler_params: {coslr: true, endlr: 0.0, gamma: 0.1, step_size: 30} 25 | params: {dataset: iNaturalist18, dropout: null, stage1_weights: false, use_fc: false} 26 | shuffle: false 27 | training_opt: 28 | backbone: resnet50 29 | batch_size: 256 30 | dataset: iNaturalist18 31 | display_step: 10 32 | display_grad: False 33 | display_grad_step: 10 34 | feature_dim: 2048 35 | log_dir: ./exp_results/test 36 | work_dir: ./exp_results 37 | exp_name: test 38 | num_classes: 8142 39 | num_epochs: 200 40 | num_workers: 12 41 | open_threshold: 0.1 42 | sampler: null 43 | sub_dir: models -------------------------------------------------------------------------------- /data/ClassAwareSampler.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 
13 | """ 14 | 15 | import random 16 | import numpy as np 17 | from torch.utils.data.sampler import Sampler 18 | import pdb 19 | 20 | ################################## 21 | ## Class-aware sampling, partly implemented by frombeijingwithlove 22 | ################################## 23 | 24 | class RandomCycleIter: 25 | 26 | def __init__ (self, data, test_mode=False): 27 | self.data_list = list(data) 28 | self.length = len(self.data_list) 29 | self.i = self.length - 1 30 | self.test_mode = test_mode 31 | 32 | def __iter__ (self): 33 | return self 34 | 35 | def __next__ (self): 36 | self.i += 1 37 | 38 | if self.i == self.length: 39 | self.i = 0 40 | if not self.test_mode: 41 | random.shuffle(self.data_list) 42 | 43 | return self.data_list[self.i] 44 | 45 | def class_aware_sample_generator (cls_iter, data_iter_list, n, num_samples_cls=1): 46 | 47 | i = 0 48 | j = 0 49 | while i < n: 50 | 51 | # yield next(data_iter_list[next(cls_iter)]) 52 | 53 | if j >= num_samples_cls: 54 | j = 0 55 | 56 | if j == 0: 57 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 58 | yield temp_tuple[j] 59 | else: 60 | yield temp_tuple[j] 61 | 62 | i += 1 63 | j += 1 64 | 65 | class ClassAwareSampler (Sampler): 66 | 67 | def __init__(self, data_source, num_samples_cls=1,): 68 | num_classes = len(np.unique(data_source.labels)) 69 | self.class_iter = RandomCycleIter(range(num_classes)) 70 | cls_data_list = [list() for _ in range(num_classes)] 71 | for i, label in enumerate(data_source.labels): 72 | cls_data_list[label].append(i) 73 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 74 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 75 | self.num_samples_cls = num_samples_cls 76 | 77 | def __iter__ (self): 78 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 79 | self.num_samples, self.num_samples_cls) 80 | 81 | def __len__ (self): 82 | return self.num_samples 83 | 84 | def get_sampler(): 85 | return ClassAwareSampler 86 | 87 | ################################## -------------------------------------------------------------------------------- /data/ClassPrioritySampler.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | 9 | import random 10 | import numpy as np 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | class RandomCycleIter: 15 | 16 | def __init__ (self, data, test_mode=False): 17 | self.data_list = list(data) 18 | self.length = len(self.data_list) 19 | self.i = self.length - 1 20 | self.test_mode = test_mode 21 | 22 | def __iter__ (self): 23 | return self 24 | 25 | def __next__ (self): 26 | self.i += 1 27 | 28 | if self.i == self.length: 29 | self.i = 0 30 | if not self.test_mode: 31 | random.shuffle(self.data_list) 32 | 33 | return self.data_list[self.i] 34 | 35 | 36 | class PriorityTree(object): 37 | def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0, 38 | alpha=1.0): 39 | """ 40 | fixed_weights: weights that wont be updated by self.update() 41 | """ 42 | assert fixed_weights is None or len(fixed_weights) == capacity 43 | assert len(init_weights) == capacity 44 | self.alpha = alpha 45 | self._capacity = capacity 46 | self._tree_size = 2 * capacity - 1 47 | self.fixed_scale = fixed_scale 48 | self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \ 49 | else fixed_weights 50 | self.tree = np.zeros(self._tree_size) 51 | self._initialized = False 52 | self.initialize(init_weights) 53 | 54 | def initialize(self, init_weights): 55 | """Initialize the tree.""" 56 | 57 | # Rescale the fixed_weights if it is not zero 58 | self.fixed_scale_init = self.fixed_scale 59 | if self.fixed_weights.sum() > 0 and init_weights.sum() > 0: 60 | self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum() 61 | self.fixed_weights *= self.fixed_scale * init_weights.sum() \ 62 | / self.fixed_weights.sum() 63 | print('FixedWeights: {}'.format(self.fixed_weights.sum())) 64 | 65 | self.update_whole(init_weights + self.fixed_weights) 66 | self._initialized = True 67 | 68 | def reset_adaptive_weights(self, adaptive_weights): 69 | self.update_whole(self.fixed_weights + adaptive_weights) 70 | 71 | def reset_fixed_weights(self, fixed_weights, rescale=False): 72 | """ Reset the manually designed weights and 73 | update the whole tree accordingly. 
74 | 75 | @rescale: rescale the fixed_weights such that 76 | fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum() 77 | """ 78 | 79 | adaptive_weights = self.get_adaptive_weights() 80 | fixed_sum = fixed_weights.sum() 81 | if rescale and fixed_sum > 0: 82 | # Rescale fixedweight based on adaptive weights 83 | scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum 84 | else: 85 | # Rescale fixedweight based on previous fixedweight 86 | scale = self.fixed_weights.sum() / fixed_sum 87 | self.fixed_weights = fixed_weights * scale 88 | self.update_whole(self.fixed_weights + adaptive_weights) 89 | 90 | def update_whole(self, total_weights): 91 | """ Update the whole tree based on per-example sampling weights """ 92 | if self.alpha != 1: 93 | total_weights = np.power(total_weights, self.alpha) 94 | lefti = self.pointer_to_treeidx(0) 95 | righti = self.pointer_to_treeidx(self.capacity-1) 96 | self.tree[lefti:righti+1] = total_weights 97 | 98 | # Iteratively find a parent layer 99 | while lefti != 0 and righti != 0: 100 | lefti = (lefti - 1) // 2 if lefti != 0 else 0 101 | righti = (righti - 1) // 2 if righti != 0 else 0 102 | 103 | # Assign paraent weights from right to left 104 | for i in range(righti, lefti-1, -1): 105 | self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2] 106 | 107 | def get_adaptive_weights(self): 108 | """ Get the instance-aware weights, that are not mannually designed""" 109 | if self.alpha == 1: 110 | return self.get_total_weights() - self.fixed_weights 111 | else: 112 | return self.get_raw_total_weights() - self.fixed_weights 113 | 114 | def get_total_weights(self): 115 | """ Get the per-example sampling weights 116 | return shape: [capacity] 117 | """ 118 | lefti = self.pointer_to_treeidx(0) 119 | righti = self.pointer_to_treeidx(self.capacity-1) 120 | return self.tree[lefti:righti+1] 121 | 122 | def get_raw_total_weights(self): 123 | """ Get the per-example sampling weights 124 | return shape: [capacity] 125 | """ 126 | lefti = self.pointer_to_treeidx(0) 127 | righti = self.pointer_to_treeidx(self.capacity-1) 128 | return np.power(self.tree[lefti:righti+1], 1/self.alpha) 129 | 130 | @property 131 | def size(self): 132 | return self._tree_size 133 | 134 | @property 135 | def capacity(self): 136 | return self._capacity 137 | 138 | def __len__(self): 139 | return self.capacity 140 | 141 | def pointer_to_treeidx(self, pointer): 142 | assert pointer < self.capacity 143 | return int(pointer + self.capacity - 1) 144 | 145 | def update(self, pointer, priority): 146 | assert pointer < self.capacity 147 | tree_idx = self.pointer_to_treeidx(pointer) 148 | priority += self.fixed_weights[pointer] 149 | if self.alpha != 1: 150 | priority = np.power(priority, self.alpha) 151 | delta = priority - self.tree[tree_idx] 152 | self.tree[tree_idx] = priority 153 | while tree_idx != 0: 154 | tree_idx = (tree_idx - 1) // 2 155 | self.tree[tree_idx] += delta 156 | 157 | def update_delta(self, pointer, delta): 158 | assert pointer < self.capacity 159 | tree_idx = self.pointer_to_treeidx(pointer) 160 | ratio = 1- self.fixed_weights[pointer] / self.tree[tree_idx] 161 | # delta *= ratio 162 | if self.alpha != 1: 163 | # Update delta 164 | if self.tree[tree_idx] < 0 or \ 165 | np.power(self.tree[tree_idx], 1/self.alpha) + delta < 0: 166 | import pdb; pdb.set_trace() 167 | delta = np.power(np.power(self.tree[tree_idx], 1/self.alpha) + delta, 168 | self.alpha) \ 169 | - self.tree[tree_idx] 170 | self.tree[tree_idx] += delta 171 | while tree_idx != 0: 172 | tree_idx = (tree_idx - 1) // 2 
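# The parent index has just been computed; the same delta is added there so every ancestor keeps the sum of its subtree (self.tree[0] stays the grand total).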
173 | self.tree[tree_idx] += delta 174 | 175 | def get_leaf(self, value): 176 | assert self._initialized, 'PriorityTree not initialized!!!!' 177 | assert self.total > 0, 'No priority weights setted!!' 178 | parent = 0 179 | while True: 180 | left_child = 2 * parent + 1 181 | right_child = 2 * parent + 2 182 | if left_child >= len(self.tree): 183 | tgt_leaf = parent 184 | break 185 | if value < self.tree[left_child]: 186 | parent = left_child 187 | else: 188 | value -= self.tree[left_child] 189 | parent = right_child 190 | data_idx = tgt_leaf - self.capacity + 1 191 | return data_idx, self.tree[tgt_leaf] # data idx, priority 192 | 193 | @property 194 | def total(self): 195 | assert self._initialized, 'PriorityTree not initialized!!!!' 196 | return self.tree[0] 197 | 198 | @property 199 | def max(self): 200 | return np.max(self.tree[-self.capacity:]) 201 | 202 | @property 203 | def min(self): 204 | assert self._initialized, 'PriorityTree not initialized!!!!' 205 | return np.min(self.tree[-self.capacity:]) 206 | 207 | def get_weights(self): 208 | wdict = {'fixed_weights': self.fixed_weights, 209 | 'total_weights': self.get_total_weights()} 210 | if self.alpha != 1: 211 | wdict.update({'raw_total_weights': self.get_raw_total_weights(), 212 | 'alpha': self.alpha}) 213 | 214 | return wdict 215 | 216 | class ClassPrioritySampler(Sampler): 217 | """ 218 | A sampler combining manually designed sampling strategy and prioritized 219 | sampling strategy. 220 | 221 | Manually disigned strategy contains two parts: 222 | 223 | $$ manual_weights = lam * balanced_weights + (1-lam) uniform_weights 224 | 225 | Here we use a generalized version of balanced weights as follows, 226 | when n limits to infinity, balanced_weights = real_balanced_weights 227 | 228 | $$ balanced_weights = uniform_weights ^ (1/n) 229 | 230 | Then the balanced weights are scaled such that 231 | 232 | $$ balanced_weights.sum() = balance_scale * uniform_weights.sum() 233 | 234 | Note: above weights are per-class weights 235 | 236 | Overall sampling weights are given as 237 | $$ sampling_weights = manual_weights * fixed_scale + priority_weights 238 | 239 | Arguments: 240 | @dataset: A dataset 241 | @balance_scale: The scale of balanced_weights 242 | @lam: A weight to combine balanced weights and uniform weights 243 | - None for shifting sampling 244 | - 0 for uniform sampling 245 | - 1 for balanced sampling 246 | @fixed_scale: The scale of manually designed weights 247 | - fixed_scale < 0 means, the manually designed distribution will 248 | be used as the backend distribution of priorities. 
249 | @cycle: shifting strategy 250 | - 0 for linear shifting: 3 -> 2 - > 1 251 | - 1 for periodic shifting: 252 | 3 -> 2 - > 1 -> 3 -> 2 - > 1 -> 3 -> 2 - > 1 253 | - 2 for cosine-like periodic shifting: 254 | 3 -> 2 - > 1 -> 1 -> 2 - > 3 -> 3 -> 2 - > 1 255 | @nroot: 256 | - None for truly balanced weights 257 | - >= 2 for pseudo-balanced weights 258 | @rescale: whether to rebalance the manual weights and priority weights 259 | every epoch 260 | @root_decay: 261 | - 'exp': for exponential decay 262 | - 'linear': for linear decay 263 | """ 264 | def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0, 265 | lam=None, epochs=90, cycle=0, nroot=None, manual_only=False, 266 | rescale=False, root_decay=None, decay_gap=30, ptype='score', 267 | pri_mode='train', momentum=0., alpha=1.0): 268 | """ 269 | """ 270 | self.dataset = dataset 271 | self.balance_scale = balance_scale 272 | self.fixed_scale = fixed_scale 273 | self.epochs = epochs 274 | self.lam = lam 275 | self.cycle = cycle 276 | self.nroot = nroot 277 | self.rescale = rescale 278 | self.manual_only = manual_only 279 | self.root_decay = root_decay 280 | self.decay_gap = decay_gap 281 | self.ptype = ptype 282 | self.pri_mode = pri_mode 283 | self.num_samples = len(dataset) 284 | self.manual_as_backend = False 285 | self.momentum = momentum 286 | self.alpha = alpha 287 | 288 | assert 0. <= self.momentum <= 1.0 289 | assert 0. <= self.alpha 290 | 291 | # Change the backend distribution of priority if needed 292 | if self.fixed_scale < 0: 293 | self.fixed_scale = 0 294 | self.manual_as_backend = True 295 | 296 | # If using root_decay, reset relevent parameters 297 | if self.root_decay in ['exp', 'linear', 'autoexp']: 298 | self.lam = 1 299 | self.manual_only = True 300 | self.nroot = 1 301 | if self.root_decay == 'autoexp': 302 | self.decay_gap = 1 303 | self.decay_factor = np.power(nroot, 1/(self.epochs-1)) 304 | else: 305 | assert self.root_decay is None 306 | assert self.nroot is None or self.nroot > 1 307 | print("====> Decay GAP: {}".format(self.decay_gap)) 308 | 309 | # Take care of lambdas 310 | self.freeze = True 311 | if self.lam is None: 312 | self.freeze = False 313 | if cycle == 0: 314 | self.lams = np.linspace(0, 1, epochs) 315 | elif cycle == 1: 316 | self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3) 317 | elif cycle == 2: 318 | self.lams = np.concatenate([np.linspace(0,1,epochs//3), 319 | np.linspace(0,1,epochs//3)[::-1], 320 | np.linspace(0,1,epochs//3)]) 321 | else: 322 | raise NotImplementedError( 323 | 'cycle = {} not implemented'.format(cycle)) 324 | else: 325 | self.lams = [self.lam] 326 | 327 | # Get num of samples per class 328 | self.cls_cnts = [] 329 | self.labels = labels = np.array(self.dataset.labels) 330 | for l in np.unique(labels): 331 | self.cls_cnts.append(np.sum(labels==l)) 332 | self.num_classes = len(self.cls_cnts) 333 | self.cnts = np.array(self.cls_cnts).astype(float) 334 | 335 | # Get per-class image indexes 336 | self.cls_idxs = [[] for _ in range(self.num_classes)] 337 | for i, label in enumerate(self.dataset.labels): 338 | self.cls_idxs[label].append(i) 339 | self.data_iter_list = [RandomCycleIter(x) for x in self.cls_idxs] 340 | for ci in range(self.num_classes): 341 | self.cls_idxs[ci] = np.array(self.cls_idxs[ci]) 342 | 343 | # Build balanced weights based on class counts 344 | self.balanced_weights = self.get_balanced_weights(self.nroot) 345 | self.uniform_weights = self.get_uniform_weights() 346 | self.manual_weights = self.get_manual_weights(self.lams[0]) 347 | 348 | # 
back_weights = self.get_balanced_weights(1.5) 349 | back_weights = self.uniform_weights 350 | 351 | # Calculate priority ratios that reshape priority into target distribution 352 | self.per_cls_ratios = self.get_cls_ratios( 353 | self.manual_weights if self.manual_as_backend else back_weights) 354 | self.per_example_ratios = self.broadcast(self.per_cls_ratios) 355 | 356 | # Setup priority tree 357 | if self.ptype == 'score': 358 | self.init_weight = 1. 359 | elif self.ptype in ['CE', 'entropy']: 360 | self.init_weight = 6.9 361 | else: 362 | raise NotImplementedError('ptype {} not implemented'.format(self.ptype)) 363 | if self.manual_only: 364 | self.init_weight = 0. 365 | self.per_example_uni_weights = np.ones(self.num_samples) * self.init_weight 366 | self.per_example_velocities = np.zeros(self.num_samples) 367 | # init_priorities = np.power(self.init_weight, self.alpha) \ 368 | # * self.uniform_weights * self.per_cls_ratios 369 | init_priorities = self.init_weight * self.uniform_weights * self.per_cls_ratios 370 | self.ptree = PriorityTree(self.num_classes, init_priorities, 371 | self.manual_weights.copy(), fixed_scale=self.fixed_scale, 372 | alpha=self.alpha) 373 | 374 | def get_cls_ratios(self, tgt_weights): 375 | if tgt_weights is self.uniform_weights: 376 | return np.ones_like(self.uniform_weights) 377 | per_cls_ratios = tgt_weights / self.uniform_weights 378 | per_cls_ratios *= self.uniform_weights.sum() / tgt_weights.sum() 379 | return per_cls_ratios 380 | 381 | def get_cls_weights(self): 382 | ratioed_ws = self.per_example_uni_weights * self.per_example_ratios 383 | return self.debroadcast_sum(ratioed_ws) 384 | 385 | def broadcast(self, per_cls_info): 386 | per_exmaple_info = np.zeros(self.num_samples) 387 | # Braodcast per-cls info to each example 388 | for ci in range(self.num_classes): 389 | per_exmaple_info[self.cls_idxs[ci]] = per_cls_info[ci] 390 | return per_exmaple_info 391 | 392 | def debroadcast_sum(self, per_example_info): 393 | per_cls_info = np.zeros(self.num_classes) 394 | # DeBraodcast per-example info to each cls by summation 395 | for ci in range(self.num_classes): 396 | per_cls_info[ci] = per_example_info[self.cls_idxs[ci]].sum() 397 | return per_cls_info 398 | 399 | def get_manual_weights(self, lam): 400 | # Merge balanced weights and uniform weights 401 | if lam == 1: 402 | manual_weights = self.balanced_weights.copy() 403 | elif lam == 0: 404 | manual_weights = self.uniform_weights.copy() 405 | else: 406 | manual_weights = self.balanced_weights * lam + (1-lam) * self.uniform_weights 407 | return manual_weights 408 | 409 | def get_uniform_weights(self): 410 | return self.cnts.copy() 411 | 412 | def get_balanced_weights(self, nroot): 413 | """ Calculate normalized generalized balanced weights """ 414 | 415 | cnts = self.cnts 416 | if nroot is None: 417 | # Real balanced sampling weights, each class has the same weights 418 | # Un-normalized !!! 419 | cls_ws = np.ones(len(cnts)) 420 | elif nroot >= 1: 421 | # Generalized balanced weights 422 | # Un-normalized !!! 
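# Worked example (before the rescaling at the end of this method): with cnts = [100, 10, 1]
# and nroot = 2, cnts / cnts.sum() = [0.901, 0.090, 0.009]; raising to the power 1/2 and
# multiplying by cnts.sum() gives roughly [105.4, 33.3, 10.5], i.e. weights between the raw
# counts (nroot = 1) and fully class-balanced weights (nroot -> infinity).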
423 | cls_ws = cnts / cnts.sum() 424 | cls_ws = np.power(cls_ws, 1./nroot) * cnts.sum() 425 | cls_ws = cls_ws 426 | else: 427 | raise NotImplementedError('root:{} not implemented'.format(nroot)) 428 | 429 | # Get un-normalized weights 430 | balanced_weights = cls_ws 431 | 432 | # Normalization and rescale 433 | balanced_weights *= self.num_samples / balanced_weights.sum() * \ 434 | self.balance_scale 435 | return balanced_weights 436 | 437 | def __iter__(self): 438 | for _ in range(self.num_samples): 439 | w = random.random() * self.ptree.total 440 | ci, pri = self.ptree.get_leaf(w) 441 | yield next(self.data_iter_list[ci]) 442 | 443 | def __len__(self): 444 | return self.num_samples 445 | 446 | def reset_weights(self, epoch): 447 | # If it is linear shifting 448 | if not self.freeze: 449 | e = np.clip(epoch, 0, self.epochs-1) 450 | self.manual_weights = self.get_manual_weights(self.lams[e]) 451 | # make sure 'self.fixed_scale > 0' and 'self.manual_as_backend = True' are 452 | # mutually exclusive 453 | if self.fixed_scale > 0: 454 | self.ptree.reset_fixed_weights(self.manual_weights, self.rescale) 455 | if self.manual_as_backend: 456 | self.update_backend_distribution(self.manual_weights) 457 | 458 | # If it is root decay 459 | if self.root_decay in ['exp', 'linear', 'autoexp'] and epoch % self.decay_gap == 0: 460 | if self.root_decay == 'exp': 461 | self.nroot *= 2 462 | elif self.root_decay == 'linear': 463 | self.nroot += 1 464 | elif self.root_decay == 'autoexp': 465 | # self.nroot *= self.decay_factor 466 | self.nroot = np.power(self.decay_factor, epoch) 467 | 468 | bw = self.get_balanced_weights(self.nroot) 469 | if self.manual_as_backend: 470 | self.update_backend_distribution(bw) 471 | else: 472 | self.ptree.reset_fixed_weights(bw) 473 | 474 | def update_backend_distribution(self, tgt_weights): 475 | # Recalculate the cls ratios based on the given target distribution 476 | self.per_cls_ratios = self.get_cls_ratios(tgt_weights) 477 | self.per_example_ratios = self.broadcast(self.per_cls_ratios) 478 | 479 | # Recalculate the new per-class weights based on the new ratios 480 | # new_backend_weights = self.init_weight * self.uniform_weights * self.per_cls_ratios 481 | new_cls_weights = self.get_cls_weights() 482 | self.ptree.reset_adaptive_weights(new_cls_weights) 483 | 484 | def update_weights(self, inds, weights, labels): 485 | """ Update priority weights """ 486 | if not self.manual_only and self.pri_mode == 'train': 487 | weights = np.clip(weights, 0, self.init_weight) 488 | 489 | # Iterate over all classes in the batch 490 | for l in np.unique(labels): 491 | # Calculate per-class delta weights 492 | example_inds = inds[labels==l] 493 | last_weights = self.per_example_uni_weights[example_inds] 494 | # delta = np.power(weights[labels==l], self.alpha) - \ 495 | # np.power(last_weights, self.alpha) 496 | delta = weights[labels==l] - last_weights 497 | delta = self.momentum * self.per_example_velocities[example_inds] + \ 498 | (1-self.momentum) * delta 499 | 500 | # Update velocities 501 | self.per_example_velocities[example_inds] = delta 502 | # Update per-example weights 503 | # self.per_example_uni_weights[example_inds] = weights[labels==l] 504 | self.per_example_uni_weights[example_inds] += delta 505 | 506 | # Sacle the delta 507 | # (ie, the per-example weights both before and after update) 508 | delta *= self.per_example_ratios[example_inds] 509 | 510 | # Update tree 511 | if self.alpha == 1: 512 | self.ptree.update_delta(l, delta.sum()) 513 | else: 514 | self.ptree.update(l, 
self.per_example_uni_weights[self.cls_idxs[l]].sum()) 515 | 516 | 517 | def reset_priority(self, weights, labels): 518 | if self.pri_mode == 'valid': 519 | assert len(np.unique(labels)) == self.num_classes 520 | weights = np.clip(weights, 0, self.init_weight) 521 | cls_weights = np.zeros(self.num_classes) 522 | for c in np.unique(labels): 523 | cls_weights[c] = weights[labels==c].mean() 524 | cls_weights *= self.cnts 525 | cls_weights *= self.per_cls_ratios 526 | self.ptree.reset_adaptive_weights(cls_weights) 527 | 528 | def get_weights(self): 529 | return self.ptree.get_weights() 530 | 531 | 532 | def get_sampler(): 533 | return ClassPrioritySampler 534 | -------------------------------------------------------------------------------- /data/ImbalanceCIFAR.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from https://github.com/Megvii-Nanjing/BBN 3 | Customized by Kaihua Tang 4 | """ 5 | 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | from PIL import Image 10 | import random 11 | 12 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 13 | cls_num = 10 14 | 15 | def __init__(self, phase, imbalance_ratio, root = '/gruntdata5/kaihua/datasets', imb_type='exp', 16 | test_imb_ratio=None, reverse=False): 17 | train = True if phase == "train" else False 18 | super(IMBALANCECIFAR10, self).__init__(root, train, transform=None, target_transform=None, download=True) 19 | self.train = train 20 | if self.train: 21 | self.img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio, 22 | reverse=reverse) 23 | self.gen_imbalanced_data(self.img_num_list) 24 | self.transform = transforms.Compose([ 25 | transforms.RandomCrop(32, padding=4), 26 | transforms.RandomHorizontalFlip(), 27 | #transforms.Resize(224), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 30 | ]) 31 | else: 32 | if test_imb_ratio: 33 | # if test imbalance ratio is explicitly given, test dataset should be imbalanced. 
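                # --- Editor's note: illustrative sketch, not part of the original source ---
                # The call just below reuses the exponential imbalance profile from
                # get_img_num_per_cls() (defined later in this class). As a standalone
                # example, assuming cls_num = 10, img_max = 5000 (the CIFAR-10 training
                # split) and imb_factor = 0.01 (i.e. imbalance ratio 100):
                #
                #   [int(5000 * 0.01 ** (i / 9.0)) for i in range(10)]
                #   # -> [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
                #
                # Passing reverse=True flips this profile (the formerly rare classes
                # become the frequent ones), which is how reversed test distributions
                # are built when test_imb_ratio is given.
                # ---------------------------------------------------------------------------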
34 | self.img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, test_imb_ratio, 35 | reverse=reverse) 36 | self.gen_imbalanced_data(self.img_num_list) 37 | else: 38 | self.img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor=1.0, 39 | reverse=reverse) 40 | 41 | self.transform = transforms.Compose([ 42 | #transforms.Resize(224), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 45 | ]) 46 | 47 | self.labels = self.targets 48 | 49 | print("{} Mode: Contain {} images".format(phase, len(self.data))) 50 | 51 | def _get_class_dict(self): 52 | class_dict = dict() 53 | for i, anno in enumerate(self.get_annotations()): 54 | cat_id = anno["category_id"] 55 | if not cat_id in class_dict: 56 | class_dict[cat_id] = [] 57 | class_dict[cat_id].append(i) 58 | return class_dict 59 | 60 | 61 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor, reverse=False): 62 | img_max = len(self.data) / cls_num 63 | img_num_per_cls = [] 64 | if imb_type == 'exp': 65 | for cls_idx in range(cls_num): 66 | if reverse: 67 | num = img_max * (imb_factor**((cls_num - 1 - cls_idx) / (cls_num - 1.0))) 68 | else: 69 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 70 | img_num_per_cls.append(int(num)) 71 | elif imb_type == 'step': 72 | for cls_idx in range(cls_num // 2): 73 | img_num_per_cls.append(int(img_max)) 74 | for cls_idx in range(cls_num // 2): 75 | img_num_per_cls.append(int(img_max * imb_factor)) 76 | else: 77 | img_num_per_cls.extend([int(img_max)] * cls_num) 78 | return img_num_per_cls 79 | 80 | def gen_imbalanced_data(self, img_num_per_cls): 81 | new_data = [] 82 | new_targets = [] 83 | targets_np = np.array(self.targets, dtype=np.int64) 84 | classes = np.unique(targets_np) 85 | 86 | self.num_per_cls_dict = dict() 87 | for the_class, the_img_num in zip(classes, img_num_per_cls): 88 | self.num_per_cls_dict[the_class] = the_img_num 89 | idx = np.where(targets_np == the_class)[0] 90 | np.random.shuffle(idx) 91 | selec_idx = idx[:the_img_num] 92 | new_data.append(self.data[selec_idx, ...]) 93 | new_targets.extend([the_class, ] * the_img_num) 94 | new_data = np.vstack(new_data) 95 | self.data = new_data 96 | self.targets = new_targets 97 | 98 | def __getitem__(self, index): 99 | img, label = self.data[index], self.labels[index] 100 | 101 | # doing this so that it is consistent with all other datasets 102 | # to return a PIL Image 103 | img = Image.fromarray(img) 104 | 105 | if self.transform is not None: 106 | img = self.transform(img) 107 | 108 | if self.target_transform is not None: 109 | label = self.target_transform(label) 110 | 111 | return img, label, index 112 | 113 | def __len__(self): 114 | return len(self.labels) 115 | 116 | def get_num_classes(self): 117 | return self.cls_num 118 | 119 | def get_annotations(self): 120 | annos = [] 121 | for label in self.labels: 122 | annos.append({'category_id': int(label)}) 123 | return annos 124 | 125 | def get_cls_num_list(self): 126 | cls_num_list = [] 127 | for i in range(self.cls_num): 128 | cls_num_list.append(self.num_per_cls_dict[i]) 129 | return cls_num_list 130 | 131 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 132 | """`CIFAR100 `_ Dataset. 133 | This is a subclass of the `CIFAR10` Dataset. 
134 | """ 135 | cls_num = 100 136 | base_folder = 'cifar-100-python' 137 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 138 | filename = "cifar-100-python.tar.gz" 139 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 140 | train_list = [ 141 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 142 | ] 143 | 144 | test_list = [ 145 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 146 | ] 147 | meta = { 148 | 'filename': 'meta', 149 | 'key': 'fine_label_names', 150 | 'md5': '7973b15100ade9c7d40fb424638fde48', 151 | } 152 | -------------------------------------------------------------------------------- /data/MixedPrioritizedSampler.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | import random 10 | import numpy as np 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | class PriorityTree(object): 15 | def __init__(self, capacity, fixed_weights=None, fixed_scale=1.0, 16 | init_weight=1.0): 17 | """ 18 | fixed_weights: weights that wont be updated by self.update() 19 | """ 20 | assert fixed_weights is None or len(fixed_weights) == capacity 21 | self._capacity = capacity 22 | self._tree_size = 2 * capacity - 1 23 | self.fixed_scale = fixed_scale 24 | self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \ 25 | else fixed_weights 26 | self.tree = np.zeros(self._tree_size) 27 | self._initialized = False 28 | self.initialize(init_weight) 29 | 30 | def initialize(self, init_weight): 31 | """Initialize the tree.""" 32 | 33 | # Rescale the fixed_weights if it is not zero 34 | if self.fixed_weights.sum() > 0 and init_weight > 0: 35 | self.fixed_weights *= self.fixed_scale * init_weight * self.capacity \ 36 | / self.fixed_weights.sum() 37 | print('FixedWeights: {}'.format(self.fixed_weights.sum())) 38 | 39 | self.update_whole(init_weight + self.fixed_weights) 40 | self._initialized = True 41 | 42 | def reset_fixed_weights(self, fixed_weights, rescale=False): 43 | """ Reset the manually designed weights and 44 | update the whole tree accordingly. 
45 | 46 | @rescale: rescale the fixed_weights such that 47 | fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum() 48 | """ 49 | 50 | adaptive_weights = self.get_adaptive_weights() 51 | fixed_sum = fixed_weights.sum() 52 | if rescale and fixed_sum > 0: 53 | scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum 54 | self.fixed_weights = fixed_weights * scale 55 | else: 56 | self.fixed_weights = fixed_weights 57 | self.update_whole(self.fixed_weights + adaptive_weights) 58 | 59 | def update_whole(self, total_weights): 60 | """ Update the whole tree based on per-example sampling weights """ 61 | lefti = self.pointer_to_treeidx(0) 62 | righti = self.pointer_to_treeidx(self.capacity-1) 63 | self.tree[lefti:righti+1] = total_weights 64 | 65 | # Iteratively find a parent layer 66 | while lefti != 0 and righti != 0: 67 | lefti = (lefti - 1) // 2 if lefti != 0 else 0 68 | righti = (righti - 1) // 2 if righti != 0 else 0 69 | 70 | # Assign paraent weights from right to left 71 | for i in range(righti, lefti-1, -1): 72 | self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2] 73 | 74 | def get_adaptive_weights(self): 75 | """ Get the instance-aware weights, that are not mannually designed""" 76 | return self.get_total_weights() - self.fixed_weights 77 | 78 | def get_total_weights(self): 79 | """ Get the per-example sampling weights 80 | return shape: [capacity] 81 | """ 82 | lefti = self.pointer_to_treeidx(0) 83 | righti = self.pointer_to_treeidx(self.capacity-1) 84 | return self.tree[lefti:righti+1] 85 | 86 | @property 87 | def size(self): 88 | return self._tree_size 89 | 90 | @property 91 | def capacity(self): 92 | return self._capacity 93 | 94 | def __len__(self): 95 | return self.capacity 96 | 97 | def pointer_to_treeidx(self, pointer): 98 | assert pointer < self.capacity 99 | return int(pointer + self.capacity - 1) 100 | 101 | def update(self, pointer, priority): 102 | assert pointer < self.capacity 103 | tree_idx = self.pointer_to_treeidx(pointer) 104 | priority += self.fixed_weights[pointer] 105 | delta = priority - self.tree[tree_idx] 106 | self.tree[tree_idx] = priority 107 | while tree_idx != 0: 108 | tree_idx = (tree_idx - 1) // 2 109 | self.tree[tree_idx] += delta 110 | 111 | def get_leaf(self, value): 112 | assert self._initialized, 'PriorityTree not initialized!!!!' 113 | assert self.total > 0, 'No priority weights setted!!' 114 | parent = 0 115 | while True: 116 | left_child = 2 * parent + 1 117 | right_child = 2 * parent + 2 118 | if left_child >= len(self.tree): 119 | tgt_leaf = parent 120 | break 121 | if value < self.tree[left_child]: 122 | parent = left_child 123 | else: 124 | value -= self.tree[left_child] 125 | parent = right_child 126 | data_idx = tgt_leaf - self.capacity + 1 127 | return data_idx, self.tree[tgt_leaf] # data idx, priority 128 | 129 | @property 130 | def total(self): 131 | assert self._initialized, 'PriorityTree not initialized!!!!' 132 | return self.tree[0] 133 | 134 | @property 135 | def max(self): 136 | return np.max(self.tree[-self.capacity:]) 137 | 138 | @property 139 | def min(self): 140 | assert self._initialized, 'PriorityTree not initialized!!!!' 141 | return np.min(self.tree[-self.capacity:]) 142 | 143 | def get_weights(self): 144 | return {'fixed_weights': self.fixed_weights, 145 | 'total_weights': self.get_total_weights()} 146 | 147 | 148 | class MixedPrioritizedSampler(Sampler): 149 | """ 150 | A sampler combining manually designed sampling strategy and prioritized 151 | sampling strategy. 
152 | 153 | Manually disigned strategy contains two parts: 154 | 155 | $$ manual_weights = lam * balanced_weights + (1-lam) uniform_weights 156 | 157 | Here we use a generalized version of balanced weights as follows, 158 | when n limits to infinity, balanced_weights = real_balanced_weights 159 | 160 | $$ balanced_weights = uniform_weights ^ (1/n) 161 | 162 | Then the balanced weights are scaled such that 163 | 164 | $$ balanced_weights.sum() = balance_scale * uniform_weights.sum() 165 | 166 | Note: above weights are per-class weights 167 | 168 | Overall sampling weights are given as 169 | $$ sampling_weights = manual_weights * fixed_scale + priority_weights 170 | 171 | Arguments: 172 | @dataset: A dataset 173 | @balance_scale: The scale of balanced_weights 174 | @lam: A weight to combine balanced weights and uniform weights 175 | - None for shifting sampling 176 | - 0 for uniform sampling 177 | - 1 for balanced sampling 178 | @fixed_scale: The scale of manually designed weights 179 | @cycle: shifting strategy 180 | - 0 for linear shifting: 3 -> 2 - > 1 181 | - 1 for periodic shifting: 182 | 3 -> 2 - > 1 -> 3 -> 2 - > 1 -> 3 -> 2 - > 1 183 | - 2 for cosine-like periodic shifting: 184 | 3 -> 2 - > 1 -> 1 -> 2 - > 3 -> 3 -> 2 - > 1 185 | @nroot: 186 | - None for truly balanced weights 187 | - >= 2 for pseudo-balanced weights 188 | @rescale: whether to rebalance the manual weights and priority weights 189 | every epoch 190 | @root_decay: 191 | - 'exp': for exponential decay 192 | - 'linear': for linear decay 193 | """ 194 | def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0, 195 | lam=None, epochs=90, cycle=0, nroot=None, manual_only=False, 196 | rescale=False, root_decay=None, decay_gap=30, ptype='score', 197 | alpha=1.0): 198 | """ 199 | """ 200 | self.dataset = dataset 201 | self.balance_scale = balance_scale 202 | self.fixed_scale = fixed_scale 203 | self.epochs = epochs 204 | self.lam = lam 205 | self.cycle = cycle 206 | self.nroot = nroot 207 | self.rescale = rescale 208 | self.manual_only = manual_only 209 | self.root_decay = root_decay 210 | self.decay_gap = decay_gap 211 | self.ptype = ptype 212 | self.num_samples = len(dataset) 213 | self.alpha = alpha 214 | 215 | # If using root_decay, reset relevent parameters 216 | if self.root_decay in ['exp', 'linear', 'autoexp']: 217 | self.lam = 1 218 | self.manual_only = True 219 | self.nroot = 1 220 | if self.root_decay == 'autoexp': 221 | self.decay_gap = 1 222 | self.decay_factor = np.power(nroot, 1/(self.epochs-1)) 223 | else: 224 | assert self.root_decay is None 225 | assert self.nroot is None or self.nroot >= 2 226 | print("====> Decay GAP: {}".format(self.decay_gap)) 227 | 228 | # Take care of lambdas 229 | if self.lam is None: 230 | self.freeze = False 231 | if cycle == 0: 232 | self.lams = np.linspace(0, 1, epochs) 233 | elif cycle == 1: 234 | self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3) 235 | elif cycle == 2: 236 | self.lams = np.concatenate([np.linspace(0,1,epochs//3), 237 | np.linspace(0,1,epochs//3)[::-1], 238 | np.linspace(0,1,epochs//3)]) 239 | else: 240 | raise NotImplementedError( 241 | 'cycle = {} not implemented'.format(cycle)) 242 | else: 243 | self.lams = [self.lam] 244 | self.freeze = True 245 | 246 | # Get num of samples per class 247 | self.cls_cnts = [] 248 | self.labels = labels = np.array(self.dataset.labels) 249 | for l in np.unique(labels): 250 | self.cls_cnts.append(np.sum(labels==l)) 251 | self.num_classes = len(self.cls_cnts) 252 | self.cnts = np.array(self.cls_cnts).astype(float) 
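        # --- Editor's note: illustrative sketch, not part of the original source ---
        # A toy walk-through of the weighting scheme described in the class docstring,
        # assuming three classes with counts [100, 10, 1] and balance_scale = 1:
        #
        #   * nroot=None -> per-class cls_ws = cnts.min() / cnts = [0.01, 0.1, 1.0];
        #     after the rescale in get_balanced_weights() each example carries
        #     [0.37, 3.7, 37.0] respectively, so every class receives the same total
        #     sampling mass (37.0) and sampling is class-balanced.
        #   * get_manual_weights() then blends this with uniform (all-ones) weights:
        #     manual_weights = lam * balanced_weights + (1 - lam), so lam = 0 recovers
        #     instance-uniform sampling and lam = 1 recovers class-balanced sampling.
        #
        # As stated in the docstring above, the adaptive priority weights kept in the
        # PriorityTree are added on top of fixed_scale * manual_weights.
        # ---------------------------------------------------------------------------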
253 | 254 | # Get per-class image indexes 255 | self.cls_idxs = [[] for _ in range(self.num_classes)] 256 | for i, label in enumerate(self.dataset.labels): 257 | self.cls_idxs[label].append(i) 258 | for ci in range(self.num_classes): 259 | self.cls_idxs[ci] = np.array(self.cls_idxs[ci]) 260 | 261 | # Build balanced weights based on class counts 262 | self.balanced_weights = self.get_balanced_weights(self.nroot) 263 | self.manual_weights = self.get_manual_weights(self.lams[0]) 264 | 265 | # Setup priority tree 266 | if self.ptype == 'score': 267 | self.init_weight = 1. 268 | elif self.ptype in ['CE', 'entropy']: 269 | self.init_weight = 6.9 270 | else: 271 | raise NotImplementedError('ptype {} not implemented'.format(self.ptype)) 272 | if self.manual_only: 273 | self.init_weight = 0. 274 | self.init_weight = np.power(self.init_weight, self.alpha) 275 | self.ptree = PriorityTree(self.num_samples, self.manual_weights, 276 | fixed_scale=self.fixed_scale, 277 | init_weight=self.init_weight) 278 | 279 | def get_manual_weights(self, lam): 280 | # Merge balanced weights and uniform weights 281 | if lam == 1: 282 | manual_weights = self.balanced_weights 283 | elif lam == 0: 284 | manual_weights = np.ones(len(self.balanced_weights)) 285 | else: 286 | manual_weights = self.balanced_weights * lam + (1-lam) 287 | return manual_weights 288 | 289 | def get_balanced_weights(self, nroot): 290 | """ Calculate normalized generalized balanced weights """ 291 | 292 | cnts = self.cnts 293 | if nroot is None: 294 | # Real balanced sampling weights 295 | cls_ws = cnts.min() / cnts 296 | elif nroot >= 1: 297 | # Generalized balanced weights 298 | cls_ws = cnts / cnts.sum() 299 | cls_ws = np.power(cls_ws, 1./nroot) * cnts.sum() 300 | cls_ws = cls_ws / cnts 301 | else: 302 | raise NotImplementedError('root:{} not implemented'.format(nroot)) 303 | 304 | # Get un-normalized weights 305 | balanced_weights = np.zeros(self.num_samples) 306 | for ci in range(self.num_classes): 307 | balanced_weights[self.cls_idxs[ci]] = cls_ws[ci] 308 | 309 | # Normalization and rescale 310 | balanced_weights *= self.num_samples / balanced_weights.sum() * \ 311 | self.balance_scale 312 | return balanced_weights 313 | 314 | def __iter__(self): 315 | for _ in range(self.num_samples): 316 | w = random.random() * self.ptree.total 317 | i, pri = self.ptree.get_leaf(w) 318 | yield i 319 | 320 | def __len__(self): 321 | return self.num_samples 322 | 323 | def reset_weights(self, epoch): 324 | if not self.freeze and self.fixed_scale > 0: 325 | if epoch >= self.epochs: 326 | e = self.epochs - 1 327 | elif epoch < 1: 328 | e = 0 329 | else: 330 | e = epoch 331 | self.manual_weights = self.get_manual_weights(self.lams[e]) 332 | self.ptree.reset_fixed_weights(self.manual_weights, self.rescale) 333 | 334 | if self.root_decay in ['exp', 'linear', 'autoexp'] and epoch % self.decay_gap == 0: 335 | if self.root_decay == 'exp': 336 | self.nroot *= 2 337 | elif self.root_decay == 'linear': 338 | self.nroot += 1 339 | elif self.root_decay == 'autoexp': 340 | # self.nroot *= self.decay_factor 341 | self.nroot = np.power(self.decay_factor, epoch) 342 | 343 | bw = self.get_balanced_weights(self.nroot) 344 | self.ptree.reset_fixed_weights(bw) 345 | 346 | def update_weights(self, inds, weights): 347 | """ Update priority weights """ 348 | if not self.manual_only: 349 | weights = np.clip(weights, 0, self.init_weight) 350 | weights = np.power(weights, self.alpha) 351 | for i, w in zip(inds, weights): 352 | self.ptree.update(i, w) 353 | 354 | def get_weights(self): 
355 | return self.ptree.get_weights() 356 | 357 | 358 | def get_sampler(): 359 | return MixedPrioritizedSampler 360 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | Copyright (c) Facebook, Inc. and its affiliates. 5 | All rights reserved. 6 | 7 | This source code is licensed under the license found in the 8 | LICENSE file in the root directory of this source tree. 9 | 10 | Portions of the source code are from the OLTR project which 11 | notice below and in LICENSE in the root directory of 12 | this source tree. 13 | 14 | Copyright (c) 2019, Zhongqi Miao 15 | All rights reserved. 16 | """ 17 | 18 | from collections import Counter 19 | import torch 20 | import numpy as np 21 | import torchvision 22 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 23 | from torchvision import transforms 24 | import os 25 | from PIL import Image 26 | from data.ImbalanceCIFAR import IMBALANCECIFAR10, IMBALANCECIFAR100 27 | 28 | # Image statistics 29 | RGB_statistics = { 30 | 'iNaturalist18': { 31 | 'mean': [0.466, 0.471, 0.380], 32 | 'std': [0.195, 0.194, 0.192] 33 | }, 34 | 'default': { 35 | 'mean': [0.485, 0.456, 0.406], 36 | 'std':[0.229, 0.224, 0.225] 37 | } 38 | } 39 | 40 | # Data transformation with augmentation 41 | def get_data_transform(split, rgb_mean, rbg_std, key='default'): 42 | data_transforms = { 43 | 'train': transforms.Compose([ 44 | transforms.RandomResizedCrop(224), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | transforms.Normalize(rgb_mean, rbg_std) 48 | ]) if key == 'iNaturalist18' else transforms.Compose([ 49 | transforms.RandomResizedCrop(224), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 52 | transforms.ToTensor(), 53 | transforms.Normalize(rgb_mean, rbg_std) 54 | ]), 55 | 'val': transforms.Compose([ 56 | transforms.Resize(256), 57 | transforms.CenterCrop(224), 58 | transforms.ToTensor(), 59 | transforms.Normalize(rgb_mean, rbg_std) 60 | ]), 61 | 'test': transforms.Compose([ 62 | transforms.Resize(256), 63 | transforms.CenterCrop(224), 64 | transforms.ToTensor(), 65 | transforms.Normalize(rgb_mean, rbg_std) 66 | ]) 67 | } 68 | return data_transforms[split] 69 | 70 | # Dataset 71 | class LT_Dataset(Dataset): 72 | 73 | def __init__(self, root, txt, transform=None, template=None, top_k=None): 74 | self.img_path = [] 75 | self.labels = [] 76 | self.transform = transform 77 | with open(txt) as f: 78 | for line in f: 79 | self.img_path.append(os.path.join(root, line.split()[0])) 80 | self.labels.append(int(line.split()[1])) 81 | 82 | # get image number list 83 | occur_dict = dict(Counter(self.labels)) 84 | self.img_num_list = [occur_dict[i] for i in sorted(occur_dict.keys())] 85 | 86 | # select top k class 87 | if top_k: 88 | # only select top k in training, in case train/val/test not matching. 
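            # --- Editor's note: descriptive comment, not part of the original source ---
            # On the training split the classes are ranked by frequency and the ranking
            # is saved to '<template>_top_<top_k>_mapping' via torch.save(); the val and
            # test splits reload that file, so every split keeps the same top-k classes
            # and remaps them to contiguous labels 0 .. top_k-1 before img_path / labels
            # are filtered below.
            # ----------------------------------------------------------------------------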
89 | if 'train' in txt: 90 | max_len = max(self.labels) + 1 91 | dist = [[i, 0] for i in range(max_len)] 92 | for i in self.labels: 93 | dist[i][-1] += 1 94 | dist.sort(key = lambda x:x[1], reverse=True) 95 | # saving 96 | torch.save(dist, template + '_top_{}_mapping'.format(top_k)) 97 | else: 98 | # loading 99 | dist = torch.load(template + '_top_{}_mapping'.format(top_k)) 100 | selected_labels = {item[0]:i for i, item in enumerate(dist[:top_k])} 101 | # replace original path and labels 102 | self.new_img_path = [] 103 | self.new_labels = [] 104 | for path, label in zip(self.img_path, self.labels): 105 | if label in selected_labels: 106 | self.new_img_path.append(path) 107 | self.new_labels.append(selected_labels[label]) 108 | self.img_path = self.new_img_path 109 | self.labels = self.new_labels 110 | 111 | def __len__(self): 112 | return len(self.labels) 113 | 114 | def __getitem__(self, index): 115 | 116 | path = self.img_path[index] 117 | label = self.labels[index] 118 | 119 | with open(path, 'rb') as f: 120 | sample = Image.open(f).convert('RGB') 121 | 122 | if self.transform is not None: 123 | sample = self.transform(sample) 124 | 125 | return sample, label, index 126 | 127 | # Load datasets 128 | def load_data(data_root, dataset, phase, batch_size, top_k_class=None, 129 | sampler_dic=None, num_workers=4, shuffle=True, cifar_imb_ratio=None, 130 | test_imb_ratio=None, reverse=False): 131 | 132 | txt_split = phase 133 | if dataset == "Places_LT": 134 | txt = f"./data/Places_LT_v2/Places_LT_{phase}.txt" 135 | template = None 136 | else: 137 | txt = './data/%s/%s_%s.txt'%(dataset, dataset, txt_split) 138 | template = './data/%s/%s'%(dataset, dataset) 139 | 140 | print('Loading data from %s' % (txt)) 141 | 142 | if dataset == 'iNaturalist18': 143 | print('===> Loading iNaturalist18 statistics') 144 | key = 'iNaturalist18' 145 | else: 146 | key = 'default' 147 | 148 | if dataset == 'CIFAR10_LT': 149 | print('====> CIFAR10 Imbalance Ratio: ', cifar_imb_ratio) 150 | set_ = IMBALANCECIFAR10(phase, imbalance_ratio=cifar_imb_ratio, root=data_root, 151 | test_imb_ratio=test_imb_ratio, reverse=reverse) 152 | elif dataset == 'CIFAR100_LT': 153 | print('====> CIFAR100 Imbalance Ratio: ', cifar_imb_ratio) 154 | set_ = IMBALANCECIFAR100(phase, imbalance_ratio=cifar_imb_ratio, root=data_root, 155 | test_imb_ratio=test_imb_ratio, reverse=reverse) 156 | else: 157 | rgb_mean, rgb_std = RGB_statistics[key]['mean'], RGB_statistics[key]['std'] 158 | if phase not in ['train', 'val']: 159 | transform = get_data_transform('test', rgb_mean, rgb_std, key) 160 | else: 161 | transform = get_data_transform(phase, rgb_mean, rgb_std, key) 162 | print('Use data transformation:', transform) 163 | 164 | set_ = LT_Dataset(data_root, txt, transform, template=template, top_k=top_k_class) 165 | 166 | 167 | print(len(set_)) 168 | 169 | if sampler_dic and phase == 'train': 170 | print('=====> Using sampler: ', sampler_dic['sampler']) 171 | # print('Sample %s samples per-class.' % sampler_dic['num_samples_cls']) 172 | print('=====> Sampler parameters: ', sampler_dic['params']) 173 | return torch.FloatTensor(set_.img_num_list) / torch.FloatTensor(set_.img_num_list).sum(), \ 174 | DataLoader(dataset=set_, batch_size=batch_size, shuffle=False, 175 | sampler=sampler_dic['sampler'](set_, **sampler_dic['params']), 176 | num_workers=num_workers) 177 | else: 178 | print('=====> No sampler.') 179 | print('=====> Shuffle is %s.' 
% (shuffle)) 180 | return torch.FloatTensor(set_.img_num_list) / torch.FloatTensor(set_.img_num_list).sum(), \ 181 | DataLoader(dataset=set_, batch_size=batch_size, 182 | shuffle=shuffle, num_workers=num_workers) 183 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import os 9 | import yaml 10 | import csv 11 | import h5py 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, logdir): 16 | self.logdir = logdir 17 | if not os.path.isdir(logdir): 18 | os.makedirs(logdir) 19 | self.cfg_file = os.path.join(self.logdir, 'cfg.yaml') 20 | self.acc_file = os.path.join(self.logdir, 'acc.csv') 21 | self.loss_file = os.path.join(self.logdir, 'loss.csv') 22 | self.ws_file = os.path.join(self.logdir, 'ws.h5') 23 | self.acc_keys = None 24 | self.loss_keys = None 25 | self.logging_ws = False 26 | 27 | def log_cfg(self, cfg): 28 | print('===> Saving cfg parameters to: ', self.cfg_file) 29 | with open(self.cfg_file, 'w') as f: 30 | yaml.dump(cfg, f) 31 | 32 | def log_acc(self, accs): 33 | if self.acc_keys is None: 34 | self.acc_keys = [k for k in accs.keys()] 35 | with open(self.acc_file, 'w') as f: 36 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 37 | writer.writeheader() 38 | writer.writerow(accs) 39 | else: 40 | with open(self.acc_file, 'a') as f: 41 | writer = csv.DictWriter(f, fieldnames=self.acc_keys) 42 | writer.writerow(accs) 43 | 44 | def log_loss(self, losses): 45 | # valid_losses = {k: v for k, v in losses.items() if v is not None} 46 | valid_losses = losses 47 | if self.loss_keys is None: 48 | self.loss_keys = [k for k in valid_losses.keys()] 49 | with open(self.loss_file, 'w') as f: 50 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 51 | writer.writeheader() 52 | writer.writerow(valid_losses) 53 | else: 54 | with open(self.loss_file, 'a') as f: 55 | writer = csv.DictWriter(f, fieldnames=self.loss_keys) 56 | writer.writerow(valid_losses) 57 | 58 | def log_ws(self, e, ws): 59 | mode = 'a' if self.logging_ws else 'w' 60 | self.logging_ws = True 61 | 62 | key = 'Epoch{:02d}'.format(e) 63 | with h5py.File(self.ws_file, mode) as f: 64 | g = f.create_group(key) 65 | for k, v in ws.items(): 66 | g.create_dataset(k, data=v) 67 | -------------------------------------------------------------------------------- /loss/FocalLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FocalLoss(nn.Module): 6 | def __init__(self, gamma=0, alpha=None, size_average=True): 7 | super(FocalLoss, self).__init__() 8 | self.gamma = gamma 9 | self.alpha = alpha 10 | self.size_average = size_average 11 | 12 | def forward(self, input, target): 13 | if input.dim()>2: 14 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 15 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 16 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 17 | target = target.view(-1,1) 18 | 19 | logpt = F.log_softmax(input, dim=-1) 20 | logpt = logpt.gather(1,target) 21 | logpt = logpt.view(-1) 22 | pt = logpt.detach().exp() 23 | 24 | if self.alpha is not None: 25 | assert False 26 | 27 | loss = -1 * 
(1-pt)**self.gamma * logpt 28 | if self.size_average: return loss.mean() 29 | else: return loss.sum() 30 | 31 | def create_loss(): 32 | return FocalLoss(gamma = 2.0) -------------------------------------------------------------------------------- /loss/LADELoss.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 3 | """ 4 | 5 | import functools 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from utils import * 12 | 13 | 14 | class LADELoss(nn.Module): 15 | def __init__(self, num_classes=10, img_max=None, prior=None, prior_txt=None, remine_lambda=0.1): 16 | super().__init__() 17 | if img_max is not None or prior_txt is not None: 18 | self.img_num_per_cls = calculate_prior(num_classes, img_max, prior, prior_txt, return_num=True).float().cuda() 19 | self.prior = self.img_num_per_cls / self.img_num_per_cls.sum() 20 | else: 21 | self.prior = None 22 | 23 | self.balanced_prior = torch.tensor(1. / num_classes).float().cuda() 24 | self.remine_lambda = remine_lambda 25 | 26 | self.num_classes = num_classes 27 | self.cls_weight = (self.img_num_per_cls.float() / torch.sum(self.img_num_per_cls.float())).cuda() 28 | 29 | def mine_lower_bound(self, x_p, x_q, num_samples_per_cls): 30 | N = x_p.size(-1) 31 | first_term = torch.sum(x_p, -1) / (num_samples_per_cls + 1e-8) 32 | second_term = torch.logsumexp(x_q, -1) - np.log(N) 33 | 34 | return first_term - second_term, first_term, second_term 35 | 36 | def remine_lower_bound(self, x_p, x_q, num_samples_per_cls): 37 | loss, first_term, second_term = self.mine_lower_bound(x_p, x_q, num_samples_per_cls) 38 | reg = (second_term ** 2) * self.remine_lambda 39 | return loss - reg, first_term, second_term 40 | 41 | def forward(self, y_pred, target, q_pred=None): 42 | """ 43 | y_pred: N x C 44 | target: N 45 | """ 46 | per_cls_pred_spread = y_pred.T * (target == torch.arange(0, self.num_classes).view(-1, 1).type_as(target)) # C x N 47 | pred_spread = (y_pred - torch.log(self.prior + 1e-9) + torch.log(self.balanced_prior + 1e-9)).T # C x N 48 | 49 | num_samples_per_cls = torch.sum(target == torch.arange(0, self.num_classes).view(-1, 1).type_as(target), -1).float() # C 50 | estim_loss, first_term, second_term = self.remine_lower_bound(per_cls_pred_spread, pred_spread, num_samples_per_cls) 51 | 52 | loss = -torch.sum(estim_loss * self.cls_weight) 53 | return loss 54 | 55 | def create_loss(num_classes, img_max=None, prior=None, prior_txt=None, remine_lambda=0.1): 56 | print("Loading LADELoss.") 57 | return LADELoss( 58 | num_classes=num_classes, 59 | img_max=img_max, 60 | prior=prior, 61 | prior_txt=prior_txt, 62 | remine_lambda=remine_lambda, 63 | ) 64 | -------------------------------------------------------------------------------- /loss/PriorCELoss.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | from utils import * 11 | 12 | 13 | class PriorCELoss(nn.Module): 14 | # Also named as LADE-CE Loss 15 | def __init__(self, num_classes, img_max=None, prior=None, prior_txt=None): 16 | super().__init__() 17 | self.img_num_per_cls = calculate_prior(num_classes, img_max, prior, prior_txt, return_num=True).float().cuda() 18 | self.prior = self.img_num_per_cls / self.img_num_per_cls.sum() 19 | self.criterion = nn.CrossEntropyLoss() 20 | self.num_classes = num_classes 21 | 22 | def forward(self, x, y): 23 | logits = x + torch.log(self.prior + 1e-9) 24 | loss = self.criterion(logits, y) 25 | return loss 26 | 27 | 28 | def create_loss(num_classes, img_max=None, prior=None, prior_txt=None): 29 | print('Loading PriorCELoss Loss.') 30 | return PriorCELoss( 31 | num_classes=num_classes, 32 | img_max=img_max, 33 | prior=prior, 34 | prior_txt=prior_txt, 35 | ) 36 | -------------------------------------------------------------------------------- /loss/SoftmaxLoss.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 13 | """ 14 | 15 | 16 | import torch.nn as nn 17 | 18 | def create_loss (): 19 | print('Loading Softmax Loss.') 20 | return nn.CrossEntropyLoss() 21 | 22 | -------------------------------------------------------------------------------- /loss/WeightedSoftmaxLoss.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | def create_loss (): 19 | print('Loading Weighted Softmax Loss.') 20 | # Imagenet_LT class distribution 21 | dist = [0 for _ in range(1000)] 22 | with open('./data/ImageNet_LT/ImageNet_LT_train.txt') as f: 23 | for line in f: 24 | dist[int(line.split()[1])] += 1 25 | num = sum(dist) 26 | prob = [i/num for i in dist] 27 | prob = torch.FloatTensor(prob) 28 | # normalization 29 | max_prob = prob.max().item() 30 | prob = prob / max_prob 31 | # class reweight 32 | weight = - prob.log() + 1 33 | 34 | return nn.CrossEntropyLoss(weight=weight) 35 | 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 
3 | """ 4 | 5 | import os 6 | import argparse 7 | import pprint 8 | from data import dataloader 9 | from run_networks import model 10 | import warnings 11 | import yaml 12 | import numpy as np 13 | from utils import source_import, update 14 | from pathlib import Path 15 | import torch.backends.cudnn as cudnn 16 | 17 | data_root_dict = {'ImageNet': '/nas/dataset/others/imagenet/raw', 18 | 'iNaturalist18': '/nas/dataset/others/iNaturalist18', 19 | 'Places': '/nas/dataset/others/places365/', 20 | 'CIFAR100': '/nas/dataset/others/cifar100',} 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--cfg', default=None, type=str) 24 | parser.add_argument('--seed', default=None, type=int) 25 | parser.add_argument('--test', default=False, action='store_true') 26 | parser.add_argument('--save_feature', default=False, action='store_true') 27 | parser.add_argument('--batch_size', type=int, default=None) 28 | parser.add_argument('--model_dir', type=str, default=None) 29 | parser.add_argument('--alpha', type=float, default=0.1) 30 | parser.add_argument('--lr', type=float, default=0.2) 31 | parser.add_argument('--cifar_imb_ratio', type=float, default=0.1, choices=[0.01, 0.02, 0.1]) 32 | parser.add_argument("--remine_lambda", default=None, type=float) 33 | parser.add_argument("--work_dir", default="./exp_results", type=str, help="output dir") 34 | parser.add_argument("--exp_name", default="test", type=str, help="exp name") 35 | parser.add_argument('--gpu', default='0', type=str, 36 | help='id(s) for CUDA_VISIBLE_DEVICES') 37 | parser.add_argument("--no-use-dv", action="store_true") 38 | parser.add_argument("--test_imb_ratio", type=float, default=None, 39 | help="Give explicit imbalance ratio for test dataset.") 40 | parser.add_argument("--exist_only", type=int, default=0) 41 | parser.add_argument("--test-reverse", type=int, default=0) 42 | parser.add_argument("--train-reverse", action="store_true") 43 | parser.add_argument('--root', default=None, type=str) 44 | 45 | args = parser.parse_args() 46 | args.test_reverse = bool(args.test_reverse) 47 | 48 | print(f'args: {args}') 49 | 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 51 | output_dir = f'{args.work_dir}/{args.exp_name}' 52 | Path(output_dir).mkdir(parents=True, exist_ok=True) 53 | # ============================================================================ 54 | # Random Seed 55 | import torch 56 | import random 57 | if args.seed is not None: 58 | print('=======> Using Fixed Random Seed <========') 59 | random.seed(args.seed) 60 | np.random.seed(args.seed) 61 | torch.manual_seed(args.seed) 62 | torch.cuda.manual_seed(args.seed) 63 | cudnn.deterministic = True 64 | 65 | # ============================================================================ 66 | # LOAD CONFIGURATIONS 67 | with open(args.cfg) as f: 68 | config = yaml.load(f) 69 | config = update(config, args, output_dir) 70 | 71 | test_mode = args.test 72 | save_mode = args.save_feature # only in eval 73 | training_opt = config['training_opt'] 74 | dataset = training_opt['dataset'] 75 | 76 | if not os.path.isdir(training_opt['log_dir']): 77 | os.makedirs(training_opt['log_dir']) 78 | 79 | if args.root is not None: 80 | data_root = args.root 81 | else: 82 | data_root = data_root_dict[dataset.rstrip('_LT')] 83 | 84 | print('Loading dataset from: %s' % data_root) 85 | pprint.pprint(config) 86 | 87 | 88 | # ============================================================================ 89 | # TRAINING 90 | if not test_mode: 91 | # during training, different sampler may be applied 
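    # --- Editor's note: illustrative sketch, not part of the original source ---
    # The 'sampler' section of training_opt (when present) is turned into a
    # sampler_dic below. A hypothetical config entry and the resulting dict might
    # look like the following (values are made up for illustration; 'type',
    # 'def_file' and 'num_samples_cls' are the keys the ClassAwareSampler branch
    # below actually reads):
    #
    #   sampler:
    #     type: ClassAwareSampler
    #     def_file: ./data/ClassAwareSampler.py
    #     num_samples_cls: 4
    #
    #   sampler_dic = {'sampler': <get_sampler() imported from def_file>,
    #                  'params': {'num_samples_cls': 4}}
    #
    # dataloader.load_data() then passes sampler_dic['sampler'](set_, **params)
    # as the sampler of the training DataLoader.
    # ---------------------------------------------------------------------------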
92 | sampler_defs = training_opt['sampler'] 93 | if sampler_defs: 94 | if sampler_defs['type'] == 'ClassAwareSampler': 95 | sampler_dic = { 96 | 'sampler': source_import(sampler_defs['def_file']).get_sampler(), 97 | 'params': {'num_samples_cls': sampler_defs['num_samples_cls']} 98 | } 99 | elif sampler_defs['type'] in ['MixedPrioritizedSampler', 100 | 'ClassPrioritySampler']: 101 | sampler_dic = { 102 | 'sampler': source_import(sampler_defs['def_file']).get_sampler(), 103 | 'params': {k: v for k, v in sampler_defs.items() \ 104 | if k not in ['type', 'def_file']} 105 | } 106 | else: 107 | sampler_dic = None 108 | 109 | # generated sub-datasets all have test split 110 | splits = ['train', 'val'] 111 | if dataset not in ['iNaturalist18', 'ImageNet']: 112 | splits.append('test') 113 | data = {x: dataloader.load_data(data_root=data_root, 114 | dataset=dataset, phase=x, 115 | batch_size=training_opt['batch_size'], 116 | sampler_dic=sampler_dic, 117 | num_workers=training_opt['num_workers'], 118 | top_k_class=training_opt['top_k'] if 'top_k' in training_opt else None, 119 | cifar_imb_ratio=training_opt['cifar_imb_ratio'] if 'cifar_imb_ratio' in training_opt else None, 120 | reverse=args.train_reverse) 121 | for x in splits} 122 | 123 | training_model = model(config, data, test=False) 124 | training_model.train() 125 | 126 | # ============================================================================ 127 | # TESTING 128 | else: 129 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", 130 | UserWarning) 131 | print('Under testing phase, we load training data simply to calculate training data number for each class.') 132 | 133 | if 'iNaturalist' in dataset.rstrip('_LT'): 134 | splits = ['train', 'val'] 135 | test_split = 'val' 136 | else: 137 | splits = ['train', 'val', 'test'] 138 | test_split = 'test' 139 | 140 | data = {x: dataloader.load_data(data_root=data_root, 141 | dataset=dataset, phase=x, 142 | batch_size=training_opt['batch_size'], 143 | sampler_dic=None, 144 | num_workers=training_opt['num_workers'], 145 | top_k_class=training_opt['top_k'] if 'top_k' in training_opt else None, 146 | shuffle=False, 147 | cifar_imb_ratio=training_opt['cifar_imb_ratio'] if 'cifar_imb_ratio' in training_opt else None, 148 | test_imb_ratio=args.test_imb_ratio, 149 | reverse=args.train_reverse if x == "train" else args.test_reverse) 150 | for x in splits} 151 | 152 | training_model = model(config, data, test=True, 153 | test_imb_ratio=args.test_imb_ratio, 154 | test_reverse=args.test_reverse) 155 | # load checkpoints 156 | training_model.load_model(args.model_dir) 157 | 158 | training_model.eval(phase=test_split, save_feat=save_mode) 159 | 160 | print('='*25, ' ALL COMPLETED ', '='*25) 161 | -------------------------------------------------------------------------------- /models/CausalNormClassifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import * 4 | from os import path 5 | import math 6 | 7 | class Causal_Norm_Classifier(nn.Module): 8 | 9 | def __init__(self, num_classes=1000, feat_dim=2048, use_effect=True, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125, *args): 10 | super(Causal_Norm_Classifier, self).__init__() 11 | self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim).cuda(), requires_grad=True) 12 | self.scale = tau / num_head # 16.0 / num_head 13 | self.norm_scale = gamma # 1.0 / 32.0 14 | self.alpha = alpha # 3.0 15 | self.num_head = num_head 16 | self.head_dim = feat_dim 
// num_head 17 | self.use_effect = use_effect 18 | self.reset_parameters(self.weight) 19 | # self.relu = nn.ReLU(inplace=True) 20 | 21 | def reset_parameters(self, weight): 22 | stdv = 1. / math.sqrt(weight.size(1)) 23 | weight.data.uniform_(-stdv, stdv) 24 | 25 | def forward(self, x, label, embed): 26 | # calculate capsule normalized feature vector and predict 27 | normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale) 28 | normed_x = self.multi_head_call(self.l2_norm, x) 29 | y = torch.mm(normed_x * self.scale, normed_w.t()) 30 | 31 | # remove the effect of confounder c during test 32 | if (not self.training) and self.use_effect: 33 | self.embed = torch.from_numpy(embed).view(1, -1).to(x.device) 34 | normed_c = self.multi_head_call(self.l2_norm, self.embed) 35 | head_dim = x.shape[1] // self.num_head 36 | x_list = torch.split(normed_x, head_dim, dim=1) 37 | c_list = torch.split(normed_c, head_dim, dim=1) 38 | w_list = torch.split(normed_w, head_dim, dim=1) 39 | output = [] 40 | 41 | for nx, nc, nw in zip(x_list, c_list, w_list): 42 | cos_val, sin_val = self.get_cos_sin(nx, nc) 43 | y0 = torch.mm((nx - cos_val * self.alpha * nc) * self.scale, nw.t()) 44 | output.append(y0) 45 | y = sum(output) 46 | 47 | return y, None 48 | 49 | def get_cos_sin(self, x, y): 50 | cos_val = (x * y).sum(-1, keepdim=True) / torch.norm(x, 2, 1, keepdim=True) / torch.norm(y, 2, 1, keepdim=True) 51 | sin_val = (1 - cos_val * cos_val).sqrt() 52 | return cos_val, sin_val 53 | 54 | def multi_head_call(self, func, x, weight=None): 55 | assert len(x.shape) == 2 56 | x_list = torch.split(x, self.head_dim, dim=1) 57 | if weight: 58 | y_list = [func(item, weight) for item in x_list] 59 | else: 60 | y_list = [func(item) for item in x_list] 61 | assert len(x_list) == self.num_head 62 | assert len(y_list) == self.num_head 63 | return torch.cat(y_list, dim=1) 64 | 65 | def l2_norm(self, x): 66 | normed_x = x / torch.norm(x, 2, 1, keepdim=True) 67 | return normed_x 68 | 69 | def capsule_norm(self, x): 70 | norm= torch.norm(x.clone(), 2, 1, keepdim=True) 71 | normed_x = (norm / (1 + norm)) * (x / norm) 72 | return normed_x 73 | 74 | def causal_norm(self, x, weight): 75 | norm= torch.norm(x, 2, 1, keepdim=True) 76 | normed_x = x / (norm + weight) 77 | return normed_x 78 | 79 | def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, use_effect=True, num_head=None, tau=None, alpha=None, gamma=None, *args): 80 | print('Loading Causal Norm Classifier with use_effect: {}, num_head: {}, tau: {}, alpha: {}, gamma: {}.'.format(str(use_effect), num_head, tau, alpha, gamma)) 81 | clf = Causal_Norm_Classifier(num_classes, feat_dim, use_effect=use_effect, num_head=num_head, tau=tau, alpha=alpha, gamma=gamma) 82 | 83 | return clf 84 | -------------------------------------------------------------------------------- /models/DotProductClassifier.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 
13 | """ 14 | 15 | 16 | import torch.nn as nn 17 | from utils import * 18 | from os import path 19 | 20 | class DotProduct_Classifier(nn.Module): 21 | 22 | def __init__(self, num_classes=1000, feat_dim=2048, use_route=False, *args): 23 | super(DotProduct_Classifier, self).__init__() 24 | self.fc = nn.Linear(feat_dim, num_classes) 25 | self.use_route = use_route 26 | 27 | def forward(self, x, *args): 28 | x = self.fc(x) 29 | if self.use_route: 30 | return x, x 31 | else: 32 | return x, None 33 | 34 | def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, use_route=False, *args): 35 | print('Loading Dot Product Classifier.') 36 | clf = DotProduct_Classifier(num_classes, feat_dim, use_route) 37 | 38 | if not test: 39 | if stage1_weights: 40 | assert(dataset) 41 | print('Loading %s Stage 1 Classifier Weights.' % dataset) 42 | if log_dir is not None: 43 | subdir = log_dir.strip('/').split('/')[-1] 44 | subdir = subdir.replace('stage2', 'stage1') 45 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir) 46 | # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 47 | else: 48 | weight_dir = './logs/%s/stage1' % dataset 49 | print('==> Loading classifier weights from %s' % weight_dir) 50 | clf.fc = init_weights(model=clf.fc, 51 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'), 52 | classifier=True) 53 | else: 54 | print('Random initialized classifier weights.') 55 | 56 | return clf 57 | -------------------------------------------------------------------------------- /models/ResNet101Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | from models.ResNetFeature import * 10 | from utils import * 11 | from os import path 12 | 13 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 14 | 15 | print('Loading Scratch ResNet 101 Feature Model.') 16 | resnet101 = ResNet(Bottleneck, [3, 4, 23, 3], use_fc=use_fc, dropout=None) 17 | 18 | if not test: 19 | if stage1_weights: 20 | assert(dataset) 21 | print('Loading %s Stage 1 ResNet 101 Weights.' % dataset) 22 | if log_dir is not None: 23 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 24 | else: 25 | weight_dir = './logs/%s/stage1' % dataset 26 | print('==> Loading weights from %s' % weight_dir) 27 | resnet101 = init_weights(model=resnet101, 28 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth')) 29 | else: 30 | print('No Pretrained Weights For Feature Model.') 31 | 32 | return resnet101 33 | -------------------------------------------------------------------------------- /models/ResNet152Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from models.ResNetFeature import * 9 | from utils import * 10 | 11 | 12 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, caffe=False, test=False): 13 | 14 | print('Loading Scratch ResNet 152 Feature Model.') 15 | resnet152 = ResNet(Bottleneck, [3, 8, 36, 3], use_fc=use_fc, dropout=None) 16 | 17 | if not test: 18 | 19 | assert(caffe != stage1_weights) 20 | 21 | if caffe: 22 | print('Loading Caffe Pretrained ResNet 152 Weights.') 23 | resnet152 = init_weights(model=resnet152, 24 | weights_path='./pretrained/caffe_resnet152.pth', 25 | caffe=True) 26 | elif stage1_weights: 27 | assert(dataset) 28 | print('Loading %s Stage 1 ResNet 152 Weights.' % dataset) 29 | resnet152 = init_weights(model=resnet152, 30 | weights_path='./logs/%s/stage1/final_model_checkpoint.pth' % dataset) 31 | else: 32 | print('No Pretrained Weights For Feature Model.') 33 | 34 | return resnet152 35 | -------------------------------------------------------------------------------- /models/ResNet32Feature.py: -------------------------------------------------------------------------------- 1 | 2 | from utils import * 3 | from os import path 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.init as init 10 | 11 | 12 | def _weights_init(m): 13 | classname = m.__class__.__name__ 14 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 15 | init.kaiming_normal_(m.weight) 16 | 17 | 18 | class LambdaLayer(nn.Module): 19 | def __init__(self, lambd): 20 | super(LambdaLayer, self).__init__() 21 | self.lambd = lambd 22 | 23 | def forward(self, x): 24 | return self.lambd(x) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, in_planes, planes, stride=1, option="A"): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = nn.Conv2d( 33 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 34 | ) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d( 37 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 38 | ) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | 41 | self.shortcut = nn.Sequential() 42 | if stride != 1 or in_planes != planes: 43 | if option == "A": 44 | """ 45 | For CIFAR10 ResNet paper uses option A. 
46 | """ 47 | self.shortcut = LambdaLayer( 48 | lambda x: F.pad( 49 | x[:, :, ::2, ::2], 50 | (0, 0, 0, 0, planes // 4, planes // 4), 51 | "constant", 52 | 0, 53 | ) 54 | ) 55 | elif option == "B": 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d( 58 | in_planes, 59 | self.expansion * planes, 60 | kernel_size=1, 61 | stride=stride, 62 | bias=False, 63 | ), 64 | nn.BatchNorm2d(self.expansion * planes), 65 | ) 66 | 67 | def forward(self, x): 68 | out = F.relu(self.bn1(self.conv1(x))) 69 | out = self.bn2(self.conv2(out)) 70 | out += self.shortcut(x) 71 | out = F.relu(out) 72 | return out 73 | 74 | 75 | class BBN_ResNet_Cifar(nn.Module): 76 | def __init__(self, block, num_blocks): 77 | super(BBN_ResNet_Cifar, self).__init__() 78 | self.in_planes = 16 79 | 80 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(16) 82 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 64, num_blocks[2] - 1, stride=2) 85 | self.cb_block = block(self.in_planes, self.in_planes, stride=1) 86 | self.rb_block = block(self.in_planes, self.in_planes, stride=1) 87 | 88 | self.apply(_weights_init) 89 | 90 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 91 | 92 | def load_model(self, pretrain): 93 | print("Loading Backbone pretrain model from {}......".format(pretrain)) 94 | model_dict = self.state_dict() 95 | pretrain_dict = torch.load(pretrain)["state_dict_best"]['feat_model'] 96 | 97 | new_dict = OrderedDict() 98 | 99 | for k, v in pretrain_dict.items(): 100 | if k.startswith("module"): 101 | k = k[7:] 102 | if "fc" not in k and "classifier" not in k: 103 | new_dict[k] = v 104 | 105 | model_dict.update(new_dict) 106 | self.load_state_dict(model_dict) 107 | print("Backbone model has been loaded......") 108 | 109 | def _make_layer(self, block, planes, num_blocks, stride, add_flag=True): 110 | strides = [stride] + [1] * (num_blocks - 1) 111 | layers = [] 112 | for stride in strides: 113 | layers.append(block(self.in_planes, planes, stride)) 114 | self.in_planes = planes * block.expansion 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def forward(self, x, **kwargs): 119 | out = F.relu(self.bn1(self.conv1(x))) 120 | out = self.layer1(out) 121 | out = self.layer2(out) 122 | out = self.layer3(out) 123 | if "feature_cb" in kwargs: 124 | out = self.cb_block(out) 125 | return out 126 | elif "feature_rb" in kwargs: 127 | out = self.rb_block(out) 128 | return out 129 | 130 | out1 = self.cb_block(out) 131 | out2 = self.rb_block(out) 132 | out = torch.cat((out1, out2), dim=1) 133 | 134 | out = self.avgpool(out) 135 | out = out.view(out.shape[0], -1) 136 | 137 | 138 | return out 139 | 140 | 141 | 142 | def create_model(use_fc=False, pretrain=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 143 | 144 | print('Loading ResNet 32 Feature Model.') 145 | resnet32 = BBN_ResNet_Cifar(BasicBlock, [5, 5, 5]) 146 | 147 | pretrained_model="./data/checkpoints/final_model_checkpoint.pth" 148 | if path.exists(pretrained_model) and pretrain: 149 | print('===> Load Initialization for ResNet32') 150 | resnet32.load_model(pretrain=pretrained_model) 151 | else: 152 | print('===> Train backbone from the scratch') 153 | 154 | return resnet32 155 | -------------------------------------------------------------------------------- /models/ResNet50Feature.py: 
-------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | from models.ResNetFeature import * 10 | from utils import * 11 | from os import path 12 | 13 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 14 | 15 | print('Loading Scratch ResNet 50 Feature Model.') 16 | resnet50 = ResNet(Bottleneck, [3, 4, 6, 3], use_fc=use_fc, dropout=None) 17 | 18 | if not test: 19 | if stage1_weights: 20 | assert(dataset) 21 | print('Loading %s Stage 1 ResNet 50 Weights.' % dataset) 22 | if log_dir is not None: 23 | # subdir = log_dir.strip('/').split('/')[-1] 24 | # subdir = subdir.replace('stage2', 'stage1') 25 | # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir) 26 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 27 | else: 28 | weight_dir = './logs/%s/stage1' % dataset 29 | print('==> Loading weights from %s' % weight_dir) 30 | resnet50 = init_weights(model=resnet50, 31 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth')) 32 | else: 33 | print('No Pretrained Weights For Feature Model.') 34 | 35 | return resnet50 36 | -------------------------------------------------------------------------------- /models/ResNetFeature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved.
13 | """ 14 | 15 | 16 | import math 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, use_fc=False, dropout=None): 97 | self.inplanes = 64 98 | super(ResNet, self).__init__() 99 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 100 | bias=False) 101 | self.bn1 = nn.BatchNorm2d(64) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 104 | self.layer1 = self._make_layer(block, 64, layers[0]) 105 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 106 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 107 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 108 | self.avgpool = nn.AvgPool2d(7, stride=1) 109 | 110 | self.use_fc = use_fc 111 | self.use_dropout = True if dropout else False 112 | 113 | if self.use_fc: 114 | print('Using fc.') 115 | self.fc_add = nn.Linear(512*block.expansion, 512) 116 | 117 | if self.use_dropout: 118 | print('Using dropout.') 119 | self.dropout = nn.Dropout(p=dropout) 120 | 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 124 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 125 | elif isinstance(m, nn.BatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | downsample = None 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.Sequential( 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=stride, bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample)) 140 | self.inplanes = planes * block.expansion 141 | for i in range(1, blocks): 142 | layers.append(block(self.inplanes, planes)) 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, x, *args): 147 | x = self.conv1(x) 148 | x = self.bn1(x) 149 | x = self.relu(x) 150 | x = self.maxpool(x) 151 | 152 | x = self.layer1(x) 153 | x = self.layer2(x) 154 | x = self.layer3(x) 155 | x = self.layer4(x) 156 | 157 | x = self.avgpool(x) 158 | 159 | x = x.view(x.size(0), -1) 160 | 161 | if self.use_fc: 162 | x = F.relu(self.fc_add(x)) 163 | 164 | if self.use_dropout: 165 | x = self.dropout(x) 166 | 167 | return x 168 | -------------------------------------------------------------------------------- /models/ResNext101Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from models.ResNextFeature import * 9 | from utils import * 10 | from os import path 11 | 12 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, last_relu=True, *args): 13 | 14 | print('Loading Scratch ResNext 101 Feature Model.') 15 | print('======> Last ReLU: ', last_relu) 16 | resnext = ResNext(Bottleneck, [3, 4, 23, 3], use_fc=use_fc, dropout=None, 17 | groups=32, width_per_group=4, last_relu=last_relu) 18 | 19 | if not test: 20 | if stage1_weights: 21 | assert(dataset) 22 | print('Loading %s Stage 1 ResNext 101 Weights.' % dataset) 23 | if log_dir is not None: 24 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 25 | else: 26 | weight_dir = './logs/%s/stage1' % dataset 27 | print('==> Loading weights from %s' % weight_dir) 28 | resnext = init_weights(model=resnext, 29 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth')) 30 | else: 31 | print('No Pretrained Weights For Feature Model.') 32 | 33 | return resnext 34 | -------------------------------------------------------------------------------- /models/ResNext152Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from models.ResNextFeature import * 9 | from utils import * 10 | from os import path 11 | 12 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, last_relu=True, *args): 13 | 14 | print('Loading Scratch ResNext 152 Feature Model.') 15 | print('======> Last ReLU: ', last_relu) 16 | resnext = ResNext(Bottleneck, [3, 8, 36, 3], use_fc=use_fc, dropout=None, 17 | groups=32, width_per_group=4, last_relu=last_relu) 18 | 19 | if not test: 20 | if stage1_weights: 21 | assert(dataset) 22 | print('Loading %s Stage 1 ResNext 152 Weights.' % dataset) 23 | if log_dir is not None: 24 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 25 | else: 26 | weight_dir = './logs/%s/stage1' % dataset 27 | print('==> Loading weights from %s' % weight_dir) 28 | resnext = init_weights(model=resnext, 29 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth')) 30 | else: 31 | print('No Pretrained Weights For Feature Model.') 32 | 33 | return resnext 34 | -------------------------------------------------------------------------------- /models/ResNext50Feature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from models.ResNextFeature import * 9 | from utils import * 10 | from os import path 11 | 12 | def create_model(use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, last_relu=True, *args): 13 | 14 | print('Loading Scratch ResNext 50 Feature Model.') 15 | print('======> Last ReLU: ', last_relu) 16 | resnext = ResNext(Bottleneck, [3, 4, 6, 3], use_fc=use_fc, dropout=None, 17 | groups=32, width_per_group=4, last_relu=last_relu) 18 | 19 | if not test: 20 | if stage1_weights: 21 | assert(dataset) 22 | print('Loading %s Stage 1 ResNext 50 Weights.' % dataset) 23 | if log_dir is not None: 24 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 25 | else: 26 | weight_dir = './logs/%s/stage1' % dataset 27 | print('==> Loading weights from %s' % weight_dir) 28 | resnext = init_weights(model=resnext, 29 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth')) 30 | else: 31 | print('No Pretrained Weights For Feature Model.') 32 | 33 | return resnext -------------------------------------------------------------------------------- /models/ResNextFeature.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from utils import * 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None, 54 | groups=1, base_width=64, is_last=False, last_relu=True): 55 | super(Bottleneck, self).__init__() 56 | width = int(planes * (base_width / 64.)) * groups 57 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) 58 | self.bn1 = nn.BatchNorm2d(width) 59 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 60 | groups=groups, padding=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(width) 62 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 63 | self.bn3 = nn.BatchNorm2d(planes * 4) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.downsample = downsample 66 | self.stride = stride 67 | self.is_last = is_last 68 | self.last_relu = last_relu 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | if self.is_last and (not self.last_relu): 89 | pass 90 | else: 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | class ResNext(nn.Module): 96 | 97 | def __init__(self, block, layers, groups=1, width_per_group=64, use_fc=False, dropout=None, 98 | use_glore=False, use_gem=False, last_relu=True): 99 | self.inplanes = 64 100 | super(ResNext, self).__init__() 101 | 102 | self.groups = groups 103 | self.base_width = width_per_group 104 | 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, is_last=True, last_relu=last_relu) 114 | self.avgpool = nn.AvgPool2d(7, stride=1) 115 | 116 | self.use_fc = use_fc 117 | self.use_dropout = True if dropout else False 118 | 119 | if self.use_fc: 120 | print('Using fc.') 121 | self.fc_add = nn.Linear(512*block.expansion, 512) 122 | 123 | if 
self.use_dropout: 124 | print('Using dropout.') 125 | self.dropout = nn.Dropout(p=dropout) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | 135 | 136 | def _make_layer(self, block, planes, blocks, stride=1, is_last=False, last_relu=True): 137 | downsample = None 138 | if stride != 1 or self.inplanes != planes * block.expansion: 139 | downsample = nn.Sequential( 140 | nn.Conv2d(self.inplanes, planes * block.expansion, 141 | kernel_size=1, stride=stride, bias=False), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, downsample, 147 | groups=self.groups, base_width=self.base_width)) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, 151 | groups=self.groups, base_width=self.base_width, 152 | is_last=(is_last and i == blocks-1), last_relu=last_relu)) 153 | 154 | return nn.Sequential(*layers) 155 | 156 | def forward(self, x, *args): 157 | x = self.conv1(x) 158 | x = self.bn1(x) 159 | x = self.relu(x) 160 | x = self.maxpool(x) 161 | 162 | x = self.layer1(x) 163 | x = self.layer2(x) 164 | x = self.layer3(x) 165 | x = self.layer4(x) 166 | 167 | x = self.avgpool(x) 168 | 169 | x = x.view(x.size(0), -1) 170 | 171 | if self.use_fc: 172 | x = F.relu(self.fc_add(x)) 173 | 174 | if self.use_dropout: 175 | x = self.dropout(x) 176 | 177 | return x -------------------------------------------------------------------------------- /models/TauNormClassifier.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Facebook, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Portions of the source code are from the OLTR project which 8 | notice below and in LICENSE in the root directory of 9 | this source tree. 10 | 11 | Copyright (c) 2019, Zhongqi Miao 12 | All rights reserved. 13 | """ 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn.parameter import Parameter 18 | from utils import * 19 | from os import path 20 | import math 21 | 22 | class DotProduct_Classifier(nn.Module): 23 | 24 | def __init__(self, num_classes=1000, feat_dim=2048, use_effect=False, num_head=2, apply_gate=False, *args): 25 | super(DotProduct_Classifier, self).__init__() 26 | self.fc = nn.Linear(feat_dim, num_classes) 27 | self.scales = Parameter(torch.ones(num_classes)) 28 | for param_name, param in self.fc.named_parameters(): 29 | param.requires_grad = False 30 | 31 | 32 | def forward(self, x, label, embed): 33 | x = self.fc(x) 34 | x *= self.scales 35 | return x, None 36 | 37 | def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 38 | print('Loading Tau Norm Classifier.') 39 | clf = DotProduct_Classifier(num_classes, feat_dim) 40 | 41 | if not test: 42 | if stage1_weights: 43 | assert(dataset) 44 | print('Loading %s Stage 1 Classifier Weights.' 
% dataset) 45 | if log_dir is not None: 46 | subdir = log_dir.strip('/').split('/')[-1] 47 | subdir = subdir.replace('stage2', 'stage1') 48 | weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir) 49 | # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 50 | else: 51 | weight_dir = './logs/%s/stage1' % dataset 52 | print('==> Loading classifier weights from %s' % weight_dir) 53 | clf.fc = init_weights(model=clf.fc, 54 | weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'), 55 | classifier=True) 56 | else: 57 | print('Randomly initialized classifier weights.') 58 | 59 | return clf -------------------------------------------------------------------------------- /parse_json.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | import argparse 3 | from tqdm import tqdm 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description="Convert a json annotation file") 7 | 8 | parser.add_argument( 9 | "--file", 10 | help="json file to be converted", 11 | required=True, 12 | type=str, 13 | ) 14 | parser.add_argument( 15 | "--root", 16 | help="root directory of the images", 17 | type=str, 18 | required=True, 19 | ) 20 | parser.add_argument( 21 | "--sp", 22 | help="save path for the converted file", 23 | type=str, 24 | required=False, 25 | default="." 26 | ) 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | def convert(json_file, image_root): 32 | all_annos = json.load(open(json_file, 'r')) 33 | 34 | annos = all_annos['annotations'] 35 | images = all_annos['images'] 36 | new_annos = [] 37 | 38 | print("Converting file {} ...".format(json_file)) 39 | for anno, image in tqdm(zip(annos, images)): 40 | assert image["id"] == anno["id"] 41 | 42 | new_annos.append({"image_id": image["id"], 43 | "im_height": image["height"], 44 | "im_width": image["width"], 45 | "category_id": anno["category_id"], 46 | "fpath": os.path.join(image_root, image["file_name"])}) 47 | num_classes = len(all_annos["categories"]) 48 | return {"annotations": new_annos, 49 | "num_classes": num_classes} 50 | 51 | if __name__ == "__main__": 52 | args = parse_args() 53 | converted_annos = convert(args.file, args.root) 54 | save_path = os.path.join(args.sp, "converted_" + os.path.split(args.file)[-1]) 55 | print("Converted. Saving converted file to {}".format(save_path)) 56 | with open(save_path, "w") as f: 57 | json.dump(converted_annos, f) -------------------------------------------------------------------------------- /run_networks.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | Copyright (c) Facebook, Inc. and its affiliates. 5 | All rights reserved. 6 | 7 | This source code is licensed under the license found in the 8 | LICENSE file in the root directory of this source tree. 9 | 10 | Portions of the source code are from the OLTR project which 11 | notice below and in LICENSE in the root directory of 12 | this source tree. 13 | 14 | Copyright (c) 2019, Zhongqi Miao 15 | All rights reserved.
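This module defines the `model` wrapper that builds the feature extractor and classifier from the config, runs the training loop, and evaluates with prior-adjusted logits.
Minimal usage sketch (hypothetical driver code; `data` maps each split name to a (prior, dataloader) pair, as read in __init__):
    trainer = model(config, data, test=False)
    trainer.train()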
16 | """ 17 | 18 | import os 19 | import copy 20 | import pickle 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import torch.nn.functional as F 25 | from tqdm import tqdm 26 | from utils import * 27 | from logger import Logger 28 | import time 29 | import numpy as np 30 | import warnings 31 | import pdb 32 | from torch.utils.tensorboard import SummaryWriter 33 | 34 | from sklearn.decomposition import IncrementalPCA 35 | 36 | class model (): 37 | def __init__(self, config, data, test=False, test_imb_ratio=1.0, test_reverse=False): 38 | self.config = config 39 | self.training_opt = self.config['training_opt'] 40 | self.data = {key: item[1] for key, item in data.items()} 41 | self.test_mode = test 42 | self.num_gpus = torch.cuda.device_count() 43 | self.do_shuffle = config['shuffle'] if 'shuffle' in config else False 44 | 45 | # Setup prior distribution 46 | self.prior_distribution = {key: item[0] for key, item in data.items()} 47 | 48 | # Setup logger 49 | self.logger = Logger(self.training_opt['log_dir']) 50 | 51 | # init moving average 52 | self.embed_mean = torch.zeros(int(self.training_opt['feature_dim'])).numpy() 53 | self.mu = 0.9 54 | self.sumexp_logits = torch.zeros(1) 55 | 56 | if not test: 57 | self.tensorboard = SummaryWriter(log_dir=f"{self.training_opt['log_dir']}/tensorboard") 58 | self.current_step = 0 59 | self.current_epoch = 0 60 | # Initialize model 61 | self.init_models() 62 | 63 | # apply incremental pca 64 | self.apply_pca = ('apply_ipca' in self.config) and self.config['apply_ipca'] 65 | if self.apply_pca: 66 | print('==========> Apply Incremental PCA <=======') 67 | self.pca = IncrementalPCA(n_components=self.config['num_components'], batch_size=self.training_opt['batch_size']) 68 | 69 | # Load pre-trained model parameters 70 | if 'model_dir' in self.config and self.config['model_dir'] is not None: 71 | self.load_model(self.config['model_dir']) 72 | 73 | # Under training mode, initialize training steps, optimizers, schedulers, criterions 74 | if not self.test_mode: 75 | 76 | # If using steps for training, we need to calculate training steps 77 | # for each epoch based on actual number of training data instead of 78 | # oversampled data number 79 | print('Using steps for training.') 80 | self.training_data_num = len(self.data['train'].dataset) 81 | self.epoch_steps = int(self.training_data_num / self.training_opt['batch_size']) 82 | 83 | # Initialize model optimizer and scheduler 84 | print('Initializing model optimizer.') 85 | self.init_optimizers(self.model_optim_params_dict) 86 | self.init_criterions() 87 | 88 | # Set up log file 89 | self.log_file = os.path.join(self.training_opt['log_dir'], 'log.txt') 90 | self.logger.log_cfg(self.config) 91 | else: 92 | if test_reverse: 93 | self.log_file = os.path.join(self.training_opt['log_dir'], f'test-reverse_imb{test_imb_ratio}.txt') 94 | else: 95 | self.log_file = os.path.join(self.training_opt['log_dir'], f'test_imb{test_imb_ratio}.txt') 96 | self.logger.log_cfg(self.config) 97 | 98 | def write_summary(self, split, step, **kargs): 99 | for n, v in kargs.items(): 100 | if hasattr(self, 'tensorboard'): 101 | self.tensorboard.add_scalar( 102 | tag=f"{split}/{n}", 103 | scalar_value=v, 104 | global_step=step 105 | ) 106 | 107 | def init_models(self, optimizer=True): 108 | networks_defs = self.config['networks'] 109 | self.networks = {} 110 | self.model_optim_params_dict = {} 111 | self.model_optim_named_params = {} 112 | 113 | print("Using", torch.cuda.device_count(), "GPUs.") 114 | 115 | 
for key, val in networks_defs.items(): 116 | # Networks 117 | def_file = val['def_file'] 118 | model_args = val['params'] 119 | model_args.update({'test': self.test_mode}) 120 | if "Prior" in def_file: 121 | model_args["prior"] = self.prior_distribution["train"] 122 | 123 | self.networks[key] = source_import(def_file).create_model(**model_args) 124 | self.networks[key] = nn.DataParallel(self.networks[key]).cuda() 125 | 126 | if 'fix' in val and val['fix']: 127 | print('Freezing weights of module {}'.format(key)) 128 | for param_name, param in self.networks[key].named_parameters(): 129 | # Freeze all parameters except final fc layer 130 | if 'fc' not in param_name: 131 | param.requires_grad = False 132 | print('=====> Freezing: {} | False'.format(key)) 133 | 134 | if 'fix_set' in val: 135 | for fix_layer in val['fix_set']: 136 | for param_name, param in self.networks[key].named_parameters(): 137 | if fix_layer == param_name: 138 | param.requires_grad = False 139 | print('=====> Freezing: {} | {}'.format(param_name, param.requires_grad)) 140 | continue 141 | 142 | 143 | # Optimizer list 144 | optim_params = val['optim_params'] 145 | self.model_optim_named_params.update(dict(self.networks[key].named_parameters())) 146 | self.model_optim_params_dict[key] = {'params': self.networks[key].parameters(), 147 | 'lr': optim_params['lr'], 148 | 'momentum': optim_params['momentum'], 149 | 'weight_decay': optim_params['weight_decay']} 150 | 151 | def init_criterions(self): 152 | criterion_defs = self.config['criterions'] 153 | self.criterions = {} 154 | self.criterion_weights = {} 155 | 156 | for key, val in criterion_defs.items(): 157 | def_file = val['def_file'] 158 | loss_args = val['loss_params'] 159 | 160 | self.criterions[key] = source_import(def_file).create_loss(**loss_args).cuda() 161 | self.criterion_weights[key] = val['weight'] 162 | 163 | if val['optim_params']: 164 | print('Initializing criterion optimizer.') 165 | optim_params = val['optim_params'] 166 | optim_params = [{'params': self.criterions[key].parameters(), 167 | 'lr': optim_params['lr'], 168 | 'momentum': optim_params['momentum'], 169 | 'weight_decay': optim_params['weight_decay']}] 170 | # Initialize criterion optimizer and scheduler 171 | self.criterion_optimizer, \ 172 | self.criterion_optimizer_scheduler = self.init_optimizers(optim_params) 173 | else: 174 | self.criterion_optimizer = None 175 | 176 | def init_optimizers(self, optim_params_dict): 177 | ''' 178 | seperate backbone optimizer and classifier optimizer 179 | by Kaihua 180 | ''' 181 | networks_defs = self.config['networks'] 182 | self.model_optimizer_dict = {} 183 | self.model_scheduler_dict = {} 184 | 185 | for key, val in networks_defs.items(): 186 | # optimizer 187 | if 'optimizer' in self.training_opt and self.training_opt['optimizer'] == 'adam': 188 | print('=====> Using Adam optimizer') 189 | optimizer = optim.Adam([optim_params_dict[key],]) 190 | else: 191 | print('=====> Using SGD optimizer') 192 | optimizer = optim.SGD([optim_params_dict[key],]) 193 | self.model_optimizer_dict[key] = optimizer 194 | # scheduler 195 | scheduler_params = val['scheduler_params'] 196 | if scheduler_params['coslr']: 197 | print("===> Module {} : Using coslr eta_min={}".format(key, scheduler_params['endlr'])) 198 | self.model_scheduler_dict[key] = torch.optim.lr_scheduler.CosineAnnealingLR( 199 | optimizer, self.training_opt['num_epochs'], eta_min=scheduler_params['endlr']) 200 | elif scheduler_params['warmup']: 201 | print("===> Module {} : Using warmup".format(key)) 202 | 
self.model_scheduler_dict[key] = WarmupMultiStepLR(optimizer, scheduler_params['lr_step'], 203 | gamma=scheduler_params['lr_factor'], warmup_epochs=scheduler_params['warm_epoch']) 204 | else: 205 | self.model_scheduler_dict[key] = optim.lr_scheduler.StepLR(optimizer, 206 | step_size=scheduler_params['step_size'], 207 | gamma=scheduler_params['gamma']) 208 | 209 | return 210 | 211 | def show_current_lr(self): 212 | max_lr = 0.0 213 | for key, val in self.model_optimizer_dict.items(): 214 | lr_set = list(set([para['lr'] for para in val.param_groups])) 215 | if max(lr_set) > max_lr: 216 | max_lr = max(lr_set) 217 | lr_set = ','.join([str(i) for i in lr_set]) 218 | print_str = ['=====> Current Learning Rate of model {} : {}'.format(key, str(lr_set))] 219 | print_write(print_str, self.log_file) 220 | return max_lr 221 | 222 | 223 | def batch_forward(self, inputs, labels=None, feature_ext=False, phase='train'): 224 | ''' 225 | This is a general single batch running function. 226 | ''' 227 | 228 | # Calculate Features 229 | self.features = self.networks['feat_model'](inputs) 230 | 231 | if self.apply_pca: 232 | if phase=='train' and self.features.shape[0] > 0: 233 | self.pca.partial_fit(self.features.cpu().numpy()) 234 | else: 235 | pca_feat = self.pca.transform(self.features.cpu().numpy()) 236 | pca_feat[:, 0] = 0.0 237 | new_feat = self.pca.inverse_transform(pca_feat) 238 | self.features = torch.from_numpy(new_feat).float().to(self.features.device) 239 | 240 | # update moving average 241 | if phase == 'train': 242 | self.embed_mean = self.mu * self.embed_mean + self.features.detach().mean(0).view(-1).cpu().numpy() 243 | 244 | # If not just extracting features, calculate logits 245 | if not feature_ext: 246 | # cont_eval = 'continue_eval' in self.training_opt and self.training_opt['continue_eval'] and phase != 'train' 247 | self.logits, self.route_logits = self.networks['classifier'](self.features, labels, self.embed_mean) 248 | 249 | def batch_backward(self, print_grad=False): 250 | # Zero out optimizer gradients 251 | for key, optimizer in self.model_optimizer_dict.items(): 252 | optimizer.zero_grad() 253 | if self.criterion_optimizer: 254 | self.criterion_optimizer.zero_grad() 255 | # Back-propagation from loss outputs 256 | self.loss.backward() 257 | # display gradient 258 | if self.training_opt['display_grad']: 259 | print_grad_norm(self.model_optim_named_params, print_write, self.log_file, verbose=print_grad) 260 | # Step optimizers 261 | for key, optimizer in self.model_optimizer_dict.items(): 262 | optimizer.step() 263 | if self.criterion_optimizer: 264 | self.criterion_optimizer.step() 265 | 266 | def batch_loss(self, labels): 267 | self.loss = 0 268 | 269 | # First, apply performance loss 270 | if 'PerformanceLoss' in self.criterions.keys(): 271 | self.loss_perf = self.criterions['PerformanceLoss'](self.logits, labels) 272 | self.loss_perf *= self.criterion_weights['PerformanceLoss'] 273 | self.loss += self.loss_perf 274 | 275 | # Apply loss on Route Weights if set up 276 | if 'RouteWeightLoss' in self.criterions.keys(): 277 | self.loss_route = self.criterions['RouteWeightLoss'](self.route_logits, labels) 278 | self.loss_route = self.loss_route * self.criterion_weights['RouteWeightLoss'] 279 | # Add Route Weights loss to total loss 280 | self.loss += self.loss_route 281 | 282 | # hard-coded 283 | self.sumexp_logits = torch.sum(torch.exp(self.logits), dim=-1) 284 | 285 | def shuffle_batch(self, x, y): 286 | index = torch.randperm(x.size(0)) 287 | x = x[index] 288 | y = y[index] 289 | 
return x, y 290 | 291 | def train(self): 292 | # When training the network 293 | print_str = ['Phase: train'] 294 | print_write(print_str, self.log_file) 295 | time.sleep(0.25) 296 | 297 | print_write(['Force shuffle in training??? --- ', self.do_shuffle], self.log_file) 298 | 299 | # Initialize best model 300 | best_model_weights = {} 301 | best_model_weights['feat_model'] = copy.deepcopy(self.networks['feat_model'].state_dict()) 302 | best_model_weights['classifier'] = copy.deepcopy(self.networks['classifier'].state_dict()) 303 | best_acc = 0.0 304 | best_epoch = 0 305 | 306 | end_epoch = self.training_opt['num_epochs'] 307 | 308 | # Loop over epochs 309 | for epoch in range(1, end_epoch + 1): 310 | self.current_epoch = epoch 311 | for key, model in self.networks.items(): 312 | # only train the module with lr > 0 313 | if self.config['networks'][key]['optim_params']['lr'] == 0.0: 314 | print_write(['=====> module {} is set to eval due to 0.0 learning rate.'.format(key)], self.log_file) 315 | model.eval() 316 | else: 317 | model.train() 318 | 319 | torch.cuda.empty_cache() 320 | 321 | # Set model modes and set scheduler 322 | # In training, step optimizer scheduler and set model to train() 323 | for key, scheduler in self.model_scheduler_dict.items(): 324 | scheduler.step() 325 | if self.criterion_optimizer: 326 | self.criterion_optimizer_scheduler.step() 327 | 328 | # Iterate over dataset 329 | total_preds = [] 330 | total_labels = [] 331 | 332 | # indicate current path 333 | print_write([self.training_opt['log_dir']], self.log_file) 334 | # print learning rate 335 | current_lr = self.show_current_lr() 336 | current_lr = min(current_lr * 50, 1.0) 337 | # scale the original mu according to the lr 338 | if 'CIFAR' not in self.training_opt['dataset']: 339 | self.mu = 1.0 - (1 - 0.9) * current_lr 340 | 341 | for step, (inputs, labels, indexes) in enumerate(self.data['train']): 342 | # Break when step equal to epoch step 343 | if step == self.epoch_steps: 344 | break 345 | self.current_step += 1 346 | if self.do_shuffle: 347 | inputs, labels = self.shuffle_batch(inputs, labels) 348 | inputs, labels = inputs.cuda(), labels.cuda() 349 | 350 | # If on training phase, enable gradients 351 | with torch.set_grad_enabled(True): 352 | 353 | # If training, forward with loss, and no top 5 accuracy calculation 354 | self.batch_forward(inputs, labels, phase='train') 355 | self.batch_loss(labels) 356 | self.batch_backward(print_grad=(step % self.training_opt['display_grad_step'] == 0)) 357 | 358 | # Tracking predictions 359 | _, preds = torch.max(self.logits, 1) 360 | total_preds.append(torch2numpy(preds)) 361 | total_labels.append(torch2numpy(labels)) 362 | 363 | # Output minibatch training results 364 | if step % self.training_opt['display_step'] == 0: 365 | 366 | records = dict() 367 | if 'RouteWeightLoss' in self.criterions: 368 | records['loss_route'] = self.loss_route.item() 369 | else: 370 | records["loss_route"] = 0. 
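# (loss_route defaults to 0 above when no RouteWeightLoss criterion is configured, so the logged record keys stay consistent.)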
371 | 372 | if 'PerformanceLoss' in self.criterions: 373 | records['loss_perf'] = self.loss_perf.item() 374 | 375 | records['loss'] = self.loss.item() 376 | 377 | minibatch_acc = mic_acc_cal(preds, labels) 378 | 379 | records['acc'] = minibatch_acc 380 | records['sum_exp_logits'] = torch.mean(self.sumexp_logits).item() 381 | 382 | print_str = ['Epoch: [%d/%d]' 383 | % (epoch, self.training_opt['num_epochs']), 384 | 'Step: %5d' 385 | % (step)] 386 | print_str.extend([ 387 | f'{key}: {val:.3f}' for key, val in records.items() 388 | ]) 389 | 390 | self.write_summary( 391 | split='train', 392 | step=self.current_step, 393 | **records 394 | ) 395 | print_write(print_str, self.log_file) 396 | 397 | loss_info = { 398 | 'Epoch': epoch, 399 | 'Step': step, 400 | 'Total': records['loss'], 401 | 'CE': records['loss_perf'], 402 | 'route': records['loss_route'], 403 | } 404 | 405 | self.logger.log_loss(loss_info) 406 | 407 | # batch-level: sampler update 408 | if hasattr(self.data['train'].sampler, 'update_weights'): 409 | if hasattr(self.data['train'].sampler, 'ptype'): 410 | ptype = self.data['train'].sampler.ptype 411 | else: 412 | ptype = 'score' 413 | ws = get_priority(ptype, self.logits.detach(), labels) 414 | 415 | inlist = [indexes.cpu().numpy(), ws] 416 | if self.training_opt['sampler']['type'] == 'ClassPrioritySampler': 417 | inlist.append(labels.cpu().numpy()) 418 | self.data['train'].sampler.update_weights(*inlist) 419 | 420 | # epoch-level: reset sampler weight 421 | if hasattr(self.data['train'].sampler, 'get_weights'): 422 | self.logger.log_ws(epoch, self.data['train'].sampler.get_weights()) 423 | if hasattr(self.data['train'].sampler, 'reset_weights'): 424 | self.data['train'].sampler.reset_weights(epoch) 425 | 426 | # After every epoch, validation 427 | rsls = {'epoch': epoch} 428 | rsls_train = self.eval_with_preds(total_preds, total_labels) 429 | rsls_eval = self.eval(phase='val') 430 | rsls.update(rsls_train) 431 | rsls.update(rsls_eval) 432 | 433 | # Reset class weights for sampling if pri_mode is valid 434 | if hasattr(self.data['train'].sampler, 'reset_priority'): 435 | ws = get_priority(self.data['train'].sampler.ptype, 436 | self.total_logits.detach(), 437 | self.total_labels) 438 | self.data['train'].sampler.reset_priority(ws, self.total_labels.cpu().numpy()) 439 | 440 | # Log results 441 | self.logger.log_acc(rsls) 442 | 443 | # Under validation, the best model need to be updated 444 | if self.eval_acc_mic_top1 > best_acc: 445 | best_epoch = epoch 446 | best_acc = self.eval_acc_mic_top1 447 | best_model_weights['feat_model'] = copy.deepcopy(self.networks['feat_model'].state_dict()) 448 | best_model_weights['classifier'] = copy.deepcopy(self.networks['classifier'].state_dict()) 449 | 450 | print('===> Saving checkpoint') 451 | self.save_latest(epoch) 452 | 453 | print() 454 | print('Training Complete.') 455 | 456 | print_str = ['Best validation accuracy is %.3f at epoch %d' % (best_acc, best_epoch)] 457 | print_write(print_str, self.log_file) 458 | # Save the best model 459 | self.save_model(epoch, best_epoch, best_model_weights, best_acc) 460 | 461 | # Test on the test set 462 | if 'CIFAR' not in self.training_opt["dataset"]: 463 | self.reset_model(best_model_weights) 464 | 465 | self.eval('test' if 'test' in self.data else 'val') 466 | print('Done') 467 | 468 | def eval_with_preds(self, preds, labels): 469 | # Count the number of examples 470 | n_total = sum([len(p) for p in preds]) 471 | 472 | # Split the examples into normal and mixup 473 | normal_preds, normal_labels = 
[], [] 474 | mixup_preds, mixup_labels1, mixup_labels2, mixup_ws = [], [], [], [] 475 | for p, l in zip(preds, labels): 476 | if isinstance(l, tuple): 477 | mixup_preds.append(p) 478 | mixup_labels1.append(l[0]) 479 | mixup_labels2.append(l[1]) 480 | mixup_ws.append(l[2] * np.ones_like(l[0])) 481 | else: 482 | normal_preds.append(p) 483 | normal_labels.append(l) 484 | 485 | # Calculate normal prediction accuracy 486 | rsl = {'train_all':0., 'train_many':0., 'train_median':0., 'train_low': 0.} 487 | if len(normal_preds) > 0: 488 | normal_preds, normal_labels = list(map(np.concatenate, [normal_preds, normal_labels])) 489 | n_top1 = mic_acc_cal(normal_preds, normal_labels) 490 | n_top1_many, \ 491 | n_top1_median, \ 492 | n_top1_low, = shot_acc(normal_preds, normal_labels, self.data['train']) 493 | rsl['train_all'] += len(normal_preds) / n_total * n_top1 494 | rsl['train_many'] += len(normal_preds) / n_total * n_top1_many 495 | rsl['train_median'] += len(normal_preds) / n_total * n_top1_median 496 | rsl['train_low'] += len(normal_preds) / n_total * n_top1_low 497 | 498 | # Calculate mixup prediction accuracy 499 | if len(mixup_preds) > 0: 500 | mixup_preds, mixup_labels, mixup_ws = \ 501 | list(map(np.concatenate, [mixup_preds*2, mixup_labels1+mixup_labels2, mixup_ws])) 502 | mixup_ws = np.concatenate([mixup_ws, 1-mixup_ws]) 503 | n_top1 = weighted_mic_acc_cal(mixup_preds, mixup_labels, mixup_ws) 504 | n_top1_many, \ 505 | n_top1_median, \ 506 | n_top1_low, = weighted_shot_acc(mixup_preds, mixup_labels, mixup_ws, self.data['train']) 507 | rsl['train_all'] += len(mixup_preds) / 2 / n_total * n_top1 508 | rsl['train_many'] += len(mixup_preds) / 2 / n_total * n_top1_many 509 | rsl['train_median'] += len(mixup_preds) / 2 / n_total * n_top1_median 510 | rsl['train_low'] += len(mixup_preds) / 2 / n_total * n_top1_low 511 | 512 | # Top-1 accuracy and additional string 513 | print_str = ['\n Training acc Top1: %.3f \n' % (rsl['train_all']), 514 | 'Many_top1: %.3f' % (rsl['train_many']), 515 | 'Median_top1: %.3f' % (rsl['train_median']), 516 | 'Low_top1: %.3f' % (rsl['train_low']), 517 | '\n'] 518 | print_write(print_str, self.log_file) 519 | 520 | return rsl 521 | 522 | def store_logits(self, phase): 523 | if phase not in self.data: 524 | print(f'No phase {phase}. 
Not storing logits.') 525 | return 526 | 527 | self.total_logits = torch.empty((0, self.training_opt['num_classes'])).cuda() 528 | self.total_labels = torch.empty(0, dtype=torch.long).cuda() 529 | # Iterate over dataset 530 | for model in self.networks.values(): 531 | model.eval() 532 | for inputs, labels, paths in tqdm(self.data[phase]): 533 | inputs, labels = inputs.cuda(), labels.cuda() 534 | 535 | # If on training phase, enable gradients 536 | with torch.set_grad_enabled(False): 537 | 538 | # In validation or testing 539 | self.batch_forward(inputs, labels, phase="val") 540 | 541 | if hasattr(self.networks["classifier"].module, "thresholds"): 542 | self.logits = self.route_logits - self.networks["classifier"].module.thresholds 543 | 544 | self.total_logits = torch.cat((self.total_logits, self.logits)) 545 | self.total_labels = torch.cat((self.total_labels, labels)) 546 | 547 | np.save(os.path.join(self.training_opt['log_dir'], 548 | f"{phase}_total_logits"), self.total_logits.cpu().data.numpy()) 549 | np.save(os.path.join(self.training_opt['log_dir'], 550 | f"{phase}_total_labels"), self.total_labels.cpu().data.numpy()) 551 | 552 | 553 | def eval(self, phase='val', save_feat=False): 554 | print_str = ['Phase: %s' % (phase)] 555 | print_write(print_str, self.log_file) 556 | 557 | if phase == "test": 558 | self.store_logits(phase="train") 559 | self.store_logits(phase="val") 560 | self.store_logits(phase="test") 561 | 562 | time.sleep(0.25) 563 | 564 | torch.cuda.empty_cache() 565 | 566 | # In validation or testing mode, set model to eval() and initialize running loss/correct 567 | for model in self.networks.values(): 568 | model.eval() 569 | 570 | self.total_logits = torch.empty((0, self.training_opt['num_classes'])).cuda() 571 | self.total_labels = torch.empty(0, dtype=torch.long).cuda() 572 | self.total_paths = np.empty(0) 573 | 574 | feats_all, labels_all, idxs_all, logits_all = [], [], [], [] 575 | featmaps_all = [] 576 | 577 | # feature saving initialization 578 | if save_feat: 579 | self.saving_feature_with_label_init() 580 | 581 | # Iterate over dataset 582 | for inputs, labels, paths in tqdm(self.data[phase]): 583 | inputs, labels = inputs.cuda(), labels.cuda() 584 | 585 | # If on training phase, enable gradients 586 | with torch.set_grad_enabled(False): 587 | 588 | # In validation or testing 589 | self.batch_forward(inputs, labels, phase=phase) 590 | # feature saving update 591 | if save_feat: 592 | self.saving_feature_with_label_update(self.features, self.logits, labels) 593 | 594 | if "Softmax" in self.config["criterions"]["PerformanceLoss"]["def_file"] and \ 595 | "RouteWeightLoss" not in self.config["criterions"] and \ 596 | "DotProductClassifier" in self.config["networks"]["classifier"]["def_file"]: 597 | self.logits -= torch.log(self.prior_distribution["train"]).cuda() 598 | else: 599 | self.logits += torch.log(self.prior_distribution[phase]).cuda() 600 | 601 | self.total_logits = torch.cat((self.total_logits, self.logits)) 602 | self.total_labels = torch.cat((self.total_labels, labels)) 603 | self.total_paths = np.concatenate((self.total_paths, paths)) 604 | 605 | # feature saving export 606 | if save_feat: 607 | self.saving_feature_with_label_export() 608 | 609 | probs, preds = F.softmax(self.total_logits.detach(), dim=1).max(dim=1) 610 | 611 | # Calculate the overall accuracy and F measurement 612 | self.eval_acc_mic_top1= mic_acc_cal(preds[self.total_labels != -1], 613 | self.total_labels[self.total_labels != -1]) 614 | self.eval_f_measure = F_measure(preds, 
self.total_labels, theta=self.training_opt['open_threshold']) 615 | 616 | self.many_acc_top1, \ 617 | self.median_acc_top1, \ 618 | self.low_acc_top1, \ 619 | self.cls_accs = shot_acc(preds[self.total_labels != -1], 620 | self.total_labels[self.total_labels != -1], 621 | self.data['train'], 622 | acc_per_cls=True) 623 | 624 | # Top-1 accuracy and additional string 625 | print_str = ['\n\n', 626 | 'Phase: %s' 627 | % (phase), 628 | '\n\n', 629 | 'Evaluation_accuracy_micro_top1: %.3f' 630 | % (self.eval_acc_mic_top1), 631 | '\n', 632 | 'Averaged F-measure: %.3f' 633 | % (self.eval_f_measure), 634 | '\n', 635 | 'Many_shot_accuracy_top1: %.3f' 636 | % (self.many_acc_top1), 637 | 'Median_shot_accuracy_top1: %.3f' 638 | % (self.median_acc_top1), 639 | 'Low_shot_accuracy_top1: %.3f' 640 | % (self.low_acc_top1), 641 | '\n'] 642 | 643 | rsl = {phase + '_all': self.eval_acc_mic_top1, 644 | phase + '_many': self.many_acc_top1, 645 | phase + '_median': self.median_acc_top1, 646 | phase + '_low': self.low_acc_top1, 647 | phase + '_fscore': self.eval_f_measure} 648 | 649 | 650 | if phase == 'val': 651 | print_write(print_str, self.log_file) 652 | self.write_summary( 653 | split='val', 654 | step=self.current_epoch, 655 | eval_acc_mic_top1=self.eval_acc_mic_top1, 656 | many_acc_top1=self.many_acc_top1, 657 | median_acc_top1=self.median_acc_top1, 658 | low_acc_top1=self.low_acc_top1, 659 | eval_f_measure=self.eval_f_measure, 660 | ) 661 | else: 662 | acc_str = ["{:.1f} \t {:.1f} \t {:.1f} \t {:.1f}".format( 663 | self.many_acc_top1 * 100, 664 | self.median_acc_top1 * 100, 665 | self.low_acc_top1 * 100, 666 | self.eval_acc_mic_top1 * 100)] 667 | if self.log_file is not None and os.path.exists(self.log_file): 668 | print_write(print_str, self.log_file) 669 | print_write(acc_str, self.log_file) 670 | else: 671 | print(*print_str) 672 | print(*acc_str) 673 | 674 | if phase == 'test': 675 | with open(os.path.join(self.training_opt['log_dir'], 'cls_accs.pkl'), 'wb') as f: 676 | pickle.dump(self.cls_accs, f) 677 | return rsl 678 | 679 | def reset_model(self, model_state): 680 | for key, model in self.networks.items(): 681 | weights = model_state[key] 682 | weights = {k: weights[k] for k in weights if k in model.state_dict()} 683 | model.load_state_dict(weights) 684 | 685 | def load_model(self, model_dir=None): 686 | model_dir = self.training_opt['log_dir'] if model_dir is None else model_dir 687 | 688 | if 'CIFAR' in self.training_opt['dataset']: 689 | # CIFARs don't have val set, so use the latest model 690 | print('Validation on the latest model.') 691 | if not model_dir.endswith('.pth'): 692 | model_dir = os.path.join(model_dir, 'latest_model_checkpoint.pth') 693 | else: 694 | print('Validation on the best model.') 695 | if not model_dir.endswith('.pth'): 696 | model_dir = os.path.join(model_dir, 'final_model_checkpoint.pth') 697 | 698 | print('Loading model from %s' % (model_dir)) 699 | checkpoint = torch.load(model_dir) 700 | if 'latest' in model_dir: 701 | model_state = checkpoint['state_dict'] 702 | else: 703 | model_state = checkpoint['state_dict_best'] 704 | 705 | for key, model in self.networks.items(): 706 | ########################################## 707 | # if loading classifier in training: 708 | # 1. only tuning memory embedding 709 | # 2. 
retrain the entire classifier 710 | ########################################## 711 | if 'embed' in checkpoint: 712 | print('============> Load Moving Average <===========') 713 | self.embed_mean = checkpoint['embed'] 714 | if not self.test_mode and 'Classifier' in self.config['networks'][key]['def_file']: 715 | if 'tuning_memory' in self.config and self.config['tuning_memory']: 716 | print('=============== WARNING! WARNING! ===============') 717 | print('========> Only Tuning Memory Embedding <========') 718 | for param_name, param in self.networks[key].named_parameters(): 719 | # frezing all params only tuning memory_embeding 720 | if 'embed' in param_name: 721 | param.requires_grad = True 722 | print('=====> Abandon Weight {} in {} from the checkpoints.'.format(param_name, key)) 723 | if param_name in model_state[key]: 724 | del model_state[key][param_name] 725 | else: 726 | param.requires_grad = False 727 | print('=====> Tuning: {} | {}'.format(str(param.requires_grad).ljust(5, ' '), param_name)) 728 | print('=================================================') 729 | else: 730 | # Skip classifier initialization 731 | #print('================ WARNING! WARNING! ================') 732 | print('=======> Load classifier from checkpoint <=======') 733 | #print('===================================================') 734 | #continue 735 | weights = model_state[key] 736 | weights = {k: weights[k] for k in weights if k in model.state_dict()} 737 | x = model.state_dict() 738 | x.update(weights) 739 | if all([weights[k].sum().item() == x[k].sum().item() for k in weights if k in x]): 740 | print('=====> All keys in weights have been loaded to the module {}'.format(key)) 741 | else: 742 | print('=====> Error! Error! Error! Error! Loading failure in module {}'.format(key)) 743 | model.load_state_dict(x) 744 | 745 | def save_latest(self, epoch): 746 | model_weights = {} 747 | model_weights['feat_model'] = copy.deepcopy(self.networks['feat_model'].state_dict()) 748 | model_weights['classifier'] = copy.deepcopy(self.networks['classifier'].state_dict()) 749 | 750 | model_states = { 751 | 'epoch': epoch, 752 | 'state_dict': model_weights, 753 | 'embed': self.embed_mean, 754 | } 755 | 756 | model_dir = os.path.join(self.training_opt['log_dir'], 757 | 'latest_model_checkpoint.pth') 758 | torch.save(model_states, model_dir) 759 | 760 | def save_model(self, epoch, best_epoch, best_model_weights, best_acc): 761 | 762 | model_states = {'epoch': epoch, 763 | 'best_epoch': best_epoch, 764 | 'state_dict_best': best_model_weights, 765 | 'best_acc': best_acc, 766 | 'embed': self.embed_mean,} 767 | 768 | model_dir = os.path.join(self.training_opt['log_dir'], 769 | 'final_model_checkpoint.pth') 770 | 771 | torch.save(model_states, model_dir) 772 | 773 | def output_logits(self): 774 | filename = os.path.join(self.training_opt['log_dir'], 'logits') 775 | print("Saving total logits to: %s.npz" % filename) 776 | np.savez(filename, 777 | logits=self.total_logits.detach().cpu().numpy(), 778 | labels=self.total_labels.detach().cpu().numpy(), 779 | paths=self.total_paths) 780 | 781 | def saving_feature_with_label_init(self): 782 | self.saving_feature_container = [] 783 | self.saving_logit_container = [] 784 | self.saving_label_container = [] 785 | 786 | 787 | def saving_feature_with_label_update(self, features, logits, labels): 788 | self.saving_feature_container.append(features.detach().cpu()) 789 | self.saving_logit_container.append(logits.detach().cpu()) 790 | self.saving_label_container.append(labels.detach().cpu()) 791 | 
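# Exports the features/logits/labels buffered by the two helpers above (filled during eval(..., save_feat=True)) to eval_features_with_labels.pth.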
792 | 793 | def saving_feature_with_label_export(self): 794 | eval_features = {'features': torch.cat(self.saving_feature_container, dim=0).numpy(), 795 | 'labels': torch.cat(self.saving_label_container, dim=0).numpy(), 796 | 'logits': torch.cat(self.saving_logit_container, dim=0).numpy(), 797 | } 798 | 799 | eval_features_dir = os.path.join(self.training_opt['log_dir'], 800 | 'eval_features_with_labels.pth') 801 | torch.save(eval_features, eval_features_dir) 802 | print_write(['=====> Features with labels are saved as {}'.format(eval_features_dir)], self.log_file) 803 | 804 | 805 | def calculate_thresholds(self, phase): 806 | for model in self.networks.values(): 807 | model.eval() 808 | 809 | store_logits = [] 810 | 811 | if phase in self.data: 812 | 813 | for inputs, labels, paths in tqdm(self.data[phase]): 814 | inputs, labels = inputs.cuda(), labels.cuda() 815 | 816 | # If on training phase, enable gradients 817 | with torch.set_grad_enabled(False): 818 | 819 | # In validation or testing 820 | self.batch_forward(inputs, labels, phase="test") 821 | store_logits.append(self.route_logits) # route_logits: (B, C) 822 | 823 | store_logits = torch.cat(store_logits, dim=0) # (number_of_samples, C) 824 | thresholds = torch.logsumexp(store_logits, dim=0) - np.log(store_logits.size(0)) 825 | 826 | return thresholds.data.cpu().numpy() 827 | else: 828 | return None 829 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) Hyperconnect, Inc. and its affiliates. 2 | All rights reserved. 3 | 4 | Copyright (c) Facebook, Inc. and its affiliates. 5 | All rights reserved. 6 | 7 | This source code is licensed under the license found in the 8 | LICENSE file in the root directory of this source tree. 9 | 10 | Portions of the source code are from the OLTR project which 11 | notice below and in LICENSE in the root directory of 12 | this source tree. 13 | 14 | Copyright (c) 2019, Zhongqi Miao 15 | All rights reserved. 
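Shared helpers used across the repo: config overriding (update, get_value), dynamic imports (source_import), checkpoint loading (init_weights), long-tailed metrics (shot_acc, mic_acc_cal, F_measure), sampler priorities (get_priority), and the WarmupMultiStepLR scheduler.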
16 | """ 17 | 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | import torch 21 | from sklearn.metrics import f1_score 22 | import torch.nn.functional as F 23 | import importlib 24 | import pdb 25 | import math 26 | from bisect import bisect_right 27 | from collections import Counter 28 | 29 | def update(config, args, output_dir): 30 | # Change parameters 31 | config['model_dir'] = get_value(config['model_dir'], args.model_dir) 32 | config['training_opt']['batch_size'] = \ 33 | get_value(config['training_opt']['batch_size'], args.batch_size) 34 | config['networks']['classifier']['optim_params']['lr'] = \ 35 | get_value(config['networks']['classifier']['optim_params']['lr'], args.lr) 36 | 37 | # if using pretrained model, don't change lr 38 | if "caffe" in config["networks"]["feat_model"]["params"] and \ 39 | config["networks"]["feat_model"]["params"]["caffe"]: 40 | pass 41 | else: 42 | config['networks']['feat_model']['optim_params']['lr'] = \ 43 | get_value(config['networks']['feat_model']['optim_params']['lr'], args.lr) 44 | 45 | if 'cifar_imb_ratio' in config['training_opt']: 46 | config['training_opt']['cifar_imb_ratio'] = \ 47 | get_value(config['training_opt']['cifar_imb_ratio'], args.cifar_imb_ratio) 48 | config['training_opt']['log_dir'] = output_dir 49 | 50 | if "RouteWeightLoss" in config["criterions"].keys(): 51 | if "prior" in config["criterions"]["PerformanceLoss"]["loss_params"]: 52 | config['criterions']['PerformanceLoss']['loss_params']['prior'] = \ 53 | get_value(config['criterions']['PerformanceLoss']['loss_params']['prior'], args.cifar_imb_ratio) 54 | if "prior" in config["criterions"]["RouteWeightLoss"]["loss_params"]: 55 | config['criterions']['RouteWeightLoss']['loss_params']['prior'] = \ 56 | get_value(config['criterions']['RouteWeightLoss']['loss_params']['prior'], args.cifar_imb_ratio) 57 | if "exist_only" in config["criterions"]["RouteWeightLoss"]["loss_params"]: 58 | config['criterions']['RouteWeightLoss']['loss_params']['exist_only'] = \ 59 | get_value(config['criterions']['RouteWeightLoss']['loss_params']['exist_only'], bool(args.exist_only)) 60 | 61 | config['criterions']['RouteWeightLoss']['weight'] = \ 62 | get_value(config['criterions']['RouteWeightLoss']['weight'], args.alpha) 63 | config['criterions']['RouteWeightLoss']['loss_params']['remine_lambda'] = \ 64 | get_value(config['criterions']['RouteWeightLoss']['loss_params']['remine_lambda'], 65 | args.remine_lambda) 66 | 67 | if args.no_use_dv: 68 | config["criterions"].pop("RouteWeightLoss", None) 69 | 70 | return config 71 | 72 | 73 | def source_import(file_path): 74 | """This function imports python module directly from source code using importlib""" 75 | spec = importlib.util.spec_from_file_location('', file_path) 76 | module = importlib.util.module_from_spec(spec) 77 | spec.loader.exec_module(module) 78 | return module 79 | 80 | def batch_show(inp, title=None): 81 | """Imshow for Tensor.""" 82 | inp = inp.numpy().transpose((1, 2, 0)) 83 | mean = np.array([0.485, 0.456, 0.406]) 84 | std = np.array([0.229, 0.224, 0.225]) 85 | inp = std * inp + mean 86 | inp = np.clip(inp, 0, 1) 87 | plt.figure(figsize=(20,20)) 88 | plt.imshow(inp) 89 | if title is not None: 90 | plt.title(title) 91 | 92 | def print_write(print_str, log_file): 93 | print(*print_str) 94 | if log_file is None: 95 | return 96 | with open(log_file, 'a') as f: 97 | print(*print_str, file=f) 98 | 99 | def init_weights(model, weights_path, caffe=False, classifier=False): 100 | """Initialize weights""" 101 | print('Pretrained %s 
weights path: %s' % ('classifier' if classifier else 'feature model', 102 | weights_path)) 103 | weights = torch.load(weights_path) 104 | if not classifier: 105 | if caffe: 106 | weights = {k: weights[k] if k in weights else model.state_dict()[k] 107 | for k in model.state_dict()} 108 | else: 109 | weights = weights['state_dict_best']['feat_model'] 110 | weights = {k: weights['module.' + k] if 'module.' + k in weights else model.state_dict()[k] 111 | for k in model.state_dict()} 112 | else: 113 | weights = weights['state_dict_best']['classifier'] 114 | weights = {k: weights['module.fc.' + k] if 'module.fc.' + k in weights else model.state_dict()[k] 115 | for k in model.state_dict()} 116 | model.load_state_dict(weights) 117 | return model 118 | 119 | def shot_acc(preds, labels, train_data, many_shot_thr=100, low_shot_thr=20, acc_per_cls=False): 120 | 121 | if isinstance(train_data, np.ndarray): 122 | training_labels = np.array(train_data).astype(int) 123 | else: 124 | training_labels = np.array(train_data.dataset.labels).astype(int) 125 | 126 | if isinstance(preds, torch.Tensor): 127 | preds = preds.detach().cpu().numpy() 128 | labels = labels.detach().cpu().numpy() 129 | elif isinstance(preds, np.ndarray): 130 | pass 131 | else: 132 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 133 | train_class_count = [] 134 | test_class_count = [] 135 | class_correct = [] 136 | for l in np.unique(labels): 137 | train_class_count.append(len(training_labels[training_labels == l])) 138 | test_class_count.append(len(labels[labels == l])) 139 | class_correct.append((preds[labels == l] == labels[labels == l]).sum()) 140 | 141 | many_shot = [] 142 | median_shot = [] 143 | low_shot = [] 144 | for i in range(len(train_class_count)): 145 | if train_class_count[i] > many_shot_thr: 146 | many_shot.append((class_correct[i] / test_class_count[i])) 147 | elif train_class_count[i] < low_shot_thr: 148 | low_shot.append((class_correct[i] / test_class_count[i])) 149 | else: 150 | median_shot.append((class_correct[i] / test_class_count[i])) 151 | 152 | if len(many_shot) == 0: 153 | many_shot.append(0) 154 | if len(median_shot) == 0: 155 | median_shot.append(0) 156 | if len(low_shot) == 0: 157 | low_shot.append(0) 158 | 159 | if acc_per_cls: 160 | class_accs = [c / cnt for c, cnt in zip(class_correct, test_class_count)] 161 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot), class_accs 162 | else: 163 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 164 | 165 | def weighted_shot_acc (preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20): 166 | 167 | training_labels = np.array(train_data.dataset.labels).astype(int) 168 | 169 | if isinstance(preds, torch.Tensor): 170 | preds = preds.detach().cpu().numpy() 171 | labels = labels.detach().cpu().numpy() 172 | elif isinstance(preds, np.ndarray): 173 | pass 174 | else: 175 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 176 | train_class_count = [] 177 | test_class_count = [] 178 | class_correct = [] 179 | for l in np.unique(labels): 180 | train_class_count.append(len(training_labels[training_labels == l])) 181 | test_class_count.append(ws[labels==l].sum()) 182 | class_correct.append(((preds[labels==l] == labels[labels==l]) * ws[labels==l]).sum()) 183 | 184 | many_shot = [] 185 | median_shot = [] 186 | low_shot = [] 187 | for i in range(len(train_class_count)): 188 | if train_class_count[i] > many_shot_thr: 189 | many_shot.append((class_correct[i] / test_class_count[i])) 190 | 
165 | def weighted_shot_acc(preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20):
166 | 
167 |     training_labels = np.array(train_data.dataset.labels).astype(int)
168 | 
169 |     if isinstance(preds, torch.Tensor):
170 |         preds = preds.detach().cpu().numpy()
171 |         labels = labels.detach().cpu().numpy()
172 |     elif isinstance(preds, np.ndarray):
173 |         pass
174 |     else:
175 |         raise TypeError('Type ({}) of preds not supported'.format(type(preds)))
176 |     train_class_count = []
177 |     test_class_count = []
178 |     class_correct = []
179 |     for l in np.unique(labels):
180 |         train_class_count.append(len(training_labels[training_labels == l]))
181 |         test_class_count.append(ws[labels == l].sum())
182 |         class_correct.append(((preds[labels == l] == labels[labels == l]) * ws[labels == l]).sum())
183 | 
184 |     many_shot = []
185 |     median_shot = []
186 |     low_shot = []
187 |     for i in range(len(train_class_count)):
188 |         if train_class_count[i] > many_shot_thr:
189 |             many_shot.append((class_correct[i] / test_class_count[i]))
190 |         elif train_class_count[i] < low_shot_thr:
191 |             low_shot.append((class_correct[i] / test_class_count[i]))
192 |         else:
193 |             median_shot.append((class_correct[i] / test_class_count[i]))
194 |     return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot)
195 | 
196 | def F_measure(preds, labels, theta=None):
197 |     # Regular macro-averaged f1 score
198 |     return f1_score(labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='macro')
199 | 
200 | def mic_acc_cal(preds, labels):
201 |     if isinstance(labels, tuple):
202 |         assert len(labels) == 3
203 |         targets_a, targets_b, lam = labels
204 |         acc_mic_top1 = (lam * preds.eq(targets_a.data).cpu().sum().float() \
205 |                         + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float()) / len(preds)
206 |     else:
207 |         acc_mic_top1 = (preds == labels).sum().item() / len(labels)
208 |     return acc_mic_top1
209 | 
210 | 
211 | def weighted_mic_acc_cal(preds, labels, ws):
212 |     acc_mic_top1 = ws[preds == labels].sum() / ws.sum()
213 |     return acc_mic_top1
214 | 
215 | def class_count(data):
216 |     labels = np.array(data.dataset.labels)
217 |     class_data_num = []
218 |     for l in np.unique(labels):
219 |         class_data_num.append(len(labels[labels == l]))
220 |     return class_data_num
221 | 
222 | 
223 | # Newly added
224 | def torch2numpy(x):
225 |     if isinstance(x, torch.Tensor):
226 |         return x.detach().cpu().numpy()
227 |     elif isinstance(x, (list, tuple)):
228 |         return tuple([torch2numpy(xi) for xi in x])
229 |     else:
230 |         return x
231 | 
232 | def logits2score(logits, labels):
233 |     scores = F.softmax(logits, dim=1)
234 |     score = scores.gather(1, labels.view(-1, 1))
235 |     score = score.squeeze().cpu().numpy()
236 |     return score
237 | 
238 | 
239 | def logits2entropy(logits):
240 |     scores = F.softmax(logits, dim=1)
241 |     scores = scores.cpu().numpy() + 1e-30
242 |     ent = -scores * np.log(scores)
243 |     ent = np.sum(ent, 1)
244 |     return ent
245 | 
246 | 
247 | def logits2CE(logits, labels):
248 |     scores = F.softmax(logits, dim=1)
249 |     score = scores.gather(1, labels.view(-1, 1))
250 |     score = score.squeeze().cpu().numpy() + 1e-30
251 |     ce = -np.log(score)
252 |     return ce
253 | 
254 | 
255 | def get_priority(ptype, logits, labels):
256 |     if ptype == 'score':
257 |         ws = 1 - logits2score(logits, labels)
258 |     elif ptype == 'entropy':
259 |         ws = logits2entropy(logits)
260 |     elif ptype == 'CE':
261 |         ws = logits2CE(logits, labels)
262 | 
263 |     return ws
264 | 
265 | def get_value(oldv, newv):
266 |     if newv is not None:
267 |         return newv
268 |     else:
269 |         return oldv
270 | 
271 | 
272 | # Newly added by Tang Kaihua
273 | def print_grad_norm(named_parameters, logger_func, log_file, verbose=False):
274 |     if not verbose:
275 |         return None
276 | 
277 |     total_norm = 0.0
278 |     param_to_norm = {}
279 |     param_to_shape = {}
280 |     for n, p in named_parameters.items():
281 |         if p.grad is not None:
282 |             param_norm = p.grad.norm(2)
283 |             total_norm += param_norm ** 2
284 |             param_to_norm[n] = param_norm
285 |             param_to_shape[n] = p.size()
286 | 
287 |     total_norm = total_norm ** (1. / 2)
288 | 
289 |     logger_func(['----------Total norm {:.5f}-----------------'.format(total_norm)], log_file)
290 |     for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]):
291 |         logger_func(["{:<50s}: {:.5f}, ({})".format(name, norm, param_to_shape[name])], log_file)
292 |     logger_func(['-------------------------------'], log_file)
293 | 
294 |     return total_norm
295 | 
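# --- Illustrative usage sketch (added for exposition; not part of the original utils.py) ---
# get_priority() turns logits into per-sample priority weights for the priority-based
# samplers: 'score' uses one minus the softmax probability of the true class, 'entropy'
# uses predictive entropy, and 'CE' uses per-sample cross-entropy. A toy call:
#
#   logits = torch.tensor([[2.0, 0.5, 0.1],
#                          [0.2, 0.3, 2.5]])
#   labels = torch.tensor([0, 2])
#   ws = get_priority('CE', logits, labels)   # numpy array of shape (2,)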
296 | def smooth_l1_loss(input, target, beta=1. / 9, reduction='mean'):
297 |     n = torch.abs(input - target)
298 |     cond = n < beta
299 |     loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
300 |     if reduction == 'mean':
301 |         return loss.mean()
302 |     elif reduction == 'sum':
303 |         return loss.sum()
304 |     else:
305 |         print("Unsupported reduction '{}' for smooth_l1_loss, falling back to mean".format(reduction))
306 |         return loss.mean()
307 | 
308 | 
309 | def l2_loss(input, target, reduction='mean'):
310 |     return F.mse_loss(input, target, reduction=reduction)
311 | 
312 | 
313 | def regression_loss(input, target, l2=False, pre_mean=True, l1=False, moving_average=False, moving_ratio=0.01):
314 |     assert (l2 + l1 + moving_average) == 1
315 |     if l2:
316 |         if (input.shape[0] == target.shape[0]):
317 |             assert not pre_mean
318 |             loss = l2_loss(input, target.clone().detach())
319 |         else:
320 |             assert pre_mean
321 |             loss = l2_loss(input, target.clone().detach().mean(0, keepdim=True))
322 |     elif l1:
323 |         loss = smooth_l1_loss(input, target.clone().detach())
324 |     elif moving_average:
325 |         # input should be register_buffer rather than nn.Parameter
326 |         with torch.no_grad():
327 |             input = (1 - moving_ratio) * input + moving_ratio * target.clone().detach().mean(0, keepdim=True)
328 |         loss = None
329 |     return loss
330 | 
331 | def gumbel_softmax(logits, tau=1, hard=False, gumbel=True, dim=-1):
332 |     if gumbel:
333 |         gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
334 |         gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
335 |         y_soft = gumbels.softmax(dim)
336 |     else:
337 |         y_soft = logits.softmax(dim)
338 | 
339 |     if hard:
340 |         # Straight through.
341 |         index = y_soft.max(dim, keepdim=True)[1]
342 |         y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
343 |         ret = y_hard - y_soft.detach() + y_soft
344 |     else:
345 |         # Reparametrization trick.
346 |         ret = y_soft
347 |     return ret
348 | 
349 | 
350 | def gumbel_sigmoid(logits, tau=1, hard=False, gumbel=True):
351 |     if gumbel:
352 |         gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
353 |         gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
354 |         y_soft = torch.sigmoid(gumbels)
355 |     else:
356 |         y_soft = torch.sigmoid(logits)
357 | 
358 |     if hard:
359 |         # Straight through.
360 |         y_hard = (y_soft > 0.5).float()
361 |         ret = y_hard - y_soft.detach() + y_soft
362 |     else:
363 |         # Reparametrization trick.
364 |         ret = y_soft
365 |     return ret
366 | 
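# --- Illustrative usage sketch (added for exposition; not part of the original utils.py) ---
# gumbel_softmax() with hard=True returns one-hot samples in the forward pass while
# gradients flow through the soft relaxation (straight-through trick); with gumbel=False
# it reduces to a plain (hard or soft) softmax. A toy call:
#
#   logits = torch.randn(4, 3, requires_grad=True)
#   y = gumbel_softmax(logits, tau=0.5, hard=True)   # rows are one-hot
#   (y * torch.randn(4, 3)).sum().backward()         # gradients reach `logits` via the soft sample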
Got {}", 381 | milestones, 382 | ) 383 | 384 | if warmup_method not in ("constant", "linear"): 385 | raise ValueError( 386 | "Only 'constant' or 'linear' warmup_method accepted" 387 | "got {}".format(warmup_method) 388 | ) 389 | self.milestones = milestones 390 | self.gamma = gamma 391 | self.warmup_factor = warmup_factor 392 | self.warmup_epochs = warmup_epochs 393 | self.warmup_method = warmup_method 394 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 395 | 396 | def get_lr(self): 397 | warmup_factor = 1 398 | if self.last_epoch < self.warmup_epochs: 399 | if self.warmup_method == "constant": 400 | warmup_factor = self.warmup_factor 401 | elif self.warmup_method == "linear": 402 | alpha = float(self.last_epoch) / self.warmup_epochs 403 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 404 | return [ 405 | base_lr 406 | * warmup_factor 407 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 408 | for base_lr in self.base_lrs 409 | ] 410 | 411 | 412 | def calculate_prior(num_classes, img_max=None, prior=None, prior_txt=None, reverse=False, return_num=False): 413 | if prior_txt: 414 | labels = [] 415 | with open(prior_txt) as f: 416 | for line in f: 417 | labels.append(int(line.split()[1])) 418 | occur_dict = dict(Counter(labels)) 419 | img_num_per_cls = [occur_dict[i] for i in range(num_classes)] 420 | else: 421 | img_num_per_cls = [] 422 | for cls_idx in range(num_classes): 423 | if reverse: 424 | num = img_max * (prior ** ((num_classes - 1 - cls_idx) / (num_classes - 1.0))) 425 | else: 426 | num = img_max * (prior ** (cls_idx / (num_classes - 1.0))) 427 | img_num_per_cls.append(int(num)) 428 | img_num_per_cls = torch.Tensor(img_num_per_cls) 429 | 430 | if return_num: 431 | return img_num_per_cls 432 | else: 433 | return img_num_per_cls / img_num_per_cls.sum() 434 | --------------------------------------------------------------------------------