├── .flake8
├── .gitattributes
├── .gitignore
├── ACKNOWLEDGMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── aff.png
├── architecture.png
├── clusten
├── __init__.py
├── clusten.py
├── src
│ ├── clustenav_cuda.cpp
│ ├── clustenav_cuda_kernel.cu
│ ├── clustenqk_cuda.cpp
│ ├── clustenqk_cuda_kernel.cu
│ ├── clustenwf_cuda.cpp
│ ├── clustenwf_cuda_kernel.cu
│ └── setup.py
├── test_av_kernel.py
├── test_qk_kernel.py
└── test_wf_kernel.py
├── config.py
├── configs
├── aff_base_22k.yaml
├── aff_base_22kto1k.yaml
├── aff_base_22kto1k_384.yaml
├── aff_mini.yaml
├── aff_mini_1_5th.yaml
├── aff_small.yaml
├── aff_small_1_5th.yaml
├── aff_tiny.yaml
└── aff_tiny_1_5th.yaml
├── create_env.sh
├── data
├── __init__.py
├── build.py
└── samplers.py
├── demo1.png
├── demo2.png
├── logger.py
├── lr_scheduler.py
├── main.py
├── models
├── __init__.py
├── aff_transformer.py
├── build.py
├── point_utils.py
└── test_cluster.py
├── optimizer.py
├── run_aff.sh
└── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = B,C,E,F,P,T4,W,B9
3 | max-line-length = 120
4 | # C408 ignored because we like the dict keyword argument syntax
5 | # E501 is not flexible enough, we're using B950 instead
6 | ignore =
7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E303,E226,
8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying
9 | # to line this up with executable bit
10 | EXE001,
11 | # these ignores are from flake8-bugbear; please fix!
12 | B007,B008,
13 | # these ignores are from flake8-comprehensions; please fix!
14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415,
15 | # for "unable to detect undefined names"
16 | F403,
17 | # for "Too many leading '#' for block comment (E266)"
18 | E266,
19 | # for "E731 do not assign a lambda expression, use a def"
20 | E731,
21 | # for "future feature annotations is not defined"
22 | F407,
23 | # do not use bare 'except'
24 | E722,
25 | per-file-ignores = __init__.py: F401
26 | optional-ascii-coding = True
27 | exclude =
28 | ./.git,
29 | ./docs,
30 | ./scripts,
31 | ./test
32 | ./third_party,
33 | ./venv,
34 | *.pyi
35 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pth filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.svg
2 | .nfs*
3 | .DS_Store
4 | __pycache__/
5 | *swp*
6 | output/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
--------------------------------------------------------------------------------
/ACKNOWLEDGMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this AutoFocusFormer Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | _____________________
6 |
7 | Microsoft (Swin Transformer)
8 | MIT License
9 |
10 | Copyright (c) Microsoft Corporation.
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 |
30 | SHI Lab (Neighborhood-Attention-Transformer)
31 | MIT License
32 |
33 | Copyright (c) 2022-2023 SHI Lab
34 |
35 | Permission is hereby granted, free of charge, to any person obtaining a copy
36 | of this software and associated documentation files (the "Software"), to deal
37 | in the Software without restriction, including without limitation the rights
38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
39 | copies of the Software, and to permit persons to whom the Software is
40 | furnished to do so, subject to the following conditions:
41 |
42 | The above copyright notice and this permission notice shall be included in all
43 | copies or substantial portions of the Software.
44 |
45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
51 | SOFTWARE.
52 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual
10 | identity and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the overall
26 | community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or advances of
31 | any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email address,
35 | without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
63 | All complaints will be reviewed and investigated promptly and fairly.
64 |
65 | All community leaders are obligated to respect the privacy and security of the
66 | reporter of any incident.
67 |
68 | ## Enforcement Guidelines
69 |
70 | Community leaders will follow these Community Impact Guidelines in determining
71 | the consequences for any action they deem in violation of this Code of Conduct:
72 |
73 | ### 1. Correction
74 |
75 | **Community Impact**: Use of inappropriate language or other behavior deemed
76 | unprofessional or unwelcome in the community.
77 |
78 | **Consequence**: A private, written warning from community leaders, providing
79 | clarity around the nature of the violation and an explanation of why the
80 | behavior was inappropriate. A public apology may be requested.
81 |
82 | ### 2. Warning
83 |
84 | **Community Impact**: A violation through a single incident or series of
85 | actions.
86 |
87 | **Consequence**: A warning with consequences for continued behavior. No
88 | interaction with the people involved, including unsolicited interaction with
89 | those enforcing the Code of Conduct, for a specified period of time. This
90 | includes avoiding interactions in community spaces as well as external channels
91 | like social media. Violating these terms may lead to a temporary or permanent
92 | ban.
93 |
94 | ### 3. Temporary Ban
95 |
96 | **Community Impact**: A serious violation of community standards, including
97 | sustained inappropriate behavior.
98 |
99 | **Consequence**: A temporary ban from any sort of interaction or public
100 | communication with the community for a specified period of time. No public or
101 | private interaction with the people involved, including unsolicited interaction
102 | with those enforcing the Code of Conduct, is allowed during this period.
103 | Violating these terms may lead to a permanent ban.
104 |
105 | ### 4. Permanent Ban
106 |
107 | **Community Impact**: Demonstrating a pattern of violation of community
108 | standards, including sustained inappropriate behavior, harassment of an
109 | individual, or aggression toward or disparagement of classes of individuals.
110 |
111 | **Consequence**: A permanent ban from any sort of public interaction within the
112 | community.
113 |
114 | ## Attribution
115 |
116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
117 | version 2.1, available at
118 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
119 |
120 | Community Impact Guidelines were inspired by
121 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
122 |
123 | For answers to common questions about this code of conduct, see the FAQ at
124 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
125 | [https://www.contributor-covenant.org/translations][translations].
126 |
127 | [homepage]: https://www.contributor-covenant.org
128 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
129 | [Mozilla CoC]: https://github.com/mozilla/diversity
130 | [FAQ]: https://www.contributor-covenant.org/faq
131 | [translations]: https://www.contributor-covenant.org/translations
132 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2023 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoFocusFormer
2 |
3 | [](CODE_OF_CONDUCT.md)
4 | [](clusten/)
5 |
6 | AFF-Base: [](https://paperswithcode.com/sota/instance-segmentation-on-cityscapes-val?p=autofocusformer-image-segmentation-off-the) [](https://paperswithcode.com/sota/panoptic-segmentation-on-cityscapes-val?p=autofocusformer-image-segmentation-off-the)
7 |
8 | This software project accompanies the research paper, *AutoFocusFormer: Image Segmentation off the Grid* (CVPR 2023).
9 |
10 | [Chen Ziwen](https://www.chenziwe.com), Kaushik Patnaik, [Shuangfei Zhai](https://scholar.google.com/citations?user=G6vdBYsAAAAJ&hl=en), [Alvin Wan](http://alvinwan.com), [Zhile Ren](https://jrenzhile.com), [Alex Schwing](https://alexander-schwing.de/), [Alex Colburn](https://www.colburn.org), [Li Fuxin](https://web.engr.oregonstate.edu/~lif/)
11 |
12 | [arXiv](https://arxiv.org/abs/2304.12406) | [video narration](https://youtu.be/i1mZtk70yGY) | [AFF-Classification (this repo)](https://github.com/apple/ml-autofocusformer) | [AFF-Segmentation](https://github.com/apple/ml-autofocusformer-segmentation)
13 |
14 | ## Introduction
15 |
16 | AutoFocusFormer (AFF) is the first **adaptive**-downsampling network capable of **dense** prediction tasks such as semantic/instance segmentation.
17 |
18 | AFF abandons the traditional grid structure of image feature maps, and automatically learns to retain the most important pixels with respect to the task goal.
19 |
20 |
21 |
![AutoFocusFormer teaser](aff.png)
22 |
23 |
24 | AFF consists of a local-attention transformer backbone and a task-specific head. The backbone consists of four stages, each stage containing three modules: balanced clustering, local-attention transformer blocks, and adaptive downsampling.
25 |
26 |
27 |
![AutoFocusFormer architecture](architecture.png)
28 |
29 |
30 | AFF demonstrates significant savings on FLOPs (see our models with 1/5 downsampling rate), and significant improvement on recognition of small objects.
31 |
32 | Notably, AFF-Small achieves **44.0** instance segmentation AP and **66.9** panoptic segmentation PQ on Cityscapes val with a backbone of only **42.6M** parameters, a performance on par with Swin-Large, a backbone with **197M** params (saving **78%**!).
33 |
34 |
35 |
![Demo 1](demo1.png)
36 |
37 |
38 |
39 |
![Demo 2](demo2.png)
40 |
41 |
42 | ## Main Results on ImageNet with Pretrained Models
43 |
44 | | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS | 1K model |
45 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
46 | | AFF-Mini | ImageNet-1K | 224x224 | 78.2 | 93.6 | 6.75M | 1.08G | 1337 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_mini.pth) |
47 | | AFF-Mini-1/5 | ImageNet-1K | 224x224 | 77.5 | 93.3 | 6.75M | 0.72G | 1678 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_mini_1_5th.pth) |
48 | | AFF-Tiny | ImageNet-1K | 224x224 | 83.0 | 96.3 | 27M | 4G | 528 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_tiny.pth) |
49 | | AFF-Tiny-1/5 | ImageNet-1K | 224x224 | 82.4 | 95.9 | 27M | 2.74G | 682 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_tiny_1_5th.pth) |
50 | | AFF-Small | ImageNet-1K | 224x224 | 83.5 | 96.6 | 42.6M | 8.16G | 321 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_small.pth) |
51 | | AFF-Small-1/5 | ImageNet-1K | 224x224 | 83.4 | 96.5 | 42.6M | 5.69G | 424 | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_small_1_5th.pth) |
52 |
53 | FPS is obtained on a single V100 GPU.
54 |
55 | We train with a total batch size 4096.
56 |
57 | | name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | 22K model | 1K model |
58 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
59 | | AFF-Base | ImageNet-22K | 384x384 | 86.2 | 98.0 | 75.34M | 42.54G | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_base_22k.pth) | [Apple ML](https://docs-assets.developer.apple.com/ml-research/models/aff/classification/aff_base_22kto1k_384.pth) |
60 |
61 | ## Getting Started
62 |
63 | ### Clone this repo
64 |
65 | ```bash
66 | git clone git@github.com:apple/ml-autofocusformer.git
67 | cd ml-autofocusformer
68 | ```
69 | One can download pre-trained checkpoints through the links in the table above.
70 |
71 | ### Create environment and install requirements
72 |
73 | ```bash
74 | sh create_env.sh
75 | ```
76 |
77 | See further documentation inside the script file.
78 |
79 | Our experiments are run with `CUDA==11.6` and `pytorch==1.12`.
80 |
81 | ### Prepare data
82 |
83 | We use standard ImageNet dataset, which can be downloaded from http://image-net.org/.
84 |
85 | For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
86 | ```bash
87 | $ tree imagenet
88 | imagenet/
89 | ├── training
90 | │ ├── class1
91 | │ │ ├── img1.jpeg
92 | │ │ ├── img2.jpeg
93 | │ │ └── ...
94 | │ ├── class2
95 | │ │ ├── img3.jpeg
96 | │ │ └── ...
97 | │ └── ...
98 | └── validation
99 | ├── class1
100 | │ ├── img4.jpeg
101 | │ ├── img5.jpeg
102 | │ └── ...
103 | ├── class2
104 | │ ├── img6.jpeg
105 | │ └── ...
106 | └── ...
107 |
108 | ```
109 |
110 | ### Train and evaluate
111 |
112 | Modify the arguments in script `run_aff.sh` (e.g., path to dataset) and run
113 | ```bash
114 | sh run_aff.sh
115 | ```
116 | for training or evaluation.
117 |
118 | Run `python main.py -h` to see full documentation of the args.
119 |
120 | One can also directly modify the config files in `configs/`.
121 |
122 | ## Citing AutoFocusFormer
123 |
124 | ```BibTeX
125 | @inproceedings{autofocusformer,
126 | title = {AutoFocusFormer: Image Segmentation off the Grid},
127 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
128 | author = {Ziwen, Chen and Patnaik, Kaushik and Zhai, Shuangfei and Wan, Alvin and Ren, Zhile and Schwing, Alex and Colburn, Alex and Fuxin, Li},
129 | year = {2023},
130 | }
131 | ```
132 |
--------------------------------------------------------------------------------
/aff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/aff.png
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/architecture.png
--------------------------------------------------------------------------------
/clusten/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from .clusten import CLUSTENQKFunction, CLUSTENAVFunction, CLUSTENWFFunction
7 |
--------------------------------------------------------------------------------
/clusten/clusten.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from torch.autograd import Function
7 |
8 | try:
9 | import clustenqk_cuda
10 | import clustenav_cuda
11 | import clustenwf_cuda
12 | except ImportError:
13 | raise RuntimeError("Could not load CLUSTEN CUDA extension. " +
14 | "Please make sure your device has CUDA, the CUDA toolkit for PyTorch is installed, and that you've compiled CLUSTEN correctly.")
15 |
16 |
class CLUSTENQKFunction(Function):
    """
    Autograd wrapper for the clustered query-times-key CUDA kernel.

    Produces attention logits between each query token and the keys of
    its neighborhood (given by ``nbhd_idx``) via ``clustenqk_cuda``.
    """
    @staticmethod
    def forward(ctx, query, key, nbhd_idx):
        # The CUDA kernel requires contiguous inputs with matching dtypes.
        query = query.contiguous()
        key = key.contiguous()
        if key.dtype != query.dtype:
            key = key.to(query.dtype)
        nbhd_idx = nbhd_idx.contiguous()
        # The kernel consumes the key with its last two dims transposed.
        key_t = key.permute(0, 1, 3, 2).contiguous()
        attn = clustenqk_cuda.forward(query, key_t, nbhd_idx)
        # Note: the *untransposed* key is what gets stashed for backward.
        ctx.save_for_backward(query, key, nbhd_idx)
        return attn

    @staticmethod
    def backward(ctx, grad_attn):
        d_query, d_key = clustenqk_cuda.backward(
            grad_attn.contiguous(), *ctx.saved_tensors)
        # No gradient flows to the integer neighborhood indices.
        return d_query, d_key, None
41 |
42 |
class CLUSTENAVFunction(Function):
    """
    Autograd wrapper for the clustered attention-times-value CUDA kernel.

    Aggregates value vectors over each token's neighborhood (given by
    ``nbhd_idx``), weighted by the attention scores, via ``clustenav_cuda``.
    """
    @staticmethod
    def forward(ctx, attn, v, nbhd_idx):
        # Kernel requires contiguous inputs; cast v to attn's dtype if needed.
        attn = attn.contiguous()
        v = v.contiguous()
        nbhd_idx = nbhd_idx.contiguous()
        v = v if v.dtype == attn.dtype else v.to(attn.dtype)
        feat = clustenav_cuda.forward(attn, v, nbhd_idx)
        ctx.save_for_backward(attn, v, nbhd_idx)
        return feat

    @staticmethod
    def backward(ctx, grad_feat):
        attn, v, nbhd_idx = ctx.saved_tensors
        d_attn, d_v = clustenav_cuda.backward(
            grad_feat.contiguous(), attn, v, nbhd_idx)
        # No gradient flows to the integer neighborhood indices.
        return d_attn, d_v, None
67 |
68 |
class CLUSTENWFFunction(Function):
    """
    Autograd wrapper for the clustered weight-times-feature CUDA kernel.

    Mixes neighborhood features (gathered with ``nbhd_idx``) using the
    provided weights, via ``clustenwf_cuda``.
    """
    @staticmethod
    def forward(ctx, weights, feat, nbhd_idx):
        # Kernel requires contiguous inputs; cast feat to the weights' dtype.
        weights = weights.contiguous()
        feat = feat.contiguous()
        nbhd_idx = nbhd_idx.contiguous()
        feat = feat if feat.dtype == weights.dtype else feat.to(weights.dtype)
        feat_new = clustenwf_cuda.forward(weights, feat, nbhd_idx)
        ctx.save_for_backward(weights, feat, nbhd_idx)
        return feat_new

    @staticmethod
    def backward(ctx, grad_feat_new):
        weights, feat, nbhd_idx = ctx.saved_tensors
        d_weights, d_feat = clustenwf_cuda.backward(
            grad_feat_new.contiguous(), weights, feat, nbhd_idx)
        # No gradient flows to the integer neighborhood indices.
        return d_weights, d_feat, None
93 |
--------------------------------------------------------------------------------
/clusten/src/clustenav_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include <torch/extension.h>
7 | #include <vector>
8 |
9 | torch::Tensor clusten_av_cuda_forward(
10 | const torch::Tensor &attn, // b x h x n x m
11 | const torch::Tensor &v, // b x h x n x c
12 | const torch::Tensor &nbhd_idx); // b x n x m
13 |
14 | std::vector<torch::Tensor> clusten_av_cuda_backward(
15 | const torch::Tensor &d_feat,
16 | const torch::Tensor &attn,
17 | const torch::Tensor &v,
18 | const torch::Tensor &nbhd_idx);
19 |
20 | // C++ interface
21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
24 |
25 | torch::Tensor clusten_av_forward(
26 | const torch::Tensor &attn,
27 | const torch::Tensor &v,
28 | const torch::Tensor &nbhd_idx) {
29 | CHECK_INPUT(attn);
30 | CHECK_INPUT(v);
31 | CHECK_INPUT(nbhd_idx);
32 | return clusten_av_cuda_forward(attn, v, nbhd_idx);
33 | }
34 |
35 | std::vector clusten_av_backward(
36 | const torch::Tensor &d_feat,
37 | const torch::Tensor &attn,
38 | const torch::Tensor &v,
39 | const torch::Tensor &nbhd_idx) {
40 | CHECK_INPUT(d_feat);
41 | CHECK_INPUT(attn);
42 | CHECK_INPUT(v);
43 | CHECK_INPUT(nbhd_idx);
44 | return clusten_av_cuda_backward(d_feat, attn, v, nbhd_idx);
45 | }
46 |
47 |
48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
49 | m.def("forward", &clusten_av_forward, "CLUSTENAV forward (CUDA)");
50 | m.def("backward", &clusten_av_backward, "CLUSTENAV backward (CUDA)");
51 | }
52 |
--------------------------------------------------------------------------------
/clusten/src/clustenav_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include <torch/extension.h>
7 | 
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 | #include <vector>
11 | #include <ATen/ATen.h>
12 | #include <ATen/AccumulateType.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <ATen/native/cuda/KernelUtils.cuh>
15 |
16 | #define CUDA_NUM_THREADS 1024
17 |
// Forward kernel for attention @ value over per-token neighborhoods:
//   feat[b,h,i,c] = sum_ni attn[b,h,i,ni] * v[b,h,nbhd_idx[b,i,ni],c]
// One thread computes one (batch*head, token, channel) output element.
//
// Fix: the `template` header, the PackedTensorAccessor32 template arguments,
// and the loop-local declarations below were stripped by text extraction;
// they are reconstructed here. NOTE(review): verify the accessor pointer
// traits against the original repository.
template <typename scalar_t>
__global__ void clusten_av_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::DefaultPtrTraits> attn,      // b x h x n x m
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::DefaultPtrTraits> v,         // b x h x n x c
    const torch::PackedTensorAccessor32<int64_t, 3, torch::DefaultPtrTraits> nbhd_idx,   // b x n x m
    torch::PackedTensorAccessor32<scalar_t, 4, torch::DefaultPtrTraits> feat,            // b x h x n x c
    const int length,       // n, number of tokens
    const int batch_size,   // b
    const int heads,        // h
    const int nbhd_size,    // m, neighborhood size per token
    const int dim) {        // c, channels per head

    const int z = blockIdx.z * blockDim.z + threadIdx.z;
    if (z < batch_size * heads){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length){
            const int c = blockIdx.x * blockDim.x + threadIdx.x;
            if (c < dim){
                // Decompose the fused batch*head index.
                const int b = z / heads;
                const int h = z - b * heads;
                int64_t nbi;
                // calculate a@v
                scalar_t updt = scalar_t(0);
                #pragma unroll
                for (unsigned int ni=0; ni < nbhd_size; ++ni) {
                    nbi = nbhd_idx[b][i][ni];
                    updt += attn[b][h][i][ni] * v[b][h][nbi][c];
                }
                feat[b][h][i][c] = updt;
            }
        }
    }
}
51 |
52 |
// Host-side launcher for the AV forward kernel.
// Splits the 1024 threads of a block across channels first, then tokens,
// then batch*head, and grids over whatever does not fit in one block.
//
// Fix: the packed_accessor32 template arguments and the <<<...>>> launch
// configuration were stripped by text extraction (left as `<<>>`); they are
// reconstructed here. NOTE(review): verify the accessor pointer traits
// against the original repository.
torch::Tensor clusten_av_cuda_forward(
    const torch::Tensor &attn,
    const torch::Tensor &v,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = attn.size(0);
    int64_t heads = attn.size(1);
    int64_t length = attn.size(2);
    int64_t dim = v.size(3);
    int64_t nbhd_size = nbhd_idx.size(2);
    int zsize = batch_size * heads;

    // Thread-block shape: channels x tokens x batch*head, capped at 1024.
    int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS));

    auto feat = torch::zeros(
        {batch_size, heads, length, dim}, v.options());

    const auto stream = c10::cuda::getCurrentCUDAStream();
    const dim3 blocks(
        (dim + CHANNELTHREADS - 1) / CHANNELTHREADS,
        (length + TOKENTHREADS - 1) / TOKENTHREADS,
        (zsize + BATCHTHREADS - 1) / BATCHTHREADS);
    const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(attn.scalar_type(), "clusten_av_cuda_forward", ([&] {
        const auto attn_a = attn.packed_accessor32<scalar_t, 4, torch::DefaultPtrTraits>();
        const auto v_a = v.packed_accessor32<scalar_t, 4, torch::DefaultPtrTraits>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t, 3, torch::DefaultPtrTraits>();
        auto feat_a = feat.packed_accessor32<scalar_t, 4, torch::DefaultPtrTraits>();

        clusten_av_cuda_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            attn_a, v_a, nbhd_idx_a, feat_a,
            length, batch_size, heads, nbhd_size, dim);
    }));
    return feat;
}
91 |
92 |
// Backward AV kernel for d_v: scatter-adds d_feat[b,h,i,c] * attn[b,h,i,ni]
// into d_v[b,h,nbhd_idx[b,i,ni],c]. Multiple tokens may share a neighbor, so
// the accumulation uses fastAtomicAdd to avoid a race condition.
template <typename scalar_t>
__global__ void clusten_av_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> d_feat,    // b x h x n x c
    const torch::PackedTensorAccessor32<scalar_t,4> attn,      // b x h x n x m
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,   // b x n x m
    torch::PackedTensorAccessor32<scalar_t,4> d_v,             // b x h x n x c
    const int length,
    const int batch_size,
    const int heads,
    const int nbhd_size,
    const int dim,
    const size_t d_v_numel) {

    const int z = blockIdx.z * blockDim.z + threadIdx.z;
    if (z < batch_size * heads){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length){
            const int c = blockIdx.x * blockDim.x + threadIdx.x;
            if (c < dim){
                const int b = z / heads;
                const int h = z - b * heads;
                int64_t nbi;
                size_t index;
                #pragma unroll
                for (unsigned int ni=0; ni < nbhd_size; ++ni) {
                    nbi = nbhd_idx[b][i][ni];
                    // calculate d_v = att * d_feat
                    // flat offset; the bare `+ c` assumes the last dim is contiguous (stride 1)
                    index = b*d_v.stride(0) + h*d_v.stride(1) + nbi*d_v.stride(2) + c;
                    at::native::fastAtomicAdd(d_v.data(), index, d_v_numel, d_feat[b][h][i][c] * attn[b][h][i][ni], true);
                    // atomicAdd(&(d_v[b][h][nbi][c]), d_feat[b][h][i][c] * attn[b][h][i][ni]); // avoid race condition
                }
            }
        }
    }
}
128 |
// Backward AV kernel for d_attn: d_attn[b,h,i,ni] = <v[b,h,nbi,:], d_feat[b,h,i,:]>.
// Each output element is owned by exactly one thread, so no atomics are needed.
template <typename scalar_t>
__global__ void clusten_av_attn_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> d_feat,    // b x h x n x c
    const torch::PackedTensorAccessor32<scalar_t,4> v,         // b x h x n x c
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,   // b x n x m
    torch::PackedTensorAccessor32<scalar_t,4> d_attn,          // b x h x n x m
    const int length,
    const int batch_size,
    const int heads,
    const int nbhd_size,
    const int dim) {

    const int z = blockIdx.z * blockDim.z + threadIdx.z;
    if (z < batch_size * heads){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length){
            // x axis walks the neighbor dimension here (not channels)
            const int ni = blockIdx.x * blockDim.x + threadIdx.x;
            if (ni < nbhd_size){
                const int b = z / heads;
                const int h = z - b * heads;
                int64_t nbi = nbhd_idx[b][i][ni];
                scalar_t updt = scalar_t(0);
                #pragma unroll
                for (unsigned int c=0; c < dim; ++c) {
                    // calculate d_attn = v * d_feat
                    updt += v[b][h][nbi][c] * d_feat[b][h][i][c];
                }
                d_attn[b][h][i][ni] = updt;
            }
        }
    }
}
161 |
// Host launcher for the AV backward pass. Returns {d_attn, d_v}.
// Two kernels are launched on the same stream: one channel-parallel kernel that
// scatter-adds into d_v, and one neighbor-parallel kernel that fills d_attn.
std::vector<torch::Tensor> clusten_av_cuda_backward(
    const torch::Tensor &d_feat,
    const torch::Tensor &attn,
    const torch::Tensor &v,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = attn.size(0);
    int64_t heads = attn.size(1);
    int64_t length = attn.size(2);
    int64_t dim = v.size(3);
    int64_t nbhd_size = nbhd_idx.size(2);
    int zsize = batch_size * heads;

    // block shape for the d_v kernel (x = channel)
    int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS));

    // block shape for the d_attn kernel (x = neighbor)
    int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), nbhd_size);
    int TOKENTHREADS_NB = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length);
    int BATCHTHREADS_NB = max(1, CUDA_NUM_THREADS / (TOKENTHREADS_NB * NBHDTHREADS));

    auto d_attn = torch::zeros_like(attn);
    auto d_v = torch::zeros_like(v);

    const auto stream = c10::cuda::getCurrentCUDAStream();

    const dim3 blocks(
        (dim + CHANNELTHREADS - 1) / CHANNELTHREADS,
        (length + TOKENTHREADS - 1) / TOKENTHREADS,
        (zsize + BATCHTHREADS - 1) / BATCHTHREADS);
    const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS);

    const dim3 blocks_nb(
        (nbhd_size + NBHDTHREADS - 1) / NBHDTHREADS,
        (length + TOKENTHREADS_NB - 1) / TOKENTHREADS_NB,
        (zsize + BATCHTHREADS_NB - 1) / BATCHTHREADS_NB);
    const dim3 threads_nb(NBHDTHREADS, TOKENTHREADS_NB, BATCHTHREADS_NB);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(attn.scalar_type(), "clusten_av_cuda_backward", ([&] {
        const auto d_feat_a = d_feat.packed_accessor32<scalar_t,4>();
        const auto attn_a = attn.packed_accessor32<scalar_t,4>();
        const auto v_a = v.packed_accessor32<scalar_t,4>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t,3>();
        auto d_attn_a = d_attn.packed_accessor32<scalar_t,4>();
        auto d_v_a = d_v.packed_accessor32<scalar_t,4>();

        const size_t d_v_numel = d_v.numel();
        clusten_av_cuda_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            d_feat_a, attn_a, nbhd_idx_a, d_v_a,
            length, batch_size, heads, nbhd_size, dim, d_v_numel);
        clusten_av_attn_cuda_backward_kernel<scalar_t><<<blocks_nb, threads_nb, 0, stream>>>(
            d_feat_a, v_a, nbhd_idx_a, d_attn_a,
            length, batch_size, heads, nbhd_size, dim);
    }));

    // .to() restores the caller's dtype in case the accumulator dtype differs
    return {d_attn, d_v.to(v.dtype())};
}
219 |
--------------------------------------------------------------------------------
/clusten/src/clustenqk_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include
7 | #include
8 |
9 | torch::Tensor clusten_qk_cuda_forward(
10 | const torch::Tensor &query, // b x h x n x c
11 | const torch::Tensor &key, // b x h x n x c
12 | const torch::Tensor &nbhd_idx); // b x n x m
13 |
14 | std::vector clusten_qk_cuda_backward(
15 | const torch::Tensor &d_attn,
16 | const torch::Tensor &query,
17 | const torch::Tensor &key,
18 | const torch::Tensor &nbhd_idx);
19 |
20 | // C++ interface
21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
24 |
// Python-facing forward entry point: validates that all inputs are contiguous
// CUDA tensors, then dispatches to the CUDA implementation.
torch::Tensor clusten_qk_forward(
    const torch::Tensor &query,
    const torch::Tensor &key,
    const torch::Tensor &nbhd_idx) {
    CHECK_INPUT(query);
    CHECK_INPUT(key);
    CHECK_INPUT(nbhd_idx);
    return clusten_qk_cuda_forward(query, key, nbhd_idx);
}
34 |
35 | std::vector clusten_qk_backward(
36 | const torch::Tensor &d_attn,
37 | const torch::Tensor &query,
38 | const torch::Tensor &key,
39 | const torch::Tensor &nbhd_idx) {
40 | CHECK_INPUT(d_attn);
41 | CHECK_INPUT(query);
42 | CHECK_INPUT(key);
43 | CHECK_INPUT(nbhd_idx);
44 | return clusten_qk_cuda_backward(d_attn, query, key, nbhd_idx);
45 | }
46 |
47 |
// Register the QK forward/backward wrappers as the extension's Python API.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &clusten_qk_forward, "CLUSTENQK forward (CUDA)");
    m.def("backward", &clusten_qk_backward, "CLUSTENQK backward (CUDA)");
}
52 |
--------------------------------------------------------------------------------
/clusten/src/clustenqk_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #define CUDA_NUM_THREADS 1024
17 |
// Forward QK kernel: attn[b,h,i,ni] = <query[b,h,i,:], key[b,h,:,nbi]> where
// nbi = nbhd_idx[b,i,ni]. Note that key is stored channel-major (b x h x c x n)
// per the comment in the original source, hence the key[b][h][c][nbi] access.
template <typename scalar_t>
__global__ void clusten_qk_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> query,     // b x h x n x c
    const torch::PackedTensorAccessor32<scalar_t,4> key,       // b x h x c x n (reordered by cluster)
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,   // b x n x m
    torch::PackedTensorAccessor32<scalar_t,4> attn,            // b x h x n x m
    const int length,       // n
    const int batch_size,   // b
    const int heads,        // h
    const int nbhd_size,    // m
    const int dim) {        // c

    const int z = blockIdx.z * blockDim.z + threadIdx.z;
    if (z < batch_size * heads){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length){
            // x axis walks the neighbor dimension
            const int ni = blockIdx.x * blockDim.x + threadIdx.x;
            if (ni < nbhd_size){
                const int b = z / heads;
                const int h = z - b * heads;
                int64_t nbi = nbhd_idx[b][i][ni];
                // calculate q@k over the channel dimension
                scalar_t updt = scalar_t(0);
                #pragma unroll
                for (unsigned int c=0; c < dim; ++c) {
                    updt += query[b][h][i][c] * key[b][h][c][nbi];
                }
                attn[b][h][i][ni] = updt;
            }
        }
    }
}
50 |
51 |
// Host launcher for the QK forward kernel. Returns attn: b x h x n x m.
torch::Tensor clusten_qk_cuda_forward(
    const torch::Tensor &query,
    const torch::Tensor &key,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = query.size(0);
    int64_t heads = query.size(1);
    int64_t length = query.size(2);
    int64_t dim = query.size(3);
    int64_t nbhd_size = nbhd_idx.size(2);
    int zsize = batch_size * heads;

    int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), nbhd_size);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * NBHDTHREADS));

    auto attn = torch::zeros(
        {batch_size, heads, length, nbhd_size}, query.options());

    const auto stream = c10::cuda::getCurrentCUDAStream();
    const dim3 blocks(
        // BUGFIX: the kernel's x axis covers the neighbor dimension (nbhd_size),
        // but the grid was sized from `dim`. Whenever ceil(m/NBHDTHREADS) exceeds
        // ceil(c/NBHDTHREADS) part of attn was silently left at zero. Compare the
        // correctly-sized neighbor grid (blocks_nb) in clustenav_cuda_kernel.cu.
        (nbhd_size + NBHDTHREADS - 1) / NBHDTHREADS,
        (length + TOKENTHREADS - 1) / TOKENTHREADS,
        (zsize + BATCHTHREADS - 1) / BATCHTHREADS);
    const dim3 threads(NBHDTHREADS, TOKENTHREADS, BATCHTHREADS);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(query.scalar_type(), "clusten_qk_cuda_forward", ([&] {
        const auto query_a = query.packed_accessor32<scalar_t,4>();
        const auto key_a = key.packed_accessor32<scalar_t,4>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t,3>();
        auto attn_a = attn.packed_accessor32<scalar_t,4>();

        clusten_qk_cuda_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            query_a, key_a, nbhd_idx_a, attn_a,
            length, batch_size, heads, nbhd_size, dim);
    }));
    return attn;
}
90 |
// Backward QK kernel. In one pass per (b,h,i,c) thread it:
//  - accumulates d_query[b,h,i,c] = sum_ni key[b,h,nbi,c] * d_attn[b,h,i,ni]
//  - scatter-adds d_key[b,h,nbi,c] += query[b,h,i,c] * d_attn[b,h,i,ni]
// d_key uses fastAtomicAdd since neighbors are shared between tokens.
// NOTE(review): key is indexed token-major here (key[b][h][nbi][c]) while the
// forward kernel reads it channel-major — presumably the caller passes a
// differently-ordered key to backward; confirm against the autograd Function.
template <typename scalar_t>
__global__ void clusten_qk_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> d_attn,    // b x h x n x m
    const torch::PackedTensorAccessor32<scalar_t,4> query,     // b x h x n x c
    const torch::PackedTensorAccessor32<scalar_t,4> key,       // b x h x n x c
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,   // b x n x m
    torch::PackedTensorAccessor32<scalar_t,4> d_query,         // b x h x n x c
    torch::PackedTensorAccessor32<scalar_t,4> d_key,           // b x h x n x c
    const int length,
    const int batch_size,
    const int heads,
    const int nbhd_size,
    const int dim,
    const size_t d_key_numel) {

    const int z = blockIdx.z * blockDim.z + threadIdx.z;
    if (z < batch_size * heads){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length){
            const int c = blockIdx.x * blockDim.x + threadIdx.x;
            if (c < dim){
                const int b = z / heads;
                const int h = z - b * heads;
                size_t index;
                scalar_t dq_update = scalar_t(0);
                scalar_t d_attn_tmp;
                #pragma unroll
                for (unsigned int ni=0; ni < nbhd_size; ++ni) {
                    const int64_t nbi = nbhd_idx[b][i][ni];
                    // calculate d_query = key * d_att
                    // calculate d_key = query * d_att
                    d_attn_tmp = d_attn[b][h][i][ni];
                    dq_update += key[b][h][nbi][c] * d_attn_tmp;
                    // flat offset; `+ c` assumes last dim contiguous (stride 1)
                    index = b*d_key.stride(0) + h*d_key.stride(1) + nbi*d_key.stride(2) + c;
                    at::native::fastAtomicAdd(d_key.data(), index, d_key_numel, query[b][h][i][c] * d_attn_tmp, true);
                    //atomicAdd(&(d_key[b][h][nbi][c]), query[b][h][i][c] * d_attn_tmp); // avoid race condition
                }
                d_query[b][h][i][c] = dq_update;
            }
        }
    }
}
133 |
// Host launcher for the QK backward pass. Returns {d_query, d_key}.
std::vector<torch::Tensor> clusten_qk_cuda_backward(
    const torch::Tensor &d_attn,
    const torch::Tensor &query,
    const torch::Tensor &key,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = query.size(0);
    int64_t heads = query.size(1);
    int64_t length = query.size(2);
    int64_t dim = query.size(3);
    int64_t nbhd_size = nbhd_idx.size(2);
    int zsize = batch_size * heads;

    // x = channel here: one thread owns one (b,h,i,c) of d_query
    int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS));

    auto d_query = torch::zeros_like(query);
    auto d_key = torch::zeros_like(key);

    const auto stream = c10::cuda::getCurrentCUDAStream();

    const dim3 blocks(
        (dim + CHANNELTHREADS - 1) / CHANNELTHREADS,
        (length + TOKENTHREADS - 1) / TOKENTHREADS,
        (zsize + BATCHTHREADS - 1) / BATCHTHREADS);

    const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(query.scalar_type(), "clusten_qk_cuda_backward", ([&] {
        const auto d_attn_a = d_attn.packed_accessor32<scalar_t,4>();
        const auto query_a = query.packed_accessor32<scalar_t,4>();
        const auto key_a = key.packed_accessor32<scalar_t,4>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t,3>();
        auto d_query_a = d_query.packed_accessor32<scalar_t,4>();
        auto d_key_a = d_key.packed_accessor32<scalar_t,4>();

        const size_t d_key_numel = d_key.numel();
        clusten_qk_cuda_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            d_attn_a, query_a, key_a, nbhd_idx_a, d_query_a, d_key_a,
            length, batch_size, heads, nbhd_size, dim, d_key_numel);
    }));

    // .to() restores the caller's dtype in case the accumulator dtype differs
    return {d_query, d_key.to(key.dtype())};
}
179 |
--------------------------------------------------------------------------------
/clusten/src/clustenwf_cuda.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include
7 | #include
8 |
9 | torch::Tensor clusten_wf_cuda_forward(
10 | const torch::Tensor &weights, // b x n_ x m x ic
11 | const torch::Tensor &feat, // b x n x c
12 | const torch::Tensor &nbhd_idx); // b x n_ x m
13 |
14 | std::vector clusten_wf_cuda_backward(
15 | const torch::Tensor &d_feat_new,
16 | const torch::Tensor &weights,
17 | const torch::Tensor &feat,
18 | const torch::Tensor &nbhd_idx);
19 |
20 | // C++ interface
21 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
24 |
// Python-facing forward entry point: validates that all inputs are contiguous
// CUDA tensors, then dispatches to the CUDA implementation.
torch::Tensor clusten_wf_forward(
    const torch::Tensor &weights,
    const torch::Tensor &feat,
    const torch::Tensor &nbhd_idx) {
    CHECK_INPUT(weights);
    CHECK_INPUT(feat);
    CHECK_INPUT(nbhd_idx);
    return clusten_wf_cuda_forward(weights, feat, nbhd_idx);
}
34 |
35 | std::vector clusten_wf_backward(
36 | const torch::Tensor &d_feat_new,
37 | const torch::Tensor &weights,
38 | const torch::Tensor &feat,
39 | const torch::Tensor &nbhd_idx) {
40 | CHECK_INPUT(d_feat_new);
41 | CHECK_INPUT(weights);
42 | CHECK_INPUT(feat);
43 | CHECK_INPUT(nbhd_idx);
44 | return clusten_wf_cuda_backward(d_feat_new, weights, feat, nbhd_idx);
45 | }
46 |
47 |
// Register the WF forward/backward wrappers as the extension's Python API.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &clusten_wf_forward, "CLUSTENWF forward (CUDA)");
    m.def("backward", &clusten_wf_backward, "CLUSTENWF backward (CUDA)");
}
52 |
--------------------------------------------------------------------------------
/clusten/src/clustenwf_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * For licensing see accompanying LICENSE file.
3 | * Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | */
5 |
6 | #include
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #define CUDA_NUM_THREADS 1024
17 |
// Forward WF kernel: feat_new[b,i,ic,c] = sum_ni weights[b,i,ni,ic] * feat[b,nbhd_idx[b,i,ni],c].
// One thread per (batch, output-token, channel); it loops over all ic outputs.
template <typename scalar_t>
__global__ void clusten_wf_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> weights,   // b x n_ x m x ic
    const torch::PackedTensorAccessor32<scalar_t,3> feat,      // b x n x c
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,   // b x n_ x m
    torch::PackedTensorAccessor32<scalar_t,4> feat_new,        // b x n_ x ic x c
    const int length,       // n
    const int length_out,   // n_
    const int batch_size,   // b
    const int nbhd_size,    // m
    const int dim,          // c
    const int dim_inner) {  // ic

    const int b = blockIdx.z * blockDim.z + threadIdx.z;
    if (b < batch_size){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length_out){
            const int c = blockIdx.x * blockDim.x + threadIdx.x;
            if (c < dim){
                int64_t nbi;
                // calculate weights@feat
                scalar_t updt;
                #pragma unroll
                for (unsigned int ic=0; ic < dim_inner; ++ic) {
                    updt = scalar_t(0);
                    #pragma unroll
                    for (unsigned int ni=0; ni < nbhd_size; ++ni) {
                        nbi = nbhd_idx[b][i][ni];
                        updt += weights[b][i][ni][ic] * feat[b][nbi][c];
                    }
                    feat_new[b][i][ic][c] = updt;
                }
            }
        }
    }
}
54 |
55 |
// Host launcher for the WF forward kernel. Returns feat_new: b x n_ x ic x c.
torch::Tensor clusten_wf_cuda_forward(
    const torch::Tensor &weights,
    const torch::Tensor &feat,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = weights.size(0);
    int64_t length_out = weights.size(1);
    int64_t nbhd_size = weights.size(2);
    int64_t dim_inner = weights.size(3);
    int64_t length = feat.size(1);
    int64_t dim = feat.size(2);

    int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length_out);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS));

    auto feat_new = torch::zeros(
        {batch_size, length_out, dim_inner, dim}, weights.options());

    const auto stream = c10::cuda::getCurrentCUDAStream();
    const dim3 blocks(
        (dim + CHANNELTHREADS - 1) / CHANNELTHREADS,
        (length_out + TOKENTHREADS - 1) / TOKENTHREADS,
        (batch_size + BATCHTHREADS - 1) / BATCHTHREADS);
    const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(weights.scalar_type(), "clusten_wf_cuda_forward", ([&] {
        const auto weights_a = weights.packed_accessor32<scalar_t,4>();
        const auto feat_a = feat.packed_accessor32<scalar_t,3>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t,3>();
        auto feat_new_a = feat_new.packed_accessor32<scalar_t,4>();

        clusten_wf_cuda_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            weights_a, feat_a, nbhd_idx_a, feat_new_a,
            length, length_out, batch_size, nbhd_size, dim, dim_inner);
    }));
    return feat_new;
}
94 |
95 |
// Backward WF kernel for d_feat: for each (b,i,c) thread, scatter-adds
// sum_ic d_feat_new[b,i,ic,c] * weights[b,i,ni,ic] into d_feat[b,nbhd_idx[b,i,ni],c].
// fastAtomicAdd is required because neighbors are shared across output tokens.
template <typename scalar_t>
__global__ void clusten_wf_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> d_feat_new,  // b x n_ x ic x c
    const torch::PackedTensorAccessor32<scalar_t,4> weights,     // b x n_ x m x ic
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,     // b x n_ x m
    torch::PackedTensorAccessor32<scalar_t,3> d_feat,            // b x n x c
    const int length,       // n
    const int length_out,   // n_
    const int batch_size,   // b
    const int nbhd_size,    // m
    const int dim,          // c
    const int dim_inner,    // ic
    const size_t d_feat_numel) {

    const int b = blockIdx.z * blockDim.z + threadIdx.z;
    if (b < batch_size){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length_out){
            const int c = blockIdx.x * blockDim.x + threadIdx.x;
            if (c < dim){
                int64_t nbi;
                size_t index;
                scalar_t updt;
                #pragma unroll
                for (unsigned int ni=0; ni < nbhd_size; ++ni) {
                    nbi = nbhd_idx[b][i][ni];
                    updt = scalar_t(0);
                    // calculate d_feat = weights * d_feat_new
                    #pragma unroll
                    for (unsigned int ic=0; ic < dim_inner; ++ic) {
                        updt += d_feat_new[b][i][ic][c] * weights[b][i][ni][ic];
                    }
                    // flat offset; `+ c` assumes last dim contiguous (stride 1)
                    index = b*d_feat.stride(0) + nbi*d_feat.stride(1) + c;
                    at::native::fastAtomicAdd(d_feat.data(), index, d_feat_numel, updt, true);
                    // atomicAdd(&(d_feat[b][nbi][c]), updt); // avoid race condition
                }
            }
        }
    }
}
136 |
// Backward WF kernel for d_weights:
// d_weights[b,i,ni,ic] = sum_c feat[b,nbhd_idx[b,i,ni],c] * d_feat_new[b,i,ic,c].
// The x grid axis is a flattened (ni, ic) pair; each output element is owned
// by exactly one thread, so no atomics are needed.
template <typename scalar_t>
__global__ void clusten_wf_weights_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4> d_feat_new,  // b x n_ x ic x c
    const torch::PackedTensorAccessor32<scalar_t,3> feat,        // b x n x c
    const torch::PackedTensorAccessor32<int64_t,3> nbhd_idx,     // b x n_ x m
    torch::PackedTensorAccessor32<scalar_t,4> d_weights,         // b x n_ x m x ic
    const int length,       // n
    const int length_out,   // n_
    const int batch_size,   // b
    const int nbhd_size,    // m
    const int dim,          // c
    const int dim_inner){   // ic

    const int b = blockIdx.z * blockDim.z + threadIdx.z;
    if (b < batch_size){
        const int i = blockIdx.y * blockDim.y + threadIdx.y;
        if (i < length_out){
            const int z = blockIdx.x * blockDim.x + threadIdx.x;
            if (z < nbhd_size * dim_inner){
                // unflatten z into (neighbor, inner-channel)
                const int ni = z / dim_inner;
                const int ic = z - ni * dim_inner;
                int64_t nbi = nbhd_idx[b][i][ni];
                scalar_t updt = scalar_t(0);
                #pragma unroll
                for (unsigned int c=0; c < dim; ++c) {
                    // calculate d_weights = feat * d_feat_new
                    updt += feat[b][nbi][c] * d_feat_new[b][i][ic][c];
                }
                d_weights[b][i][ni][ic] = updt;
            }
        }
    }
}
170 |
// Host launcher for the WF backward pass. Returns {d_weights, d_feat}.
// Two kernels on the same stream: a channel-parallel scatter-add into d_feat,
// and a (neighbor x inner-channel)-parallel fill of d_weights.
std::vector<torch::Tensor> clusten_wf_cuda_backward(
    const torch::Tensor &d_feat_new,
    const torch::Tensor &weights,
    const torch::Tensor &feat,
    const torch::Tensor &nbhd_idx) {

    int64_t batch_size = weights.size(0);
    int64_t length_out = weights.size(1);
    int64_t nbhd_size = weights.size(2);
    int64_t dim_inner = weights.size(3);
    int64_t length = feat.size(1);
    int64_t dim = feat.size(2);

    // flattened (neighbor, inner-channel) extent for the d_weights kernel
    int64_t zsize = nbhd_size * dim_inner;

    int CHANNELTHREADS = min(int64_t(CUDA_NUM_THREADS), dim);
    int TOKENTHREADS = min(int64_t(CUDA_NUM_THREADS / CHANNELTHREADS), length_out);
    int BATCHTHREADS = max(1, CUDA_NUM_THREADS / (TOKENTHREADS * CHANNELTHREADS));

    int NBHDTHREADS = min(int64_t(CUDA_NUM_THREADS), zsize);
    int TOKENTHREADS_NB = min(int64_t(CUDA_NUM_THREADS / NBHDTHREADS), length_out);
    int BATCHTHREADS_NB = max(1, CUDA_NUM_THREADS / (TOKENTHREADS_NB * NBHDTHREADS));

    auto d_weights = torch::zeros_like(weights);
    auto d_feat = torch::zeros_like(feat);

    const auto stream = c10::cuda::getCurrentCUDAStream();

    const dim3 blocks(
        (dim + CHANNELTHREADS - 1) / CHANNELTHREADS,
        (length_out + TOKENTHREADS - 1) / TOKENTHREADS,
        (batch_size + BATCHTHREADS - 1) / BATCHTHREADS);
    const dim3 threads(CHANNELTHREADS, TOKENTHREADS, BATCHTHREADS);

    const dim3 blocks_nb(
        (zsize + NBHDTHREADS - 1) / NBHDTHREADS,
        (length_out + TOKENTHREADS_NB - 1) / TOKENTHREADS_NB,
        (batch_size + BATCHTHREADS_NB - 1) / BATCHTHREADS_NB);
    const dim3 threads_nb(NBHDTHREADS, TOKENTHREADS_NB, BATCHTHREADS_NB);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(weights.scalar_type(), "clusten_wf_cuda_backward", ([&] {
        const auto d_feat_new_a = d_feat_new.packed_accessor32<scalar_t,4>();
        const auto weights_a = weights.packed_accessor32<scalar_t,4>();
        const auto feat_a = feat.packed_accessor32<scalar_t,3>();
        const auto nbhd_idx_a = nbhd_idx.packed_accessor32<int64_t,3>();
        auto d_weights_a = d_weights.packed_accessor32<scalar_t,4>();
        auto d_feat_a = d_feat.packed_accessor32<scalar_t,3>();

        const size_t d_feat_numel = d_feat.numel();
        clusten_wf_cuda_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            d_feat_new_a, weights_a, nbhd_idx_a, d_feat_a,
            length, length_out, batch_size, nbhd_size, dim, dim_inner, d_feat_numel);
        clusten_wf_weights_cuda_backward_kernel<scalar_t><<<blocks_nb, threads_nb, 0, stream>>>(
            d_feat_new_a, feat_a, nbhd_idx_a, d_weights_a,
            length, length_out, batch_size, nbhd_size, dim, dim_inner);
    }));

    // .to() restores the caller's dtype in case the accumulator dtype differs
    return {d_weights, d_feat.to(feat.dtype())};
}
230 |
--------------------------------------------------------------------------------
/clusten/src/setup.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from setuptools import setup
7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8 |
# Build three independent CUDA extensions, one per custom attention primitive:
#   clustenqk_cuda - Q@K over gathered neighborhoods
#   clustenav_cuda - attn@V over gathered neighborhoods
#   clustenwf_cuda - weights@feat neighborhood aggregation
# Build with: python setup.py install (run from this directory; sources are
# given relative to it).
setup(
    name='clustencuda',
    version='0.1',
    author='Ziwen Chen',
    author_email='chenziw@oregonstate.edu',
    description='Cluster Attention CUDA Kernel',
    ext_modules=[
        CUDAExtension('clustenqk_cuda', [
            'clustenqk_cuda.cpp',
            'clustenqk_cuda_kernel.cu',
        ]),
        CUDAExtension('clustenav_cuda', [
            'clustenav_cuda.cpp',
            'clustenav_cuda_kernel.cu',
        ]),
        CUDAExtension('clustenwf_cuda', [
            'clustenwf_cuda.cpp',
            'clustenwf_cuda_kernel.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension  # torch's builder handles nvcc flags
    })
32 |
--------------------------------------------------------------------------------
/clusten/test_av_kernel.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import torch
7 | from torch import nn
8 | from clusten import CLUSTENAVFunction
9 |
10 | """
11 | Test the correctness of AV custom kernel
12 | """
13 |
# Problem sizes: batch, heads, tokens, neighborhood size, channels per head.
b = 256
h = 4
n = 196
m = 48
c = 32

# dummy data
attn = nn.Parameter(torch.randn(b, h, n, m)).cuda()
attn.retain_grad()  # .cuda() returns a non-leaf tensor; retain_grad keeps .grad populated
val = nn.Parameter(torch.randn(b, h, n, c)).cuda()
val.retain_grad()
nn_idx = torch.randint(n, (b, n, m)).cuda()  # random neighbor indices in [0, n)

# use the custom kernel
feat = CLUSTENAVFunction.apply(attn, val, nn_idx)
feat.mean().backward()
grad_attn = attn.grad.clone().detach()
attn.grad.data.zero_()  # reset so the reference pass accumulates from zero
grad_val = val.grad.clone().detach()
val.grad.data.zero_()

# use the pytorch equivalent
'''
feat2 = CLUSTENAVFunction.apply(attn,val,nn_idx)
'''
# gather neighbor values to (b, h, n, m, c), weight by attn, reduce over m
val_gather = val.gather(index=nn_idx.reshape(b, 1, -1, 1).expand(-1, h, -1, c), dim=2).reshape(b, h, n, m, c)
feat2 = (attn.unsqueeze(4) * val_gather).sum(3)
feat2.mean().backward()
grad_attn2 = attn.grad.clone().detach()
attn.grad.data.zero_()
grad_val2 = val.grad.clone().detach()
val.grad.data.zero_()


# all three norms should be near zero (small fp differences expected)
print('diff of forward: ', torch.linalg.norm(feat2 - feat))
print('diff of grad attn: ', torch.linalg.norm(grad_attn2 - grad_attn))
print('diff of grad val: ', torch.linalg.norm(grad_val2 - grad_val))
51 |
--------------------------------------------------------------------------------
/clusten/test_qk_kernel.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import torch
7 | from torch import nn
8 | from clusten import CLUSTENQKFunction
9 |
10 | """
11 | Test the correctness of QK custom kernel
12 | """
13 |
# Problem sizes: batch, heads, tokens, neighborhood size, channels per head.
b = 256
h = 4
n = 196
m = 48
c = 32

# dummy data
query = nn.Parameter(torch.randn(b, h, n, c)).cuda()
query.retain_grad()  # .cuda() returns a non-leaf tensor; retain_grad keeps .grad populated
key = nn.Parameter(torch.randn(b, h, n, c)).cuda()
key.retain_grad()
nn_idx = torch.randint(n, (b, n, m)).cuda()  # random neighbor indices in [0, n)

# use the custom kernel
attn = CLUSTENQKFunction.apply(query, key, nn_idx)
attn.mean().backward()
grad_query = query.grad.clone().detach()
query.grad.data.zero_()  # reset so the reference pass accumulates from zero
grad_key = key.grad.clone().detach()
key.grad.data.zero_()

# use the pytorch equivalent
'''
attn2 = CLUSTENQKFunction.apply(query, key, nn_idx)
'''
# gather neighbor keys to (b, h, n, m, c), dot with query over channels
key_gather = key.gather(index=nn_idx.reshape(b, 1, -1, 1).expand(-1, h, -1, c), dim=2).reshape(b, h, n, m, c)
attn2 = (query.unsqueeze(3) * key_gather).sum(-1)
attn2.mean().backward()
grad_query2 = query.grad.clone().detach()
query.grad.data.zero_()
grad_key2 = key.grad.clone().detach()
key.grad.data.zero_()


# all three norms should be near zero (small fp differences expected)
print('diff of forward: ', torch.linalg.norm(attn2 - attn))
print('diff of grad query: ', torch.linalg.norm(grad_query2 - grad_query))
print('diff of grad key: ', torch.linalg.norm(grad_key2 - grad_key))
51 |
--------------------------------------------------------------------------------
/clusten/test_wf_kernel.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import torch
7 | from torch import nn
8 | from clusten import CLUSTENWFFunction
9 |
10 | """
11 | Test the correctness of WF custom kernel
12 | """
13 |
# Problem sizes: batch, output tokens, input tokens, neighborhood size,
# channels, inner (mixing) channels.
b = 256
n_ = 64
n = 196
m = 48
c = 32
ic = 4

# dummy data
weights = nn.Parameter(torch.randn(b, n_, m, ic)).cuda()
weights.retain_grad()  # .cuda() returns a non-leaf tensor; retain_grad keeps .grad populated
feat = nn.Parameter(torch.randn(b, n, c)).cuda()
feat.retain_grad()
nn_idx = torch.randint(n, (b, n_, m)).cuda()  # random neighbor indices in [0, n)

# use the custom kernel
feat_new = CLUSTENWFFunction.apply(weights, feat, nn_idx)
feat_new.mean().backward()
grad_weights = weights.grad.clone().detach()
weights.grad.data.zero_()  # reset so the reference pass accumulates from zero
grad_feat = feat.grad.clone().detach()
feat.grad.data.zero_()

# use the pytorch equivalent
'''
feat_new2 = CLUSTENWFFunction.apply(weights,feat,nn_idx)
'''
# gather neighbor features to (b, n_, m, c), then (ic x m) @ (m x c) per token
feat_gather = feat.gather(index=nn_idx.reshape(b, -1, 1).expand(-1, -1, c), dim=1).reshape(b, n_, m, c)
feat_new2 = weights.transpose(-1, -2) @ feat_gather
feat_new2.mean().backward()
grad_weights2 = weights.grad.clone().detach()
weights.grad.data.zero_()
grad_feat2 = feat.grad.clone().detach()
feat.grad.data.zero_()


# all three norms should be near zero (small fp differences expected)
print('diff of forward: ', torch.linalg.norm(feat_new2 - feat_new))
print('diff of grad weights: ', torch.linalg.norm(grad_weights2 - grad_weights))
print('diff of grad feat: ', torch.linalg.norm(grad_feat2 - grad_feat))
52 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 |
9 | import os
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 |
_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = 'imagenet'
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Input channels
_C.DATA.IN_CHANS = 3
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'aff'
# Model name
_C.MODEL.NAME = 'aff_mini_1_4th'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Pretrained checkpoint to fine-tune from (e.g. 22k -> 1k transfer).
# configs/aff_base_22kto1k.yaml and aff_base_22kto1k_384.yaml set this key;
# without a default here, yacs raises "Non-existent config key" when merging
# those files.
_C.MODEL.PRETRAINED = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.0
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# AFF parameters
_C.MODEL.AFF = CN()
_C.MODEL.AFF.DEPTHS = [2, 2, 6, 2]
_C.MODEL.AFF.NUM_HEADS = [2, 4, 8, 16]
_C.MODEL.AFF.EMBED_DIM = [32, 128, 256, 384]
_C.MODEL.AFF.MLP_RATIO = 2.
_C.MODEL.AFF.PATCH_NORM = True

_C.MODEL.AFF.CLUSTER_SIZE = 8
_C.MODEL.AFF.NBHD_SIZE = [48, 48, 48, 49]
_C.MODEL.AFF.ALPHA = 4.0
_C.MODEL.AFF.DS_RATE = 0.25
_C.MODEL.AFF.LAYER_SCALE = 0.0
_C.MODEL.AFF.RESERVE = True

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.COOLDOWN_EPOCHS = 0
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# EMA
_C.TRAIN.USE_EMA = False
_C.TRAIN.EMA_DECAY = 0.9998

# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.0  # 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 0.0  # 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Pytorch native amp, overwritten by command line argument
_C.AMP_ENABLE = False
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint (epochs)
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Frequency to validate (epochs)
_C.EVAL_FREQ = 1
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
173 |
def _update_config_from_file(config, cfg_file):
    """Recursively merge ``cfg_file`` into ``config``.

    Files listed under the yaml ``BASE`` key are merged first (depth-first),
    so values in the current file override those of its base files.  The
    config is left frozen on return.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        raw = yaml.load(f, Loader=yaml.FullLoader)

    # pull in every ancestor config before merging this one
    base_dir = os.path.dirname(cfg_file)
    for base in raw.setdefault('BASE', ['']):
        if not base:
            continue
        _update_config_from_file(config, os.path.join(base_dir, base))

    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
187 |
188 |
def update_config(config, args):
    """Merge the yaml file named by ``args.cfg`` into ``config``, then apply
    command-line overrides from ``args`` on top.

    Falsy argument values (None, 0, '', False) leave the corresponding config
    entry untouched.  The config is frozen on return.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # value-carrying overrides: (args attribute, target node, leaf key)
    overrides = (
        ('batch_size', config.DATA, 'BATCH_SIZE'),
        ('data_path', config.DATA, 'DATA_PATH'),
        ('blr', config.TRAIN, 'BASE_LR'),
        ('resume', config.MODEL, 'RESUME'),
        ('accumulation_steps', config.TRAIN, 'ACCUMULATION_STEPS'),
        ('output', config, 'OUTPUT'),
        ('tag', config, 'TAG'),
    )
    for arg_name, node, key in overrides:
        value = getattr(args, arg_name)
        if value:
            setattr(node, key, value)

    # boolean mode switches
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True
    if args.epochs:
        config.TRAIN.EPOCHS = args.epochs

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
225 |
226 |
def get_config(args):
    """Get a yacs CfgNode object with default values.

    A clone of the module-level defaults is returned so that repeated calls
    never mutate ``_C`` itself (the "local variable" use pattern).
    """
    cfg = _C.clone()
    update_config(cfg, args)
    return cfg
235 |
--------------------------------------------------------------------------------
/configs/aff_base_22k.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_base_22k
4 | DROP_PATH_RATE: 0.2
5 | AFF:
6 | DEPTHS: [3,4,18,2]
7 | NUM_HEADS: [4,8,16,32]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [128, 256, 512, 1024]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | LAYER_SCALE: 1e-5 # turned off if 0.0
13 | ALPHA: 4.0
14 | DS_RATE: 0.25
15 | DATA:
16 | DATASET: imagenet22k
17 | IMG_SIZE: 224
18 | BATCH_SIZE: 64
19 | DATA_PATH: path/to/22k
20 | TRAIN:
21 | EPOCHS: 90
22 | WARMUP_EPOCHS: 5
23 | WEIGHT_DECAY: 0.05
24 | BASE_LR: 5e-4
25 | MIN_LR: 1.25e-6
26 | USE_EMA: False
27 | EMA_DECAY: 0.9998
28 | WARMUP_LR: 1.25e-7
29 | ACCUMULATION_STEPS: 1
30 | AUG:
31 | MIXUP: 0.8
32 | CUTMIX: 1.0
33 |
--------------------------------------------------------------------------------
/configs/aff_base_22kto1k.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_base_22kto1k
4 | DROP_PATH_RATE: 0.2
5 | AFF:
6 | DEPTHS: [3,4,18,2]
7 | NUM_HEADS: [4,8,16,32]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [128, 256, 512, 1024]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | LAYER_SCALE: 1e-5 # turned off if 0.0
13 | ALPHA: 4.0
14 | DS_RATE: 0.25
15 | PRETRAINED: aff_base_22k.pth
16 | DATA:
17 | DATASET: imagenet
18 | BATCH_SIZE: 64
19 | TRAIN:
20 | EPOCHS: 30
21 | WARMUP_EPOCHS: 5
22 | WEIGHT_DECAY: 1e-8
23 | BASE_LR: 2e-05
24 | MIN_LR: 2e-07
25 | USE_EMA: False
26 | EMA_DECAY: 0.9998
27 | WARMUP_LR: 2e-08
28 | AUG:
29 | MIXUP: 0.8
30 | CUTMIX: 1.0
31 |
--------------------------------------------------------------------------------
/configs/aff_base_22kto1k_384.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_base_22kto1k_384
4 | DROP_PATH_RATE: 0.2
5 | AFF:
6 | DEPTHS: [3,4,18,2]
7 | NUM_HEADS: [4,8,16,32]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [128, 256, 512, 1024]
10 | CLUSTER_SIZE: 24
11 | NBHD_SIZE: [144,144,144,144]
12 | LAYER_SCALE: 1e-5 # turned off if 0.0
13 | ALPHA: 4.0
14 | DS_RATE: 0.25
15 | PRETRAINED: aff_base_22k.pth
16 | DATA:
17 | DATASET: imagenet
18 | IMG_SIZE: 384
19 | BATCH_SIZE: 16
20 | TRAIN:
21 | EPOCHS: 30
22 | WARMUP_EPOCHS: 5
23 | WEIGHT_DECAY: 1e-8
24 | BASE_LR: 2e-05
25 | MIN_LR: 2e-07
26 | USE_EMA: False
27 | EMA_DECAY: 0.9998
28 | WARMUP_LR: 2e-08
29 | ACCUMULATION_STEPS: 4
30 | AUG:
31 | MIXUP: 0.8
32 | CUTMIX: 1.0
33 | TEST:
34 | CROP: False
35 |
--------------------------------------------------------------------------------
/configs/aff_mini.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_mini_1_4th
4 | DROP_PATH_RATE: 0.0
5 | AFF:
6 | DEPTHS: [ 2, 2, 6, 2]
7 | NUM_HEADS: [ 2, 4, 8, 16 ]
8 | MLP_RATIO: 2.
9 | EMBED_DIM: [32,128,256,384]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | ALPHA: 4.0
13 | DS_RATE: 0.25
14 | DATA:
15 | DATASET: imagenet
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 1024
18 | TRAIN:
19 | EPOCHS: 300
20 | BASE_LR: 5e-4
21 | MIN_LR: 5e-6
22 | WARMUP_LR: 5e-7
23 | AUG:
24 | MIXUP: 0.0
25 | CUTMIX: 0.0
26 |
--------------------------------------------------------------------------------
/configs/aff_mini_1_5th.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_mini_1_5th
4 | DROP_PATH_RATE: 0.0
5 | AFF:
6 | DEPTHS: [ 2, 2, 6, 2]
7 | NUM_HEADS: [ 2, 4, 8, 16 ]
8 | MLP_RATIO: 2.
9 | EMBED_DIM: [32,128,256,384]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | ALPHA: 4.0
13 | DS_RATE: 0.2
14 | DATA:
15 | DATASET: imagenet
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 1024
18 | TRAIN:
19 | EPOCHS: 300
20 | BASE_LR: 5e-4
21 | MIN_LR: 5e-6
22 | WARMUP_LR: 5e-7
23 | AUG:
24 | MIXUP: 0.0
25 | CUTMIX: 0.0
26 |
--------------------------------------------------------------------------------
/configs/aff_small.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_small_1_4th
4 | DROP_PATH_RATE: 0.3
5 | AFF:
6 | DEPTHS: [3,4,18,2]
7 | NUM_HEADS: [3,6,12,24]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [96,192,384,768]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | LAYER_SCALE: 1e-5 # turned off if 0.0
13 | ALPHA: 4.0
14 | DS_RATE: 0.25
15 | DATA:
16 | DATASET: imagenet
17 | IMG_SIZE: 224
18 | BATCH_SIZE: 256
19 | TRAIN:
20 | EPOCHS: 300
21 | BASE_LR: 5e-4
22 | MIN_LR: 5e-6
23 | WARMUP_LR: 5e-7
24 | AUG:
25 | MIXUP: 0.8
26 | CUTMIX: 1.0
27 |
--------------------------------------------------------------------------------
/configs/aff_small_1_5th.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_small_1_5th
4 | DROP_PATH_RATE: 0.3
5 | AFF:
6 | DEPTHS: [3,4,18,2]
7 | NUM_HEADS: [3,6,12,24]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [96,192,384,768]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | LAYER_SCALE: 1e-5 # turned off if 0.0
13 | ALPHA: 4.0
14 | DS_RATE: 0.2
15 | DATA:
16 | DATASET: imagenet
17 | IMG_SIZE: 224
18 | BATCH_SIZE: 256
19 | TRAIN:
20 | EPOCHS: 300
21 | BASE_LR: 5e-4
22 | MIN_LR: 5e-6
23 | WARMUP_LR: 5e-7
24 | AUG:
25 | MIXUP: 0.8
26 | CUTMIX: 1.0
27 |
--------------------------------------------------------------------------------
/configs/aff_tiny.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_tiny_1_4th
4 | DROP_PATH_RATE: 0.2
5 | AFF:
6 | DEPTHS: [3,4,18,5]
7 | NUM_HEADS: [2,4,8,16]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [64,128,256,512]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | ALPHA: 4.0
13 | DS_RATE: 0.25
14 | DATA:
15 | DATASET: imagenet
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 256
18 | TRAIN:
19 | EPOCHS: 300
20 | BASE_LR: 5e-4
21 | MIN_LR: 5e-6
22 | WARMUP_LR: 5e-7
23 | AUG:
24 | MIXUP: 0.8
25 | CUTMIX: 1.0
26 |
--------------------------------------------------------------------------------
/configs/aff_tiny_1_5th.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: aff
3 | NAME: aff_tiny_1_5th
4 | DROP_PATH_RATE: 0.2
5 | AFF:
6 | DEPTHS: [3,4,18,5]
7 | NUM_HEADS: [2,4,8,16]
8 | MLP_RATIO: 3.
9 | EMBED_DIM: [64,128,256,512]
10 | CLUSTER_SIZE: 8
11 | NBHD_SIZE: [48,48,48,49]
12 | ALPHA: 4.0
13 | DS_RATE: 0.2
14 | DATA:
15 | DATASET: imagenet
16 | IMG_SIZE: 224
17 | BATCH_SIZE: 256
18 | TRAIN:
19 | EPOCHS: 300
20 | BASE_LR: 5e-4
21 | MIN_LR: 5e-6
22 | WARMUP_LR: 5e-7
23 | AUG:
24 | MIXUP: 0.8
25 | CUTMIX: 1.0
26 |
--------------------------------------------------------------------------------
/create_env.sh:
--------------------------------------------------------------------------------
1 | # Create a conda virtual environment and activate it
2 | conda create -n aff python=3.8
3 | conda activate aff
4 |
5 | # Install requirements
6 | pip install \
7 | yacs==0.1.8 \
8 | termcolor==2.2.0 \
9 | timm==0.6.12 \
10 | pykeops==2.1.1 \
11 | ptflops==0.6.9 \
12 | numpy==1.22.4
13 | conda install -c conda-forge opencv
14 | conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.6 -c pytorch -c conda-forge
15 |
16 | # Install the custom CUDA kernels for AFF
17 | cd clusten/src/ && python setup.py install
18 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from .build import build_loader
9 |
--------------------------------------------------------------------------------
/data/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 |
9 | import os
10 | import torch
11 | import numpy as np
12 | import utils
13 | import torch.distributed as dist
14 | from torchvision import datasets, transforms
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from timm.data import Mixup
17 | from timm.data import create_transform
18 | from timm.data.transforms import _str_to_pil_interpolation
19 |
20 | from .samplers import SubsetRandomSampler
21 |
22 |
def build_loader(config):
    """Build distributed train/val dataloaders and the optional Mixup callable.

    Side effect: fills in ``config.MODEL.NUM_CLASSES`` from the training
    dataset (the config is briefly defrosted for this).

    Returns:
        (data_loader_train, data_loader_val, mixup_fn) where mixup_fn is None
        when neither mixup nor cutmix is enabled.
    """
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()
    print(f"local rank {config.LOCAL_RANK} / global rank {utils.get_rank()} successfully build train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"local rank {config.LOCAL_RANK} / global rank {utils.get_rank()} successfully build val dataset")

    world_size = dist.get_world_size()
    global_rank = utils.get_rank()
    # each rank shuffles its own shard of the training set
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=world_size, rank=global_rank, shuffle=True
    )

    # strided split of the validation set across ranks
    val_indices = np.arange(utils.get_rank(), len(dataset_val), dist.get_world_size())
    sampler_val = SubsetRandomSampler(val_indices)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False
    )

    # setup mixup / cutmix: active if either alpha is positive or an explicit
    # cutmix min/max ratio is supplied
    mixup_fn = None
    if config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return data_loader_train, data_loader_val, mixup_fn
67 |
68 |
def build_dataset(is_train, config):
    """Return ``(dataset, num_classes)`` for the configured dataset.

    Only ImageNet is supported; it is expected as an ImageFolder layout with
    'training' and 'validation' subdirectories under DATA.DATA_PATH.
    """
    transform = build_transform(is_train, config)
    if config.DATA.DATASET != 'imagenet':
        raise NotImplementedError("We only support ImageNet now.")
    split = 'training' if is_train else 'validation'
    dataset = datasets.ImageFolder(os.path.join(config.DATA.DATA_PATH, split), transform=transform)
    return dataset, 1000
80 |
81 |
def build_transform_imagenet(is_train, config):
    """Build the image transform pipeline for ImageNet.

    Training uses timm's ``create_transform``; evaluation resizes (optionally
    with a center crop, per TEST.CROP) and normalizes.  For tiny images
    (IMG_SIZE <= 32) the training resized-crop is swapped for a padded
    RandomCrop and the evaluation resize steps are skipped.
    """
    resize_im = config.DATA.IMG_SIZE > 32

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
            re_prob=config.AUG.REPROB,
            re_mode=config.AUG.REMODE,
            re_count=config.AUG.RECOUNT,
            interpolation=config.DATA.INTERPOLATION,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
        return transform

    # evaluation pipeline
    steps = []
    if resize_im:
        if config.TEST.CROP:
            # resize keeping the 256/224 ratio w.r.t. 224 images, then center-crop
            scaled = int((256 / 224) * config.DATA.IMG_SIZE)
            steps.append(
                transforms.Resize(scaled, interpolation=_str_to_pil_interpolation[config.DATA.INTERPOLATION])
            )
            steps.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
        else:
            steps.append(
                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                  interpolation=_str_to_pil_interpolation[config.DATA.INTERPOLATION])
            )

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(steps)
120 |
121 |
def build_transform(is_train, config):
    """Dispatch to the dataset-specific transform builder."""
    if config.DATA.DATASET != 'imagenet':
        raise NotImplementedError("We only support ImageNet now.")
    return build_transform_imagenet(is_train, config)
127 |
--------------------------------------------------------------------------------
/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 |
class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        # epoch is kept for API compatibility with DistributedSampler;
        # it does not influence the sampling order here
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        # a fresh random permutation of positions on every pass
        order = torch.randperm(len(self.indices))
        return iter(self.indices[pos] for pos in order)

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
30 |
--------------------------------------------------------------------------------
/demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/demo1.png
--------------------------------------------------------------------------------
/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-autofocusformer/9a687eae0649685d998db854a02dad9ba6f8d120/demo2.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import sys
10 | import logging
11 | import functools
12 | from termcolor import colored
13 |
14 |
@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
    """Create (and memoize) a DEBUG-level logger.

    Every rank appends to its own file ``log_rank{dist_rank}.txt`` inside
    ``output_dir``; only rank 0 additionally logs to stdout with colored
    prefixes.  ``lru_cache`` ensures handlers are attached at most once per
    (output_dir, dist_rank, name) combination.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # plain format for files, colorized format for the console
    plain_fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
        colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    # console handler for the master process only
    if dist_rank == 0:
        to_console = logging.StreamHandler(sys.stdout)
        to_console.setLevel(logging.DEBUG)
        to_console.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
        logger.addHandler(to_console)

    # per-rank file handler
    to_file = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
    to_file.setLevel(logging.DEBUG)
    to_file.setFormatter(logging.Formatter(fmt=plain_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(to_file)

    return logger
42 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 |
9 | import torch
10 | from timm.scheduler.cosine_lr import CosineLRScheduler
11 | from timm.scheduler.step_lr import StepLRScheduler
12 | from timm.scheduler.scheduler import Scheduler
13 |
14 |
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """
    Build the LR scheduler selected by config.TRAIN.LR_SCHEDULER.NAME.

    Options for scheduler:
        cosine - set learning rate using a cosine annealing schedule, as proposed in
                "SGDR: Stochastic Gradient Descent with Warm Restarts"
        linear - decays the learning rate by a linearly changing factor until total number of steps is reached
        step - after every decay_steps, the learning rate is updated to be lr * decay_rate

    All schedulers step per iteration (t_in_epochs=False), so epoch counts
    from the config are converted to iteration counts here.  Returns None for
    an unrecognized scheduler name.
    """
    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)

    # NOTE: the original code also computed a `num_epochs`/`cycle_length`
    # value that was never used or returned; that dead code has been removed.
    lr_scheduler = None
    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min=config.TRAIN.MIN_LR,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
        lr_scheduler = LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    return lr_scheduler
63 |
64 |
class LinearLRScheduler(Scheduler):
    """Linear-decay LR scheduler with optional linear warmup.

    After ``warmup_t`` steps of linearly ramping from ``warmup_lr_init`` up to
    each param group's base LR, the LR decays linearly from the base value
    down to ``base * lr_min_rate`` over the remaining
    ``t_initial - warmup_t`` steps.

    NOTE(review): relies on the timm ``Scheduler`` base class for
    ``base_values`` / ``update_groups`` and the epoch-vs-update dispatch.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # per-group LR increment applied at each warmup step
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            # start all param groups at the warmup initial LR
            super().update_groups(self.warmup_lr_init)
        else:
            # placeholder; never used since the warmup branch is not taken
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of per-group LRs at step (or epoch) ``t``."""
        if t < self.warmup_t:
            # linear warmup: warmup_lr_init -> base value
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            # linear decay: base value -> base * lr_min_rate
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        # only active when the scheduler is stepped per epoch
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        # only active when the scheduler is stepped per iteration
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
115 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Adapted for AutoFocusFormer by Ziwen 2023
8 |
9 | import os
10 | import time
11 | import argparse
12 | import datetime
13 | import numpy as np
14 | import random
15 | import copy
16 |
17 | import torch
18 | import torch.backends.cudnn as cudnn
19 |
20 | from timm.loss import SoftTargetCrossEntropy
21 | from timm.utils import accuracy, AverageMeter, ModelEmaV2
22 |
23 | from config import get_config
24 | from models import build_model
25 | from data import build_loader
26 | from lr_scheduler import build_scheduler
27 | from optimizer import build_optimizer
28 | from logger import create_logger
29 | from utils import load_checkpoint, save_checkpoint, auto_resume_helper, reduce_tensor, get_rank, init_distributed_mode, get_local_rank, get_world_size, NativeScalerWithGradNormCount
30 |
# opt into TF32 tensor-core matmuls (PyTorch global flag)
torch.backends.cuda.matmul.allow_tf32 = True

# verbose diagnostics from torch.distributed
os.environ['TORCH_DISTRIBUTED_DEBUG'] = "INFO"
34 |
35 |
36 | def parse_option():
37 | parser = argparse.ArgumentParser('AutoFocusFormer training and evaluation script', add_help=True)
38 | parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
39 | parser.add_argument(
40 | "--opts",
41 | help="Modify config options by adding 'KEY VALUE' pairs. ",
42 | default=None,
43 | nargs='+',
44 | )
45 |
46 | # easy config modification
47 | parser.add_argument('--batch-size', type=int, help="batch size per GPU")
48 | parser.add_argument('--epochs', type=int, help="number of epochs")
49 | parser.add_argument('--blr', type=float, help='base learning rate: absolute_lr = base_lr * total_batch_size / 512')
50 | parser.add_argument('--data-path', type=str, help='path to dataset')
51 | parser.add_argument('--resume', help='resume from checkpoint')
52 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
53 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
54 | help='root of output folder, the full path is