├── .gitignore ├── README.md ├── cce.py ├── license.txt └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm,jupyternotebooks 132 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,jupyternotebooks 133 | 134 | ### JupyterNotebooks ### 135 | # gitignore template for Jupyter Notebooks 136 | # website: http://jupyter.org/ 137 | 138 | .ipynb_checkpoints 139 | */.ipynb_checkpoints/* 140 | 141 | # IPython 142 | profile_default/ 143 | ipython_config.py 144 | 145 | # Remove previous ipynb_checkpoints 146 | # git rm -r .ipynb_checkpoints/ 147 | 148 | ### PyCharm ### 149 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 150 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 151 | 152 | # User-specific stuff 153 | .idea/**/workspace.xml 154 | .idea/**/tasks.xml 155 | .idea/**/usage.statistics.xml 156 | .idea/**/dictionaries 157 | .idea/**/shelf 158 | 159 | # AWS User-specific 160 | .idea/**/aws.xml 161 | 162 | # Generated files 163 | .idea/**/contentModel.xml 164 | 165 | # Sensitive or high-churn files 166 | .idea/**/dataSources/ 167 | .idea/**/dataSources.ids 168 | .idea/**/dataSources.local.xml 169 | .idea/**/sqlDataSources.xml 170 | .idea/**/dynamic.xml 171 | .idea/**/uiDesigner.xml 172 | .idea/**/dbnavigator.xml 173 | 174 | # Gradle 175 | 
.idea/**/gradle.xml 176 | .idea/**/libraries 177 | 178 | # Gradle and Maven with auto-import 179 | # When using Gradle or Maven with auto-import, you should exclude module files, 180 | # since they will be recreated, and may cause churn. Uncomment if using 181 | # auto-import. 182 | # .idea/artifacts 183 | # .idea/compiler.xml 184 | # .idea/jarRepositories.xml 185 | # .idea/modules.xml 186 | # .idea/*.iml 187 | # .idea/modules 188 | # *.iml 189 | # *.ipr 190 | 191 | # CMake 192 | cmake-build-*/ 193 | 194 | # Mongo Explorer plugin 195 | .idea/**/mongoSettings.xml 196 | 197 | # File-based project format 198 | *.iws 199 | 200 | # IntelliJ 201 | out/ 202 | 203 | # mpeltonen/sbt-idea plugin 204 | .idea_modules/ 205 | 206 | # JIRA plugin 207 | atlassian-ide-plugin.xml 208 | 209 | # Cursive Clojure plugin 210 | .idea/replstate.xml 211 | 212 | # Crashlytics plugin (for Android Studio and IntelliJ) 213 | com_crashlytics_export_strings.xml 214 | crashlytics.properties 215 | crashlytics-build.properties 216 | fabric.properties 217 | 218 | # Editor-based Rest Client 219 | .idea/httpRequests 220 | 221 | # Android studio 3.1+ serialized cache file 222 | .idea/caches/build_file_checksums.ser 223 | 224 | ### PyCharm Patch ### 225 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 226 | 227 | # *.iml 228 | # modules.xml 229 | # .idea/misc.xml 230 | # *.ipr 231 | 232 | # Sonarlint plugin 233 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 234 | .idea/**/sonarlint/ 235 | 236 | # SonarQube Plugin 237 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 238 | .idea/**/sonarIssues.xml 239 | 240 | # Markdown Navigator plugin 241 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 242 | .idea/**/markdown-navigator.xml 243 | .idea/**/markdown-navigator-enh.xml 244 | .idea/**/markdown-navigator/ 245 | 246 | # Cache file creation bug 247 | # See https://youtrack.jetbrains.com/issue/JBR-2257 248 | 
.idea/$CACHE_FILE$ 249 | 250 | # CodeStream plugin 251 | # https://plugins.jetbrains.com/plugin/12206-codestream 252 | .idea/codestream.xml 253 | 254 | # End of https://www.toptal.com/developers/gitignore/api/pycharm,jupyternotebooks 255 | 256 | 257 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode 258 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode 259 | 260 | ### VisualStudioCode ### 261 | .vscode/* 262 | !.vscode/settings.json 263 | !.vscode/tasks.json 264 | !.vscode/launch.json 265 | !.vscode/extensions.json 266 | *.code-workspace 267 | 268 | # Local History for Visual Studio Code 269 | .history/ 270 | 271 | ### VisualStudioCode Patch ### 272 | # Ignore all local history of files 273 | .history 274 | .ionide 275 | 276 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode 277 | 278 | 279 | # 280 | *.pt 281 | *.pth 282 | *.h5 283 | 284 | # Created by https://www.toptal.com/developers/gitignore/api/images 285 | # Edit at https://www.toptal.com/developers/gitignore?templates=images 286 | 287 | ### Images ### 288 | # JPEG 289 | *.jpg 290 | *.jpeg 291 | *.jpe 292 | *.jif 293 | *.jfif 294 | *.jfi 295 | 296 | # JPEG 2000 297 | *.jp2 298 | *.j2k 299 | *.jpf 300 | *.jpx 301 | *.jpm 302 | *.mj2 303 | 304 | # JPEG XR 305 | *.jxr 306 | *.hdp 307 | *.wdp 308 | 309 | # Graphics Interchange Format 310 | *.gif 311 | 312 | # RAW 313 | *.raw 314 | 315 | # Web P 316 | *.webp 317 | 318 | # Portable Network Graphics 319 | *.png 320 | 321 | # Animated Portable Network Graphics 322 | *.apng 323 | 324 | # Multiple-image Network Graphics 325 | *.mng 326 | 327 | # Tagged Image File Format 328 | *.tiff 329 | *.tif 330 | 331 | # Scalable Vector Graphics 332 | *.svg 333 | *.svgz 334 | 335 | # Portable Document Format 336 | *.pdf 337 | 338 | # X BitMap 339 | *.xbm 340 | 341 | # BMP 342 | *.bmp 343 | *.dib 344 | 345 | # ICO 346 | *.ico 347 | 348 | # 3D Images 349 | *.3dm 350 | *.max 351 | 352 | # End of 
https://www.toptal.com/developers/gitignore/api/images 353 | 354 | *.csv 355 | *.mat 356 | 357 | 358 | *.pkl 359 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Imbalanced Image Classification with Complement Cross Entropy (PyTorch) 2 | **[Yechan Kim](https://github.com/unique-chan), [Younkwan Lee](https://github.com/brightyoun), and [Moongu Jeon](https://scholar.google.co.kr/citations?user=zfngGSkAAAAJ&hl=ko&oi=ao)** 3 | 4 | [Cite this paper](https://doi.org/10.1016/j.patrec.2021.07.017) 5 | 6 | ## News: 7 | - (06/2022) Now, you can easily try our loss function with **[Holocron](https://github.com/frgfm/Holocron)**. Holocron includes implementations of recent Deep Learning tricks in computer vision, easily paired up with your favorite framework and model zoo. 8 | - (08/2021) Our paper has been accepted for publication in ***Pattern Recognition Letters*** 🎉. 9 | 10 | 11 | ## This repository contains: 12 | - Complement Cross Entropy (code) 13 | - For simplicity, classification code is provided separately in this [GitHub repo 🖱️](https://github.com/unique-chan/Simple-Image-Classification): you can easily use `Complement Cross Entropy` by passing `--loss_function='CCE'` when running `train.py`. For details, please visit the above repository.
14 | 15 | ## Prerequisites 16 | * See requirements.txt 17 | ``` 18 | torch 19 | torchvision 20 | ``` 21 | 22 | ## Code 23 | ```python 24 | class CCE(nn.Module): 25 | def __init__(self, device, balancing_factor=1): 26 | super(CCE, self).__init__() 27 | self.nll_loss = nn.NLLLoss() 28 | self.device = device # {'cpu', 'cuda:0', 'cuda:1', ...} 29 | self.balancing_factor = balancing_factor 30 | 31 | def forward(self, yHat, y): 32 | # Note: yHat.shape[1] <=> number of classes 33 | batch_size = len(y) 34 | # cross entropy 35 | cross_entropy = self.nll_loss(F.log_softmax(yHat, dim=1), y) 36 | # complement entropy 37 | yHat = F.softmax(yHat, dim=1) 38 | Yg = yHat.gather(dim=1, index=torch.unsqueeze(y, 1)) 39 | Px = yHat / (1 - Yg) + 1e-7 40 | Px_log = torch.log(Px + 1e-10) 41 | y_zerohot = torch.ones(batch_size, yHat.shape[1]).scatter_( 42 | 1, y.view(batch_size, 1).data.cpu(), 0) 43 | output = Px * Px_log * y_zerohot.to(device=self.device) 44 | complement_entropy = torch.sum(output) / (float(batch_size) * float(yHat.shape[1])) 45 | 46 | return cross_entropy - self.balancing_factor * complement_entropy 47 | ``` 48 | 49 | ## Citation 50 | If you use this code for your research, please cite the following paper: 51 | ~~~ME 52 | @article{kim2021imbalanced, 53 | title={Imbalanced image classification with complement cross entropy}, 54 | author={Kim, Yechan and Lee, Younkwan and Jeon, Moongu}, 55 | journal={Pattern Recognition Letters}, 56 | volume={151}, 57 | pages={33--40}, 58 | year={2021}, 59 | publisher={Elsevier} 60 | } 61 | ~~~ 62 | 63 | ## Contribution 64 | If you find any bugs or have opinions for further improvements, please feel free to create a pull request or contact me (yechankim@gm.gist.ac.kr). All contributions are welcome. 65 | 66 | ## Reference 67 | 1. Hao-Yun Chen, Pei-Hsin Wang, Chun-Hao Liu, Shih-Chieh Chang, Jia-Yu Pan, Yu-Ting Chen, Wei Wei, and Da-Cheng Juan. Complement objective training. arXiv preprint arXiv:1903.01182, 2019. 68 | 2. 
Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Doll ́ar.Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision, pages 2980–2988, 2017. 69 | 3. Tong He, Zhi Zhang, Hang Zhang, Zhongyue Zhang, Junyuan Xie, andMu Li. Bag of tricks for image classification with convolutional neuralnetworks. InProceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 558–567, 2019. 70 | 4. https://github.com/calmisential/Basic_CNNs_TensorFlow2 71 | 5. https://github.com/Hsuxu/Loss_ToolBox-PyTorch 72 | 6. https://github.com/weiaicunzai/pytorch-cifar100 73 | -------------------------------------------------------------------------------- /cce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CCE(nn.Module): 7 | def __init__(self, device, balancing_factor=1): 8 | super(CCE, self).__init__() 9 | self.nll_loss = nn.NLLLoss() 10 | self.device = device # {'cpu', 'cuda:0', 'cuda:1', ...} 11 | self.balancing_factor = balancing_factor 12 | 13 | def forward(self, yHat, y): 14 | # Note: yHat.shape[1] <=> number of classes 15 | batch_size = len(y) 16 | # cross entropy 17 | cross_entropy = self.nll_loss(F.log_softmax(yHat, dim=1), y) 18 | # complement entropy 19 | yHat = F.softmax(yHat, dim=1) 20 | Yg = yHat.gather(dim=1, index=torch.unsqueeze(y, 1)) 21 | Px = yHat / (1 - Yg) + 1e-7 22 | Px_log = torch.log(Px + 1e-10) 23 | y_zerohot = torch.ones(batch_size, yHat.shape[1]).scatter_( 24 | 1, y.view(batch_size, 1).data.cpu(), 0) 25 | output = Px * Px_log * y_zerohot.to(device=self.device) 26 | complement_entropy = torch.sum(output) / (float(batch_size) * float(yHat.shape[1])) 27 | 28 | return cross_entropy - self.balancing_factor * complement_entropy 29 | -------------------------------------------------------------------------------- /license.txt: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yechan Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | 4 | --------------------------------------------------------------------------------