├── .flake8 ├── .gitattributes ├── .gitignore ├── ACKNOWLEDGMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── aff.png ├── architecture.png ├── clusten ├── __init__.py ├── clusten.py ├── src │ ├── clustenav_cuda.cpp │ ├── clustenav_cuda_kernel.cu │ ├── clustenqk_cuda.cpp │ ├── clustenqk_cuda_kernel.cu │ ├── clustenwf_cuda.cpp │ ├── clustenwf_cuda_kernel.cu │ └── setup.py ├── test_av_kernel.py ├── test_qk_kernel.py └── test_wf_kernel.py ├── config.py ├── configs ├── aff_base_22k.yaml ├── aff_base_22kto1k.yaml ├── aff_base_22kto1k_384.yaml ├── aff_mini.yaml ├── aff_mini_1_5th.yaml ├── aff_small.yaml ├── aff_small_1_5th.yaml ├── aff_tiny.yaml └── aff_tiny_1_5th.yaml ├── create_env.sh ├── data ├── __init__.py ├── build.py └── samplers.py ├── demo1.png ├── demo2.png ├── logger.py ├── lr_scheduler.py ├── main.py ├── models ├── __init__.py ├── aff_transformer.py ├── build.py ├── point_utils.py └── test_cluster.py ├── optimizer.py ├── run_aff.sh └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | max-line-length = 120 4 | # C408 ignored because we like the dict keyword argument syntax 5 | # E501 is not flexible enough, we're using B950 instead 6 | ignore = 7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E303,E226, 8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 9 | # to line this up with executable bit 10 | EXE001, 11 | # these ignores are from flake8-bugbear; please fix! 12 | B007,B008, 13 | # these ignores are from flake8-comprehensions; please fix! 14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415, 15 | # for "unable to detect undefined names" 16 | F403, 17 | # for "Too many leading '#' for block comment (E266)" 18 | E266, 19 | # for "E731 do not assign a lambda expression, use a def" 20 | E731, 21 | # for "future feature annotations is not defined" 22 | F407, 23 | # do not use bare 'except' 24 | E722, 25 | per-file-ignores = __init__.py: F401 26 | optional-ascii-coding = True 27 | exclude = 28 | ./.git, 29 | ./docs, 30 | ./scripts, 31 | ./test 32 | ./third_party, 33 | ./venv, 34 | *.pyi 35 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.svg 2 | .nfs* 3 | .DS_Store 4 | __pycache__/ 5 | *swp* 6 | output/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this AutoFocusFormer Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | 7 | Microsoft (Swin Transformer) 8 | MIT License 9 | 10 | Copyright (c) Microsoft Corporation. 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE 29 | 30 | SHI Lab (Neighborhood-Attention-Transformer) 31 | MIT License 32 | 33 | Copyright (c) 2022-2023 SHI Lab 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining a copy 36 | of this software and associated documentation files (the "Software"), to deal 37 | in the Software without restriction, including without limitation the rights 38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 39 | copies of the Software, and to permit persons to whom the Software is 40 | furnished to do so, subject to the following conditions: 41 | 42 | The above copyright notice and this permission notice shall be included in all 43 | copies or substantial portions of the Software. 44 | 45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 51 | SOFTWARE. 52 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. 
This
90 | includes avoiding interactions in community spaces as well as external channels
91 | like social media. Violating these terms may lead to a temporary or permanent
92 | ban.
93 | 
94 | ### 3. Temporary Ban
95 | 
96 | **Community Impact**: A serious violation of community standards, including
97 | sustained inappropriate behavior.
98 | 
99 | **Consequence**: A temporary ban from any sort of interaction or public
100 | communication with the community for a specified period of time. No public or
101 | private interaction with the people involved, including unsolicited interaction
102 | with those enforcing the Code of Conduct, is allowed during this period.
103 | Violating these terms may lead to a permanent ban.
104 | 
105 | ### 4. Permanent Ban
106 | 
107 | **Community Impact**: Demonstrating a pattern of violation of community
108 | standards, including sustained inappropriate behavior, harassment of an
109 | individual, or aggression toward or disparagement of classes of individuals.
110 | 
111 | **Consequence**: A permanent ban from any sort of public interaction within the
112 | community.
113 | 
114 | ## Attribution
115 | 
116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
117 | version 2.1, available at
118 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
119 | 
120 | Community Impact Guidelines were inspired by
121 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
122 | 
123 | For answers to common questions about this code of conduct, see the FAQ at
124 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
125 | [https://www.contributor-covenant.org/translations][translations].
126 | 
127 | [homepage]: https://www.contributor-covenant.org
128 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
129 | [Mozilla CoC]: https://github.com/mozilla/diversity
130 | [FAQ]: https://www.contributor-covenant.org/faq
131 | [translations]: https://www.contributor-covenant.org/translations
132 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 | 
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 | 
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 | 
7 | ## Before you get started
8 | 
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 | 
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2023 Apple Inc. All Rights Reserved.
2 | 
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms.
If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoFocusFormer 2 | 3 | [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md) 4 | [![CLUSTEN](https://img.shields.io/badge/CUDA%20Extension-CLUSTEN-red)](clusten/) 5 | 6 | AFF-Base: [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/autofocusformer-image-segmentation-off-the/instance-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/instance-segmentation-on-cityscapes-val?p=autofocusformer-image-segmentation-off-the) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/autofocusformer-image-segmentation-off-the/panoptic-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/panoptic-segmentation-on-cityscapes-val?p=autofocusformer-image-segmentation-off-the) 7 | 8 | This software project accompanies the research paper, *AutoFocusFormer: Image Segmentation off the Grid* (CVPR 2023). 
9 | 10 | [Chen Ziwen](https://www.chenziwe.com), Kaushik Patnaik, [Shuangfei Zhai](https://scholar.google.com/citations?user=G6vdBYsAAAAJ&hl=en), [Alvin Wan](http://alvinwan.com), [Zhile Ren](https://jrenzhile.com), [Alex Schwing](https://alexander-schwing.de/), [Alex Colburn](https://www.colburn.org), [Li Fuxin](https://web.engr.oregonstate.edu/~lif/) 11 | 12 | [arXiv](https://arxiv.org/abs/2304.12406) | [video narration](https://youtu.be/i1mZtk70yGY) | [AFF-Classification (this repo)](https://github.com/apple/ml-autofocusformer) | [AFF-Segmentation](https://github.com/apple/ml-autofocusformer-segmentation) 13 | 14 | ## Introduction 15 | 16 | AutoFocusFormer (AFF) is the first **adaptive**-downsampling network capable of **dense** prediction tasks such as semantic/instance segmentation. 17 | 18 | AFF abandons the traditional grid structure of image feature maps, and automatically learns to retain the most important pixels with respect to the task goal. 19 | 20 |
20 | <p align="center">
21 | <img src="aff.png" width="60%">
22 | </p>
23 | 
24 | AFF consists of a local-attention transformer backbone and a task-specific head. The backbone consists of four stages, each stage containing three modules: balanced clustering, local-attention transformer blocks, and adaptive downsampling.
25 | 
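In rough PyTorch pseudocode, one stage behaves like the sketch below. This is a hypothetical illustration only: `AFFStageSketch`, its plain global attention, and the top-k token selection are simplified stand-ins for the balanced clustering, CLUSTEN local-attention kernels, and learned merging actually implemented in `models/aff_transformer.py`.

```python
import torch
import torch.nn as nn


class AFFStageSketch(nn.Module):
    """One backbone stage: group tokens, attend locally, then downsample adaptively."""

    def __init__(self, dim, num_blocks, keep_ratio=0.2):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
            for _ in range(num_blocks)
        )
        self.score = nn.Linear(dim, 1)  # learned importance score for downsampling
        self.keep_ratio = keep_ratio    # e.g. 0.2 echoes the 1/5-rate models

    def forward(self, tokens):  # tokens: (batch, n, dim), one vector per retained location
        # 1) balanced clustering would group tokens into local neighborhoods here
        for attn in self.blocks:
            out, _ = attn(tokens, tokens, tokens)  # 2) attention (global here for brevity)
            tokens = tokens + out
        # 3) adaptive downsampling: keep only the highest-scoring tokens
        n_keep = max(1, int(tokens.shape[1] * self.keep_ratio))
        idx = self.score(tokens).squeeze(-1).topk(n_keep, dim=1).indices  # (batch, n_keep)
        return tokens.gather(1, idx.unsqueeze(-1).expand(-1, -1, tokens.shape[-1]))


stage = AFFStageSketch(dim=32, num_blocks=2)
print(stage(torch.randn(4, 196, 32)).shape)  # torch.Size([4, 39, 32])
```

Because the retained tokens no longer lie on a grid, the real blocks carry explicit 2D positions alongside features; the CUDA kernels under `clusten/` implement the neighborhood attention over those irregular point sets.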
26 | <p align="center">
27 | <img src="architecture.png" width="80%">
28 | </p>
29 | 
30 | AFF demonstrates significant savings on FLOPs (see our models with 1/5 downsampling rate), and significant improvement on recognition of small objects.
31 | 
32 | Notably, AFF-Small achieves **44.0** instance segmentation AP and **66.9** panoptic segmentation PQ on Cityscapes val with a backbone of only **42.6M** parameters, a performance on par with Swin-Large, a backbone with **197M** params (saving **78%**!).
33 | 
34 | <p align="center">
35 | <img src="demo1.png" width="60%">
36 | </p>
37 | 
38 | <p align="center">
39 | <img src="demo2.png" width="60%"></p>
41 | 42 | ## Main Results on ImageNet with Pretrained Models 43 | 44 | | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model | 45 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 46 | | AFF-Mini | ImageNet-1K | 224x224 | 78.2 | 93.6 | 6.75M | 1.08G | 1337 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_mini.pth) | 47 | | AFF-Mini-1/5 | ImageNet-1K | 224x224 | 77.5 | 93.3 | 6.75M | 0.72G | 1678 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_mini_1_5th.pth) | 48 | | AFF-Tiny | ImageNet-1K | 224x224 | 83.0 | 96.3 | 27M | 4G | 528 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_tiny.pth) | 49 | | AFF-Tiny-1/5 | ImageNet-1K | 224x224 | 82.4 | 95.9 | 27M | 2.74G | 682 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_tiny_1_5th.pth) | 50 | | AFF-Small | ImageNet-1K | 224x224 | 83.5 | 96.6 | 42.6M | 8.16G | 321 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_small.pth) | 51 | | AFF-Small-1/5 | ImageNet-1K | 224x224 | 83.4 | 96.5 | 42.6M | 5.69G | 424 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_small_1_5th.pth) | 52 | 53 | FPS is obtained on a single V100 GPU. 54 | 55 | We train with a total batch size 4096. 56 | 57 | | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model | 58 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 59 | | AFF-Base | ImageNet-22K | 384x384 | 86.2 | 98.0 | 75.34M | 42.54G | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_base_22k.pth) | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_base_22kto1k_384.pth) | 60 | 61 | ## Getting Started 62 | 63 | ### Clone this repo 64 | 65 | ```bash 66 | git clone git@github.com:apple/ml-autofocusformer.git 67 | cd ml-autofocusformer 68 | ``` 69 | One can download pre-trained checkpoints through the links in the table above. 70 | 71 | ### Create environment and install requirements 72 | 73 | ```bash 74 | sh create_env.sh 75 | ``` 76 | 77 | See further documentation inside the script file. 78 | 79 | Our experiments are run with `CUDA==11.6` and `pytorch==1.12`. 80 | 81 | ### Prepare data 82 | 83 | We use standard ImageNet dataset, which can be downloaded from http://image-net.org/. 84 | 85 | For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like: 86 | ```bash 87 | $ tree imagenet 88 | imagenet/ 89 | ├── training 90 | │ ├── class1 91 | │ │ ├── img1.jpeg 92 | │ │ ├── img2.jpeg 93 | │ │ └── ... 94 | │ ├── class2 95 | │ │ ├── img3.jpeg 96 | │ │ └── ... 97 | │ └── ... 98 | └── validation 99 | ├── class1 100 | │ ├── img4.jpeg 101 | │ ├── img5.jpeg 102 | │ └── ... 103 | ├── class2 104 | │ ├── img6.jpeg 105 | │ └── ... 106 | └── ... 107 | 108 | ``` 109 | 110 | ### Train and evaluate 111 | 112 | Modify the arguments in script `run_aff.sh` (e.g., path to dataset) and run 113 | ```bash 114 | sh run_aff.sh 115 | ``` 116 | for training or evaluation. 117 | 118 | Run `python main.py -h` to see full documentation of the args. 119 | 120 | One can also directly modify the config files in `configs/`. 
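The config system is built on `yacs` (see `config.py`). As a minimal, illustrative sketch of how the YAML files and command-line overrides interact with the defaults (the keys below mirror the `config.py` excerpt in this repo; the override values are arbitrary):

```python
from yacs.config import CfgNode as CN

# Mirror a small slice of the defaults defined in config.py
cfg = CN()
cfg.DATA = CN()
cfg.DATA.BATCH_SIZE = 128
cfg.DATA.DATA_PATH = 'imagenet'
cfg.DATA.IMG_SIZE = 224

# The YAML files in configs/ (and the command-line args parsed in main.py)
# override defaults of exactly this form
cfg.merge_from_list(['DATA.BATCH_SIZE', 256, 'DATA.IMG_SIZE', 384])
cfg.freeze()  # prevent accidental modification afterwards
assert cfg.DATA.BATCH_SIZE == 256 and cfg.DATA.IMG_SIZE == 384
```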
121 | 122 | ## Citing AutoFocusFormer 123 | 124 | ```BibTeX 125 | @inproceedings{autofocusformer, 126 | title = {AutoFocusFormer: Image Segmentation off the Grid}, 127 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 128 | author = {Ziwen, Chen and Patnaik, Kaushik and Zhai, Shuangfei and Wan, Alvin and Ren, Zhile and Schwing, Alex and Colburn, Alex and Fuxin, Li}, 129 | year = {2023}, 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /aff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/aff.png -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/architecture.png -------------------------------------------------------------------------------- /clusten/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from .clusten import CLUSTENQKFunction, CLUSTENAVFunction, CLUSTENWFFunction 7 | -------------------------------------------------------------------------------- /clusten/clusten.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from torch.autograd import Function 7 | 8 | try: 9 | import clustenqk_cuda 10 | import clustenav_cuda 11 | import clustenwf_cuda 12 | except ImportError: 13 | raise RuntimeError("Could not load CLUSTEN CUDA extension. 
" + 14 | "Please make sure your device has CUDA, the CUDA toolkit for PyTorch is installed, and that you've compiled CLUSTEN correctly.") 15 | 16 | 17 | class CLUSTENQKFunction(Function): 18 | """ 19 | query times key function 20 | """ 21 | @staticmethod 22 | def forward(ctx, query, key, nbhd_idx): 23 | query = query.contiguous() 24 | key = key.contiguous() 25 | if key.dtype != query.dtype: 26 | key = key.to(query.dtype) 27 | nbhd_idx = nbhd_idx.contiguous() 28 | attn = clustenqk_cuda.forward( 29 | query, 30 | key.permute(0, 1, 3, 2).contiguous(), 31 | nbhd_idx) 32 | ctx.save_for_backward(query, key, nbhd_idx) 33 | return attn 34 | 35 | @staticmethod 36 | def backward(ctx, grad_attn): 37 | outputs = clustenqk_cuda.backward( 38 | grad_attn.contiguous(), *ctx.saved_tensors) 39 | d_query, d_key = outputs 40 | return d_query, d_key, None 41 | 42 | 43 | class CLUSTENAVFunction(Function): 44 | """ 45 | attention times value function 46 | """ 47 | @staticmethod 48 | def forward(ctx, attn, v, nbhd_idx): 49 | attn = attn.contiguous() 50 | v = v.contiguous() 51 | nbhd_idx = nbhd_idx.contiguous() 52 | if attn.dtype != v.dtype: 53 | v = v.to(attn.dtype) 54 | feat = clustenav_cuda.forward( 55 | attn, 56 | v, 57 | nbhd_idx) 58 | ctx.save_for_backward(attn, v, nbhd_idx) 59 | return feat 60 | 61 | @staticmethod 62 | def backward(ctx, grad_feat): 63 | outputs = clustenav_cuda.backward( 64 | grad_feat.contiguous(), *ctx.saved_tensors) 65 | d_attn, d_v = outputs 66 | return d_attn, d_v, None 67 | 68 | 69 | class CLUSTENWFFunction(Function): 70 | """ 71 | weight times feature function 72 | """ 73 | @staticmethod 74 | def forward(ctx, weights, feat, nbhd_idx): 75 | weights = weights.contiguous() 76 | feat = feat.contiguous() 77 | nbhd_idx = nbhd_idx.contiguous() 78 | if feat.dtype != weights.dtype: 79 | feat = feat.to(weights.dtype) 80 | feat_new = clustenwf_cuda.forward( 81 | weights, 82 | feat, 83 | nbhd_idx) 84 | ctx.save_for_backward(weights, feat, nbhd_idx) 85 | return feat_new 86 | 87 | @staticmethod 88 | def backward(ctx, grad_feat_new): 89 | outputs = clustenwf_cuda.backward( 90 | grad_feat_new.contiguous(), *ctx.saved_tensors) 91 | d_weights, d_feat = outputs 92 | return d_weights, d_feat, None 93 | -------------------------------------------------------------------------------- /clusten/src/clustenav_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | */ 5 | 6 | #include 7 | #include 8 | 9 | torch::Tensor clusten_av_cuda_forward( 10 | const torch::Tensor &attn, // b x h x n x m 11 | const torch::Tensor &v, // b x h x n x c 12 | const torch::Tensor &nbhd_idx); // b x n x m 13 | 14 | std::vector clusten_av_cuda_backward( 15 | const torch::Tensor &d_feat, 16 | const torch::Tensor &attn, 17 | const torch::Tensor &v, 18 | const torch::Tensor &nbhd_idx); 19 | 20 | // C++ interface 21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 24 | 25 | torch::Tensor clusten_av_forward( 26 | const torch::Tensor &attn, 27 | const torch::Tensor &v, 28 | const torch::Tensor &nbhd_idx) { 29 | CHECK_INPUT(attn); 30 | CHECK_INPUT(v); 31 | CHECK_INPUT(nbhd_idx); 32 | return clusten_av_cuda_forward(attn, v, nbhd_idx); 33 | } 34 | 35 | std::vector clusten_av_backward( 36 | const torch::Tensor &d_feat, 37 | const torch::Tensor &attn, 38 | const torch::Tensor &v, 39 | const torch::Tensor &nbhd_idx) { 40 | CHECK_INPUT(d_feat); 41 | CHECK_INPUT(attn); 42 | CHECK_INPUT(v); 43 | CHECK_INPUT(nbhd_idx); 44 | return clusten_av_cuda_backward(d_feat, attn, v, nbhd_idx); 45 | } 46 | 47 | 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", &clusten_av_forward, "CLUSTENAV forward (CUDA)"); 50 | m.def("backward", &clusten_av_backward, "CLUSTENAV backward (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /clusten/src/clustenav_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | */ 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #define CUDA_NUM_THREADS 1024 17 | 18 | template 19 | __global__ void clusten_av_cuda_forward_kernel( 20 | const torch::PackedTensorAccessor32 attn, // b x h x n x m 21 | const torch::PackedTensorAccessor32 v, // b x h x n x c 22 | const torch::PackedTensorAccessor32 nbhd_idx, // b x n x m 23 | torch::PackedTensorAccessor32 feat, // b x n x c 24 | const int length, // n 25 | const int batch_size, // b 26 | const int heads, // h 27 | const int nbhd_size, // m 28 | const int dim) { // c 29 | 30 | const int z = blockIdx.z * blockDim.z + threadIdx.z; 31 | if (z < batch_size * heads){ 32 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 33 | if (i < length){ 34 | const int c = blockIdx.x * blockDim.x + threadIdx.x; 35 | if (c < dim){ 36 | const int b = z / heads; 37 | const int h = z - b * heads; 38 | int64_t nbi; 39 | // calculate a@v 40 | scalar_t updt = scalar_t(0); 41 | #pragma unroll 42 | for (unsigned int ni=0; ni < nbhd_size; ++ni) { 43 | nbi = nbhd_idx[b][i][ni]; 44 | updt += attn[b][h][i][ni] * v[b][h][nbi][c]; 45 | } 46 | feat[b][h][i][c] = updt; 47 | } 48 | } 49 | } 50 | } 51 | 52 | 53 | torch::Tensor clusten_av_cuda_forward( 54 | const torch::Tensor &attn, 55 | const torch::Tensor &v, 56 | const torch::Tensor &nbhd_idx) { 57 | 58 | int64_t batch_size = attn.size(0); 59 | int64_t heads = attn.size(1); 60 | int64_t length = attn.size(2); 61 | int64_t dim = v.size(3); 62 | int64_t nbhd_size = nbhd_idx.size(2); 63 | int zsize = batch_size * heads; 64 | 65 | int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim); 66 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length); 67 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS)); 68 | 69 | auto feat = torch::zeros( 70 | {batch_size, heads, length, dim}, v.options()); 71 | 72 | const auto stream = c10::cuda::getCurrentCUDAStream(); 73 | const dim3 blocks( 74 | (dim + CHANNELTHREADS - 1) / CHANNELTHREADS, 75 | (length + TOKENTHREADS - 1) / TOKENTHREADS, 76 | (zsize + BATCHTHREADS - 1) / BATCHTHREADS); 77 | const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(attn.scalar_type(), "clusten_av_cuda_forward", ([&] { 80 | const auto attn_a = attn.packed_accessor32(); 81 | const auto v_a = v.packed_accessor32(); 82 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 83 | auto feat_a = feat.packed_accessor32(); 84 | 85 | clusten_av_cuda_forward_kernel<<>>( 86 | attn_a, v_a, nbhd_idx_a, feat_a, 87 | length, batch_size, heads, nbhd_size, dim); 88 | })); 89 | return feat; 90 | } 91 | 92 | 93 | template 94 | __global__ void clusten_av_cuda_backward_kernel( 95 | const torch::PackedTensorAccessor32 d_feat, 96 | const torch::PackedTensorAccessor32 attn, 97 | const torch::PackedTensorAccessor32 nbhd_idx, 98 | torch::PackedTensorAccessor32 d_v, 99 | const int length, 100 | const int batch_size, 101 | const int heads, 102 | const int nbhd_size, 103 | const int dim, 104 | const size_t d_v_numel) { 105 | 106 | const int z = blockIdx.z * blockDim.z + threadIdx.z; 107 | if (z < batch_size * heads){ 108 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 109 | if (i < length){ 110 | const int c = blockIdx.x * blockDim.x + threadIdx.x; 111 | if (c < dim){ 112 | const int b = z / heads; 113 | const int h = z - b * heads; 114 | int64_t nbi; 115 | size_t index; 116 | #pragma unroll 117 | for (unsigned int ni=0; ni < 
nbhd_size; ++ni) { 118 | nbi = nbhd_idx[b][i][ni]; 119 | // calculate d_v = att * d_feat 120 | index = b*d_v.stride(0) + h*d_v.stride(1) + nbi*d_v.stride(2) + c; 121 | at::native::fastAtomicAdd(d_v.data(), index, d_v_numel, d_feat[b][h][i][c] * attn[b][h][i][ni], true); 122 | // atomicAdd(&(d_v[b][h][nbi][c]), d_feat[b][h][i][c] * attn[b][h][i][ni]); // avoid race condition 123 | } 124 | } 125 | } 126 | } 127 | } 128 | 129 | template 130 | __global__ void clusten_av_attn_cuda_backward_kernel( 131 | const torch::PackedTensorAccessor32 d_feat, 132 | const torch::PackedTensorAccessor32 v, 133 | const torch::PackedTensorAccessor32 nbhd_idx, 134 | torch::PackedTensorAccessor32 d_attn, 135 | const int length, 136 | const int batch_size, 137 | const int heads, 138 | const int nbhd_size, 139 | const int dim) { 140 | 141 | const int z = blockIdx.z * blockDim.z + threadIdx.z; 142 | if (z < batch_size * heads){ 143 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 144 | if (i < length){ 145 | const int ni = blockIdx.x * blockDim.x + threadIdx.x; 146 | if (ni < nbhd_size){ 147 | const int b = z / heads; 148 | const int h = z - b * heads; 149 | int64_t nbi = nbhd_idx[b][i][ni]; 150 | scalar_t updt = scalar_t(0); 151 | #pragma unroll 152 | for (unsigned int c=0; c < dim; ++c) { 153 | // calculate d_attn = v * d_feat 154 | updt += v[b][h][nbi][c] * d_feat[b][h][i][c]; 155 | } 156 | d_attn[b][h][i][ni] = updt; 157 | } 158 | } 159 | } 160 | } 161 | 162 | std::vector clusten_av_cuda_backward( 163 | const torch::Tensor &d_feat, 164 | const torch::Tensor &attn, 165 | const torch::Tensor &v, 166 | const torch::Tensor &nbhd_idx) { 167 | 168 | int64_t batch_size = attn.size(0); 169 | int64_t heads = attn.size(1); 170 | int64_t length = attn.size(2); 171 | int64_t dim = v.size(3); 172 | int64_t nbhd_size = nbhd_idx.size(2); 173 | int zsize = batch_size * heads; 174 | 175 | int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim); 176 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length); 177 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS* CHANNELTHREADS)); 178 | 179 | int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), nbhd_size); 180 | int TOKENTHREADS_NB = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length); 181 | int BATCHTHREADS_NB = max(1, CUDA_NUM_THREADS / (TOKENTHREADS_NB* NBHDTHREADS)); 182 | 183 | auto d_attn = torch::zeros_like(attn); 184 | auto d_v = torch::zeros_like(v); 185 | 186 | const auto stream = c10::cuda::getCurrentCUDAStream(); 187 | 188 | const dim3 blocks( 189 | (dim + CHANNELTHREADS - 1) / CHANNELTHREADS, 190 | (length + TOKENTHREADS - 1) / TOKENTHREADS, 191 | (zsize + BATCHTHREADS - 1) / BATCHTHREADS); 192 | const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS); 193 | 194 | const dim3 blocks_nb( 195 | (nbhd_size + NBHDTHREADS - 1) / NBHDTHREADS, 196 | (length + TOKENTHREADS_NB - 1) / TOKENTHREADS_NB, 197 | (zsize + BATCHTHREADS_NB - 1) / BATCHTHREADS_NB); 198 | const dim3 threads_nb(NBHDTHREADS, TOKENTHREADS_NB, BATCHTHREADS_NB); 199 | 200 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(attn.scalar_type(), "clusten_av_cuda_backward", ([&] { 201 | const auto d_feat_a = d_feat.packed_accessor32(); 202 | const auto attn_a = attn.packed_accessor32(); 203 | const auto v_a = v.packed_accessor32(); 204 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 205 | auto d_attn_a = d_attn.packed_accessor32(); 206 | auto d_v_a = d_v.packed_accessor32(); 207 | 208 | const size_t d_v_numel = d_v.numel(); 209 | clusten_av_cuda_backward_kernel<<>>( 210 | d_feat_a, 
attn_a, nbhd_idx_a, d_v_a, 211 | length, batch_size, heads, nbhd_size, dim, d_v_numel); 212 | clusten_av_attn_cuda_backward_kernel<<>>( 213 | d_feat_a, v_a, nbhd_idx_a, d_attn_a, 214 | length, batch_size, heads, nbhd_size, dim); 215 | })); 216 | 217 | return {d_attn, d_v.to(v.dtype())}; 218 | } 219 | -------------------------------------------------------------------------------- /clusten/src/clustenqk_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | */ 5 | 6 | #include 7 | #include 8 | 9 | torch::Tensor clusten_qk_cuda_forward( 10 | const torch::Tensor &query, // b x h x n x c 11 | const torch::Tensor &key, // b x h x n x c 12 | const torch::Tensor &nbhd_idx); // b x n x m 13 | 14 | std::vector clusten_qk_cuda_backward( 15 | const torch::Tensor &d_attn, 16 | const torch::Tensor &query, 17 | const torch::Tensor &key, 18 | const torch::Tensor &nbhd_idx); 19 | 20 | // C++ interface 21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 24 | 25 | torch::Tensor clusten_qk_forward( 26 | const torch::Tensor &query, 27 | const torch::Tensor &key, 28 | const torch::Tensor &nbhd_idx) { 29 | CHECK_INPUT(query); 30 | CHECK_INPUT(key); 31 | CHECK_INPUT(nbhd_idx); 32 | return clusten_qk_cuda_forward(query, key, nbhd_idx); 33 | } 34 | 35 | std::vector clusten_qk_backward( 36 | const torch::Tensor &d_attn, 37 | const torch::Tensor &query, 38 | const torch::Tensor &key, 39 | const torch::Tensor &nbhd_idx) { 40 | CHECK_INPUT(d_attn); 41 | CHECK_INPUT(query); 42 | CHECK_INPUT(key); 43 | CHECK_INPUT(nbhd_idx); 44 | return clusten_qk_cuda_backward(d_attn, query, key, nbhd_idx); 45 | } 46 | 47 | 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", &clusten_qk_forward, "CLUSTENQK forward (CUDA)"); 50 | m.def("backward", &clusten_qk_backward, "CLUSTENQK backward (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /clusten/src/clustenqk_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | */ 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #define CUDA_NUM_THREADS 1024 17 | 18 | template 19 | __global__ void clusten_qk_cuda_forward_kernel( 20 | const torch::PackedTensorAccessor32 query, // b x h x n x c 21 | const torch::PackedTensorAccessor32 key, // b x h x c x n (reordered by cluster) 22 | const torch::PackedTensorAccessor32 nbhd_idx, // b x n x m 23 | torch::PackedTensorAccessor32 attn, // b x h x n x m 24 | const int length, // n 25 | const int batch_size, // b 26 | const int heads, // h 27 | const int nbhd_size, // m 28 | const int dim) { // c 29 | 30 | const int z = blockIdx.z * blockDim.z + threadIdx.z; 31 | if (z < batch_size * heads){ 32 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 33 | if (i < length){ 34 | const int ni = blockIdx.x * blockDim.x + threadIdx.x; 35 | if (ni < nbhd_size){ 36 | const int b = z / heads; 37 | const int h = z - b * heads; 38 | int64_t nbi = nbhd_idx[b][i][ni]; 39 | // calculate q@k 40 | scalar_t updt = scalar_t(0); 41 | #pragma unroll 42 | for (unsigned int c=0; c < dim; ++c) { 43 | updt += query[b][h][i][c] * key[b][h][c][nbi]; 44 | } 45 | attn[b][h][i][ni] = updt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | 52 | torch::Tensor clusten_qk_cuda_forward( 53 | const torch::Tensor &query, 54 | const torch::Tensor &key, 55 | const torch::Tensor &nbhd_idx) { 56 | 57 | int64_t batch_size = query.size(0); 58 | int64_t heads = query.size(1); 59 | int64_t length = query.size(2); 60 | int64_t dim = query.size(3); 61 | int64_t nbhd_size = nbhd_idx.size(2); 62 | int zsize = batch_size * heads; 63 | 64 | int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), nbhd_size); 65 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length); 66 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * NBHDTHREADS)); 67 | 68 | auto attn = torch::zeros( 69 | {batch_size, heads, length, nbhd_size}, query.options()); 70 | 71 | const auto stream = c10::cuda::getCurrentCUDAStream(); 72 | const dim3 blocks( 73 | (dim + NBHDTHREADS - 1) / NBHDTHREADS, 74 | (length + TOKENTHREADS - 1) / TOKENTHREADS, 75 | (zsize + BATCHTHREADS - 1) / BATCHTHREADS); 76 | const dim3 threads(NBHDTHREADS, TOKENTHREADS, BATCHTHREADS); 77 | 78 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(query.scalar_type(), "clusten_qk_cuda_forward", ([&] { 79 | const auto query_a = query.packed_accessor32(); 80 | const auto key_a = key.packed_accessor32(); 81 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 82 | auto attn_a = attn.packed_accessor32(); 83 | 84 | clusten_qk_cuda_forward_kernel<<>>( 85 | query_a, key_a, nbhd_idx_a, attn_a, 86 | length, batch_size, heads, nbhd_size, dim); 87 | })); 88 | return attn; 89 | } 90 | 91 | template 92 | __global__ void clusten_qk_cuda_backward_kernel( 93 | const torch::PackedTensorAccessor32 d_attn, 94 | const torch::PackedTensorAccessor32 query, 95 | const torch::PackedTensorAccessor32 key, 96 | const torch::PackedTensorAccessor32 nbhd_idx, 97 | torch::PackedTensorAccessor32 d_query, 98 | torch::PackedTensorAccessor32 d_key, 99 | const int length, 100 | const int batch_size, 101 | const int heads, 102 | const int nbhd_size, 103 | const int dim, 104 | const size_t d_key_numel) { 105 | 106 | const int z = blockIdx.z * blockDim.z + threadIdx.z; 107 | if (z < batch_size * heads){ 108 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 109 | if (i < length){ 110 | const int c = blockIdx.x * blockDim.x + threadIdx.x; 111 | if (c < dim){ 112 | const int b = z / heads; 113 
| const int h = z - b * heads; 114 | size_t index; 115 | scalar_t dq_update = scalar_t(0); 116 | scalar_t d_attn_tmp; 117 | #pragma unroll 118 | for (unsigned int ni=0; ni < nbhd_size; ++ni) { 119 | const int64_t nbi = nbhd_idx[b][i][ni]; 120 | // calculate d_query = key * d_att 121 | // calculate d_key = query * d_att 122 | d_attn_tmp = d_attn[b][h][i][ni]; 123 | dq_update += key[b][h][nbi][c] * d_attn_tmp; 124 | index = b*d_key.stride(0) + h*d_key.stride(1) + nbi*d_key.stride(2) + c; 125 | at::native::fastAtomicAdd(d_key.data(), index, d_key_numel, query[b][h][i][c] * d_attn_tmp, true); 126 | //atomicAdd(&(d_key[b][h][nbi][c]), query[b][h][i][c] * d_attn_tmp); // avoid race condition 127 | } 128 | d_query[b][h][i][c] = dq_update; 129 | } 130 | } 131 | } 132 | } 133 | 134 | std::vector clusten_qk_cuda_backward( 135 | const torch::Tensor &d_attn, 136 | const torch::Tensor &query, 137 | const torch::Tensor &key, 138 | const torch::Tensor &nbhd_idx) { 139 | 140 | int64_t batch_size = query.size(0); 141 | int64_t heads = query.size(1); 142 | int64_t length = query.size(2); 143 | int64_t dim = query.size(3); 144 | int64_t nbhd_size = nbhd_idx.size(2); 145 | int zsize = batch_size * heads; 146 | 147 | int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim); 148 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length); 149 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS)); 150 | 151 | auto d_query = torch::zeros_like(query); 152 | auto d_key = torch::zeros_like(key); 153 | 154 | const auto stream = c10::cuda::getCurrentCUDAStream(); 155 | 156 | const dim3 blocks( 157 | (dim + CHANNELTHREADS - 1) / CHANNELTHREADS, 158 | (length + TOKENTHREADS - 1) / TOKENTHREADS, 159 | (zsize + BATCHTHREADS - 1) / BATCHTHREADS); 160 | 161 | const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS); 162 | 163 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(query.scalar_type(), "clusten_qk_cuda_backward", ([&] { 164 | const auto d_attn_a = d_attn.packed_accessor32(); 165 | const auto query_a = query.packed_accessor32(); 166 | const auto key_a = key.packed_accessor32(); 167 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 168 | auto d_query_a = d_query.packed_accessor32(); 169 | auto d_key_a = d_key.packed_accessor32(); 170 | 171 | const size_t d_key_numel = d_key.numel(); 172 | clusten_qk_cuda_backward_kernel<<>>( 173 | d_attn_a, query_a, key_a, nbhd_idx_a, d_query_a, d_key_a, 174 | length, batch_size, heads, nbhd_size, dim, d_key_numel); 175 | })); 176 | 177 | return {d_query, d_key.to(key.dtype())}; 178 | } 179 | -------------------------------------------------------------------------------- /clusten/src/clustenwf_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | */ 5 | 6 | #include 7 | #include 8 | 9 | torch::Tensor clusten_wf_cuda_forward( 10 | const torch::Tensor &weights, // b x n_ x m x ic 11 | const torch::Tensor &feat, // b x n x c 12 | const torch::Tensor &nbhd_idx); // b x n_ x m 13 | 14 | std::vector clusten_wf_cuda_backward( 15 | const torch::Tensor &d_feat_new, 16 | const torch::Tensor &weights, 17 | const torch::Tensor &feat, 18 | const torch::Tensor &nbhd_idx); 19 | 20 | // C++ interface 21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 24 | 25 | torch::Tensor clusten_wf_forward( 26 | const torch::Tensor &weights, 27 | const torch::Tensor &feat, 28 | const torch::Tensor &nbhd_idx) { 29 | CHECK_INPUT(weights); 30 | CHECK_INPUT(feat); 31 | CHECK_INPUT(nbhd_idx); 32 | return clusten_wf_cuda_forward(weights, feat, nbhd_idx); 33 | } 34 | 35 | std::vector clusten_wf_backward( 36 | const torch::Tensor &d_feat_new, 37 | const torch::Tensor &weights, 38 | const torch::Tensor &feat, 39 | const torch::Tensor &nbhd_idx) { 40 | CHECK_INPUT(d_feat_new); 41 | CHECK_INPUT(weights); 42 | CHECK_INPUT(feat); 43 | CHECK_INPUT(nbhd_idx); 44 | return clusten_wf_cuda_backward(d_feat_new, weights, feat, nbhd_idx); 45 | } 46 | 47 | 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", &clusten_wf_forward, "CLUSTENWF forward (CUDA)"); 50 | m.def("backward", &clusten_wf_backward, "CLUSTENWF backward (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /clusten/src/clustenwf_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * For licensing see accompanying LICENSE file. 3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | */ 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #define CUDA_NUM_THREADS 1024 17 | 18 | template 19 | __global__ void clusten_wf_cuda_forward_kernel( 20 | const torch::PackedTensorAccessor32 weights, // b x n_ x m x ic 21 | const torch::PackedTensorAccessor32 feat, // b x n x c 22 | const torch::PackedTensorAccessor32 nbhd_idx, // b x n_ x m 23 | torch::PackedTensorAccessor32 feat_new, // b x n_ x ic x c 24 | const int length, // n 25 | const int length_out, // n_ 26 | const int batch_size, // b 27 | const int nbhd_size, // m 28 | const int dim, // c 29 | const int dim_inner) { // ic 30 | 31 | const int b = blockIdx.z * blockDim.z + threadIdx.z; 32 | if (b < batch_size){ 33 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 34 | if (i < length_out){ 35 | const int c = blockIdx.x * blockDim.x + threadIdx.x; 36 | if (c < dim){ 37 | int64_t nbi; 38 | // calculate weights@feat 39 | scalar_t updt; 40 | #pragma unroll 41 | for (unsigned int ic=0; ic < dim_inner; ++ic) { 42 | updt = scalar_t(0); 43 | #pragma unroll 44 | for (unsigned int ni=0; ni < nbhd_size; ++ni) { 45 | nbi = nbhd_idx[b][i][ni]; 46 | updt += weights[b][i][ni][ic] * feat[b][nbi][c]; 47 | } 48 | feat_new[b][i][ic][c] = updt; 49 | } 50 | } 51 | } 52 | } 53 | } 54 | 55 | 56 | torch::Tensor clusten_wf_cuda_forward( 57 | const torch::Tensor &weights, 58 | const torch::Tensor &feat, 59 | const torch::Tensor &nbhd_idx) { 60 | 61 | int64_t batch_size = weights.size(0); 62 | int64_t length_out = weights.size(1); 63 | int64_t nbhd_size = weights.size(2); 64 | int64_t dim_inner = weights.size(3); 65 | int64_t length = feat.size(1); 66 | int64_t dim = feat.size(2); 67 | 68 | int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim); 69 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length_out); 70 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS)); 71 | 72 | auto feat_new = torch::zeros( 73 | {batch_size, length_out, dim_inner, dim}, weights.options()); 74 | 75 | const auto stream = c10::cuda::getCurrentCUDAStream(); 76 | const dim3 blocks( 77 | (dim + CHANNELTHREADS - 1) / CHANNELTHREADS, 78 | (length_out + TOKENTHREADS - 1) / TOKENTHREADS, 79 | (batch_size + BATCHTHREADS - 1) / BATCHTHREADS); 80 | const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS); 81 | 82 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(weights.scalar_type(), "clusten_wf_cuda_forward", ([&] { 83 | const auto weights_a = weights.packed_accessor32(); 84 | const auto feat_a = feat.packed_accessor32(); 85 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 86 | auto feat_new_a = feat_new.packed_accessor32(); 87 | 88 | clusten_wf_cuda_forward_kernel<<>>( 89 | weights_a, feat_a, nbhd_idx_a, feat_new_a, 90 | length, length_out, batch_size, nbhd_size, dim, dim_inner); 91 | })); 92 | return feat_new; 93 | } 94 | 95 | 96 | template 97 | __global__ void clusten_wf_cuda_backward_kernel( 98 | const torch::PackedTensorAccessor32 d_feat_new, 99 | const torch::PackedTensorAccessor32 weights, 100 | const torch::PackedTensorAccessor32 nbhd_idx, 101 | torch::PackedTensorAccessor32 d_feat, 102 | const int length, // n 103 | const int length_out, // n_ 104 | const int batch_size, // b 105 | const int nbhd_size, // m 106 | const int dim, // c 107 | const int dim_inner, // ic 108 | const size_t d_feat_numel) { 109 | 110 | const int b = blockIdx.z * blockDim.z + threadIdx.z; 111 | if (b < batch_size){ 112 | const int i = blockIdx.y * blockDim.y + 
threadIdx.y; 113 | if (i < length_out){ 114 | const int c = blockIdx.x * blockDim.x + threadIdx.x; 115 | if (c < dim){ 116 | int64_t nbi; 117 | size_t index; 118 | scalar_t updt; 119 | #pragma unroll 120 | for (unsigned int ni=0; ni < nbhd_size; ++ni) { 121 | nbi = nbhd_idx[b][i][ni]; 122 | updt = scalar_t(0); 123 | // calculate d_feat = weights * d_feat_new 124 | #pragma unroll 125 | for (unsigned int ic=0; ic < dim_inner; ++ic) { 126 | updt += d_feat_new[b][i][ic][c] * weights[b][i][ni][ic]; 127 | } 128 | index = b*d_feat.stride(0) + nbi*d_feat.stride(1) + c; 129 | at::native::fastAtomicAdd(d_feat.data(), index, d_feat_numel, updt, true); 130 | // atomicAdd(&(d_feat[b][nbi][c]), updt); // avoid race condition 131 | } 132 | } 133 | } 134 | } 135 | } 136 | 137 | template 138 | __global__ void clusten_wf_weights_cuda_backward_kernel( 139 | const torch::PackedTensorAccessor32 d_feat_new, 140 | const torch::PackedTensorAccessor32 feat, 141 | const torch::PackedTensorAccessor32 nbhd_idx, 142 | torch::PackedTensorAccessor32 d_weights, 143 | const int length, // n 144 | const int length_out, // n_ 145 | const int batch_size, // b 146 | const int nbhd_size, // m 147 | const int dim, // c 148 | const int dim_inner){ // ic 149 | 150 | const int b = blockIdx.z * blockDim.z + threadIdx.z; 151 | if (b < batch_size){ 152 | const int i = blockIdx.y * blockDim.y + threadIdx.y; 153 | if (i < length_out){ 154 | const int z = blockIdx.x * blockDim.x + threadIdx.x; 155 | if (z < nbhd_size * dim_inner){ 156 | const int ni = z / dim_inner; 157 | const int ic = z - ni * dim_inner; 158 | int64_t nbi = nbhd_idx[b][i][ni]; 159 | scalar_t updt = scalar_t(0); 160 | #pragma unroll 161 | for (unsigned int c=0; c < dim; ++c) { 162 | // calculate d_weights = feat * d_feat_new 163 | updt += feat[b][nbi][c] * d_feat_new[b][i][ic][c]; 164 | } 165 | d_weights[b][i][ni][ic] = updt; 166 | } 167 | } 168 | } 169 | } 170 | 171 | std::vector clusten_wf_cuda_backward( 172 | const torch::Tensor &d_feat_new, 173 | const torch::Tensor &weights, 174 | const torch::Tensor &feat, 175 | const torch::Tensor &nbhd_idx) { 176 | 177 | int64_t batch_size = weights.size(0); 178 | int64_t length_out = weights.size(1); 179 | int64_t nbhd_size = weights.size(2); 180 | int64_t dim_inner = weights.size(3); 181 | int64_t length = feat.size(1); 182 | int64_t dim = feat.size(2); 183 | 184 | int64_t zsize = nbhd_size * dim_inner; 185 | 186 | int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim); 187 | int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length_out); 188 | int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS* CHANNELTHREADS)); 189 | 190 | int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), zsize); 191 | int TOKENTHREADS_NB = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length_out); 192 | int BATCHTHREADS_NB = max(1, CUDA_NUM_THREADS / (TOKENTHREADS_NB* NBHDTHREADS)); 193 | 194 | auto d_weights = torch::zeros_like(weights); 195 | auto d_feat = torch::zeros_like(feat); 196 | 197 | const auto stream = c10::cuda::getCurrentCUDAStream(); 198 | 199 | const dim3 blocks( 200 | (dim + CHANNELTHREADS - 1) / CHANNELTHREADS, 201 | (length_out + TOKENTHREADS - 1) / TOKENTHREADS, 202 | (batch_size + BATCHTHREADS - 1) / BATCHTHREADS); 203 | const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS); 204 | 205 | const dim3 blocks_nb( 206 | (zsize + NBHDTHREADS - 1) / NBHDTHREADS, 207 | (length_out + TOKENTHREADS_NB - 1) / TOKENTHREADS_NB, 208 | (batch_size + BATCHTHREADS_NB - 1) / BATCHTHREADS_NB); 209 | const dim3 
threads_nb(NBHDTHREADS, TOKENTHREADS_NB, BATCHTHREADS_NB); 210 | 211 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(weights.scalar_type(), "clusten_wf_cuda_backward", ([&] { 212 | const auto d_feat_new_a = d_feat_new.packed_accessor32(); 213 | const auto weights_a = weights.packed_accessor32(); 214 | const auto feat_a = feat.packed_accessor32(); 215 | const auto nbhd_idx_a = nbhd_idx.packed_accessor32(); 216 | auto d_weights_a = d_weights.packed_accessor32(); 217 | auto d_feat_a = d_feat.packed_accessor32(); 218 | 219 | const size_t d_feat_numel = d_feat.numel(); 220 | clusten_wf_cuda_backward_kernel<<>>( 221 | d_feat_new_a, weights_a, nbhd_idx_a, d_feat_a, 222 | length, length_out, batch_size, nbhd_size, dim, dim_inner, d_feat_numel); 223 | clusten_wf_weights_cuda_backward_kernel<<>>( 224 | d_feat_new_a, feat_a, nbhd_idx_a, d_weights_a, 225 | length, length_out, batch_size, nbhd_size, dim, dim_inner); 226 | })); 227 | 228 | return {d_weights, d_feat.to(feat.dtype())}; 229 | } 230 | -------------------------------------------------------------------------------- /clusten/src/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | 9 | setup( 10 | name='clustencuda', 11 | version='0.1', 12 | author='Ziwen Chen', 13 | author_email='chenziw@oregonstate.edu', 14 | description='Cluster Attention CUDA Kernel', 15 | ext_modules=[ 16 | CUDAExtension('clustenqk_cuda', [ 17 | 'clustenqk_cuda.cpp', 18 | 'clustenqk_cuda_kernel.cu', 19 | ]), 20 | CUDAExtension('clustenav_cuda', [ 21 | 'clustenav_cuda.cpp', 22 | 'clustenav_cuda_kernel.cu', 23 | ]), 24 | CUDAExtension('clustenwf_cuda', [ 25 | 'clustenwf_cuda.cpp', 26 | 'clustenwf_cuda_kernel.cu', 27 | ]), 28 | ], 29 | cmdclass={ 30 | 'build_ext': BuildExtension 31 | }) 32 | -------------------------------------------------------------------------------- /clusten/test_av_kernel.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import torch 7 | from torch import nn 8 | from clusten import CLUSTENAVFunction 9 | 10 | """ 11 | Test the correctness of AV custom kernel 12 | """ 13 | 14 | b = 256 15 | h = 4 16 | n = 196 17 | m = 48 18 | c = 32 19 | 20 | # dummy data 21 | attn = nn.Parameter(torch.randn(b, h, n, m)).cuda() 22 | attn.retain_grad() 23 | val = nn.Parameter(torch.randn(b, h, n, c)).cuda() 24 | val.retain_grad() 25 | nn_idx = torch.randint(n, (b, n, m)).cuda() 26 | 27 | # use the custom kernel 28 | feat = CLUSTENAVFunction.apply(attn, val, nn_idx) 29 | feat.mean().backward() 30 | grad_attn = attn.grad.clone().detach() 31 | attn.grad.data.zero_() 32 | grad_val = val.grad.clone().detach() 33 | val.grad.data.zero_() 34 | 35 | # use the pytorch equivalent 36 | ''' 37 | feat2 = CLUSTENAVFunction.apply(attn,val,nn_idx) 38 | ''' 39 | val_gather = val.gather(index=nn_idx.reshape(b, 1, -1, 1).expand(-1, h, -1, c), dim=2).reshape(b, h, n, m, c) 40 | feat2 = (attn.unsqueeze(4) * val_gather).sum(3) 41 | feat2.mean().backward() 42 | grad_attn2 = attn.grad.clone().detach() 43 | attn.grad.data.zero_() 44 | grad_val2 = val.grad.clone().detach() 45 | val.grad.data.zero_() 46 | 47 | 48 | print('diff of forward: ', torch.linalg.norm(feat2 - feat)) 49 | print('diff of grad attn: ', torch.linalg.norm(grad_attn2 - grad_attn)) 50 | print('diff of grad val: ', torch.linalg.norm(grad_val2 - grad_val)) 51 | -------------------------------------------------------------------------------- /clusten/test_qk_kernel.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | from torch import nn 8 | from clusten import CLUSTENQKFunction 9 | 10 | """ 11 | Test the correctness of QK custom kernel 12 | """ 13 | 14 | b = 256 15 | h = 4 16 | n = 196 17 | m = 48 18 | c = 32 19 | 20 | # dummy data 21 | query = nn.Parameter(torch.randn(b, h, n, c)).cuda() 22 | query.retain_grad() 23 | key = nn.Parameter(torch.randn(b, h, n, c)).cuda() 24 | key.retain_grad() 25 | nn_idx = torch.randint(n, (b, n, m)).cuda() 26 | 27 | # use the custom kernel 28 | attn = CLUSTENQKFunction.apply(query, key, nn_idx) 29 | attn.mean().backward() 30 | grad_query = query.grad.clone().detach() 31 | query.grad.data.zero_() 32 | grad_key = key.grad.clone().detach() 33 | key.grad.data.zero_() 34 | 35 | # use the pytorch equivalent 36 | ''' 37 | attn2 = CLUSTENQKFunction.apply(query, key, nn_idx) 38 | ''' 39 | key_gather = key.gather(index=nn_idx.reshape(b, 1, -1, 1).expand(-1, h, -1, c), dim=2).reshape(b, h, n, m, c) 40 | attn2 = (query.unsqueeze(3) * key_gather).sum(-1) 41 | attn2.mean().backward() 42 | grad_query2 = query.grad.clone().detach() 43 | query.grad.data.zero_() 44 | grad_key2 = key.grad.clone().detach() 45 | key.grad.data.zero_() 46 | 47 | 48 | print('diff of forward: ', torch.linalg.norm(attn2 - attn)) 49 | print('diff of grad query: ', torch.linalg.norm(grad_query2 - grad_query)) 50 | print('diff of grad key: ', torch.linalg.norm(grad_key2 - grad_key)) 51 | -------------------------------------------------------------------------------- /clusten/test_wf_kernel.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
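#
# Shape contract exercised by this test (inferred from the PyTorch-equivalent
# computation below): weights is b x n_ x m x ic, feat is b x n x c and nn_idx
# is b x n_ x m; the kernel computes feat_new[b,i] = weights[b,i].T @ feat[b,nn_idx[b,i]].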
4 | # 5 | 6 | import torch 7 | from torch import nn 8 | from clusten import CLUSTENWFFunction 9 | 10 | """ 11 | Test the correctness of WF custom kernel 12 | """ 13 | 14 | b = 256 15 | n_ = 64 16 | n = 196 17 | m = 48 18 | c = 32 19 | ic = 4 20 | 21 | # dummy data 22 | weights = nn.Parameter(torch.randn(b, n_, m, ic)).cuda() 23 | weights.retain_grad() 24 | feat = nn.Parameter(torch.randn(b, n, c)).cuda() 25 | feat.retain_grad() 26 | nn_idx = torch.randint(n, (b, n_, m)).cuda() 27 | 28 | # use the custom kernel 29 | feat_new = CLUSTENWFFunction.apply(weights, feat, nn_idx) 30 | feat_new.mean().backward() 31 | grad_weights = weights.grad.clone().detach() 32 | weights.grad.data.zero_() 33 | grad_feat = feat.grad.clone().detach() 34 | feat.grad.data.zero_() 35 | 36 | # use the pytorch equivalent 37 | ''' 38 | feat_new2 = CLUSTENWFFunction.apply(weights,feat,nn_idx) 39 | ''' 40 | feat_gather = feat.gather(index=nn_idx.reshape(b, -1, 1).expand(-1, -1, c), dim=1).reshape(b, n_, m, c) 41 | feat_new2 = weights.transpose(-1, -2) @ feat_gather 42 | feat_new2.mean().backward() 43 | grad_weights2 = weights.grad.clone().detach() 44 | weights.grad.data.zero_() 45 | grad_feat2 = feat.grad.clone().detach() 46 | feat.grad.data.zero_() 47 | 48 | 49 | print('diff of forward: ', torch.linalg.norm(feat_new2 - feat_new)) 50 | print('diff of grad weights: ', torch.linalg.norm(grad_weights2 - grad_weights)) 51 | print('diff of grad feat: ', torch.linalg.norm(grad_feat2 - grad_feat)) 52 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Adapted for AutoFocusFormer by Ziwen 2023 8 | 9 | import os 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | _C = CN() 14 | 15 | # Base config files 16 | _C.BASE = [''] 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Data settings 20 | # ----------------------------------------------------------------------------- 21 | _C.DATA = CN() 22 | # Batch size for a single GPU, could be overwritten by command line argument 23 | _C.DATA.BATCH_SIZE = 128 24 | # Path to dataset, could be overwritten by command line argument 25 | _C.DATA.DATA_PATH = 'imagenet' 26 | # Dataset name 27 | _C.DATA.DATASET = 'imagenet' 28 | # Input image size 29 | _C.DATA.IMG_SIZE = 224 30 | # Input channels 31 | _C.DATA.IN_CHANS = 3 32 | # Interpolation to resize image (random, bilinear, bicubic) 33 | _C.DATA.INTERPOLATION = 'bicubic' 34 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
35 | _C.DATA.PIN_MEMORY = True 36 | # Number of data loading threads 37 | _C.DATA.NUM_WORKERS = 8 38 | 39 | # ----------------------------------------------------------------------------- 40 | # Model settings 41 | # ----------------------------------------------------------------------------- 42 | _C.MODEL = CN() 43 | # Model type 44 | _C.MODEL.TYPE = 'aff' 45 | # Model name 46 | _C.MODEL.NAME = 'aff_mini_1_4th' 47 | # Checkpoint to resume, could be overwritten by command line argument 48 | _C.MODEL.RESUME = '' 49 | # Number of classes, overwritten in data preparation 50 | _C.MODEL.NUM_CLASSES = 1000 51 | # Dropout rate 52 | _C.MODEL.DROP_RATE = 0.0 53 | # Drop path rate 54 | _C.MODEL.DROP_PATH_RATE = 0.0 55 | # Label Smoothing 56 | _C.MODEL.LABEL_SMOOTHING = 0.1 57 | 58 | # AFF parameters 59 | _C.MODEL.AFF = CN() 60 | _C.MODEL.AFF.DEPTHS = [2, 2, 6, 2] 61 | _C.MODEL.AFF.NUM_HEADS = [2, 4, 8, 16] 62 | _C.MODEL.AFF.EMBED_DIM = [32, 128, 256, 384] 63 | _C.MODEL.AFF.MLP_RATIO = 2. 64 | _C.MODEL.AFF.PATCH_NORM = True 65 | 66 | _C.MODEL.AFF.CLUSTER_SIZE = 8 67 | _C.MODEL.AFF.NBHD_SIZE = [48, 48, 48, 49] 68 | _C.MODEL.AFF.ALPHA = 4.0 69 | _C.MODEL.AFF.DS_RATE = 0.25 70 | _C.MODEL.AFF.LAYER_SCALE = 0.0 71 | _C.MODEL.AFF.RESERVE = True 72 | 73 | # ----------------------------------------------------------------------------- 74 | # Training settings 75 | # ----------------------------------------------------------------------------- 76 | _C.TRAIN = CN() 77 | _C.TRAIN.START_EPOCH = 0 78 | _C.TRAIN.EPOCHS = 300 79 | _C.TRAIN.WARMUP_EPOCHS = 20 80 | _C.TRAIN.COOLDOWN_EPOCHS = 0 81 | _C.TRAIN.WEIGHT_DECAY = 0.05 82 | _C.TRAIN.BASE_LR = 5e-4 83 | _C.TRAIN.WARMUP_LR = 5e-7 84 | _C.TRAIN.MIN_LR = 5e-6 85 | # EMA 86 | _C.TRAIN.USE_EMA = False 87 | _C.TRAIN.EMA_DECAY = 0.9998 88 | 89 | # Clip gradient norm 90 | _C.TRAIN.CLIP_GRAD = 5.0 91 | # Auto resume from latest checkpoint 92 | _C.TRAIN.AUTO_RESUME = True 93 | # Gradient accumulation steps 94 | # could be overwritten by command line argument 95 | _C.TRAIN.ACCUMULATION_STEPS = 0 96 | 97 | # LR scheduler 98 | _C.TRAIN.LR_SCHEDULER = CN() 99 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 100 | # Epoch interval to decay LR, used in StepLRScheduler 101 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 102 | # LR decay rate, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 104 | 105 | # Optimizer 106 | _C.TRAIN.OPTIMIZER = CN() 107 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 108 | # Optimizer Epsilon 109 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 110 | # Optimizer Betas 111 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 112 | # SGD momentum 113 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 114 | 115 | # ----------------------------------------------------------------------------- 116 | # Augmentation settings 117 | # ----------------------------------------------------------------------------- 118 | _C.AUG = CN() 119 | # Color jitter factor 120 | _C.AUG.COLOR_JITTER = 0.4 121 | # Use AutoAugment policy. 
"v0" or "original" 122 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 123 | # Random erase prob 124 | _C.AUG.REPROB = 0.25 125 | # Random erase mode 126 | _C.AUG.REMODE = 'pixel' 127 | # Random erase count 128 | _C.AUG.RECOUNT = 1 129 | # Mixup alpha, mixup enabled if > 0 130 | _C.AUG.MIXUP = 0.0 # 0.8 131 | # Cutmix alpha, cutmix enabled if > 0 132 | _C.AUG.CUTMIX = 0.0 # 1.0 133 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 134 | _C.AUG.CUTMIX_MINMAX = None 135 | # Probability of performing mixup or cutmix when either/both is enabled 136 | _C.AUG.MIXUP_PROB = 1.0 137 | # Probability of switching to cutmix when both mixup and cutmix enabled 138 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 139 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 140 | _C.AUG.MIXUP_MODE = 'batch' 141 | 142 | # ----------------------------------------------------------------------------- 143 | # Testing settings 144 | # ----------------------------------------------------------------------------- 145 | _C.TEST = CN() 146 | # Whether to use center crop when testing 147 | _C.TEST.CROP = True 148 | 149 | # ----------------------------------------------------------------------------- 150 | # Misc 151 | # ----------------------------------------------------------------------------- 152 | # Pytorch native amp, overwritten by command line argument 153 | _C.AMP_ENABLE = False 154 | # Path to output folder, overwritten by command line argument 155 | _C.OUTPUT = '' 156 | # Tag of experiment, overwritten by command line argument 157 | _C.TAG = 'default' 158 | # Frequency to save checkpoint (epochs) 159 | _C.SAVE_FREQ = 1 160 | # Frequency to logging info 161 | _C.PRINT_FREQ = 10 162 | # Frequency to validate (epochs) 163 | _C.EVAL_FREQ = 1 164 | # Fixed random seed 165 | _C.SEED = 0 166 | # Perform evaluation only, overwritten by command line argument 167 | _C.EVAL_MODE = False 168 | # Test throughput only, overwritten by command line argument 169 | _C.THROUGHPUT_MODE = False 170 | # local rank for DistributedDataParallel, given by command line argument 171 | _C.LOCAL_RANK = 0 172 | 173 | 174 | def _update_config_from_file(config, cfg_file): 175 | config.defrost() 176 | with open(cfg_file, 'r') as f: 177 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 178 | 179 | for cfg in yaml_cfg.setdefault('BASE', ['']): 180 | if cfg: 181 | _update_config_from_file( 182 | config, os.path.join(os.path.dirname(cfg_file), cfg) 183 | ) 184 | print('=> merge config from {}'.format(cfg_file)) 185 | config.merge_from_file(cfg_file) 186 | config.freeze() 187 | 188 | 189 | def update_config(config, args): 190 | _update_config_from_file(config, args.cfg) 191 | 192 | config.defrost() 193 | if args.opts: 194 | config.merge_from_list(args.opts) 195 | 196 | # merge from specific arguments 197 | if args.batch_size: 198 | config.DATA.BATCH_SIZE = args.batch_size 199 | if args.data_path: 200 | config.DATA.DATA_PATH = args.data_path 201 | if args.blr: 202 | config.TRAIN.BASE_LR = args.blr 203 | if args.resume: 204 | config.MODEL.RESUME = args.resume 205 | if args.accumulation_steps: 206 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 207 | if args.output: 208 | config.OUTPUT = args.output 209 | if args.tag: 210 | config.TAG = args.tag 211 | if args.eval: 212 | config.EVAL_MODE = True 213 | if args.throughput: 214 | config.THROUGHPUT_MODE = True 215 | if args.epochs: 216 | config.TRAIN.EPOCHS = args.epochs 217 | 218 | # set local rank for distributed training 219 | config.LOCAL_RANK = args.local_rank 220 | 221 
| # output folder 222 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 223 | 224 | config.freeze() 225 | 226 | 227 | def get_config(args): 228 | """Get a yacs CfgNode object with default values.""" 229 | # Return a clone so that the defaults will not be altered 230 | # This is for the "local variable" use pattern 231 | config = _C.clone() 232 | update_config(config, args) 233 | 234 | return config 235 | -------------------------------------------------------------------------------- /configs/aff_base_22k.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_base_22k 4 | DROP_PATH_RATE: 0.2 5 | AFF: 6 | DEPTHS: [3,4,18,2] 7 | NUM_HEADS: [4,8,16,32] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [128, 256, 512, 1024] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | LAYER_SCALE: 1e-5 # turned off if 0.0 13 | ALPHA: 4.0 14 | DS_RATE: 0.25 15 | DATA: 16 | DATASET: imagenet22k 17 | IMG_SIZE: 224 18 | BATCH_SIZE: 64 19 | DATA_PATH: path/to/22k 20 | TRAIN: 21 | EPOCHS: 90 22 | WARMUP_EPOCHS: 5 23 | WEIGHT_DECAY: 0.05 24 | BASE_LR: 5e-4 25 | MIN_LR: 1.25e-6 26 | USE_EMA: False 27 | EMA_DECAY: 0.9998 28 | WARMUP_LR: 1.25e-7 29 | ACCUMULATION_STEPS: 1 30 | AUG: 31 | MIXUP: 0.8 32 | CUTMIX: 1.0 33 | -------------------------------------------------------------------------------- /configs/aff_base_22kto1k.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_base_22kto1k 4 | DROP_PATH_RATE: 0.2 5 | AFF: 6 | DEPTHS: [3,4,18,2] 7 | NUM_HEADS: [4,8,16,32] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [128, 256, 512, 1024] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | LAYER_SCALE: 1e-5 # turned off if 0.0 13 | ALPHA: 4.0 14 | DS_RATE: 0.25 15 | PRETRAINED: aff_base_22k.pth 16 | DATA: 17 | DATASET: imagenet 18 | BATCH_SIZE: 64 19 | TRAIN: 20 | EPOCHS: 30 21 | WARMUP_EPOCHS: 5 22 | WEIGHT_DECAY: 1e-8 23 | BASE_LR: 2e-05 24 | MIN_LR: 2e-07 25 | USE_EMA: False 26 | EMA_DECAY: 0.9998 27 | WARMUP_LR: 2e-08 28 | AUG: 29 | MIXUP: 0.8 30 | CUTMIX: 1.0 31 | -------------------------------------------------------------------------------- /configs/aff_base_22kto1k_384.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_base_22kto1k_384 4 | DROP_PATH_RATE: 0.2 5 | AFF: 6 | DEPTHS: [3,4,18,2] 7 | NUM_HEADS: [4,8,16,32] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [128, 256, 512, 1024] 10 | CLUSTER_SIZE: 24 11 | NBHD_SIZE: [144,144,144,144] 12 | LAYER_SCALE: 1e-5 # turned off if 0.0 13 | ALPHA: 4.0 14 | DS_RATE: 0.25 15 | PRETRAINED: aff_base_22k.pth 16 | DATA: 17 | DATASET: imagenet 18 | IMG_SIZE: 384 19 | BATCH_SIZE: 16 20 | TRAIN: 21 | EPOCHS: 30 22 | WARMUP_EPOCHS: 5 23 | WEIGHT_DECAY: 1e-8 24 | BASE_LR: 2e-05 25 | MIN_LR: 2e-07 26 | USE_EMA: False 27 | EMA_DECAY: 0.9998 28 | WARMUP_LR: 2e-08 29 | ACCUMULATION_STEPS: 4 30 | AUG: 31 | MIXUP: 0.8 32 | CUTMIX: 1.0 33 | TEST: 34 | CROP: False 35 | -------------------------------------------------------------------------------- /configs/aff_mini.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_mini_1_4th 4 | DROP_PATH_RATE: 0.0 5 | AFF: 6 | DEPTHS: [ 2, 2, 6, 2] 7 | NUM_HEADS: [ 2, 4, 8, 16 ] 8 | MLP_RATIO: 2. 
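    # Naming note: the "1_4th" suffix in MODEL.NAME refers to the DS_RATE of 0.25
    # below (roughly 1/4 of the tokens kept at each downsampling stage);
    # aff_mini_1_5th.yaml is identical except for DS_RATE: 0.2.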
9 | EMBED_DIM: [32,128,256,384] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | ALPHA: 4.0 13 | DS_RATE: 0.25 14 | DATA: 15 | DATASET: imagenet 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 1024 18 | TRAIN: 19 | EPOCHS: 300 20 | BASE_LR: 5e-4 21 | MIN_LR: 5e-6 22 | WARMUP_LR: 5e-7 23 | AUG: 24 | MIXUP: 0.0 25 | CUTMIX: 0.0 26 | -------------------------------------------------------------------------------- /configs/aff_mini_1_5th.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_mini_1_5th 4 | DROP_PATH_RATE: 0.0 5 | AFF: 6 | DEPTHS: [ 2, 2, 6, 2] 7 | NUM_HEADS: [ 2, 4, 8, 16 ] 8 | MLP_RATIO: 2. 9 | EMBED_DIM: [32,128,256,384] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | ALPHA: 4.0 13 | DS_RATE: 0.2 14 | DATA: 15 | DATASET: imagenet 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 1024 18 | TRAIN: 19 | EPOCHS: 300 20 | BASE_LR: 5e-4 21 | MIN_LR: 5e-6 22 | WARMUP_LR: 5e-7 23 | AUG: 24 | MIXUP: 0.0 25 | CUTMIX: 0.0 26 | -------------------------------------------------------------------------------- /configs/aff_small.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_small_1_4th 4 | DROP_PATH_RATE: 0.3 5 | AFF: 6 | DEPTHS: [3,4,18,2] 7 | NUM_HEADS: [3,6,12,24] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [96,192,384,768] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | LAYER_SCALE: 1e-5 # turned off if 0.0 13 | ALPHA: 4.0 14 | DS_RATE: 0.25 15 | DATA: 16 | DATASET: imagenet 17 | IMG_SIZE: 224 18 | BATCH_SIZE: 256 19 | TRAIN: 20 | EPOCHS: 300 21 | BASE_LR: 5e-4 22 | MIN_LR: 5e-6 23 | WARMUP_LR: 5e-7 24 | AUG: 25 | MIXUP: 0.8 26 | CUTMIX: 1.0 27 | -------------------------------------------------------------------------------- /configs/aff_small_1_5th.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_small_1_5th 4 | DROP_PATH_RATE: 0.3 5 | AFF: 6 | DEPTHS: [3,4,18,2] 7 | NUM_HEADS: [3,6,12,24] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [96,192,384,768] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | LAYER_SCALE: 1e-5 # turned off if 0.0 13 | ALPHA: 4.0 14 | DS_RATE: 0.2 15 | DATA: 16 | DATASET: imagenet 17 | IMG_SIZE: 224 18 | BATCH_SIZE: 256 19 | TRAIN: 20 | EPOCHS: 300 21 | BASE_LR: 5e-4 22 | MIN_LR: 5e-6 23 | WARMUP_LR: 5e-7 24 | AUG: 25 | MIXUP: 0.8 26 | CUTMIX: 1.0 27 | -------------------------------------------------------------------------------- /configs/aff_tiny.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_tiny_1_4th 4 | DROP_PATH_RATE: 0.2 5 | AFF: 6 | DEPTHS: [3,4,18,5] 7 | NUM_HEADS: [2,4,8,16] 8 | MLP_RATIO: 3. 9 | EMBED_DIM: [64,128,256,512] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | ALPHA: 4.0 13 | DS_RATE: 0.25 14 | DATA: 15 | DATASET: imagenet 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 256 18 | TRAIN: 19 | EPOCHS: 300 20 | BASE_LR: 5e-4 21 | MIN_LR: 5e-6 22 | WARMUP_LR: 5e-7 23 | AUG: 24 | MIXUP: 0.8 25 | CUTMIX: 1.0 26 | -------------------------------------------------------------------------------- /configs/aff_tiny_1_5th.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: aff 3 | NAME: aff_tiny_1_5th 4 | DROP_PATH_RATE: 0.2 5 | AFF: 6 | DEPTHS: [3,4,18,5] 7 | NUM_HEADS: [2,4,8,16] 8 | MLP_RATIO: 3. 
9 | EMBED_DIM: [64,128,256,512] 10 | CLUSTER_SIZE: 8 11 | NBHD_SIZE: [48,48,48,49] 12 | ALPHA: 4.0 13 | DS_RATE: 0.2 14 | DATA: 15 | DATASET: imagenet 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 256 18 | TRAIN: 19 | EPOCHS: 300 20 | BASE_LR: 5e-4 21 | MIN_LR: 5e-6 22 | WARMUP_LR: 5e-7 23 | AUG: 24 | MIXUP: 0.8 25 | CUTMIX: 1.0 26 | -------------------------------------------------------------------------------- /create_env.sh: -------------------------------------------------------------------------------- 1 | # Create a conda virtual environment and activate it 2 | conda create -n aff python=3.8 3 | conda activate aff 4 | 5 | # Install requirements 6 | pip install \ 7 | yacs==0.1.8 \ 8 | termcolor==2.2.0 \ 9 | timm==0.6.12 \ 10 | pykeops==2.1.1 \ 11 | ptflops==0.6.9 \ 12 | numpy==1.22.4 13 | conda install -c conda-forge opencv 14 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.6 -c pytorch -c conda-forge 15 | 16 | # Install the custom CUDA kernels for AFF 17 | cd clusten/src/ && python setup.py install 18 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .build import build_loader 9 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Adapted for AutoFocusFormer by Ziwen 2023 8 | 9 | import os 10 | import torch 11 | import numpy as np 12 | import utils 13 | import torch.distributed as dist 14 | from torchvision import datasets, transforms 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.data import Mixup 17 | from timm.data import create_transform 18 | from timm.data.transforms import _str_to_pil_interpolation 19 | 20 | from .samplers import SubsetRandomSampler 21 | 22 | 23 | def build_loader(config): 24 | config.defrost() 25 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 26 | config.freeze() 27 | print(f"local rank {config.LOCAL_RANK} / global rank {utils.get_rank()} successfully build train dataset") 28 | dataset_val, _ = build_dataset(is_train=False, config=config) 29 | print(f"local rank {config.LOCAL_RANK} / global rank {utils.get_rank()} successfully build val dataset") 30 | 31 | num_tasks = dist.get_world_size() 32 | global_rank = utils.get_rank() 33 | sampler_train = torch.utils.data.DistributedSampler( 34 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 35 | ) 36 | 37 | indices = np.arange(utils.get_rank(), len(dataset_val), dist.get_world_size()) 38 | sampler_val = SubsetRandomSampler(indices) 39 | 40 | data_loader_train = torch.utils.data.DataLoader( 41 | dataset_train, sampler=sampler_train, 42 | batch_size=config.DATA.BATCH_SIZE, 43 | num_workers=config.DATA.NUM_WORKERS, 44 | pin_memory=config.DATA.PIN_MEMORY, 45 | drop_last=True, 
46 | ) 47 | 48 | data_loader_val = torch.utils.data.DataLoader( 49 | dataset_val, sampler=sampler_val, 50 | batch_size=config.DATA.BATCH_SIZE, 51 | shuffle=False, 52 | num_workers=config.DATA.NUM_WORKERS, 53 | pin_memory=config.DATA.PIN_MEMORY, 54 | drop_last=False 55 | ) 56 | 57 | # setup mixup / cutmix 58 | mixup_fn = None 59 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 60 | if mixup_active: 61 | mixup_fn = Mixup( 62 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 63 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 64 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 65 | 66 | return data_loader_train, data_loader_val, mixup_fn 67 | 68 | 69 | def build_dataset(is_train, config): 70 | transform = build_transform(is_train, config) 71 | if config.DATA.DATASET == 'imagenet': 72 | prefix = 'training' if is_train else 'validation' 73 | root = os.path.join(config.DATA.DATA_PATH, prefix) 74 | dataset = datasets.ImageFolder(root, transform=transform) 75 | nb_classes = 1000 76 | else: 77 | raise NotImplementedError("We only support ImageNet now.") 78 | 79 | return dataset, nb_classes 80 | 81 | 82 | def build_transform_imagenet(is_train, config): 83 | resize_im = config.DATA.IMG_SIZE > 32 84 | if is_train: 85 | # this should always dispatch to transforms_imagenet_train 86 | transform = create_transform( 87 | input_size=config.DATA.IMG_SIZE, 88 | is_training=True, 89 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 90 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 91 | re_prob=config.AUG.REPROB, 92 | re_mode=config.AUG.REMODE, 93 | re_count=config.AUG.RECOUNT, 94 | interpolation=config.DATA.INTERPOLATION, 95 | ) 96 | if not resize_im: 97 | # replace RandomResizedCropAndInterpolation with 98 | # RandomCrop 99 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 100 | return transform 101 | 102 | t = [] 103 | if resize_im: 104 | if config.TEST.CROP: 105 | size = int((256 / 224) * config.DATA.IMG_SIZE) 106 | t.append( 107 | transforms.Resize(size, interpolation=_str_to_pil_interpolation[config.DATA.INTERPOLATION]), 108 | # to maintain same ratio w.r.t. 
224 images 109 | ) 110 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 111 | else: 112 | t.append( 113 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 114 | interpolation=_str_to_pil_interpolation[config.DATA.INTERPOLATION]) 115 | ) 116 | 117 | t.append(transforms.ToTensor()) 118 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 119 | return transforms.Compose(t) 120 | 121 | 122 | def build_transform(is_train, config): 123 | if config.DATA.DATASET == 'imagenet': 124 | return build_transform_imagenet(is_train, config) 125 | else: 126 | raise NotImplementedError("We only support ImageNet now.") 127 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/demo1.png -------------------------------------------------------------------------------- /demo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/demo2.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, 
datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Adapted for AutoFocusFormer by Ziwen 2023 8 | 9 | import torch 10 | from timm.scheduler.cosine_lr import CosineLRScheduler 11 | from timm.scheduler.step_lr import StepLRScheduler 12 | from timm.scheduler.scheduler import Scheduler 13 | 14 | 15 | def build_scheduler(config, optimizer, n_iter_per_epoch): 16 | """ 17 | Options for scheduler: 18 | cosine - set learning rating using a cosine annealing schedule, as proposed in 19 | "SGDR: Stochastic Gradient Descent with Warm Restarts" 20 | linear - decays the learning rate by a linearly changing factor until total number of steps is reached 21 | step - after every decay_steps, the learning rate is updated to be lr * decay_rate 22 | """ 23 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 24 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 25 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 26 | 27 | num_epochs = config.TRAIN.EPOCHS 28 | 29 | lr_scheduler = None 30 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 31 | lr_scheduler = CosineLRScheduler( 32 | optimizer, 33 | t_initial=num_steps, 34 | # t_mul=1., 35 | lr_min=config.TRAIN.MIN_LR, 36 | warmup_lr_init=config.TRAIN.WARMUP_LR, 37 | warmup_t=warmup_steps, 38 | cycle_limit=1, 39 | t_in_epochs=False, 40 | ) 41 | cycle_length = lr_scheduler.get_cycle_length() // n_iter_per_epoch 42 | num_epochs = cycle_length + config.TRAIN.COOLDOWN_EPOCHS 43 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 44 | lr_scheduler = LinearLRScheduler( 45 | optimizer, 46 | t_initial=num_steps, 47 | lr_min_rate=0.01, 48 | warmup_lr_init=config.TRAIN.WARMUP_LR, 49 | warmup_t=warmup_steps, 50 | t_in_epochs=False, 51 | ) 52 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 53 | lr_scheduler = StepLRScheduler( 54 | optimizer, 55 | decay_t=decay_steps, 56 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 57 | warmup_lr_init=config.TRAIN.WARMUP_LR, 58 | warmup_t=warmup_steps, 59 | t_in_epochs=False, 60 | ) 61 | 62 | return lr_scheduler 63 | 64 | 65 | class LinearLRScheduler(Scheduler): 66 | def __init__(self, 67 | optimizer: torch.optim.Optimizer, 68 | t_initial: int, 69 | lr_min_rate: float, 70 | warmup_t=0, 71 | warmup_lr_init=0., 72 | t_in_epochs=True, 73 | noise_range_t=None, 74 | noise_pct=0.67, 75 | noise_std=1.0, 76 | noise_seed=42, 77 | initialize=True, 78 | ) -> None: 79 | super().__init__( 80 | optimizer, param_group_field="lr", 81 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 82 | initialize=initialize) 83 | 84 | self.t_initial = t_initial 85 | self.lr_min_rate = lr_min_rate 86 | self.warmup_t = warmup_t 87 | self.warmup_lr_init = warmup_lr_init 88 | 
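        # After warmup, _get_lr below decays each base value v linearly from v to
        # v * lr_min_rate: lr(t) = v - (v - v * lr_min_rate) * t / total_t.
        # For example, with v = 5e-4 and lr_min_rate = 0.01, the lr ends at 5e-6.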
self.t_in_epochs = t_in_epochs 89 | if self.warmup_t: 90 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 91 | super().update_groups(self.warmup_lr_init) 92 | else: 93 | self.warmup_steps = [1 for _ in self.base_values] 94 | 95 | def _get_lr(self, t): 96 | if t < self.warmup_t: 97 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 98 | else: 99 | t = t - self.warmup_t 100 | total_t = self.t_initial - self.warmup_t 101 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 102 | return lrs 103 | 104 | def get_epoch_values(self, epoch: int): 105 | if self.t_in_epochs: 106 | return self._get_lr(epoch) 107 | else: 108 | return None 109 | 110 | def get_update_values(self, num_updates: int): 111 | if not self.t_in_epochs: 112 | return self._get_lr(num_updates) 113 | else: 114 | return None 115 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Adapted for AutoFocusFormer by Ziwen 2023 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | import random 15 | import copy 16 | 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | 20 | from timm.loss import SoftTargetCrossEntropy 21 | from timm.utils import accuracy, AverageMeter, ModelEmaV2 22 | 23 | from config import get_config 24 | from models import build_model 25 | from data import build_loader 26 | from lr_scheduler import build_scheduler 27 | from optimizer import build_optimizer 28 | from logger import create_logger 29 | from utils import load_checkpoint, save_checkpoint, auto_resume_helper, reduce_tensor, get_rank, init_distributed_mode, get_local_rank, get_world_size, NativeScalerWithGradNormCount 30 | 31 | torch.backends.cuda.matmul.allow_tf32 = True 32 | 33 | os.environ['TORCH_DISTRIBUTED_DEBUG'] = "INFO" 34 | 35 | 36 | def parse_option(): 37 | parser = argparse.ArgumentParser('AutoFocusFormer training and evaluation script', add_help=True) 38 | parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', ) 39 | parser.add_argument( 40 | "--opts", 41 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 42 | default=None, 43 | nargs='+', 44 | ) 45 | 46 | # easy config modification 47 | parser.add_argument('--batch-size', type=int, help="batch size per GPU") 48 | parser.add_argument('--epochs', type=int, help="number of epochs") 49 | parser.add_argument('--blr', type=float, help='base learning rate: absolute_lr = base_lr * total_batch_size / 512') 50 | parser.add_argument('--data-path', type=str, help='path to dataset') 51 | parser.add_argument('--resume', help='resume from checkpoint') 52 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 53 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 54 | help='root of output folder, the full path is // (default: output)') 55 | parser.add_argument('--tag', help='tag of experiment') 56 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 57 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 58 | 59 | # distributed training 60 | parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel') 61 | 62 | args, unparsed = parser.parse_known_args() 63 | 64 | return args 65 | 66 | 67 | def main(config, logger): 68 | """ 69 | Initializes all components needed for training, validates the resume checkpoint, 70 | and trains the model 71 | Args: 72 | config: CfgNode object, containing training and model configs 73 | logger: logger object for logging 74 | """ 75 | 76 | # build dataloader 77 | data_loader_train, data_loader_val, mixup_fn = build_loader(config) 78 | 79 | # build model 80 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 81 | print(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 82 | model = build_model(config) 83 | model.cuda() 84 | logger.info(str(model)) 85 | 86 | # build loss scaler 87 | loss_scaler = NativeScalerWithGradNormCount(config) 88 | 89 | # build optimizer 90 | optimizer = build_optimizer(config, model) 91 | 92 | # build distributed model 93 | model_without_ddp = model 94 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) # , find_unused_parameters=True) 95 | 96 | # print model param number and flop count 97 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 98 | logger.info(f"number of params: {n_parameters}") 99 | 100 | from ptflops import get_model_complexity_info 101 | with torch.no_grad(): 102 | macs, params = get_model_complexity_info(copy.deepcopy(model_without_ddp), (config.DATA.IN_CHANS, config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), as_strings=True, print_per_layer_stat=False, verbose=True) 103 | logger.info(f"macs {macs}, params {params}") 104 | 105 | # test model throughput 106 | with torch.no_grad(): 107 | throughput(config, data_loader_val, model, logger) 108 | torch.cuda.synchronize() 109 | if config.THROUGHPUT_MODE: 110 | return 111 | 112 | # build scheduler 113 | if config.TRAIN.ACCUMULATION_STEPS > 1: 114 | lr_scheduler build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS) 115 | else: 116 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 117 | 118 | # build criterion 119 | if config.AUG.MIXUP > 0.: 120 | # smoothing is handled with mixup label transform 121 | criterion = SoftTargetCrossEntropy() 122 | else: 123 | criterion = torch.nn.CrossEntropyLoss(label_smoothing=config.MODEL.LABEL_SMOOTHING) 124 | 125 | # resume from checkpoint (if applicable) 126 | 
max_accuracy = 0.0 127 | if config.TRAIN.AUTO_RESUME: 128 | resume_file = auto_resume_helper(config.OUTPUT) 129 | if resume_file: 130 | if config.MODEL.RESUME: 131 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 132 | config.defrost() 133 | config.MODEL.RESUME = resume_file 134 | config.freeze() 135 | logger.info(f'auto resuming from {resume_file}') 136 | else: 137 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 138 | 139 | if config.MODEL.RESUME: 140 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger) 141 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 142 | logger.info(f"Accuracy of the network: {acc1:.1f}%") 143 | if config.EVAL_MODE: 144 | return 145 | # EMA 146 | model_ema = None 147 | if config.TRAIN.USE_EMA: 148 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 149 | model_ema = ModelEmaV2( 150 | model_without_ddp, decay=config.TRAIN.EMA_DECAY, device=None) 151 | if config.MODEL.RESUME: 152 | load_checkpoint(config, model_ema, None, None, None, logger, use_ema=True) 153 | logger.info("Validating EMA checkpoint...") 154 | acc1, acc5, loss = validate(config, data_loader_val, model_ema.module, logger) 155 | logger.info(f"Accuracy of model ema: {acc1:.1f}%") 156 | 157 | # start training 158 | num_epochs = config.TRAIN.EPOCHS + config.TRAIN.COOLDOWN_EPOCHS 159 | logger.info("Start training") 160 | start_time = time.time() 161 | for epoch in range(config.TRAIN.START_EPOCH, num_epochs): 162 | data_loader_train.sampler.set_epoch(epoch) 163 | 164 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, logger, model_ema=model_ema, total_epochs=num_epochs) 165 | if get_rank() == 0 and ((epoch+1) % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1) or epoch == 0): 166 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema=model_ema, total_epochs=num_epochs) 167 | torch.cuda.synchronize() 168 | 169 | if (epoch % config.EVAL_FREQ == 0 or epoch == (num_epochs - 1)): 170 | acc1, acc5, loss = validate(config, data_loader_val, model, logger) 171 | logger.info(f"Accuracy of the network: {acc1:.1f}%") 172 | max_accuracy = max(max_accuracy, acc1) 173 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 174 | if model_ema is not None: 175 | ema_acc1, ema_acc5, ema_loss = validate(config, data_loader_val, model_ema.module, logger) 176 | logger.info(f"Accuracy of model ema: {ema_acc1:.1f}%") 177 | else: 178 | logger.info("Not at eval epoch yet!") 179 | 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | logger.info('Training time {}'.format(total_time_str)) 183 | 184 | 185 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, logger, model_ema=None, total_epochs=None): 186 | """ 187 | Trains the model for one epoch 188 | Args: 189 | config: CfgNode object, containing training and model configs 190 | model: the model to train 191 | criterion: the criterion for computing loss 192 | data_loader: torch.utils.data.DataLoader object 193 | optimizer: optimizer for training 194 | epoch: int, current epoch 195 | mixup_fn: mixup function used for mixup augmentation 196 | lr_scheduler: learning rate scheduler 197 | loss_scaler: loss scaler, used 
during mixed-precision training 198 | logger: logger object for logging 199 | model_ema (optional): EMA version of the model 200 | total_epochs (optional): int, total number of epochs 201 | """ 202 | if total_epochs is None: 203 | total_epochs = config.TRAIN.EPOCHS 204 | model.train() 205 | optimizer.zero_grad() 206 | 207 | num_steps = len(data_loader) 208 | batch_time = AverageMeter() 209 | loss_meter = AverageMeter() 210 | norm_meter = AverageMeter() 211 | scaler_meter = AverageMeter() 212 | 213 | start = time.time() 214 | end = time.time() 215 | 216 | for idx, (samples, targets) in enumerate(data_loader): 217 | if mixup_fn is not None: 218 | samples, targets = mixup_fn(samples, targets) 219 | samples = samples.cuda() 220 | targets = targets.cuda() 221 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 222 | outputs = model(samples) 223 | 224 | if config.TRAIN.ACCUMULATION_STEPS <= 1: 225 | ACCUMULATION_STEPS = 1 226 | else: 227 | ACCUMULATION_STEPS = config.TRAIN.ACCUMULATION_STEPS 228 | loss = criterion(outputs, targets) 229 | loss = loss / ACCUMULATION_STEPS 230 | total_loss = loss 231 | grad_norm = loss_scaler(total_loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, 232 | parameters=model.parameters(), create_graph=False, 233 | update_grad=(idx + 1) % ACCUMULATION_STEPS == 0) 234 | if (idx + 1) % ACCUMULATION_STEPS == 0: 235 | optimizer.zero_grad() 236 | lr_scheduler.step_update((epoch * num_steps + idx) // ACCUMULATION_STEPS) 237 | if model_ema is not None: 238 | model_ema.update(model) 239 | if loss_scaler.is_enabled(): 240 | loss_scale_value = loss_scaler.state_dict()["scale"] 241 | else: 242 | loss_scale_value = 1.0 243 | 244 | torch.cuda.synchronize() 245 | 246 | loss_meter.update(loss.item() * ACCUMULATION_STEPS, targets.size(0)) 247 | if grad_norm is not None: # loss_scaler return None if not update 248 | norm_meter.update(grad_norm) 249 | scaler_meter.update(loss_scale_value) 250 | batch_time.update(time.time() - end) 251 | end = time.time() 252 | 253 | if idx % (config.PRINT_FREQ * ACCUMULATION_STEPS) == 0: 254 | lr = optimizer.param_groups[0]['lr'] 255 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 256 | etas = batch_time.avg * (num_steps - idx) 257 | logger.info( 258 | f'Train: [{epoch}/{total_epochs}][{idx}/{num_steps}]\t' 259 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 260 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 261 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 262 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 263 | f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t' 264 | f'mem {memory_used:.0f}MB') 265 | del total_loss, outputs 266 | torch.cuda.empty_cache() 267 | epoch_time = time.time() - start 268 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 269 | 270 | 271 | @torch.no_grad() 272 | def validate(config, data_loader, model, logger): 273 | """ 274 | Validates the accuracy of a model loaded with pre-trained checkpoint 275 | Args: 276 | config: CfgNode object, containing training and model configs 277 | data_loader: torch.utils.data.DataLoader object 278 | model: the model to validate 279 | logger: logger object for logging 280 | """ 281 | criterion = torch.nn.CrossEntropyLoss() 282 | model.eval() 283 | 284 | batch_time = AverageMeter() 285 | loss_meter = AverageMeter() 286 | acc1_meter = AverageMeter() 287 | acc5_meter = AverageMeter() 288 | 289 | end = time.time() 290 | for idx, (images, target) in 
enumerate(data_loader): 291 | 292 | images = images.cuda() 293 | target = target.cuda() 294 | 295 | # compute output 296 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 297 | output = model(images) 298 | 299 | # measure accuracy and record loss 300 | loss = criterion(output, target) 301 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 302 | 303 | acc1 = reduce_tensor(acc1) 304 | acc5 = reduce_tensor(acc5) 305 | loss = reduce_tensor(loss) 306 | 307 | loss_meter.update(loss.item(), target.size(0)) 308 | acc1_meter.update(acc1.item(), target.size(0)) 309 | acc5_meter.update(acc5.item(), target.size(0)) 310 | 311 | # measure elapsed time 312 | batch_time.update(time.time() - end) 313 | end = time.time() 314 | 315 | if idx % config.PRINT_FREQ == 0: 316 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 317 | logger.info( 318 | f'Test: [{idx}/{len(data_loader)}]\t' 319 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 320 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 321 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 322 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 323 | f'Mem {memory_used:.0f}MB') 324 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 325 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 326 | 327 | 328 | @torch.no_grad() 329 | def throughput(config, data_loader, model, logger): 330 | """ 331 | Computes the throughput of the model averaging over 30 steps 332 | Args: 333 | config: CfgNode object, containing training and model configs 334 | data_loader: torch.utils.data.DataLoader object 335 | model: the model to test 336 | logger: logger object for logging 337 | """ 338 | model.eval() 339 | 340 | for idx, (images, _) in enumerate(data_loader): 341 | images = images.cuda(non_blocking=True) 342 | batch_size = images.shape[0] 343 | for i in range(50): 344 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 345 | model(images) 346 | torch.cuda.synchronize() 347 | logger.info("throughput averaged with 30 times") 348 | tic1 = time.time() 349 | for i in range(30): 350 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): 351 | model(images) 352 | torch.cuda.synchronize() 353 | tic2 = time.time() 354 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 355 | return 356 | 357 | 358 | def run_all(config): 359 | """ 360 | Run main() on all parallel gpus 361 | """ 362 | 363 | # initialize distributed training and get the current GPU 364 | init_distributed_mode() 365 | config.defrost() 366 | config.LOCAL_RANK = get_local_rank() 367 | config.freeze() 368 | 369 | seed = config.SEED + get_rank() 370 | print('Finished init distributed') 371 | torch.manual_seed(seed) 372 | torch.random.manual_seed(seed) 373 | np.random.seed(seed) 374 | random.seed(seed) 375 | cudnn.benchmark = True 376 | world_size = get_world_size() 377 | 378 | # linear scale the learning rate according to total batch size, may not be optimal 379 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * world_size / 512.0 380 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * world_size / 512.0 381 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * world_size / 512.0 382 | # gradient accumulation also need to scale the learning rate 383 | if config.TRAIN.ACCUMULATION_STEPS > 1: 384 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 385 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * 
config.TRAIN.ACCUMULATION_STEPS 386 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 387 | config.defrost() 388 | config.TRAIN.BASE_LR = linear_scaled_lr 389 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 390 | config.TRAIN.MIN_LR = linear_scaled_min_lr 391 | config.freeze() 392 | 393 | # create output folder 394 | os.makedirs(config.OUTPUT, exist_ok=True) 395 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=get_rank(), name=f"{config.MODEL.NAME}") 396 | print('Logger created') 397 | if get_rank() == 0: 398 | path = os.path.join(config.OUTPUT, "config.json") 399 | with open(path, "w") as f: 400 | f.write(config.dump()) 401 | logger.info(f"Full config saved to {path}") 402 | 403 | # print config 404 | logger.info(config.dump()) 405 | 406 | import pykeops 407 | import tempfile 408 | with tempfile.TemporaryDirectory() as dirname: 409 | pykeops.set_build_folder(dirname) 410 | main(config, logger) 411 | 412 | 413 | if __name__ == "__main__": 414 | args = parse_option() 415 | config = get_config(args) 416 | run_all(config) 417 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .build import build_model 9 | -------------------------------------------------------------------------------- /models/aff_transformer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from timm.models.layers import DropPath, trunc_normal_ 10 | from .point_utils import knn_keops, space_filling_cluster 11 | from clusten import CLUSTENQKFunction, CLUSTENAVFunction, CLUSTENWFFunction 12 | 13 | 14 | def build_pe_lookup(img_size): 15 | """ 16 | Pre-compute lookup table of relative positions for position embedding 17 | each entry: (delta x, delta y, distance, sin, cos) 18 | """ 19 | global rel_pos_width, table_width, pre_table 20 | rel_pos_width = img_size // 4 - 1 # 55 for input img 224 x 224, after stride-4 downsampling, max delta is 55 21 | table_width = 2 * rel_pos_width + 1 22 | 23 | pre_hs = torch.arange(table_width).float()-rel_pos_width 24 | pre_ws = torch.arange(table_width).float()-rel_pos_width 25 | pre_ys, pre_xs = torch.meshgrid(pre_hs, pre_ws) # 111 x 111 26 | 27 | dis_table = (pre_ys**2 + pre_xs**2) ** 0.5 28 | sin_table = pre_ys / dis_table 29 | cos_table = pre_xs / dis_table 30 | pre_table = torch.stack([pre_xs, pre_ys, dis_table, sin_table, cos_table], dim=2) # 111 x 111 x 5 31 | pre_table[torch.bitwise_or(pre_table.isnan(), pre_table.isinf()).nonzero(as_tuple=True)] = 0 32 | pre_table = pre_table.reshape(-1, 5) 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class ClusterAttention(nn.Module): 55 | """ 56 | Performs local attention on nearest clusters 57 | 58 | Args: 59 | dim (int): Number of input channels. 60 | num_heads (int): Number of attention heads. 61 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 62 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 63 | """ 64 | 65 | def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.): 66 | 67 | super().__init__() 68 | self.dim = dim 69 | self.pos_dim = 2 70 | self.num_heads = num_heads 71 | 72 | head_dim = dim // num_heads 73 | self.scale = head_dim ** -0.5 74 | self.q = nn.Linear(dim, dim) 75 | self.kv = nn.Linear(dim, 2*dim) 76 | self.softmax = nn.Softmax(dim=-1) 77 | 78 | self.blank_k = nn.Parameter(torch.randn(dim)) 79 | self.blank_v = nn.Parameter(torch.randn(dim)) 80 | 81 | self.pos_embed = nn.Linear(self.pos_dim+3, num_heads) 82 | 83 | self.attn_drop = nn.Dropout(attn_drop) 84 | self.proj = nn.Linear(dim, dim) 85 | self.proj_drop = nn.Dropout(proj_drop) 86 | 87 | def forward(self, feat, member_idx, cluster_mask, pe_idx, global_attn): 88 | """ 89 | Args: 90 | feat - b x n x c, token features 91 | member_idx - b x n x nbhd, token idx in each local nbhd 92 | cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid) 93 | pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table 94 | global_attn - bool, whether to perform global attention 95 | """ 96 | 97 | b, n, c = feat.shape 98 | c_ = c // self.num_heads 99 | assert c == self.dim, "dim does not accord to input" 100 | h = self.num_heads 101 | 102 | # get qkv 103 | q = self.q(feat) # b x n x c 104 | q = q * self.scale 105 | kv = self.kv(feat) # b x n x 2c 106 | 107 | # get attention 108 | if global_attn: 109 | q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3) # b x h x n x c_ 110 | kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4) # 2 x b x h x n x c_ 111 | key, v = kv[0], kv[1] 112 | attn = q @ key.transpose(-1, -2) # b x h x n x n 113 | mask = None 114 | else: 115 | nbhd_size = member_idx.shape[-1] 116 | m = nbhd_size 117 | q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3) 118 | kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4) # 2 x b x h x n x c_ 119 | key, v = kv[0], kv[1] 120 | attn = CLUSTENQKFunction.apply(q, key, member_idx) # b x h x n x m 121 | mask = cluster_mask 122 | if mask is not None: 123 | mask = mask.reshape(b, 1, n, m) 124 | 125 | # position embedding 126 | global pre_table 127 | if not pre_table.is_cuda: 128 | pre_table = pre_table.to(pe_idx.device) 129 | pe_table = self.pos_embed(pre_table) # 111 x 111 x h for img_size 224x224 130 | 131 | pe_shape = pe_idx.shape 132 | pos_embed = pe_table.gather(index=pe_idx.view(-1, 1).expand(-1, h), dim=0).reshape(*(pe_shape), h).permute(0, 3, 1, 2) 133 | 134 | attn = attn + pos_embed 135 | 136 | if mask is not None: 137 | attn = attn + (1-mask)*(-100) 138 | 139 | # blank token 140 | blank_attn = (q * self.blank_k.reshape(1, h, 1, c_)).sum(-1, keepdim=True) # b x h x n x 1 141 | attn = torch.cat([attn, blank_attn], dim=-1) 142 | attn = self.softmax(attn) 143 | attn = self.attn_drop(attn) 144 | 145 | blank_attn = attn[..., -1:] 146 | attn = attn[..., :-1] 147 | blank_v = blank_attn * self.blank_v.reshape(1, h, 1, c_) # b x h x n x c_ 148 | 149 | # aggregate v 150 | if global_attn: 151 | feat = (attn @ v).permute(0, 2, 1, 3).reshape(b, n, c) 152 | feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c) 153 | else: 154 | feat = CLUSTENAVFunction.apply(attn, v, member_idx).permute(0, 2, 1, 3).reshape(b, n, c) 155 | feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c) 156 | 157 | feat = self.proj(feat) 158 | feat = self.proj_drop(feat) 159 | 160 | return feat 161 | 162 | def extra_repr(self) -> str: 163 | return f'dim={self.dim}, num_heads={self.num_heads}' 164 | 165 | 166 | class ClusterTransformerBlock(nn.Module): 167 | r""" 
Cluster Transformer Block.
168 | 
169 |     Args:
170 |         dim (int): Number of input channels.
171 |         num_heads (int): Number of attention heads.
172 |         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
173 |         drop (float, optional): Dropout rate. Default: 0.0
174 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
175 |         drop_path (float, optional): Stochastic depth rate. Default: 0.0
176 |         layer_scale (float, optional): Layer scale initial parameter. Default: 0.0
177 |         act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
178 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
179 |     """
180 | 
181 |     def __init__(self, dim, num_heads,
182 |                  mlp_ratio=2., drop=0., attn_drop=0., drop_path=0., layer_scale=0.0,
183 |                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
184 |         super().__init__()
185 |         self.dim = dim
186 |         self.num_heads = num_heads
187 |         self.mlp_ratio = mlp_ratio
188 | 
189 |         self.norm1 = norm_layer(dim)
190 |         self.attn = ClusterAttention(
191 |             dim, num_heads=num_heads,
192 |             attn_drop=attn_drop, proj_drop=drop)
193 | 
194 |         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
195 |         self.norm2 = norm_layer(dim)
196 |         mlp_hidden_dim = int(dim * mlp_ratio)
197 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
198 | 
199 |         # layer_scale code copied from https://github.com/SHI-Labs/Neighborhood-Attention-Transformer/blob/a2cfef599fffd36d058a5a4cfdbd81c008e1c349/classification/nat.py
200 |         self.layer_scale = False
201 |         if layer_scale is not None and type(layer_scale) in [int, float] and layer_scale > 0:
202 |             self.layer_scale = True
203 |             self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
204 |             self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
205 | 
206 |     def forward(self, feat, member_idx, cluster_mask, pe_idx, global_attn):
207 |         """
208 |         Args:
209 |             feat - b x n x c, token features
210 |             member_idx - b x n x nbhd, token idx in each local nbhd
211 |             cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
212 |             pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
213 |             global_attn - bool, whether to perform global attention
214 |         """
215 | 
216 |         b, n, c = feat.shape
217 |         assert c == self.dim, "dim does not match input"
218 | 
219 |         shortcut = feat
220 |         feat = self.norm1(feat)
221 | 
222 |         # cluster attention
223 |         feat = self.attn(feat=feat,
224 |                          member_idx=member_idx,
225 |                          cluster_mask=cluster_mask,
226 |                          pe_idx=pe_idx,
227 |                          global_attn=global_attn)
228 | 
229 |         # FFN
230 |         if not self.layer_scale:
231 |             feat = shortcut + self.drop_path(feat)
232 |             feat_mlp = self.mlp(self.norm2(feat))
233 |             feat = feat + self.drop_path(feat_mlp)
234 |         else:
235 |             feat = shortcut + self.drop_path(self.gamma1 * feat)
236 |             feat_mlp = self.mlp(self.norm2(feat))
237 |             feat = feat + self.drop_path(self.gamma2 * feat_mlp)
238 | 
239 |         return feat
240 | 
241 |     def extra_repr(self) -> str:
242 |         return f"dim={self.dim}, num_heads={self.num_heads}, " \
243 |                f"mlp_ratio={self.mlp_ratio}"
244 | 
245 | 
246 | class ClusterMerging(nn.Module):
247 |     r""" Adaptive Downsampling.
248 | 
249 |     Args:
250 |         dim (int): Number of input channels.
251 |         out_dim (int): Number of output channels.
252 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
253 |         alpha (float, optional): the weight to be multiplied with importance scores.
Default: 4.0 254 | ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25 255 | reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True 256 | """ 257 | 258 | def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm, alpha=4.0, ds_rate=0.25, reserve_on=True): 259 | super().__init__() 260 | self.dim = dim 261 | self.pos_dim = 2 262 | self.alpha = alpha 263 | self.ds_rate = ds_rate 264 | self.reserve_on = reserve_on 265 | 266 | # pointconv 267 | inner_ch = 4 268 | self.weight_net = nn.Sequential( 269 | nn.Linear(self.pos_dim+3, inner_ch, bias=True), 270 | nn.LayerNorm(inner_ch), 271 | nn.GELU() 272 | ) 273 | self.norm = norm_layer(inner_ch*dim) 274 | self.linear = nn.Linear(dim*inner_ch, out_dim) 275 | 276 | def forward(self, pos, feat, member_idx, cluster_mask, learned_prob, stride, pe_idx, reserve_num): 277 | """ 278 | Args: 279 | pos - b x n x 2, token positions 280 | feat - b x n x c, token features 281 | member_idx - b x n x nbhd, token idx in each local nbhd 282 | cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid) 283 | learned_prob - b x n x 1, learned importance scores 284 | stride - int, "stride" of the current feature map, 2,4,8 for the 3 stages respectively 285 | pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table 286 | reserve_num - int, number of tokens to be reserved 287 | """ 288 | 289 | b, n, c = feat.shape 290 | d = pos.shape[2] 291 | 292 | keep_num = int(n*self.ds_rate) 293 | 294 | # grid prior 295 | if stride == 2: # no ada ds yet, no need ada grid 296 | grid_prob = ((pos % stride).sum(-1) == 0).float() # b x n 297 | else: 298 | _, min_dist = knn_keops(pos, pos, 2, return_dist=True) # b x n x 2 299 | min_dist = min_dist[:, :, 1] # b x n 300 | ada_stride = 2**(min_dist.log2().ceil()+1) # b x n 301 | grid_prob = ((pos.long() % ada_stride.unsqueeze(2).long()).sum(-1) == 0).float() # b x n 302 | 303 | final_prob = grid_prob 304 | 305 | # add importance score 306 | if learned_prob is not None: 307 | lp = learned_prob.detach().view(b, n) 308 | lp = lp * self.alpha 309 | final_prob = final_prob + lp 310 | 311 | # reserve points on a coarse grid 312 | if self.reserve_on: 313 | reserve_mask = ((pos % (stride*2)).sum(-1) == 0).float() # b x n 314 | final_prob = final_prob + (reserve_mask*(-100)) 315 | sample_num = keep_num - reserve_num 316 | else: 317 | sample_num = keep_num 318 | 319 | # select topk tokens as merging centers 320 | sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1] # b x n_ 321 | 322 | if self.reserve_on: 323 | reserve_idx = reserve_mask.nonzero(as_tuple=True)[1].reshape(b, reserve_num) 324 | idx = torch.cat([sample_idx, reserve_idx], dim=-1).unsqueeze(2) # b x n_ x 1 325 | else: 326 | idx = sample_idx.unsqueeze(2) 327 | 328 | n = idx.shape[1] 329 | assert n == keep_num, "n not equal to keep num!" 
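        # Worked example of the selection above (illustrative numbers, not
        # taken from the code): with n = 1024 tokens, ds_rate = 0.25 and
        # reserve_num = 64, we get keep_num = 256 and sample_num = 192; the
        # 192 merging centers are the top entries of final_prob (grid prior
        # plus alpha-weighted importance), the reserved tokens are pushed out
        # of the top-k by the -100 penalty, and concatenating their indices
        # afterwards restores exactly keep_num centers, satisfying the assert.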
330 | 331 | # gather pos, nbhd, nbhd position embedding, nbhd importance scores for topk merging locations 332 | pos = pos.gather(index=idx.expand(-1, -1, d), dim=1) # b x n' x d 333 | 334 | nbhd_size = member_idx.shape[-1] 335 | member_idx = member_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1) # b x n' x m 336 | pe_idx = pe_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1) # b x n' x m 337 | if cluster_mask is not None: 338 | cluster_mask = cluster_mask.gather(index=idx.expand(-1, -1, nbhd_size), dim=1) # b x n' x m 339 | if learned_prob is not None: 340 | lp = learned_prob.gather(index=member_idx.view(b, -1, 1), dim=1).reshape(b, n, nbhd_size, 1) # b x n x m x 1 341 | 342 | # pointconv weights 343 | global pre_table 344 | if not pre_table.is_cuda: 345 | pre_table = pre_table.to(pe_idx.device) 346 | weights_table = self.weight_net(pre_table) # 111 x 111 x ic 347 | 348 | weight_shape = pe_idx.shape 349 | inner_ch = weights_table.shape[-1] 350 | weights = weights_table.gather(index=pe_idx.view(-1, 1).expand(-1, inner_ch), dim=0).reshape(*(weight_shape), inner_ch) 351 | 352 | if learned_prob is not None: 353 | if cluster_mask is not None: 354 | lp = lp * cluster_mask.unsqueeze(3) 355 | weights = weights * lp 356 | else: 357 | if cluster_mask is not None: 358 | weights = weights * cluster_mask.unsqueeze(3) 359 | 360 | # merge features 361 | feat = CLUSTENWFFunction.apply(weights, feat, member_idx.view(b, n, -1)).reshape(b, n, -1) # b x n x ic*c 362 | feat = self.norm(feat) 363 | feat = self.linear(feat) # b x n x 2c 364 | 365 | return pos, feat 366 | 367 | 368 | class BasicLayer(nn.Module): 369 | """ AutoFocusFormer layer for one stage. 370 | 371 | Args: 372 | dim (int): Number of input channels. 373 | out_dim (int): Number of output channels. 374 | cluster_size (int): Cluster size. 375 | nbhd_size (int): Neighbor size. If larger than or equal to number of tokens, perform global attention; 376 | otherwise, rounded to the nearest multiples of cluster_size. 377 | depth (int): Number of blocks. 378 | num_heads (int): Number of attention heads. 379 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 380 | alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0 381 | ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25 382 | reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True 383 | drop (float, optional): Dropout rate. Default: 0.0 384 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 385 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 386 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 387 | layer_scale (float, optional): Layer scale initial parameter. Default: 0.0 388 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None 389 | """ 390 | 391 | def __init__(self, dim, out_dim, cluster_size, nbhd_size, 392 | depth, num_heads, mlp_ratio, 393 | alpha=4.0, ds_rate=0.25, reserve_on=True, 394 | drop=0., attn_drop=0., 395 | drop_path=0., norm_layer=nn.LayerNorm, 396 | layer_scale=0.0, downsample=None): 397 | 398 | super().__init__() 399 | self.dim = dim 400 | self.nbhd_size = nbhd_size 401 | self.cluster_size = cluster_size 402 | self.depth = depth 403 | 404 | # build blocks 405 | self.blocks = nn.ModuleList([ 406 | ClusterTransformerBlock(dim=dim, 407 | num_heads=num_heads, 408 | mlp_ratio=mlp_ratio, 409 | drop=drop, attn_drop=attn_drop, 410 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 411 | layer_scale=layer_scale, 412 | norm_layer=norm_layer) 413 | for i in range(depth)]) 414 | 415 | # merging layer 416 | if downsample is not None: 417 | self.downsample = downsample(dim=dim, out_dim=out_dim, norm_layer=norm_layer, alpha=alpha, ds_rate=ds_rate, reserve_on=reserve_on) 418 | else: 419 | self.downsample = None 420 | 421 | # cache the clustering result for the first feature map since it is on grid 422 | self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = None, None, None, None, None 423 | 424 | # fc for importance scores 425 | if downsample is not None: 426 | self.prob_net = nn.Linear(dim, 1) 427 | 428 | def forward(self, pos, feat, h, w, on_grid, stride): 429 | """ 430 | Args: 431 | pos - b x n x 2, token positions 432 | feat - b x n x c, token features 433 | h,w - max height and width of token positions 434 | on_grid - bool, whether the tokens are still on grid; True for the first feature map 435 | stride - int, "stride" of the current token set; starts with 2, then doubles in each stage 436 | """ 437 | b, n, d = pos.shape 438 | c = feat.shape[2] 439 | assert self.cluster_size > 0, 'self.cluster_size must be positive' 440 | 441 | if self.nbhd_size >= n: 442 | global_attn = True 443 | member_idx, cluster_mask = None, None 444 | else: 445 | global_attn = False 446 | k = int(math.ceil(n / float(self.cluster_size))) # number of clusters 447 | nnc = min(int(round(self.nbhd_size / float(self.cluster_size))), k) # number of nearest clusters 448 | nbhd_size = self.cluster_size * nnc 449 | self.nbhd_size = nbhd_size # if not global attention, then nbhd size is rounded to nearest multiples of cluster 450 | 451 | if global_attn: 452 | rel_pos = (pos[:, None, :, :]+rel_pos_width) - pos[:, :, None, :] # b x n x n x d 453 | else: 454 | if k == n: 455 | # if number of clusters equal to number of tokens 456 | cluster_mean_pos = pos 457 | member_idx = torch.arange(n, device=feat.device).long().reshape(1, n, 1).expand(b, -1, -1) # b x n x 1 458 | cluster_mask = None 459 | else: 460 | # perform clustering 461 | if on_grid: 462 | if self.cluster_mean_pos is None: 463 | self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False) 464 | pos, cluster_mean_pos, member_idx, cluster_mask = self.pos[:b], self.cluster_mean_pos[:b], self.member_idx[:b], self.cluster_mask 465 | # reorder the tokens so that tokens in same cluster are stored together 466 | feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), self.reorder[:b].view(-1)].reshape(b, n, c) 467 | if cluster_mask is not None: 468 | cluster_mask = cluster_mask[:b] 469 | else: 470 | pos, cluster_mean_pos, member_idx, cluster_mask, reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False) 471 
| # reorder the tokens so that tokens in same cluster are stored together 472 | feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), reorder.view(-1)].reshape(b, n, c) 473 | 474 | assert member_idx.shape[1] == k and member_idx.shape[2] == self.cluster_size, "member_idx shape incorrect!" 475 | 476 | nearest_cluster = knn_keops(pos, cluster_mean_pos, nnc) # b x n x nnc 477 | 478 | # collect neighbor indices from nearest clusters 479 | m = self.cluster_size 480 | member_idx = member_idx.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size) # b x n x nnc*m 481 | if cluster_mask is not None: 482 | cluster_mask = cluster_mask.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size) 483 | pos_ = pos.gather(index=member_idx.view(b, -1, 1).expand(-1, -1, d), dim=1).reshape(b, n, nbhd_size, d) 484 | rel_pos = pos_ - (pos.unsqueeze(2)-rel_pos_width) # b x n x nbhd_size x d 485 | 486 | # compute indices in the position embedding lookup table 487 | pe_idx = (rel_pos[..., 1] * table_width + rel_pos[..., 0]).long() 488 | 489 | for i_blk in range(len(self.blocks)): 490 | blk = self.blocks[i_blk] 491 | feat = blk(feat=feat, 492 | member_idx=member_idx, 493 | cluster_mask=cluster_mask, 494 | pe_idx=pe_idx, 495 | global_attn=global_attn) 496 | 497 | if self.downsample is not None: 498 | learned_prob = self.prob_net(feat).sigmoid() # b x n x 1 499 | reserve_num = math.ceil(h/(stride*2)) * math.ceil(w/(stride*2)) 500 | pos, feat = self.downsample(pos=pos, feat=feat, 501 | member_idx=member_idx, cluster_mask=cluster_mask, 502 | learned_prob=learned_prob, stride=stride, 503 | pe_idx=pe_idx, reserve_num=reserve_num) 504 | 505 | return pos, feat 506 | 507 | def extra_repr(self) -> str: 508 | return f"dim={self.dim}, depth={self.depth}" 509 | 510 | 511 | class PatchEmbed(nn.Module): 512 | r""" Image to Patch Embedding 513 | 514 | Args: 515 | in_chans (int): Number of input image channels. Default: 3. 516 | embed_dim (int): Number of channels. Default: 32. 517 | norm_layer (nn.Module, optional): Normalization layer. Default: None 518 | """ 519 | 520 | def __init__(self, in_chans=3, embed_dim=32, norm_layer=None): 521 | super().__init__() 522 | 523 | self.in_chans = in_chans 524 | self.embed_dim = embed_dim 525 | 526 | self.proj1 = nn.Conv2d(in_chans, embed_dim//2, kernel_size=3, stride=2, padding=1) 527 | self.bn = nn.BatchNorm2d(embed_dim//2) 528 | self.act1 = nn.GELU() 529 | self.proj2 = nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1) 530 | 531 | if norm_layer is not None: 532 | self.norm = norm_layer(embed_dim) 533 | else: 534 | self.norm = None 535 | 536 | def forward(self, x): 537 | """ 538 | Args: 539 | x - b x c x h x w, input imgs 540 | """ 541 | 542 | x = self.proj2(self.act1(self.bn(self.proj1(x)))) 543 | b, c, h, w = x.shape 544 | x = x.flatten(2).transpose(1, 2) # b x n x c 545 | if self.norm is not None: 546 | x = self.norm(x) 547 | 548 | hs = torch.arange(0, h, device=x.device) 549 | ws = torch.arange(0, w, device=x.device) 550 | ys, xs = torch.meshgrid(hs, ws) 551 | pos = torch.stack([xs, ys], dim=2).unsqueeze(0).expand(b, -1, -1, -1).reshape(b, -1, 2).to(x.dtype) 552 | 553 | return pos, x, h, w 554 | 555 | 556 | class AutoFocusFormer(nn.Module): 557 | """ 558 | 559 | Args: 560 | in_chans (int): Number of input image channels. Default: 3 561 | num_classes (int): Number of classes for classification head. Default: 1000 562 | embed_dim (tuple(int)): Feature dimension of each stage. 
Default: [32,128,256,512] 563 | cluster_size (int): Cluster size. Default: 8 564 | nbhd_size (tuple(int)): Neighborhood size of local attention of each stage. Default: [48,48,48,49] 565 | alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0 566 | ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25 567 | reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True 568 | depths (tuple(int)): Depth of each AFF layer. 569 | num_heads (tuple(int)): Number of attention heads in different layers. 570 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2.0 571 | drop_rate (float): Dropout rate. Default: 0 572 | attn_drop_rate (float): Attention dropout rate. Default: 0 573 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 574 | norm_layer (nn.Module): Normalization layer. 575 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 576 | layer_scale (float, optional): Layer scale initial parameter; turned off if 0.0. Default: 0.0 577 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 578 | """ 579 | 580 | def __init__(self, in_chans=3, num_classes=1000, embed_dim=[32, 128, 256, 512], 581 | cluster_size=8, nbhd_size=[48, 48, 48, 49], 582 | alpha=4.0, ds_rate=0.25, reserve_on=True, 583 | depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 584 | mlp_ratio=2., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 585 | norm_layer=nn.LayerNorm, patch_norm=True, 586 | layer_scale=0.0, 587 | downsample=ClusterMerging, 588 | img_size=224, 589 | **kwargs): 590 | super().__init__() 591 | 592 | self.num_classes = num_classes 593 | self.num_layers = len(depths) 594 | self.embed_dim = embed_dim 595 | self.patch_norm = patch_norm 596 | self.num_features = embed_dim[-1] 597 | self.mlp_ratio = mlp_ratio 598 | 599 | self.patch_embed = PatchEmbed( 600 | in_chans=in_chans, embed_dim=embed_dim[0], 601 | norm_layer=norm_layer if self.patch_norm else None) 602 | 603 | self.pos_drop = nn.Dropout(p=drop_rate) 604 | 605 | build_pe_lookup(img_size) 606 | 607 | # stochastic depth 608 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 609 | 610 | # build layers 611 | self.layers = nn.ModuleList() 612 | for i_layer in range(self.num_layers): 613 | layer = BasicLayer(dim=int(embed_dim[i_layer]), 614 | out_dim=int(embed_dim[i_layer+1]) if (i_layer < self.num_layers - 1) else None, 615 | cluster_size=cluster_size, 616 | nbhd_size=nbhd_size[i_layer], 617 | depth=depths[i_layer], 618 | num_heads=num_heads[i_layer], 619 | mlp_ratio=self.mlp_ratio, 620 | alpha=alpha, 621 | ds_rate=ds_rate, 622 | reserve_on=reserve_on, 623 | drop=drop_rate, attn_drop=attn_drop_rate, 624 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 625 | norm_layer=norm_layer, 626 | downsample=downsample if (i_layer < self.num_layers - 1) else None, 627 | layer_scale=layer_scale) 628 | self.layers.append(layer) 629 | 630 | self.norm = norm_layer(self.num_features) 631 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 632 | 633 | self.apply(self._init_weights) 634 | 635 | def _init_weights(self, m): 636 | if isinstance(m, nn.Linear): 637 | trunc_normal_(m.weight, std=.02) 638 | if isinstance(m, nn.Linear) and m.bias is not None: 639 | nn.init.constant_(m.bias, 0) 640 | elif isinstance(m, nn.LayerNorm): 641 | nn.init.constant_(m.bias, 0) 642 | 
nn.init.constant_(m.weight, 1.0)
643 | 
644 |     @torch.jit.ignore
645 |     def no_weight_decay(self):
646 |         return {}
647 | 
648 |     @torch.jit.ignore
649 |     def no_weight_decay_keywords(self):
650 |         return {}
651 | 
652 |     def forward_features(self, x):
653 |         '''
654 |         x - b x c x h x w
655 |         '''
656 |         pos, x, h, w = self.patch_embed(x)  # b x n x c, b x n x d
657 |         x = self.pos_drop(x)
658 | 
659 |         for i_layer in range(len(self.layers)):
660 |             layer = self.layers[i_layer]
661 |             pos, x = layer(pos, x, h=h, w=w, on_grid=i_layer == 0, stride=2**(i_layer+1))
662 | 
663 |         x = self.norm(x)  # b x n x c
664 |         x = x.mean(1)
665 |         return x
666 | 
667 |     def forward(self, x):
668 |         x = self.forward_features(x)
669 |         x = self.head(x)
670 |         return x
671 | 
--------------------------------------------------------------------------------
/models/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 | 
9 | from .aff_transformer import AutoFocusFormer
10 | 
11 | 
12 | def build_model(config):
13 |     model_type = config.MODEL.TYPE
14 |     if model_type == 'aff':
15 |         model = AutoFocusFormer(in_chans=config.DATA.IN_CHANS,
16 |                                 num_classes=config.MODEL.NUM_CLASSES,
17 |                                 embed_dim=config.MODEL.AFF.EMBED_DIM,
18 |                                 cluster_size=config.MODEL.AFF.CLUSTER_SIZE,
19 |                                 nbhd_size=config.MODEL.AFF.NBHD_SIZE,
20 |                                 alpha=config.MODEL.AFF.ALPHA,
21 |                                 ds_rate=config.MODEL.AFF.DS_RATE,
22 |                                 reserve_on=config.MODEL.AFF.RESERVE,
23 |                                 depths=config.MODEL.AFF.DEPTHS,
24 |                                 num_heads=config.MODEL.AFF.NUM_HEADS,
25 |                                 mlp_ratio=config.MODEL.AFF.MLP_RATIO,
26 |                                 drop_rate=config.MODEL.DROP_RATE,
27 |                                 drop_path_rate=config.MODEL.DROP_PATH_RATE,
28 |                                 patch_norm=config.MODEL.AFF.PATCH_NORM,
29 |                                 layer_scale=config.MODEL.AFF.LAYER_SCALE,
30 |                                 img_size=config.DATA.IMG_SIZE)
31 |     else:
32 |         raise NotImplementedError(f"Unknown model: {model_type}")
33 | 
34 |     return model
35 | 
--------------------------------------------------------------------------------
/models/point_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | 
6 | import math
7 | import torch
8 | 
9 | 
10 | def points2img(pos, pixel, h, w):
11 |     """
12 |     Scatter tokens onto a canvas of size h x w
13 |     Args:
14 |         pos - b x n x 2, position of tokens, should be valid indices in the canvas
15 |         pixel - b x n x c, feature of tokens
16 |         h,w - int, height and width of the canvas
17 |     Returns:
18 |         img - b x c x h x w, the resulting grid img; blank spots filled with 0
19 |     """
20 |     b, n, c = pixel.shape
21 |     img = torch.zeros(b, h*w, c, device=pos.device).to(pixel.dtype)
22 |     idx = (pos[:, :, 1]*w+pos[:, :, 0]).long().unsqueeze(2).expand(-1, -1, c)  # b x n x c
23 |     img = img.scatter(src=pixel, index=idx, dim=1)
24 |     return img.permute(0, 2, 1).reshape(b, c, h, w)
25 | 
26 | 
27 | def knn_keops(query, database, k, return_dist=False):
28 |     """
29 |     Compute k-nearest neighbors using the KeOps library
30 |     Backward pass is turned off; KeOps does not provide a backward pass for the distance
31 |     Args:
32 |         query - b x n_ x c, the position of tokens looking for knn
33 |         database - b x n x c, the candidate tokens for knn
34 |         k - int, the number of neighbors to be found
35 |         return_dist - bool, whether to return distance to the neighbors
36 |     Returns:
37 |         nn_idx - b x n_ x k, the indices of the knn
38 |         nn_dist - b x n_ x k, if return_dist, the distance to the knn
39 |     """
40 |     b, n, c = database.shape
41 |     with torch.no_grad():
42 |         query = query.detach()
43 |         database = database.detach()
44 |         # KeOps does not support half precision
45 |         if query.dtype != torch.float32:
46 |             query = query.to(torch.float32)
47 |         if database.dtype != torch.float32:
48 |             database = database.to(torch.float32)
49 |         from pykeops.torch import LazyTensor
50 |         query_ = LazyTensor(query[:, None, :, :])
51 |         database_ = LazyTensor(database[:, :, None, :])
52 |         dist = ((query_-database_) ** 2).sum(-1) ** 0.5  # b x n x n_
53 |         if return_dist:
54 |             nn_dist, nn_idx = dist.Kmin_argKmin(k, dim=1)  # b x n_ x k
55 |             return nn_idx, nn_dist
56 |         else:
57 |             nn_idx = dist.argKmin(k, dim=1)  # b x n_ x k
58 |             return nn_idx
59 | 
60 | 
61 | def space_filling_cluster(pos, m, h, w, no_reorder=False, sf_type='', use_anchor=True):
62 |     """
63 |     The balanced clustering algorithm based on space-filling curves
64 |     In the case where the number of tokens is not divisible by the cluster size,
65 |     the last cluster will have a few blank spots, indicated by the mask returned
66 |     Args:
67 |         pos - b x n x 2, positions of tokens
68 |         m - int, target size of the clusters
69 |         h,w - int, height and width
70 |         no_reorder - bool, if True, return the clustering based on the original order of tokens;
71 |                      otherwise, reorder the tokens so that the same cluster stays together
72 |         sf_type - str, can be 'peano' or 'hilbert'; otherwise, defaults to horizontal scanlines
73 |                   with alternating direction in each row
74 |         use_anchor - bool, whether to use space-filling anchors or not; if False, directly compute
75 |                      space-filling curves on the token positions
76 |     Returns:
77 |         pos - b x n x 2, returned only if no_reorder is False; the reordered position of tokens
78 |         cluster_mean_pos - b x k x 2, the clustering centers
79 |         member_idx - b x k x m, the indices of tokens in each cluster
80 |         cluster_mask - b x k x m, the binary mask indicating the paddings in last cluster (0 if padding)
81 |         pos_ranking - b x n x 1, returned only if no_reorder is False; i-th entry is the idx of the token
82 |                       rank i in the new order
83 |     """
84 |     with torch.no_grad():
85 |         pos = pos.detach()
86 | 
87 |         if pos.dtype != torch.float:
88 |             pos = pos.to(torch.float)
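        # Illustrative sizes (assumed for exposition, not from the code): for
        # h = w = 56 and cluster size m = 8, k = ceil(3136/8) = 392 clusters;
        # the anchor grid below is then roughly round(56/sqrt(8)) = 20 patches
        # per side, and tokens are ranked along the space-filling order of
        # these anchors before being chopped into k groups of m.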
89 | b, n, d = pos.shape 90 | 91 | k = int(math.ceil(n/m)) 92 | 93 | if use_anchor: 94 | patch_len = (h*w/k)**0.5 95 | num_patch_h = int(round(h / patch_len)) 96 | num_patch_w = int(round(w / patch_len)) 97 | patch_len_h, patch_len_w = h / num_patch_h, w / num_patch_w 98 | if sf_type == 'peano': 99 | num_patch_h = max(3, int(3**round(math.log(num_patch_h, 3)))) 100 | patch_len_h = h / num_patch_h 101 | num_patch_w = int(round(w / h * 3) * (num_patch_h / 3)) 102 | patch_len_w = w / num_patch_w 103 | elif sf_type == 'hilbert': 104 | num_patch_h = max(2, int(2**round(math.log(num_patch_h, 2)))) 105 | patch_len_h = h / num_patch_h 106 | num_patch_w = int(round(w / h * 2) * (num_patch_h / 2)) 107 | patch_len_w = w / num_patch_w 108 | hs = torch.arange(0, num_patch_h, device=pos.device) 109 | ws = torch.arange(0, num_patch_w, device=pos.device) 110 | ys, xs = torch.meshgrid(hs, ws) 111 | grid_pos = torch.stack([xs, ys], dim=2) # h x w x 2 112 | grid_pos = grid_pos.reshape(-1, 2) 113 | 114 | # sort the grid centers to one line 115 | if sf_type == 'peano': 116 | order_grid_idx, order_idx = calculate_peano_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0)) 117 | order_grid_idx = order_grid_idx[0] 118 | order_idx = order_idx[0] 119 | elif sf_type == 'hilbert': 120 | order_grid_idx, order_idx = calculate_hilbert_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0)) 121 | order_grid_idx = order_grid_idx[0] 122 | order_idx = order_idx[0] 123 | else: 124 | order_mask = torch.ones_like(ys) # h x w 125 | order_mask[1::2] = -1 126 | order_mask = order_mask * xs 127 | order_mask = order_mask + ys*w 128 | order_mask[1::2] += (w-1) 129 | order_mask = order_mask.reshape(-1) 130 | order_idx = order_mask.sort()[1] 131 | order_idx_src = torch.arange(len(order_idx)).to(pos.device) 132 | order_grid_idx = torch.zeros_like(order_idx_src) 133 | order_grid_idx.scatter_(index=order_idx, dim=0, src=order_idx_src) 134 | 135 | ordered_grid = grid_pos[order_idx] 136 | patch_len_hw = torch.Tensor([patch_len_w, patch_len_h]).to(pos.device) 137 | 138 | init_pos_means = ordered_grid * patch_len_hw + patch_len_hw/2 - 0.5 139 | nump = ordered_grid.shape[0] 140 | 141 | prev_means = torch.zeros_like(init_pos_means) 142 | prev_means[1:] = init_pos_means[:nump-1].clone() 143 | prev_means[0] = prev_means[1] - (prev_means[2]-prev_means[1]) # float('inf') 144 | next_means = torch.zeros_like(init_pos_means) 145 | next_means[:nump-1] = init_pos_means[1:].clone() 146 | next_means[-1] = next_means[-2] + (next_means[-2]-next_means[-3]) # float('inf') 147 | 148 | mean_assignment = (pos / patch_len_hw).floor() 149 | mean_assignment = mean_assignment[..., 0] + mean_assignment[..., 1] * num_patch_w 150 | mean_assignment = order_grid_idx.unsqueeze(0).expand(b, -1).gather(index=mean_assignment.long(), dim=1).unsqueeze(2) # b x n x 1 151 | 152 | prev_mean_assign = prev_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1) # b x n x d 153 | next_mean_assign = next_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1) # b x n x d 154 | dist_prev = (pos-prev_mean_assign).pow(2).sum(-1) # b x n 155 | dist_next = (pos-next_mean_assign).pow(2).sum(-1) 156 | dist_ratio = dist_prev / (dist_next + 1e-5) 157 | 158 | pos_ranking = mean_assignment * (dist_ratio.max()+1) + dist_ratio.unsqueeze(2) 159 | pos_ranking = pos_ranking.sort(dim=1)[1] # b x n x 1 160 | 161 | else: 162 | if sf_type == 'peano': 163 | _, pos_ranking = calculate_peano_order(h, w, pos) 164 | elif sf_type == 
'hilbert': 165 | _, pos_ranking = calculate_hilbert_order(h, w, pos) 166 | else: 167 | hs = torch.arange(0, h, device=pos.device) 168 | ws = torch.arange(0, w, device=pos.device) 169 | ys, xs = torch.meshgrid(hs, ws) 170 | order_mask = torch.ones_like(ys) # h x w 171 | order_mask[1::2] = -1 172 | order_mask = order_mask * xs 173 | order_mask = order_mask + ys*w 174 | order_mask[1::2] += (w-1) 175 | order_mask = order_mask.reshape(-1) 176 | pos_idx = pos[..., 0] + pos[..., 1] * w 177 | order_mask = order_mask.gather(index=pos_idx.long().reshape(-1), dim=0).reshape(b, n) 178 | pos_ranking = order_mask.sort()[1] 179 | pos_ranking = pos_ranking.unsqueeze(2) 180 | 181 | pos = pos.gather(index=pos_ranking.expand(-1, -1, d), dim=1) # b x n x d 182 | 183 | if k*m == n: 184 | cluster_mask = None 185 | cluster_mean_pos = pos.reshape(b, k, -1, d).mean(2) 186 | else: 187 | pos_pad = torch.zeros(b, k*m, d, dtype=pos.dtype, device=pos.device) 188 | pos_pad[:, :n] = pos.clone() 189 | cluster_mask = torch.zeros(b, k*m, device=pos.device).long() 190 | cluster_mask[:, :n] = 1 191 | cluster_mask = cluster_mask.reshape(b, k, m) 192 | cluster_mean_pos = pos_pad.reshape(b, k, -1, d).sum(2) / cluster_mask.sum(2, keepdim=True) 193 | 194 | if no_reorder: 195 | if k*m == n: 196 | member_idx = pos_ranking.reshape(b, k, m) 197 | else: 198 | member_idx = torch.zeros(b, k*m, device=pos.device, dtype=torch.int64) 199 | member_idx[:, :n] = pos_ranking.squeeze(2) 200 | member_idx = member_idx.reshape(b, k, m) 201 | return cluster_mean_pos, member_idx, cluster_mask 202 | else: 203 | member_idx = torch.arange(k*m, device=pos.device) 204 | member_idx[n:] = 0 205 | member_idx = member_idx.unsqueeze(0).expand(b, -1) # b x k*m 206 | member_idx = member_idx.reshape(b, k, m) 207 | 208 | return pos, cluster_mean_pos, member_idx, cluster_mask, pos_ranking 209 | 210 | 211 | def calculate_peano_order(h, w, pos): 212 | """ 213 | Given height and width of the canvas and position of tokens, 214 | calculate the peano curve order of the tokens 215 | Args: 216 | h,w - int, height and width 217 | pos - b x n x 2, positions of tokens 218 | Returns: 219 | final_order_ - b x n, i-th entry is the rank of i-th token in the new order 220 | final_order_index - b x n, i-th entry is the idx of the token rank i in the new order 221 | """ 222 | b, n, _ = pos.shape 223 | num_levels = math.ceil(math.log(h, 3)) 224 | assert num_levels >= 1, "h too short" 225 | first_w = None 226 | if h != w: 227 | first_w = round(3 * (w/h)) 228 | if first_w == 3: 229 | first_w = None 230 | init_dict = torch.Tensor([[2, 3, 8], [1, 4, 7], [0, 5, 6]]).to(pos.device) 231 | inverse_dict = torch.Tensor([[[1, 1], [1, -1], [1, 1]], [[-1, 1], [-1, -1], [-1, 1]], [[1, 1], [1, -1], [1, 1]]]).to(pos.device) 232 | if first_w is not None: 233 | init_dict_flip = init_dict.flip(dims=[0]) 234 | init_dict_f = torch.cat([init_dict, init_dict_flip], dim=1) # 3 x 6 235 | init_dict_f = init_dict_f.repeat(1, math.ceil(first_w/6)) 236 | init_dict_f = init_dict_f[:, :first_w] # 3 x fw 237 | w_index = torch.arange(math.ceil(first_w/3)).to(pos.device).repeat_interleave(3)[:first_w] * 9 # fw 238 | init_dict_f = init_dict_f + w_index 239 | init_dict_f = init_dict_f.reshape(-1) # 3*fw 240 | inverse_dict_f = inverse_dict[:, :2].repeat(1, math.ceil(first_w/2), 1)[:, :first_w] # 3 x fw x 2 241 | inverse_dict_f = inverse_dict_f.reshape(-1, 2) 242 | init_dict = init_dict.reshape(-1) # 9 243 | inverse_dict = inverse_dict.reshape(-1, 2) # 9 x 2 244 | last_h = h 245 | rem_pos = pos 246 | levels_pos = [] 247 
| for le in range(num_levels):
248 |         cur_h = last_h / 3
249 |         level_pos = (rem_pos / cur_h).floor()
250 |         levels_pos.append(level_pos)
251 |         rem_pos = rem_pos % cur_h
252 |         last_h = cur_h
253 |     orders = []
254 |     for i in range(len(levels_pos)):
255 |         inverse = torch.ones_like(pos)  # b x n x 2
256 |         for j in range(i):
257 |             cur_level_pos = levels_pos[i-j-1]
258 |             if i-j-1 == 0 and first_w is not None:
259 |                 cur_level_pos_index = cur_level_pos[..., 0] + cur_level_pos[..., 1] * first_w  # b x n
260 |                 cur_inverse = inverse_dict_f.gather(index=cur_level_pos_index.long().view(-1, 1).expand(-1, 2), dim=0).reshape(b, n, 2)
261 |             else:
262 |                 cur_level_pos_index = cur_level_pos[..., 0] + cur_level_pos[..., 1] * 3  # b x n
263 |                 cur_inverse = inverse_dict.gather(index=cur_level_pos_index.long().view(-1, 1).expand(-1, 2), dim=0).reshape(b, n, 2)
264 |             inverse = cur_inverse * inverse
265 |         level_pos = levels_pos[i]
266 |         inversed_pos = torch.where(inverse > 0, level_pos, 2-level_pos)
267 |         if i == 0 and first_w is not None:
268 |             inversed_pos_index = inversed_pos[..., 0] + inversed_pos[..., 1] * first_w  # b x n
269 |             cur_order = init_dict_f.gather(index=inversed_pos_index.long().view(-1), dim=0).reshape(b, n)
270 |         else:
271 |             inversed_pos_index = inversed_pos[..., 0] + inversed_pos[..., 1] * 3  # b x n
272 |             cur_order = init_dict.gather(index=inversed_pos_index.long().view(-1), dim=0).reshape(b, n)
273 |         orders.append(cur_order)
274 |     final_order = orders[-1]
275 |     for i in range(len(orders)-1):
276 |         cur_order = orders[i]
277 |         final_order = final_order + cur_order * (9**(num_levels-i-1))
278 |     final_order_index = final_order.sort(dim=1)[1]
279 |     order_src = torch.arange(n).to(pos.device).unsqueeze(0).expand(b, -1)  # b x n
280 |     final_order_ = torch.zeros_like(order_src)
281 |     final_order_.scatter_(index=final_order_index, src=order_src, dim=1)
282 |     return final_order_, final_order_index
283 | 
284 | 
285 | def calculate_hilbert_order(h, w, pos):
286 |     """
287 |     Given height and width of the canvas and position of tokens,
288 |     calculate the hilbert curve order of the tokens
289 |     Args:
290 |         h,w - int, height and width
291 |         pos - b x n x 2, positions of tokens
292 |     Returns:
293 |         final_order_ - b x n, i-th entry is the rank of i-th token in the new order
294 |         final_order_index - b x n, i-th entry is the idx of the token rank i in the new order
295 |     """
296 |     b, n, _ = pos.shape
297 |     num_levels = math.ceil(math.log(h, 2))
298 |     assert num_levels >= 1, "h too short"
299 |     first_w = None
300 |     if h != w:
301 |         first_w = round(2 * (w/h))
302 |         if first_w == 2:
303 |             first_w = None
304 |     rotate_dict = torch.Tensor([[[-1, 1], [0, 0]], [[0, -1], [0, 1]], [[1, 0], [-1, 0]]]).to(pos.device)  # 3 x 2 x 2; -1 means left, 1 means right
305 |     if first_w is not None:
306 |         rotate_dict_f = rotate_dict[0].repeat(1, math.ceil(first_w/2))[:, :first_w]  # 2 x fw
307 |         rotate_dict_f = rotate_dict_f.reshape(-1)  # 2*fw
308 |     rotate_dict = rotate_dict.reshape(3, -1)  # 3 x 4
309 |     rot_res_dict = torch.Tensor([[0, 3, 1, 2], [2, 3, 1, 0], [2, 1, 3, 0], [0, 1, 3, 2]]).to(pos.device)  # 4 x 4
310 |     last_h = h
311 |     rem_pos = pos
312 |     levels_pos = []
313 |     for le in range(num_levels):
314 |         cur_h = last_h / 2
315 |         level_pos = (rem_pos / cur_h).floor()
316 |         levels_pos.append(level_pos)
317 |         rem_pos = rem_pos % cur_h
318 |         last_h = cur_h
319 |     orders = []
320 |     for i in range(len(levels_pos)):
321 |         level_pos = levels_pos[i]
322 |         if i == 0 and first_w is not None:
323 |             level_pos_index = level_pos[..., 0] + level_pos[..., 1] *
first_w # b x n 324 | else: 325 | level_pos_index = level_pos[..., 0] + level_pos[..., 1] * 2 # b x n 326 | rotate = torch.zeros_like(pos[..., 0]) 327 | for j in range(i): 328 | cur_level_pos = levels_pos[j] 329 | if j == 0 and first_w is not None: 330 | cur_level_pos_index = cur_level_pos[..., 0] + cur_level_pos[..., 1] * first_w # b x n 331 | cur_rotate = rotate_dict_f.gather(index=cur_level_pos_index.long().view(-1), dim=0).reshape(b, n) 332 | else: 333 | rotate_d = rotate_dict.gather(index=(rotate % 3).long().view(-1, 1).expand(-1, 4), dim=0).reshape(b, n, 4) 334 | cur_level_pos_index = cur_level_pos[..., 0] + cur_level_pos[..., 1] * 2 # b x n 335 | cur_rotate = rotate_d.gather(index=cur_level_pos_index.long().unsqueeze(2), dim=2).reshape(b, n) 336 | rotate = cur_rotate + rotate 337 | rotate = rotate % 4 338 | rotate_res = rot_res_dict.gather(index=rotate.long().view(-1, 1).expand(-1, 4), dim=0).reshape(b, n, 4) 339 | rotate_res = rotate_res.gather(index=level_pos_index.long().unsqueeze(2), dim=2).squeeze(2) # b x n 340 | orders.append(rotate_res) 341 | final_order = orders[-1] 342 | for i in range(len(orders)-1): 343 | cur_order = orders[i] 344 | final_order = final_order + cur_order * (4**(num_levels-i-1)) 345 | final_order_index = final_order.sort(dim=1)[1] 346 | order_src = torch.arange(n).to(pos.device).unsqueeze(0).expand(b, -1) # b x n 347 | final_order_ = torch.zeros_like(order_src) 348 | final_order_.scatter_(index=final_order_index, src=order_src, dim=1) 349 | return final_order_, final_order_index 350 | -------------------------------------------------------------------------------- /models/test_cluster.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import torch 7 | import numpy as np 8 | from point_utils import space_filling_cluster 9 | import cv2 10 | 11 | """ 12 | Test the correctness of the space_filling_cluster function 13 | """ 14 | 15 | 16 | def display_img(img): 17 | cv2.imshow('img', img) 18 | cv2.waitKey(0) 19 | cv2.destroyAllWindows() 20 | 21 | 22 | h, w = 100, 100 # canvas size 23 | n = 2499 # number of tokens 24 | m = 20 # cluster size 25 | show_center = True 26 | 27 | hs = torch.arange(0, h) 28 | ws = torch.arange(0, w) 29 | ys, xs = torch.meshgrid(hs, ws) 30 | pos = torch.stack([xs, ys], dim=2).reshape(1, -1, 2) # 1 x hw x 2 31 | 32 | # random point cloud 33 | pos = pos[:, torch.randperm(h*w)[:n]] # 1 x n x 2 34 | 35 | # cluster_mean_pos, member_idx, cluster_mask = space_filling_cluster(pos, m, h, w, no_reorder=True) 36 | pos, cluster_mean_pos, member_idx, cluster_mask, _ = space_filling_cluster(pos, m, h, w, no_reorder=False) 37 | if show_center: 38 | cluster_mean_pos = cluster_mean_pos.round().long() 39 | k = member_idx.shape[1] # number of clusters 40 | print("n,k,m", n, k, m) 41 | if cluster_mask is not None: 42 | cluster_mask = cluster_mask.reshape(1, -1, 1) 43 | print("cluster_mask invalid indices", (cluster_mask[0, :, 0] == 0).nonzero()) 44 | 45 | cluster_idx = torch.arange(k).view(1, -1, 1).expand(-1, -1, m).reshape(1, -1, 1) # 1 x km x 1 46 | mean_assignment = torch.zeros(1, n, 1, dtype=cluster_idx.dtype) 47 | mean_assignment.scatter_(index=member_idx.reshape(1, -1, 1)[:, :n], dim=1, src=cluster_idx) 48 | 49 | colors = torch.Tensor(np.random.uniform(size=(k, 3))) 50 | ca = colors.gather(index=mean_assignment.reshape(-1, 1).expand(-1, 3), dim=0) # n x 3 51 | c = ca.shape[-1] 52 | 53 | img = torch.zeros(h*w, c) 54 | pos = pos[0] 55 | idx = (pos[:, 1]*w+pos[:, 0]).long() # n 56 | idx = idx.unsqueeze(1).expand(-1, c) 57 | img.scatter_(src=ca, index=idx, dim=0) 58 | if show_center: 59 | cluster_mean_pos = cluster_mean_pos[0] 60 | center_idx = cluster_mean_pos[:, 1]*w+cluster_mean_pos[:, 0] 61 | img[center_idx] = torch.Tensor([0, 0, 1]) # cluster centers shown as red dots 62 | 63 | img = img.reshape(h, w, c).numpy() 64 | img = img.repeat(4, axis=0).repeat(4, axis=1) 65 | # img = 1.0-img 66 | display_img(img) 67 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from torch import optim as optim 9 | 10 | 11 | def build_optimizer(config, model, ignore=[]): 12 | """ 13 | Build optimizer, set weight decay of normalization to 0 by default. 
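    Parameters that are 1-dimensional or whose names end with '.bias', and names
    returned by the model's no_weight_decay()/no_weight_decay_keywords() hooks,
    are placed in the zero-weight-decay group; frozen weights and names listed in
    `ignore` are excluded from the optimizer entirely (see set_weight_decay below).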
14 | """ 15 | skip = {} 16 | skip_keywords = {} 17 | if hasattr(model, 'no_weight_decay'): 18 | skip = model.no_weight_decay() 19 | if hasattr(model, 'no_weight_decay_keywords'): 20 | skip_keywords = model.no_weight_decay_keywords() 21 | 22 | parameters = set_weight_decay(model, skip, skip_keywords, ignore=ignore) 23 | 24 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 25 | optimizer = None 26 | if opt_lower == 'sgd': 27 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 28 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 29 | elif opt_lower == 'adamw': 30 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 31 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 32 | 33 | return optimizer 34 | 35 | 36 | def set_weight_decay(model, skip_list=(), skip_keywords=(), ignore=[]): 37 | has_decay = [] 38 | no_decay = [] 39 | 40 | for name, param in model.named_parameters(): 41 | if not param.requires_grad or name in ignore: 42 | continue # frozen weights 43 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 44 | check_keywords_in_name(name, skip_keywords): 45 | no_decay.append(param) 46 | # print(f"{name} has no weight decay") 47 | else: 48 | has_decay.append(param) 49 | return [{'params': has_decay}, 50 | {'params': no_decay, 'weight_decay': 0.}] 51 | 52 | 53 | def check_keywords_in_name(name, keywords=()): 54 | isin = False 55 | for keyword in keywords: 56 | if keyword in name: 57 | isin = True 58 | return isin 59 | -------------------------------------------------------------------------------- /run_aff.sh: -------------------------------------------------------------------------------- 1 | # number of gpus for data parallel 2 | GPUS=2 3 | 4 | # dataset path 5 | DATA=imagenet/ 6 | 7 | # config file path 8 | CONFIG=configs/aff_small.yaml 9 | 10 | # checkpoint path for resume 11 | RESUME=checkpoints/aff_small.pth 12 | 13 | python -m torch.distributed.launch --nproc_per_node $GPUS --master_port 12345 main.py \ 14 | --data-path $DATA \ 15 | --cfg $CONFIG \ 16 | --eval \ 17 | --resume $RESUME \ 18 | 19 | # Comment out '--eval' and '--resume' to start training from fresh. 20 | # To enlarge the effective batch size, use '--accumulation-steps'. For example, '--accumulation-steps 2' doubles the effective total batch size. 
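
# For example, a from-scratch training run with gradient accumulation might
# look like the following (illustrative; flags are the ones used above):
# python -m torch.distributed.launch --nproc_per_node $GPUS --master_port 12345 main.py \
#     --data-path $DATA \
#     --cfg $CONFIG \
#     --accumulation-steps 2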
21 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 | 
9 | import builtins
10 | import datetime
11 | import os
12 | import torch
13 | import numpy as np
14 | import random
15 | import torch.distributed as dist
16 | 
17 | 
18 | def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger, use_ema=False):
19 |     logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
20 |     if config.MODEL.RESUME.startswith('https'):
21 |         checkpoint = torch.hub.load_state_dict_from_url(
22 |             config.MODEL.RESUME, map_location='cpu', check_hash=True)
23 |     else:
24 |         checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
25 |     if use_ema:
26 |         msg = model.load_state_dict(checkpoint['model_ema'], strict=False)
27 |         logger.info(msg)
28 |         del checkpoint
29 |         torch.cuda.empty_cache()
30 |         return
31 |     msg = model.load_state_dict(checkpoint['model'], strict=False)
32 |     logger.info(msg)
33 |     max_accuracy = 0.0
34 |     if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
35 |         optimizer.load_state_dict(checkpoint['optimizer'])
36 |         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
37 |         loss_scaler.load_state_dict(checkpoint['loss_scaler'])
38 |         config.defrost()
39 |         config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
40 |         config.freeze()
41 |         logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
42 |         if 'max_accuracy' in checkpoint:
43 |             max_accuracy = checkpoint['max_accuracy']
44 |         if 'rng' in checkpoint:
45 |             np.random.set_state(checkpoint['np_rng'])
46 |             torch.set_rng_state(checkpoint['rng'])
47 |             torch.random.set_rng_state(checkpoint['random'])
48 |             random.setstate(checkpoint['prng'])
49 | 
50 |     del checkpoint
51 |     torch.cuda.empty_cache()
52 |     return max_accuracy
53 | 
54 | 
55 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema=None, total_epochs=None):
56 |     if total_epochs is None:
57 |         total_epochs = config.TRAIN.EPOCHS
58 |     save_state = {'model': model.state_dict(),
59 |                   'optimizer': optimizer.state_dict(),
60 |                   'lr_scheduler': lr_scheduler.state_dict(),
61 |                   'loss_scaler': loss_scaler.state_dict(),
62 |                   'max_accuracy': max_accuracy,
63 |                   'epoch': epoch,
64 |                   'rng': torch.get_rng_state(),
65 |                   'random': torch.random.get_rng_state(),
66 |                   'np_rng': np.random.get_state(),
67 |                   'prng': random.getstate()}
68 |     if model_ema is not None:
69 |         save_state['model_ema'] = model_ema.state_dict()
70 | 
71 |     save_path = os.path.join(config.OUTPUT, 'ckpt_epoch.pth')
72 |     logger.info(f"{save_path} saving......")
73 |     torch.save(save_state, save_path)
74 |     logger.info(f"{save_path} saved !!!")
75 |     if ((epoch+1) % config.SAVE_FREQ == 0 or epoch == (total_epochs - 1) or epoch == 0):
76 |         save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
77 |         torch.save(save_state, save_path)
78 | 
79 | 
80 | def get_grad_norm(parameters, norm_type=2):
81 |     if isinstance(parameters, torch.Tensor):
82 |         parameters = [parameters]
83 |     parameters = list(filter(lambda p: p.grad is not
None, parameters))
84 |     norm_type = float(norm_type)
85 |     total_norm = 0
86 |     for p in parameters:
87 |         param_norm = p.grad.data.norm(norm_type)
88 |         total_norm += param_norm.item() ** norm_type
89 |     total_norm = total_norm ** (1. / norm_type)
90 |     return total_norm
91 | 
92 | 
93 | def auto_resume_helper(output_dir):
94 |     checkpoints = os.listdir(output_dir)
95 |     checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
96 |     print(f"All checkpoints found in {output_dir}: {checkpoints}")
97 |     if len(checkpoints) > 0:
98 |         latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
99 |         print(f"The latest checkpoint found: {latest_checkpoint}")
100 |         resume_file = latest_checkpoint
101 |     else:
102 |         resume_file = None
103 |     return resume_file
104 | 
105 | 
106 | def get_rank():
107 |     if not is_dist_avail_and_initialized():
108 |         return 0
109 |     return dist.get_rank()
110 | 
111 | 
112 | def get_local_rank():
113 |     if not is_dist_avail_and_initialized():
114 |         return 0
115 |     return dist.get_rank()  # NOTE: returns the global rank, which equals the local rank only for single-node jobs
116 | 
117 | 
118 | def get_world_size():
119 |     if not is_dist_avail_and_initialized():
120 |         return 1
121 |     return dist.get_world_size()
122 | 
123 | 
124 | def reduce_tensor(tensor):
125 |     rt = tensor.clone()
126 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
127 |     rt /= dist.get_world_size()
128 |     return rt
129 | 
130 | 
131 | def init_distributed_mode():
132 |     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
133 |         rank = int(os.environ["RANK"])
134 |         world_size = int(os.environ['WORLD_SIZE'])
135 |         print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
136 |     else:
137 |         rank = -1
138 |         world_size = -1
139 |     torch.cuda.set_device(rank)
140 |     torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
141 |     torch.distributed.barrier()
142 | 
143 | 
144 | def setup_for_distributed(is_master):
145 |     """
146 |     This function disables printing when not in master process
147 |     """
148 |     builtin_print = builtins.print
149 | 
150 |     def print(*args, **kwargs):
151 |         force = kwargs.pop('force', False)
152 |         force = force or (get_world_size() > 8)
153 |         if is_master or force:
154 |             now = datetime.datetime.now().time()
155 |             builtin_print('[{}] '.format(now), end='')  # print with time stamp
156 |             builtin_print(*args, **kwargs)
157 | 
158 |     builtins.print = print
159 | 
160 | 
161 | def is_dist_avail_and_initialized():
162 |     if not dist.is_available():
163 |         return False
164 |     if not dist.is_initialized():
165 |         return False
166 |     return True
167 | 
168 | 
169 | class NativeScalerWithGradNormCount:
170 |     def __init__(self, config):
171 |         self._scaler = torch.cuda.amp.GradScaler(enabled=config.AMP_ENABLE)
172 | 
173 |     def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
174 | 
175 |         self._scaler.scale(loss).backward(create_graph=create_graph)
176 | 
177 |         if update_grad:
178 |             if clip_grad is not None:
179 |                 assert parameters is not None
180 |                 self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
181 |                 norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad, error_if_nonfinite=False)
182 |             else:
183 |                 self._scaler.unscale_(optimizer)
184 |                 norm = get_grad_norm(parameters)
185 |             self._scaler.step(optimizer)
186 |             self._scaler.update()
187 |         else:
188 |             norm = None
189 |         return norm
190 | 
191 |     def state_dict(self):
192 |         return self._scaler.state_dict()
193 | 
194 |     def load_state_dict(self, state_dict):
195 | 
self._scaler.load_state_dict(state_dict) 196 | 197 | def is_enabled(self): 198 | return self._scaler.is_enabled() 199 | --------------------------------------------------------------------------------
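
--------------------------------------------------------------------------------
# A minimal usage sketch of the classification model above, assuming the
# clusten CUDA kernels have been compiled (clusten/src/setup.py) and pykeops
# is installed; the local-attention kernels require a CUDA device.

import torch
from models.aff_transformer import AutoFocusFormer

model = AutoFocusFormer(num_classes=1000, img_size=224).cuda().eval()
x = torch.randn(2, 3, 224, 224, device='cuda')
with torch.no_grad():
    logits = model(x)  # b x num_classes, here 2 x 1000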