├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── code-style.yml │ ├── flake8-lint.yml │ └── python-publish.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── _config.yml ├── examples ├── gctf_horses_v_humans.ipynb └── gctf_mnist.ipynb ├── gctf ├── __init__.py ├── centralized_gradients.py ├── optimizers.py └── version.py ├── images ├── gctf.png ├── illutstration.png └── projected_grad.png └── setup.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## :camera: Screenshots 2 | 3 | 4 | ## :page_facing_up: Context 5 | 6 | 7 | ## :pencil: Changes 8 | 9 | 10 | ## :paperclip: Related PR 11 | 12 | 13 | ## :no_entry_sign: Breaking 14 | 15 | 16 | ## :stopwatch: Next steps 17 | 18 | -------------------------------------------------------------------------------- /.github/workflows/code-style.yml: -------------------------------------------------------------------------------- 1 | name: Format python code 2 | on: push 3 | jobs: 4 | autopep8: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - name: autopep8 9 | uses: peter-evans/autopep8@v1 10 | with: 11 | args: --recursive --in-place --aggressive --aggressive . 
12 | - name: Create Pull Request 13 | uses: peter-evans/create-pull-request@v3 14 | with: 15 | commit-message: autopep8 action fixes 16 | title: Fixes by autopep8 action 17 | body: This is an auto-generated PR with fixes by autopep8. 18 | labels: autopep8 19 | reviewers: Rishit-dagli 20 | branch: autopep8-patches 21 | -------------------------------------------------------------------------------- /.github/workflows/flake8-lint.yml: -------------------------------------------------------------------------------- 1 | name: Flake8 Lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.7, 3.8, 3.9] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install flake8 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | workflow_dispatch: 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | - name: Set up Python 🐍 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: '3.x' 23 | 24 | - name: Cache pip 25 | uses: actions/cache@v2 26 | with: 27 | path: ~/.cache/pip 28 | # Look to see if there is a cache hit for the corresponding requirements file 29 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 30 | restore-keys: | 31 | ${{ runner.os }}-pip- 32 | ${{ runner.os }}- 33 | 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install setuptools wheel twine 38 | 39 | - name: Build and publish 🚀 40 | env: 41 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 42 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 43 | run: | 44 | python setup.py sdist bdist_wheel 45 | twine upload dist/* 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | 
sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at rishit.dagli@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Feeling Awesome! Thanks for thinking about this. 2 | 3 | You can contribute us by filing issues, bugs and PRs. You can also take a look at active issues and fix them. We love your input! 
We want to make contributing to this project as easy and transparent as possible, whether it's: 4 | 5 | - Reporting a bug 6 | - Discussing the current state of the code 7 | - Submitting a fix 8 | - Proposing new features 9 | - Becoming a maintainer 10 | 11 | If you want to discuss something, feel free to present your opinions, views or any other relevant comment on [discussions](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/discussions). 12 | Please note we have a [code of conduct](CODE_OF_CONDUCT.md); please follow it in all your interactions with the project. 13 | 14 | ### Code contribution 15 | 16 | - Open an issue regarding the proposed change. 17 | - If your proposed change is approved, fork this repo and make your changes. 18 | - Open a PR against the latest *development* branch. Add a nice description in the PR. 19 | - You're done! 20 | 21 | ### Code contribution checklist 22 | 23 | - New code additions/deletions should not break the existing flow of the system. 24 | - Ensure any install or build dependencies are removed before the end of the layer when doing a build. 25 | - Update the [README.md](README.md) if the change requires it, with details of changes to the interface. 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | recursive-include gctf *.py 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gradient Centralization TensorFlow [![Twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Fgithub.com%2FRishit-dagli%2FGradient-Centralization-TensorFlow)](https://twitter.com/intent/tweet?text=Wow:&url=https%3A%2F%2Fgithub.com%2FRishit-dagli%2FGradient-Centralization-TensorFlow) 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/gradient-centralization-tf)](https://pypi.org/project/gradient-centralization-tf/) 4 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4570279.svg)](https://doi.org/10.5281/zenodo.4570279) 5 | [![Upload Python 
Package](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/actions/workflows/python-publish.yml/badge.svg)](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/actions/workflows/python-publish.yml) 6 | [![Flake8 Lint](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/actions/workflows/flake8-lint.yml/badge.svg)](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/actions/workflows/flake8-lint.yml) 7 | ![Python Version](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue) 8 | 9 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Rishit-dagli/Gradient-Centralization-TensorFlow/HEAD) 10 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow) 11 | 12 | [![GitHub license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) 13 | [![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/) 14 | [![GitHub stars](https://img.shields.io/github/stars/Rishit-dagli/Gradient-Centralization-TensorFlow?style=social)](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/stargazers) 15 | [![GitHub forks](https://img.shields.io/github/forks/Rishit-dagli/Gradient-Centralization-TensorFlow?style=social)](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/network) 16 | [![GitHub watchers](https://img.shields.io/github/watchers/Rishit-dagli/Gradient-Centralization-TensorFlow?style=social)](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/watchers) 17 | 18 | This Python package implements Gradient Centralization in TensorFlow, a simple and effective optimization technique for 19 | Deep Neural Networks as suggested by Yong et al. 
in the paper 20 | [Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/abs/2004.01461). It can both speed up the training 21 | process and improve the final generalization performance of DNNs. 22 | 23 | ![](images/gctf.png) 24 | 25 | ## Installation 26 | 27 | Run the following to install: 28 | 29 | ```bash 30 | pip install gradient-centralization-tf 31 | ``` 32 | 33 | ## About the Examples 34 | 35 | ### [`gctf_mnist.ipynb`](examples/gctf_mnist.ipynb) 36 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_mnist.ipynb) 37 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Rishit-dagli/Gradient-Centralization-TensorFlow/c4c1b0f947b0ae6de0a2964b2fcb5c37faa6c72b?filepath=examples%2Fgctf_mnist.ipynb) 38 | 39 | This notebook shows the process of using the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) 40 | Python package to train on the [Fashion MNIST](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist) 41 | dataset available from [`tf.keras.datasets`](https://www.tensorflow.org/api_docs/python/tf/keras/datasets). It further 42 | compares performance with and without using `gctf`. 
43 | 44 | ### [`gctf_horses_v_humans.ipynb`](examples/gctf_horses_v_humans.ipynb) 45 | 46 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_horses_v_humans.ipynb) 47 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Rishit-dagli/Gradient-Centralization-TensorFlow/c4c1b0f947b0ae6de0a2964b2fcb5c37faa6c72b?filepath=examples%2Fgctf_horses_v_humans.ipynb) 48 | 49 | This notebook shows the process of using the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) 50 | Python package to train on the [Horses vs Humans](http://www.laurencemoroney.com/horses-or-humans-dataset/) dataset by 51 | [Laurence Moroney](https://twitter.com/lmoroney). It further compares performance with and without using 52 | `gctf`. 53 | 54 | ## Usage 55 | 56 | ### [`gctf.centralized_gradients_for_optimizer`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/gctf/centralized_gradients.py#L45-L55) 57 | 58 | Creates a centralized gradients function for a specified optimizer. 59 | 60 | #### Arguments: 61 | - `optimizer`: a `tf.keras.optimizers.Optimizer` object. The optimizer you are using. 62 | 63 | #### Example: 64 | 65 | ```py 66 | >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) 67 | >>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt) 68 | >>> model.compile(optimizer = opt, ...) 69 | ``` 70 | 71 | #### Returns: 72 | A `tf.keras.optimizers.Optimizer` object. 73 | 74 | ### [`gctf.get_centralized_gradients`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/a7c5226dad86ca42341061e3fafc8c8d1ec3f51f/gctf/centralized_gradients.py#L5-L42) 75 | 76 | Computes the centralized gradients. 
77 | 78 | This function is ideally not meant to be used directly unless you are building a custom optimizer, in which case you 79 | could point `get_gradients` to this function. This is a modified version of 80 | `tf.keras.optimizers.Optimizer.get_gradients`. 81 | 82 | #### Arguments: 83 | - `optimizer`: a `tf.keras.optimizers.Optimizer` object. The optimizer you are using. 84 | - `loss`: Scalar tensor to minimize. 85 | - `params`: List of variables. 86 | 87 | #### Returns: 88 | A gradients tensor. 89 | 90 | ### [`gctf.optimizers`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/gctf/optimizers.py) 91 | 92 | Pre-built updated optimizers implementing GC. 93 | 94 | This module is specially built for testing out GC and in most cases you would be using [`gctf.centralized_gradients_for_optimizer`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow#gctfcentralized_gradients_for_optimizer) though this module implements `gctf.centralized_gradients_for_optimizer`. You can directly use all optimizers with [`tf.keras.optimizers`](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) updated for GC. 95 | 96 | #### Example: 97 | 98 | ```py 99 | >>> model.compile(optimizer = gctf.optimizers.adam(learning_rate = 0.01), ...) 100 | >>> model.compile(optimizer = gctf.optimizers.rmsprop(learning_rate = 0.01, rho = 0.91), ...) 101 | >>> model.compile(optimizer = gctf.optimizers.sgd(), ...) 102 | ``` 103 | 104 | #### Returns: 105 | A `tf.keras.optimizers.Optimizer` object. 106 | 107 | ## Developing `gctf` 108 | 109 | To install `gradient-centralization-tf`, along with tools you need to develop and test, run the following in your 110 | virtualenv: 111 | 112 | ```bash 113 | git clone https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow 114 | # or clone your own fork 115 | 116 | pip install -e .[dev] 117 | ``` 118 | 119 | ## Want to Contribute 🙋‍♂️? 120 | 121 | Awesome! 
If you want to contribute to this project, you're always welcome! See [Contributing Guidelines](CONTRIBUTING.md). You can also take a look at [open issues](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/issues) for getting more information about current or upcoming tasks. 122 | 123 | ## Want to discuss? 💬 124 | 125 | Have any questions, doubts or want to present your opinions, views? You're always welcome. You can [start discussions](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/discussions). 126 | 127 | ## License 128 | 129 | ``` 130 | Copyright 2020 Rishit Dagli 131 | 132 | Licensed under the Apache License, Version 2.0 (the "License"); 133 | you may not use this file except in compliance with the License. 134 | You may obtain a copy of the License at 135 | 136 | http://www.apache.org/licenses/LICENSE-2.0 137 | 138 | Unless required by applicable law or agreed to in writing, software 139 | distributed under the License is distributed on an "AS IS" BASIS, 140 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 141 | See the License for the specific language governing permissions and 142 | limitations under the License. 
143 | ``` 144 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /examples/gctf_horses_v_humans.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "gctf-horses-v-humans.ipynb", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "IieqPex4O3eJ" 31 | }, 32 | "source": [ 33 | "# GCTF Horses vs Humans\r\n", 34 | "\r\n", 35 | "This notebook shows the the process of using the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) Python package to train on the [Horses vs Humans](http://www.laurencemoroney.com/horses-or-humans-dataset/) dataset by [Laurence Moroney](https://twitter.com/lmoroney). Gradient Centralization is a simple and effective optimization technique for Deep Neural Networks as suggested by Yong et al. in the paper \r\n", 36 | "[Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/abs/2004.01461). 
It can both speedup training \r\n", 37 | " process and improve the final generalization performance of DNNs.\r\n", 38 | "\r\n", 39 | "If you find this useful please consider giving a ⭐ to [the repo](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/).\r\n", 40 | "\r\n", 41 | "## A bit about GC\r\n", 42 | "\r\n", 43 | "Gradient Centralization operates directly on gradients by centralizing the gradient vectors to have zero mean. It can both speedup training process and improve the final generalization performance of DNNs. Here is an Illustration of the GC operation on gradient matrix/tensor of weights in the fully-connected layer (left) and convolutional layer (right). GC computes the column/slice mean of gradient matrix/tensor and centralizes each column/slice to have zero mean.\r\n", 44 | "\r\n", 45 | "![](https://i.imgur.com/KitoO8J.png)\r\n", 46 | "\r\n", 47 | "GC can be viewed as a projected gradient descent method with a constrained loss function. The geometrical interpretation of GC. 
The gradient is projected on a hyperplane $e^T(w-w^t)=0$, where the projected gradient is used to update the weight.\r\n", 48 | "\r\n", 49 | "![](https://i.imgur.com/ekHhQv0.png)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "WcpkfjkwSjmv" 56 | }, 57 | "source": [ 58 | "## Setup" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "XjTQqCEaTSSG" 65 | }, 66 | "source": [ 67 | "import tensorflow as tf\r\n", 68 | "from time import time" 69 | ], 70 | "execution_count": 1, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "32NuQrDdSpP3" 77 | }, 78 | "source": [ 79 | "### Install the package\r\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "colab": { 86 | "base_uri": "https://localhost:8080/" 87 | }, 88 | "id": "Qce8MP-42vkt", 89 | "outputId": "571100f0-5735-42eb-d9c5-1868db53e755" 90 | }, 91 | "source": [ 92 | "!pip install gradient-centralization-tf" 93 | ], 94 | "execution_count": 2, 95 | "outputs": [ 96 | { 97 | "output_type": "stream", 98 | "text": [ 99 | "Collecting gradient-centralization-tf\n", 100 | " Downloading https://files.pythonhosted.org/packages/58/4c/6253587b8f6ccdf03fd4830de2574cbda48a1a84bc660d5dd8978d0f94fb/gradient_centralization_tf-0.0.2-py3-none-any.whl\n", 101 | "Requirement already satisfied: tensorflow~=2.4.0 in /usr/local/lib/python3.6/dist-packages (from gradient-centralization-tf) (2.4.1)\n", 102 | "Requirement already satisfied: keras~=2.4.0 in /usr/local/lib/python3.6/dist-packages (from gradient-centralization-tf) (2.4.3)\n", 103 | "Requirement already satisfied: h5py~=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.10.0)\n", 104 | "Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.1.2)\n", 105 | "Requirement already satisfied: gast==0.3.3 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.3.3)\n", 106 | "Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.2.0)\n", 107 | "Requirement already satisfied: grpcio~=1.32.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.32.0)\n", 108 | "Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.3.0)\n", 109 | "Requirement already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.7.4.3)\n", 110 | "Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.10.0)\n", 111 | "Requirement already satisfied: numpy~=1.19.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.19.5)\n", 112 | "Requirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.15.0)\n", 113 | "Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.1.0)\n", 114 | "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.12.4)\n", 115 | "Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.36.2)\n", 116 | "Requirement already satisfied: tensorflow-estimator<2.5.0,>=2.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.4.0)\n", 117 | "Requirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.6/dist-packages 
(from tensorflow~=2.4.0->gradient-centralization-tf) (1.12.1)\n", 118 | "Requirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.12)\n", 119 | "Requirement already satisfied: tensorboard~=2.4 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.4.1)\n", 120 | "Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.6.3)\n", 121 | "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from keras~=2.4.0->gradient-centralization-tf) (1.4.1)\n", 122 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras~=2.4.0->gradient-centralization-tf) (3.13)\n", 123 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.9.2->tensorflow~=2.4.0->gradient-centralization-tf) (53.0.0)\n", 124 | "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.8.0)\n", 125 | "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.4.2)\n", 126 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.3.3)\n", 127 | "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2.23.0)\n", 128 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.0.1)\n", 129 | "Requirement already satisfied: 
google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.25.0)\n", 130 | "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.3.0)\n", 131 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.4.0)\n", 132 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.24.3)\n", 133 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2020.12.5)\n", 134 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.0.4)\n", 135 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2.10)\n", 136 | "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (4.2.1)\n", 137 | "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.2.8)\n", 138 | "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /usr/local/lib/python3.6/dist-packages (from 
google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (4.7)\n", 139 | "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.1.0)\n", 140 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.4.0)\n", 141 | "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.4.8)\n", 142 | "Installing collected packages: gradient-centralization-tf\n", 143 | "Successfully installed gradient-centralization-tf-0.0.2\n" 144 | ], 145 | "name": "stdout" 146 | } 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": { 152 | "id": "-lVvc6uKHDYI" 153 | }, 154 | "source": [ 155 | "## Get the data" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "metadata": { 161 | "id": "5lEoHGCV61Bv", 162 | "colab": { 163 | "base_uri": "https://localhost:8080/" 164 | }, 165 | "outputId": "87e66434-eaba-4892-d2b1-99972e63305a" 166 | }, 167 | "source": [ 168 | "!wget --no-check-certificate \\\r\n", 169 | " https://storage.googleapis.com/laurencemoroney-blog.appspot.com/horse-or-human.zip \\\r\n", 170 | " -O /tmp/horse-or-human.zip\r\n", 171 | "\r\n", 172 | "!wget --no-check-certificate \\\r\n", 173 | " https://storage.googleapis.com/laurencemoroney-blog.appspot.com/validation-horse-or-human.zip \\\r\n", 174 | " -O /tmp/validation-horse-or-human.zip\r\n", 175 | " \r\n", 176 | "import os\r\n", 177 | "import zipfile\r\n", 178 | "\r\n", 179 | "local_zip = '/tmp/horse-or-human.zip'\r\n", 180 | "zip_ref = zipfile.ZipFile(local_zip, 'r')\r\n", 181 | 
"zip_ref.extractall('/tmp/horse-or-human')\r\n", 182 | "local_zip = '/tmp/validation-horse-or-human.zip'\r\n", 183 | "zip_ref = zipfile.ZipFile(local_zip, 'r')\r\n", 184 | "zip_ref.extractall('/tmp/validation-horse-or-human')\r\n", 185 | "zip_ref.close()\r\n", 186 | "# Directory with our training horse pictures\r\n", 187 | "train_horse_dir = os.path.join('/tmp/horse-or-human/horses')\r\n", 188 | "\r\n", 189 | "# Directory with our training human pictures\r\n", 190 | "train_human_dir = os.path.join('/tmp/horse-or-human/humans')\r\n", 191 | "\r\n", 192 | "# Directory with our training horse pictures\r\n", 193 | "validation_horse_dir = os.path.join('/tmp/validation-horse-or-human/horses')\r\n", 194 | "\r\n", 195 | "# Directory with our training human pictures\r\n", 196 | "validation_human_dir = os.path.join('/tmp/validation-horse-or-human/humans')" 197 | ], 198 | "execution_count": 3, 199 | "outputs": [ 200 | { 201 | "output_type": "stream", 202 | "text": [ 203 | "--2021-02-21 12:08:31-- https://storage.googleapis.com/laurencemoroney-blog.appspot.com/horse-or-human.zip\n", 204 | "Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.115.128, 172.253.122.128, 142.250.31.128, ...\n", 205 | "Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.115.128|:443... connected.\n", 206 | "HTTP request sent, awaiting response... 200 OK\n", 207 | "Length: 149574867 (143M) [application/zip]\n", 208 | "Saving to: ‘/tmp/horse-or-human.zip’\n", 209 | "\n", 210 | "/tmp/horse-or-human 100%[===================>] 142.65M 262MB/s in 0.5s \n", 211 | "\n", 212 | "2021-02-21 12:08:31 (262 MB/s) - ‘/tmp/horse-or-human.zip’ saved [149574867/149574867]\n", 213 | "\n", 214 | "--2021-02-21 12:08:32-- https://storage.googleapis.com/laurencemoroney-blog.appspot.com/validation-horse-or-human.zip\n", 215 | "Resolving storage.googleapis.com (storage.googleapis.com)... 
172.217.164.144, 142.250.73.208, 172.253.62.128, ...\n", 216 | "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.164.144|:443... connected.\n", 217 | "HTTP request sent, awaiting response... 200 OK\n", 218 | "Length: 11480187 (11M) [application/zip]\n", 219 | "Saving to: ‘/tmp/validation-horse-or-human.zip’\n", 220 | "\n", 221 | "/tmp/validation-hor 100%[===================>] 10.95M --.-KB/s in 0.1s \n", 222 | "\n", 223 | "2021-02-21 12:08:32 (93.0 MB/s) - ‘/tmp/validation-horse-or-human.zip’ saved [11480187/11480187]\n", 224 | "\n" 225 | ], 226 | "name": "stdout" 227 | } 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "id": "UsLyqydwHPOZ" 234 | }, 235 | "source": [ 236 | "## Image Augmentation\r\n", 237 | "\r\n", 238 | "We will perform a couple of augmentations on the image" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "metadata": { 244 | "colab": { 245 | "base_uri": "https://localhost:8080/" 246 | }, 247 | "id": "JUagrOo6wHJP", 248 | "outputId": "873ec40e-58a0-4aa6-91d0-4d80cd810947" 249 | }, 250 | "source": [ 251 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\r\n", 252 | "\r\n", 253 | "# All images will be rescaled by 1./255\r\n", 254 | "train_datagen = ImageDataGenerator(\r\n", 255 | " rescale=1./255,\r\n", 256 | " rotation_range=40,\r\n", 257 | " width_shift_range=0.2,\r\n", 258 | " height_shift_range=0.2,\r\n", 259 | " shear_range=0.2,\r\n", 260 | " zoom_range=0.2,\r\n", 261 | " horizontal_flip=True,\r\n", 262 | " fill_mode='nearest')\r\n", 263 | "\r\n", 264 | "validation_datagen = ImageDataGenerator(rescale=1/255)\r\n", 265 | "\r\n", 266 | "# Flow training images in batches of 128 using train_datagen generator\r\n", 267 | "train_generator = train_datagen.flow_from_directory(\r\n", 268 | " '/tmp/horse-or-human/', # This is the source directory for training images\r\n", 269 | " target_size=(300, 300), # All images will be resized to 150x150\r\n", 270 | " 
batch_size=128,\r\n", 271 | " # Since we use binary_crossentropy loss, we need binary labels\r\n", 272 | " class_mode='binary')\r\n", 273 | "\r\n", 274 | "# Flow training images in batches of 128 using train_datagen generator\r\n", 275 | "validation_generator = validation_datagen.flow_from_directory(\r\n", 276 | " '/tmp/validation-horse-or-human/', # This is the source directory for training images\r\n", 277 | " target_size=(300, 300), # All images will be resized to 150x150\r\n", 278 | " batch_size=32,\r\n", 279 | " # Since we use binary_crossentropy loss, we need binary labels\r\n", 280 | " class_mode='binary')" 281 | ], 282 | "execution_count": 4, 283 | "outputs": [ 284 | { 285 | "output_type": "stream", 286 | "text": [ 287 | "Found 1027 images belonging to 2 classes.\n", 288 | "Found 256 images belonging to 2 classes.\n" 289 | ], 290 | "name": "stdout" 291 | } 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": { 297 | "id": "JCcBZ1UTHVc3" 298 | }, 299 | "source": [ 300 | "## Training the model\r\n", 301 | "\r\n", 302 | "Here we have built a very simple model with 5 Convolutional for this example. 
" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "id": "LuW4o9yyvlg_" 309 | }, 310 | "source": [ 311 | "model = tf.keras.models.Sequential([\r\n", 312 | " # Note the input shape is the desired size of the image 300x300 with 3 bytes color\r\n", 313 | " # This is the first convolution\r\n", 314 | " tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),\r\n", 315 | " tf.keras.layers.MaxPooling2D(2, 2),\r\n", 316 | " # The second convolution\r\n", 317 | " tf.keras.layers.Conv2D(32, (3,3), activation='relu'),\r\n", 318 | " tf.keras.layers.Dropout(0.5),\r\n", 319 | " tf.keras.layers.MaxPooling2D(2,2),\r\n", 320 | " # The third convolution\r\n", 321 | " tf.keras.layers.Conv2D(64, (3,3), activation='relu'),\r\n", 322 | " tf.keras.layers.Dropout(0.5),\r\n", 323 | " tf.keras.layers.MaxPooling2D(2,2),\r\n", 324 | " # The fourth convolution\r\n", 325 | " tf.keras.layers.Conv2D(64, (3,3), activation='relu'),\r\n", 326 | " tf.keras.layers.MaxPooling2D(2,2),\r\n", 327 | " # The fifth convolution\r\n", 328 | " tf.keras.layers.Conv2D(64, (3,3), activation='relu'),\r\n", 329 | " tf.keras.layers.MaxPooling2D(2,2),\r\n", 330 | " # Flatten the results to feed into a DNN\r\n", 331 | " \r\n", 332 | " tf.keras.layers.Flatten(),\r\n", 333 | " tf.keras.layers.Dropout(0.5),\r\n", 334 | " # 512 neuron hidden layer\r\n", 335 | " tf.keras.layers.Dense(512, activation='relu'),\r\n", 336 | " # Only 1 output neuron. It will contain a value from 0-1 where 0 for 1 class ('horses') and 1 for the other ('humans')\r\n", 337 | " tf.keras.layers.Dense(1, activation='sigmoid')\r\n", 338 | "])" 339 | ], 340 | "execution_count": 5, 341 | "outputs": [] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "id": "pMBGO9JbHgWG" 347 | }, 348 | "source": [ 349 | "On the same note since we are interested in comparing results we will create a callback which allows us to compute the training time." 
350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "metadata": { 355 | "id": "Pfk5BFgewUfq" 356 | }, 357 | "source": [ 358 | "class TimeHistory(tf.keras.callbacks.Callback):\r\n", 359 | " def on_train_begin(self, logs={}):\r\n", 360 | " self.times = []\r\n", 361 | "\r\n", 362 | " def on_epoch_begin(self, batch, logs={}):\r\n", 363 | " self.epoch_time_start = time()\r\n", 364 | "\r\n", 365 | " def on_epoch_end(self, batch, logs={}):\r\n", 366 | " self.times.append(time() - self.epoch_time_start)" 367 | ], 368 | "execution_count": 6, 369 | "outputs": [] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": { 374 | "id": "Xf1NPKe3HqVf" 375 | }, 376 | "source": [ 377 | "### Train a model without [`gctf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "m-_AtY2fwCDZ", 384 | "outputId": "38f19bc7-c7b1-4d9b-e543-4d31cbc51105", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/" 387 | } 388 | }, 389 | "source": [ 390 | "from tensorflow.keras.optimizers import RMSprop\r\n", 391 | "\r\n", 392 | "time_callback_no_gctf = TimeHistory()\r\n", 393 | "model.compile(loss='binary_crossentropy',\r\n", 394 | " optimizer=RMSprop(lr=1e-4),\r\n", 395 | " metrics=['accuracy'])\r\n", 396 | "\r\n", 397 | "history_no_gctf = model.fit(\r\n", 398 | " train_generator,\r\n", 399 | " steps_per_epoch=8, \r\n", 400 | " epochs=10,\r\n", 401 | " verbose=1,\r\n", 402 | " validation_data = validation_generator,\r\n", 403 | " validation_steps=8,\r\n", 404 | " callbacks = [time_callback_no_gctf])" 405 | ], 406 | "execution_count": 7, 407 | "outputs": [ 408 | { 409 | "output_type": "stream", 410 | "text": [ 411 | "Epoch 1/10\n", 412 | "8/8 [==============================] - 27s 2s/step - loss: 0.7229 - accuracy: 0.4852 - val_loss: 0.6898 - val_accuracy: 0.5000\n", 413 | "Epoch 2/10\n", 414 | "8/8 [==============================] - 22s 3s/step - loss: 0.6871 - accuracy: 
0.5560 - val_loss: 0.6858 - val_accuracy: 0.5234\n", 415 | "Epoch 3/10\n", 416 | "8/8 [==============================] - 22s 3s/step - loss: 0.6732 - accuracy: 0.6040 - val_loss: 0.6801 - val_accuracy: 0.5508\n", 417 | "Epoch 4/10\n", 418 | "8/8 [==============================] - 22s 3s/step - loss: 0.6343 - accuracy: 0.6694 - val_loss: 0.6916 - val_accuracy: 0.5000\n", 419 | "Epoch 5/10\n", 420 | "8/8 [==============================] - 21s 3s/step - loss: 0.6548 - accuracy: 0.6131 - val_loss: 0.6718 - val_accuracy: 0.8281\n", 421 | "Epoch 6/10\n", 422 | "8/8 [==============================] - 21s 3s/step - loss: 0.5896 - accuracy: 0.6966 - val_loss: 0.6733 - val_accuracy: 0.5000\n", 423 | "Epoch 7/10\n", 424 | "8/8 [==============================] - 21s 3s/step - loss: 0.5870 - accuracy: 0.7058 - val_loss: 0.6604 - val_accuracy: 0.6094\n", 425 | "Epoch 8/10\n", 426 | "8/8 [==============================] - 21s 3s/step - loss: 0.5534 - accuracy: 0.7235 - val_loss: 0.6887 - val_accuracy: 0.5000\n", 427 | "Epoch 9/10\n", 428 | "8/8 [==============================] - 23s 3s/step - loss: 0.5626 - accuracy: 0.7112 - val_loss: 0.6570 - val_accuracy: 0.5586\n", 429 | "Epoch 10/10\n", 430 | "8/8 [==============================] - 21s 3s/step - loss: 0.5607 - accuracy: 0.7258 - val_loss: 0.6463 - val_accuracy: 0.6016\n" 431 | ], 432 | "name": "stdout" 433 | } 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": { 439 | "id": "96YKj6_GH5Yf" 440 | }, 441 | "source": [ 442 | "### Train a model with [`gctf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow/)" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "metadata": { 448 | "id": "HRNC9FT0wjY0", 449 | "outputId": "e74de2a8-4ceb-4713-f35c-a7f88b84704b", 450 | "colab": { 451 | "base_uri": "https://localhost:8080/" 452 | } 453 | }, 454 | "source": [ 455 | "import gctf #import gctf\r\n", 456 | "\r\n", 457 | "time_callback_gctf = TimeHistory()\r\n", 458 | 
"model.compile(loss='binary_crossentropy',\r\n", 459 | " optimizer=gctf.optimizers.rmsprop(learning_rate = 1e-4),\r\n", 460 | " metrics=['accuracy'])\r\n", 461 | "\r\n", 462 | "history_gctf = model.fit(\r\n", 463 | " train_generator,\r\n", 464 | " steps_per_epoch=8, \r\n", 465 | " epochs=10,\r\n", 466 | " verbose=1,\r\n", 467 | " validation_data = validation_generator,\r\n", 468 | " validation_steps=8,\r\n", 469 | " callbacks = [time_callback_gctf])" 470 | ], 471 | "execution_count": 8, 472 | "outputs": [ 473 | { 474 | "output_type": "stream", 475 | "text": [ 476 | "Epoch 1/10\n", 477 | "8/8 [==============================] - 24s 3s/step - loss: 0.6394 - accuracy: 0.6779 - val_loss: 0.6885 - val_accuracy: 0.5000\n", 478 | "Epoch 2/10\n", 479 | "8/8 [==============================] - 21s 3s/step - loss: 0.5504 - accuracy: 0.7124 - val_loss: 0.6450 - val_accuracy: 0.5625\n", 480 | "Epoch 3/10\n", 481 | "8/8 [==============================] - 22s 3s/step - loss: 0.5050 - accuracy: 0.7673 - val_loss: 0.6163 - val_accuracy: 0.6094\n", 482 | "Epoch 4/10\n", 483 | "8/8 [==============================] - 21s 3s/step - loss: 0.5206 - accuracy: 0.7589 - val_loss: 0.5969 - val_accuracy: 0.6797\n", 484 | "Epoch 5/10\n", 485 | "8/8 [==============================] - 21s 3s/step - loss: 0.5175 - accuracy: 0.7506 - val_loss: 0.7745 - val_accuracy: 0.5000\n", 486 | "Epoch 6/10\n", 487 | "8/8 [==============================] - 23s 3s/step - loss: 0.6449 - accuracy: 0.6996 - val_loss: 0.6114 - val_accuracy: 0.5820\n", 488 | "Epoch 7/10\n", 489 | "8/8 [==============================] - 21s 3s/step - loss: 0.5059 - accuracy: 0.7551 - val_loss: 0.5494 - val_accuracy: 0.7461\n", 490 | "Epoch 8/10\n", 491 | "8/8 [==============================] - 21s 3s/step - loss: 0.4751 - accuracy: 0.7774 - val_loss: 0.5426 - val_accuracy: 0.7461\n", 492 | "Epoch 9/10\n", 493 | "8/8 [==============================] - 21s 3s/step - loss: 0.4755 - accuracy: 0.7816 - val_loss: 0.5948 - val_accuracy: 
0.6172\n", 494 | "Epoch 10/10\n", 495 | "8/8 [==============================] - 21s 3s/step - loss: 0.4431 - accuracy: 0.7922 - val_loss: 0.7306 - val_accuracy: 0.5273\n" 496 | ], 497 | "name": "stdout" 498 | } 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": { 504 | "id": "P6_aG1L_H-Ko" 505 | }, 506 | "source": [ 507 | "## Compare results\r\n", 508 | "\r\n", 509 | "In this example we are further interested in also comparing the results" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "metadata": { 515 | "id": "zOg1wLwrfpqd", 516 | "outputId": "2a00e0f6-c9b0-44bb-ce8d-e4f373777b38", 517 | "colab": { 518 | "base_uri": "https://localhost:8080/" 519 | } 520 | }, 521 | "source": [ 522 | "from tabulate import tabulate\r\n", 523 | "\r\n", 524 | "data = [[\"Model without gctf:\",sum(time_callback_no_gctf.times),history_no_gctf.history['accuracy'][-1],history_no_gctf.history['loss'][-1]],\r\n", 525 | " [\"Model with gctf\",sum(time_callback_gctf.times),history_gctf.history['accuracy'][-1],history_gctf.history['loss'][-1]]] \r\n", 526 | "\r\n", 527 | "print(tabulate(data, headers=[\"Type\",\"Execution time\", \"Accuracy\", \"Loss\"]))" 528 | ], 529 | "execution_count": 9, 530 | "outputs": [ 531 | { 532 | "output_type": "stream", 533 | "text": [ 534 | "Type Execution time Accuracy Loss\n", 535 | "------------------- ---------------- ---------- --------\n", 536 | "Model without gctf: 221.626 0.690768 0.625912\n", 537 | "Model with gctf 216.744 0.805339 0.426568\n" 538 | ], 539 | "name": "stdout" 540 | } 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "metadata": { 546 | "id": "2Y5WxRlBf45e" 547 | }, 548 | "source": [ 549 | "" 550 | ], 551 | "execution_count": null, 552 | "outputs": [] 553 | } 554 | ] 555 | } 556 | -------------------------------------------------------------------------------- /examples/gctf_mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | 
"nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "gctf-mnist.ipynb", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "IieqPex4O3eJ" 31 | }, 32 | "source": [ 33 | "# GCTF MNIST\r\n", 34 | "\r\n", 35 | "This notebook shows the process of using the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) Python package to train on the [Fashion MNIST](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist) dataset available from [`tf.keras.datasets`](https://www.tensorflow.org/api_docs/python/tf/keras/datasets). Gradient Centralization is a simple and effective optimization technique for Deep Neural Networks as suggested by Yong et al. in the paper \r\n", 36 | "[Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/abs/2004.01461). It can both speedup training \r\n", 37 | " process and improve the final generalization performance of DNNs.\r\n", 38 | "\r\n", 39 | "## A bit about GC\r\n", 40 | "\r\n", 41 | "Gradient Centralization operates directly on gradients by centralizing the gradient vectors to have zero mean. It can both speedup training process and improve the final generalization performance of DNNs. Here is an Illustration of the GC operation on gradient matrix/tensor of weights in the fully-connected layer (left) and convolutional layer (right). 
GC computes the column/slice mean of gradient matrix/tensor and centralizes each column/slice to have zero mean.\r\n", 42 | "\r\n", 43 | "![](https://i.imgur.com/KitoO8J.png)\r\n", 44 | "\r\n", 45 | "GC can be viewed as a projected gradient descent method with a constrained loss function. The geometrical interpretation of GC. The gradient is projected on a hyperplane $e^T(w-w^t)=0$, where the projected gradient is used to update the weight.\r\n", 46 | "\r\n", 47 | "![](https://i.imgur.com/ekHhQv0.png)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "id": "WcpkfjkwSjmv" 54 | }, 55 | "source": [ 56 | "## Setup" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "XjTQqCEaTSSG" 63 | }, 64 | "source": [ 65 | "import tensorflow as tf\r\n", 66 | "from time import time" 67 | ], 68 | "execution_count": 1, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "id": "32NuQrDdSpP3" 75 | }, 76 | "source": [ 77 | "### Install the package" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "colab": { 84 | "base_uri": "https://localhost:8080/" 85 | }, 86 | "id": "aFvPquvcOq1B", 87 | "outputId": "36ab7fcf-8561-4403-b216-76be2519c52b" 88 | }, 89 | "source": [ 90 | "!pip install gradient-centralization-tf" 91 | ], 92 | "execution_count": 2, 93 | "outputs": [ 94 | { 95 | "output_type": "stream", 96 | "text": [ 97 | "Collecting gradient-centralization-tf\n", 98 | " Downloading https://files.pythonhosted.org/packages/58/4c/6253587b8f6ccdf03fd4830de2574cbda48a1a84bc660d5dd8978d0f94fb/gradient_centralization_tf-0.0.2-py3-none-any.whl\n", 99 | "Requirement already satisfied: tensorflow~=2.4.0 in /usr/local/lib/python3.6/dist-packages (from gradient-centralization-tf) (2.4.1)\n", 100 | "Requirement already satisfied: keras~=2.4.0 in /usr/local/lib/python3.6/dist-packages (from gradient-centralization-tf) (2.4.3)\n", 101 | "Requirement already satisfied: numpy~=1.19.2 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.19.5)\n", 102 | "Requirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.12.1)\n", 103 | "Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.10.0)\n", 104 | "Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.3.0)\n", 105 | "Requirement already satisfied: tensorflow-estimator<2.5.0,>=2.4.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.4.0)\n", 106 | "Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.2.0)\n", 107 | "Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.6.3)\n", 108 | "Requirement already satisfied: grpcio~=1.32.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.32.0)\n", 109 | "Requirement already satisfied: h5py~=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.10.0)\n", 110 | "Requirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.12)\n", 111 | "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.12.4)\n", 112 | "Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.1.2)\n", 113 | "Requirement already satisfied: termcolor~=1.1.0 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.1.0)\n", 114 | "Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.36.2)\n", 115 | "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (0.3.3)\n", 116 | "Requirement already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (3.7.4.3)\n", 117 | "Requirement already satisfied: tensorboard~=2.4 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (2.4.1)\n", 118 | "Requirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow~=2.4.0->gradient-centralization-tf) (1.15.0)\n", 119 | "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from keras~=2.4.0->gradient-centralization-tf) (1.4.1)\n", 120 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras~=2.4.0->gradient-centralization-tf) (3.13)\n", 121 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.9.2->tensorflow~=2.4.0->gradient-centralization-tf) (53.0.0)\n", 122 | "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.25.0)\n", 123 | "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.8.0)\n", 124 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.3.3)\n", 125 | "Requirement already satisfied: 
google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.4.2)\n", 126 | "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2.23.0)\n", 127 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.0.1)\n", 128 | "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.2.8)\n", 129 | "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (4.7)\n", 130 | "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (4.2.1)\n", 131 | "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.4.0)\n", 132 | "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.3.0)\n", 133 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2.10)\n", 134 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) 
(3.0.4)\n", 135 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (2020.12.5)\n", 136 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (1.24.3)\n", 137 | "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.6/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (0.4.8)\n", 138 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.4.0)\n", 139 | "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.4->tensorflow~=2.4.0->gradient-centralization-tf) (3.1.0)\n", 140 | "Installing collected packages: gradient-centralization-tf\n", 141 | "Successfully installed gradient-centralization-tf-0.0.2\n" 142 | ], 143 | "name": "stdout" 144 | } 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "Okru2B3uTKRx" 151 | }, 152 | "source": [ 153 | "## Get the data and create model structure" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "ywUm1ZYiOY8x", 160 | "colab": { 161 | "base_uri": "https://localhost:8080/" 162 | }, 163 | "outputId": "630d7bff-15a4-487c-84c9-3c96bea42912" 164 | }, 165 | "source": [ 166 | "mnist = tf.keras.datasets.fashion_mnist\r\n", 167 | "(training_images, training_labels), (test_images, test_labels) = mnist.load_data()\r\n", 168 | "training_images = training_images / 255.0\r\n", 169 | "test_images = test_images / 
255.0\r\n", 170 | "\r\n", 171 | "# Model architecture\r\n", 172 | "model = tf.keras.models.Sequential([\r\n", 173 | " tf.keras.layers.Flatten(), \r\n", 174 | " tf.keras.layers.Dense(512, activation=tf.nn.relu),\r\n", 175 | " tf.keras.layers.Dense(256, activation=tf.nn.relu),\r\n", 176 | " tf.keras.layers.Dense(64, activation=tf.nn.relu),\r\n", 177 | " tf.keras.layers.Dense(512, activation=tf.nn.relu),\r\n", 178 | " tf.keras.layers.Dense(256, activation=tf.nn.relu),\r\n", 179 | " tf.keras.layers.Dense(64, activation=tf.nn.relu), \r\n", 180 | " tf.keras.layers.Dense(10, activation=tf.nn.softmax)])" 181 | ], 182 | "execution_count": 3, 183 | "outputs": [ 184 | { 185 | "output_type": "stream", 186 | "text": [ 187 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n", 188 | "32768/29515 [=================================] - 0s 0us/step\n", 189 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n", 190 | "26427392/26421880 [==============================] - 0s 0us/step\n", 191 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n", 192 | "8192/5148 [===============================================] - 0s 0us/step\n", 193 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n", 194 | "4423680/4422102 [==============================] - 0s 0us/step\n" 195 | ], 196 | "name": "stdout" 197 | } 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": { 203 | "id": "_p_jknUHUILI" 204 | }, 205 | "source": [ 206 | "## Train a model without `gctf`" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": { 212 | "id": "4Us9_ZCBpD2a" 213 | }, 214 | "source": [ 215 | "Make a Callback to compute computation time\r\n" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "ADdFS7OapCXZ" 222 
| }, 223 | "source": [ 224 | "class TimeHistory(tf.keras.callbacks.Callback):\r\n", 225 | " def on_train_begin(self, logs={}):\r\n", 226 | " self.times = []\r\n", 227 | "\r\n", 228 | " def on_epoch_begin(self, batch, logs={}):\r\n", 229 | " self.epoch_time_start = time()\r\n", 230 | "\r\n", 231 | " def on_epoch_end(self, batch, logs={}):\r\n", 232 | " self.times.append(time() - self.epoch_time_start)" 233 | ], 234 | "execution_count": 4, 235 | "outputs": [] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "metadata": { 240 | "colab": { 241 | "base_uri": "https://localhost:8080/" 242 | }, 243 | "id": "BQXmGp4_Oj7l", 244 | "outputId": "407f4d3f-c422-439d-e26b-74eb90edf373" 245 | }, 246 | "source": [ 247 | "time_callback_no_gctf = TimeHistory()\r\n", 248 | "\r\n", 249 | "model.compile(optimizer = tf.keras.optimizers.Adam(),\r\n", 250 | " loss = 'sparse_categorical_crossentropy',\r\n", 251 | " metrics = ['accuracy'])\r\n", 252 | "\r\n", 253 | "history_no_gctf = model.fit(training_images, training_labels, epochs=5, callbacks = [time_callback_no_gctf])" 254 | ], 255 | "execution_count": 5, 256 | "outputs": [ 257 | { 258 | "output_type": "stream", 259 | "text": [ 260 | "Epoch 1/5\n", 261 | "1875/1875 [==============================] - 6s 2ms/step - loss: 0.6533 - accuracy: 0.7619\n", 262 | "Epoch 2/5\n", 263 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.3986 - accuracy: 0.8559\n", 264 | "Epoch 3/5\n", 265 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.3527 - accuracy: 0.8731\n", 266 | "Epoch 4/5\n", 267 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.3271 - accuracy: 0.8810\n", 268 | "Epoch 5/5\n", 269 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.3075 - accuracy: 0.8884\n" 270 | ], 271 | "name": "stdout" 272 | } 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "X_qW351shXgM" 279 | }, 280 | "source": [ 281 | "## Train a model with `gctf`" 
282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "metadata": { 287 | "colab": { 288 | "base_uri": "https://localhost:8080/" 289 | }, 290 | "id": "slHbE9hKWBWV", 291 | "outputId": "cfed8cb2-6260-4303-e749-ac1a71143949" 292 | }, 293 | "source": [ 294 | "import gctf #import gctf\r\n", 295 | "\r\n", 296 | "time_callback_gctf = TimeHistory()\r\n", 297 | "\r\n", 298 | "model.compile(optimizer = gctf.optimizers.adam(),\r\n", 299 | " loss = 'sparse_categorical_crossentropy',\r\n", 300 | " metrics=['accuracy'])\r\n", 301 | "\r\n", 302 | "history_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_gctf])" 303 | ], 304 | "execution_count": 8, 305 | "outputs": [ 306 | { 307 | "output_type": "stream", 308 | "text": [ 309 | "Epoch 1/5\n", 310 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2572 - accuracy: 0.9063\n", 311 | "Epoch 2/5\n", 312 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2551 - accuracy: 0.9103\n", 313 | "Epoch 3/5\n", 314 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2330 - accuracy: 0.9133\n", 315 | "Epoch 4/5\n", 316 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2288 - accuracy: 0.9168\n", 317 | "Epoch 5/5\n", 318 | "1875/1875 [==============================] - 4s 2ms/step - loss: 0.2237 - accuracy: 0.9165\n" 319 | ], 320 | "name": "stdout" 321 | } 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": { 327 | "id": "MhfeOw1FgLPE" 328 | }, 329 | "source": [ 330 | "## Compare results\r\n", 331 | "\r\n", 332 | "In this example we are further interested in also comparing the results" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "metadata": { 338 | "colab": { 339 | "base_uri": "https://localhost:8080/" 340 | }, 341 | "id": "M-PfvMpCdhm6", 342 | "outputId": "16ed65f5-d958-45cc-bd61-c09dcd6e9c1d" 343 | }, 344 | "source": [ 345 | "#Compare Results\r\n", 346 | "from tabulate import tabulate\r\n", 347 | 
"\r\n", 348 | "data = [[\"Model without gctf:\",sum(time_callback_no_gctf.times),history_no_gctf.history['accuracy'][-1],history_no_gctf.history['loss'][-1]],\r\n", 349 | " [\"Model with gctf\",sum(time_callback_gctf.times),history_gctf.history['accuracy'][-1],history_gctf.history['loss'][-1]]] \r\n", 350 | "\r\n", 351 | "print(tabulate(data, headers=[\"Type\",\"Execution time\", \"Accuracy\", \"Loss\"]))" 352 | ], 353 | "execution_count": 9, 354 | "outputs": [ 355 | { 356 | "output_type": "stream", 357 | "text": [ 358 | "Type Execution time Accuracy Loss\n", 359 | "------------------- ---------------- ---------- --------\n", 360 | "Model without gctf: 20.183 0.887617 0.310299\n", 361 | "Model with gctf 18.464 0.916467 0.22555\n" 362 | ], 363 | "name": "stdout" 364 | } 365 | ] 366 | } 367 | ] 368 | } 369 | -------------------------------------------------------------------------------- /gctf/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .centralized_gradients import get_centralized_gradients 4 | from .centralized_gradients import centralized_gradients_for_optimizer 5 | from .optimizers import * 6 | -------------------------------------------------------------------------------- /gctf/centralized_gradients.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import keras.backend as K 3 | 4 | 5 | def get_centralized_gradients(optimizer, loss, params): 6 | """Compute the centralized gradients. 7 | 8 | This function is ideally not meant to be used directly unless you are building a custom optimizer, in which case you 9 | could point `get_gradients` to this function. This is a modified version of 10 | `tf.keras.optimizers.Optimizer.get_gradients`. 11 | 12 | # Arguments: 13 | optimizer: a `tf.keras.optimizers.Optimizer object`. The optimizer you are using. 14 | loss: Scalar tensor to minimize. 
15 | params: List of variables. 16 | 17 | # Returns: 18 | A gradients tensor. 19 | 20 | # Reference: 21 | [Yong et al., 2020](https://arxiv.org/abs/2004.01461) 22 | """ 23 | 24 | # We here just provide a modified get_gradients() function since we are trying to just compute the centralized 25 | # gradients at this stage which can be used in other optimizers. 26 | grads = [] 27 | for grad in K.gradients(loss, params): 28 | grad_len = len(grad.shape) 29 | if grad_len > 1: 30 | axis = list(range(grad_len - 1)) 31 | grad -= tf.reduce_mean(grad, 32 | axis=axis, 33 | keep_dims=True) 34 | grads.append(grad) 35 | 36 | if None in grads: 37 | raise ValueError('An operation has `None` for gradient. ' 38 | 'Please make sure that all of your ops have a ' 39 | 'gradient defined (i.e. are differentiable). ' 40 | 'Common ops without gradient: ' 41 | 'K.argmax, K.round, K.eval.') 42 | if hasattr(optimizer, 'clipnorm') and optimizer.clipnorm > 0: 43 | norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) 44 | grads = [ 45 | tf.keras.optimizers.clip_norm( 46 | g, 47 | optimizer.clipnorm, 48 | norm) for g in grads] 49 | if hasattr(optimizer, 'clipvalue') and optimizer.clipvalue > 0: 50 | grads = [K.clip(g, -optimizer.clipvalue, optimizer.clipvalue) 51 | for g in grads] 52 | return grads 53 | 54 | 55 | def centralized_gradients_for_optimizer(optimizer): 56 | """Create a centralized gradients functions for a specified optimizer. 57 | 58 | # Arguments: 59 | optimizer: a `tf.keras.optimizers.Optimizer object`. The optimizer you are using. 60 | 61 | # Usage: 62 | 63 | ```py 64 | >>> opt = tf.keras.optimizers.Adam(learning_rate=0.1) 65 | >>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt) 66 | >>> model.compile(optimizer = opt, ...) 
67 | ``` 68 | """ 69 | 70 | def get_centralized_gradients_for_optimizer(loss, params): 71 | return get_centralized_gradients(optimizer, loss, params) 72 | 73 | return get_centralized_gradients_for_optimizer 74 | -------------------------------------------------------------------------------- /gctf/optimizers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .centralized_gradients import centralized_gradients_for_optimizer 4 | 5 | 6 | def update_optimizer(optimizer): 7 | optimizer.get_gradients = centralized_gradients_for_optimizer(optimizer) 8 | return optimizer 9 | 10 | 11 | def adagrad(learning_rate=0.001, initial_accumulator_value=0.1, epsilon=1e-07): 12 | optimizer = tf.keras.optimizers.Adagrad( 13 | learning_rate=learning_rate, 14 | initial_accumulator_value=initial_accumulator_value, 15 | epsilon=epsilon) 16 | return update_optimizer(optimizer) 17 | 18 | 19 | def adadelta(learning_rate=0.001, rho=0.95, epsilon=1e-07): 20 | optimizer = tf.keras.optimizers.Adadelta(learning_rate=learning_rate, 21 | rho=rho, 22 | epsilon=epsilon) 23 | return update_optimizer(optimizer) 24 | 25 | 26 | def adam( 27 | learning_rate=0.001, 28 | beta_1=0.9, 29 | beta_2=0.999, 30 | epsilon=1e-7, 31 | amsgrad=False): 32 | optimizer = tf.keras.optimizers.Adam( 33 | learning_rate=learning_rate, 34 | beta_1=beta_1, 35 | beta_2=beta_2, 36 | epsilon=epsilon, 37 | amsgrad=amsgrad) 38 | return update_optimizer(optimizer) 39 | 40 | 41 | def adamax(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07): 42 | optimizer = tf.keras.optimizers.Adamax(learning_rate=learning_rate, 43 | beta_1=beta_1, 44 | beta_2=beta_2, 45 | epsilon=epsilon) 46 | return update_optimizer(optimizer) 47 | 48 | 49 | def ftrl( 50 | learning_rate=0.001, 51 | learning_rate_power=-0.5, 52 | initial_accumulator_value=0.1, 53 | l1_regularization_strength=0.0, 54 | l2_regularization_strength=0.0, 55 | l2_shrinkage_regularization_strength=0.0, 56 
| beta=0.0): 57 | optimizer = tf.keras.optimizers.Adamax( 58 | learning_rate=learning_rate, 59 | learning_rate_power=learning_rate_power, 60 | initial_accumulator_value=initial_accumulator_value, 61 | l1_regularization_strength=l1_regularization_strength, 62 | l2_regularization_strength=l2_regularization_strength, 63 | l2_shrinkage_regularization_strength=l2_shrinkage_regularization_strength, 64 | beta=beta) 65 | return update_optimizer(optimizer) 66 | 67 | 68 | def nadam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07): 69 | optimizer = tf.keras.optimizers.Nadam( 70 | learning_rate=learning_rate, 71 | beta_1=beta_1, 72 | beta_2=beta_2, 73 | epsilon=epsilon) 74 | return update_optimizer(optimizer) 75 | 76 | 77 | def rmsprop( 78 | learning_rate=0.001, 79 | rho=0.9, 80 | momentum=0.0, 81 | epsilon=1e-07, 82 | centered=False): 83 | optimizer = tf.keras.optimizers.RMSprop( 84 | learning_rate=learning_rate, 85 | rho=rho, 86 | momentum=momentum, 87 | epsilon=epsilon, 88 | centered=centered) 89 | return update_optimizer(optimizer) 90 | 91 | 92 | def sgd(learning_rate=0.01, momentum=0.0, nesterov=False): 93 | optimizer = tf.keras.optimizers.SGD( 94 | learning_rate=learning_rate, 95 | momentum=momentum, 96 | nesterov=nesterov) 97 | return update_optimizer(optimizer) 98 | -------------------------------------------------------------------------------- /gctf/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.3' 2 | -------------------------------------------------------------------------------- /images/gctf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rishit-dagli/Gradient-Centralization-TensorFlow/6fbb2e3f049665724a3cc87b7c9dde07830da6f1/images/gctf.png -------------------------------------------------------------------------------- /images/illutstration.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rishit-dagli/Gradient-Centralization-TensorFlow/6fbb2e3f049665724a3cc87b7c9dde07830da6f1/images/illutstration.png -------------------------------------------------------------------------------- /images/projected_grad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rishit-dagli/Gradient-Centralization-TensorFlow/6fbb2e3f049665724a3cc87b7c9dde07830da6f1/images/projected_grad.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | exec(open('gctf/version.py').read()) 4 | 5 | with open("README.md", "r") as fh: 6 | long_description = fh.read() 7 | 8 | setup( 9 | name="gradient-centralization-tf", 10 | version="0.0.3", 11 | description="Implement Gradient Centralization in TensorFlow", 12 | packages=["gctf"], 13 | 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | 17 | classifiers=[ 18 | "Development Status :: 4 - Beta", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3 :: Only", 21 | "Programming Language :: Python :: 3.7", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "License :: OSI Approved :: Apache Software License", 25 | "Intended Audience :: Developers", 26 | "Intended Audience :: Education", 27 | "Intended Audience :: Science/Research", 28 | "Topic :: Scientific/Engineering", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | "Topic :: Software Development", 31 | "Topic :: Software Development :: Libraries", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | "Topic :: Scientific/Engineering :: Mathematics" 34 | ], 35 | 36 | 
url="https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow", 37 | author="Rishit Dagli", 38 | author_email="rishit.dagli@gmail.com", 39 | 40 | install_requires=[ 41 | "tensorflow >= 2.2.0", 42 | "keras ~= 2.4.0", 43 | ], 44 | 45 | extras_require={ 46 | "dev": [ 47 | "check-manifest", 48 | "twine", 49 | "numpy" 50 | ], 51 | }, 52 | ) 53 | --------------------------------------------------------------------------------