├── .gitattributes ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── ConvMixer.yaml ├── WFWS-ConvMixer.yaml └── WS-ConvMixer.yaml ├── models ├── convmixer_helper.py ├── fuse_weights.py ├── image_model_builder.py └── weight_sharing_helper.py ├── pretrained ├── ConvMixer_768_32_14_3-Stripped.pyth ├── WF-Mean-WS-ConvMixer_768_32_14_3-Fused-Stripped.pyth └── WS-ConvMixer_768_32_14_3-Stripped.pyth ├── run.sh ├── setup.sh └── spin-slowfast.patch /.gitattributes: -------------------------------------------------------------------------------- 1 | pretrained/*.pyth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this SPIN Software may utilize the following copyrighted material, the use of which is hereby acknowledged. 3 | 4 | Facebook, Inc. and its affiliates (facebookresearch/SlowFast) 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 
51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. 
You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright 2019, Facebook, Inc 194 | 195 | Licensed under the Apache License, Version 2.0 (the "License"); 196 | you may not use this file except in compliance with the License. 197 | You may obtain a copy of the License at 198 | 199 | http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | Unless required by applicable law or agreed to in writing, software 202 | distributed under the License is distributed on an "AS IS" BASIS, 203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 204 | See the License for the specific language governing permissions and 205 | limitations under the License. 
206 | 207 | Adam Paszke and Soumith Chintala and Ronan Collobert and Koray Kavukcuoglu and Clement Farabet and Ronan Collobert and Leon Bottou and Iain Melvin and Jason Weston and Samy Bengio and Johnny Mariethoz (PyTorch) 208 | From PyTorch: 209 | 210 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 211 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 212 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 213 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 214 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 215 | Copyright (c) 2011-2013 NYU (Clement Farabet) 216 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 217 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 218 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 219 | 220 | From Caffe2: 221 | 222 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 223 | 224 | All contributions by Facebook: 225 | Copyright (c) 2016 Facebook Inc. 226 | 227 | All contributions by Google: 228 | Copyright (c) 2015 Google Inc. 229 | All rights reserved. 230 | 231 | All contributions by Yangqing Jia: 232 | Copyright (c) 2015 Yangqing Jia 233 | All rights reserved. 234 | 235 | All contributions by Kakao Brain: 236 | Copyright 2019-2020 Kakao Brain 237 | 238 | All contributions by Cruise LLC: 239 | Copyright (c) 2022 Cruise LLC. 240 | All rights reserved. 241 | 242 | All contributions from Caffe: 243 | Copyright(c) 2013, 2014, 2015, the respective contributors 244 | All rights reserved. 245 | 246 | All other contributions: 247 | Copyright(c) 2015, 2016 the respective contributors 248 | All rights reserved. 249 | 250 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 251 | copyright over their contributions to Caffe2. The project versioning records 252 | all such contribution and copyright details. If a contributor wants to further 253 | mark their specific copyright on a particular contribution, they should 254 | indicate their copyright solely in the commit message of the change when it is 255 | committed. 256 | 257 | All rights reserved. 258 | 259 | Redistribution and use in source and binary forms, with or without 260 | modification, are permitted provided that the following conditions are met: 261 | 262 | 1. Redistributions of source code must retain the above copyright 263 | notice, this list of conditions and the following disclaimer. 264 | 265 | 2. Redistributions in binary form must reproduce the above copyright 266 | notice, this list of conditions and the following disclaimer in the 267 | documentation and/or other materials provided with the distribution. 268 | 269 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 270 | and IDIAP Research Institute nor the names of its contributors may be 271 | used to endorse or promote products derived from this software without 272 | specific prior written permission. 273 | 274 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 275 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 276 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 277 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 278 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 279 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 280 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 281 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 282 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 283 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 284 | POSSIBILITY OF SUCH DAMAGE. 285 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple.
Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH SlowFast: 43 | 44 | The SlowFast software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | 48 | ------------------------------------------------------------------------------- 49 | SOFTWARE DISTRIBUTED WITH PyTorch: 50 | 51 | The PyTorch software includes a number of subcomponents with separate 52 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 53 | ------------------------------------------------------------------------------- 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPIN 2 | This repository contains the official implementation for the ECCV'22 paper, ["SPIN: An Empirical Evaluation on Sharing Parameters of Isotropic Networks"](https://arxiv.org/abs/2207.10237). 3 | 4 | ## Code Overview 5 | We provide the implementation of a weight-sharing version of the [ConvMixer](https://openreview.net/pdf?id=TVHS5Y4dNvM) model. The main code for the implementation is in the `models` directory. The model can be configured by the files in `configs`. We provide three example configs: 6 | * `configs/ConvMixer.yaml` for the vanilla ConvMixer model. 7 | * `configs/WS-ConvMixer.yaml` for the Weight-shared ConvMixer (WS-ConvMixer) model. 8 | * `configs/WFWS-ConvMixer.yaml` for the Weight-fusion Weight-shared ConvMixer (WFWS-ConvMixer) model. 9 | 10 | Note that in order to run the model with `configs/WFWS-ConvMixer.yaml`, you must have a corresponding pretrained ConvMixer model. Please refer to our paper for details of each technique. 11 | 12 | ## Installation 13 | First, clone this repo with 14 | ``` 15 | git clone https://github.com/apple/ml-spin.git 16 | ``` 17 | The implementation of SPIN reuses the infrastructure of Meta Research's open source project [SlowFast](https://github.com/facebookresearch/SlowFast). Our modifications to the SlowFast code are stored in `spin-slowfast.patch`.
To download the SlowFast code and apply our changes, run 18 | ``` 19 | bash setup.sh 20 | ``` 21 | After getting the codebase ready, follow this [link](https://github.com/facebookresearch/SlowFast/blob/main/INSTALL.md) from the SlowFast repo to set up your environment and install other dependencies. 22 | 23 | ## Training 24 | After the environment is set up, you can run the following example training script to train a weight-sharing ConvMixer model. The script assumes you have a machine with 4 GPUs. 25 | ``` 26 | bash run.sh 27 | ``` 28 | ### Pre-trained ConvMixer Models on ImageNet1K 29 | We provide our pretrained models of ConvMixer, WS-ConvMixer and WFWS-ConvMixer in the following table. For the WFWS-ConvMixer, we first initialized the model using the proposed weight fusion technique with the mean strategy, and then ran `models/fuse_weights.py` to export the fused model after training. In order to re-run the model, please use the WS-ConvMixer configuration. Please note that we did light hyperparameter tuning, so the accuracy is slightly higher than the numbers reported in the paper. 30 | | C/D/P/K | Weight Sharing? | Weight Fusion? | Sharing Rate | Share Distribution | Sharing Mapping | Accuracy | Model Size | 31 | | ------- | --------------- | -------------- | ------------ | ------------- | -------------------- | -------- | ---------- | 32 | | 768/32/14/3 | No | No | - | - | - | 76.32% | [79MB](pretrained/ConvMixer_768_32_14_3-Stripped.pyth) | 33 | | 768/32/14/3 | Yes | No | 2 | Uniform | Sequential | 74.27% | [43MB](pretrained/WS-ConvMixer_768_32_14_3-Stripped.pyth) | 34 | | 768/32/14/3 | Yes | Mean | 2 | Uniform | Sequential | 75.21% | [43MB](pretrained/WF-Mean-WS-ConvMixer_768_32_14_3-Fused-Stripped.pyth) | 35 | 36 | ## Citation 37 | If you find our code or paper helpful, please consider citing: 38 | ``` 39 | @inproceedings{spin_eccv22, 40 | author = {Lin, Chien-Yu and Prabhu, Anish and Merth, Thomas and Mehta, Sachin and Ranjan, Anurag and Horton, Maxwell and Rastegari, Mohammad}, 41 | title = {SPIN: An Empirical Evaluation on Sharing Parameters of Isotropic Networks}, 42 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, 43 | year = {2022} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /configs/ConvMixer.yaml: -------------------------------------------------------------------------------- 1 | # Configs for regular ConvMixer 2 | TRAIN: 3 | ENABLE: True 4 | DATASET: imagenet 5 | BATCH_SIZE: 64 6 | EVAL_PERIOD: 10 7 | CHECKPOINT_PERIOD: 10 8 | AUTO_RESUME: True 9 | MIXED_PRECISION: True 10 | DATA: 11 | PATH_TO_DATA_DIR: 12 | MEAN: [0.485, 0.456, 0.406] 13 | STD: [0.229, 0.224, 0.225] 14 | NUM_FRAMES: 1 15 | TRAIN_CROP_SIZE: 224 16 | TEST_CROP_SIZE: 224 17 | INPUT_CHANNEL_NUM: [3] 18 | TRAIN_JITTER_SCALES_RELATIVE: [0.08, 1.0] 19 | TRAIN_JITTER_ASPECT_RELATIVE: [0.75, 1.3333333333333333] 20 | CONVMIXER: 21 | CHANNEL: 768 22 | PATCH_KERNEL: 14 23 | PATCH_STRIDE: 14 24 | PATCH_PADDING: 0 25 | KERNEL: 3 26 | DEPTH: 32 27 | ACT_FUNC: RELU 28 | AUG: 29 | ENABLE: True 30 | COLOR_JITTER: 0.4 31 | AA_TYPE: rand-m9-mstd0.5-inc1 32 | INTERPOLATION: bicubic 33 | RE_PROB: 0.25 34 | RE_MODE: pixel 35 | RE_COUNT: 1 36 | RE_SPLIT: False 37 | MIXUP: 38 | ENABLE: True 39 | ALPHA: 0.5 40 | CUTMIX_ALPHA: 0.5 41 | PROB: 1.0 42 | SWITCH_PROB: 0.5 43 | LABEL_SMOOTH_VALUE: 0.1 44 | SOLVER: 45 | BASE_LR: 0.01 46 | BASE_LR_SCALE_NUM_SHARDS: False 47 | LR_POLICY: onecycle 48 | MAX_EPOCH: 300 49 | OPTIMIZING_METHOD: adamw 50 | WEIGHT_DECAY:
2e-5 51 | MOMENTUM: 0.9 52 | WARMUP_EPOCHS: 0.0 53 | CLIP_GRAD_L2NORM: 1.0 54 | MODEL: 55 | NUM_CLASSES: 1000 56 | ARCH: 2d 57 | MODEL_NAME: WSConvMixer 58 | LOSS_FUNC: soft_cross_entropy 59 | DROPOUT_RATE: 0.5 60 | TEST: 61 | ENABLE: False 62 | DATASET: imagenet 63 | BATCH_SIZE: 256 64 | CHECKPOINT_FILE_PATH: 65 | DATA_LOADER: 66 | NUM_WORKERS: 1 67 | PIN_MEMORY: True 68 | NUM_GPUS: 4 69 | NUM_SHARDS: 1 70 | RNG_SEED: 0 71 | OUTPUT_DIR: ./ 72 | -------------------------------------------------------------------------------- /configs/WFWS-ConvMixer.yaml: -------------------------------------------------------------------------------- 1 | # Configs for Weight Fusion Weight Sharing ConvMixer 2 | TRAIN: 3 | ENABLE: True 4 | DATASET: imagenet 5 | BATCH_SIZE: 64 6 | EVAL_PERIOD: 10 7 | CHECKPOINT_PERIOD: 10 8 | CHECKPOINT_EPOCH_RESET: True 9 | CHECKPOINT_FILE_PATH: 10 | AUTO_RESUME: True 11 | MIXED_PRECISION: True 12 | DATA: 13 | PATH_TO_DATA_DIR: 14 | MEAN: [0.485, 0.456, 0.406] 15 | STD: [0.229, 0.224, 0.225] 16 | NUM_FRAMES: 1 17 | TRAIN_CROP_SIZE: 224 18 | TEST_CROP_SIZE: 224 19 | INPUT_CHANNEL_NUM: [3] 20 | TRAIN_JITTER_SCALES_RELATIVE: [0.08, 1.0] 21 | TRAIN_JITTER_ASPECT_RELATIVE: [0.75, 1.3333333333333333] 22 | CONVMIXER: 23 | CHANNEL: 768 24 | PATCH_KERNEL: 14 25 | PATCH_STRIDE: 14 26 | PATCH_PADDING: 0 27 | KERNEL: 3 28 | DEPTH: 32 29 | ACT_FUNC: RELU 30 | WEIGHT_SHARE: 31 | ENABLE: True 32 | SHARING_DISTRIBUTION: uniform 33 | SHARING_MAPPING: sequential 34 | SHARE_RATE: 2 35 | REDUCTION_FN: mean 36 | AUG: 37 | ENABLE: True 38 | COLOR_JITTER: 0.4 39 | AA_TYPE: rand-m9-mstd0.5-inc1 40 | INTERPOLATION: bicubic 41 | RE_PROB: 0.25 42 | RE_MODE: pixel 43 | RE_COUNT: 1 44 | RE_SPLIT: False 45 | MIXUP: 46 | ENABLE: True 47 | ALPHA: 0.5 48 | CUTMIX_ALPHA: 0.5 49 | PROB: 1.0 50 | SWITCH_PROB: 0.5 51 | LABEL_SMOOTH_VALUE: 0.1 52 | SOLVER: 53 | BASE_LR: 0.01 54 | BASE_LR_SCALE_NUM_SHARDS: False 55 | LR_POLICY: onecycle 56 | MAX_EPOCH: 300 57 | OPTIMIZING_METHOD: adamw 58 | WEIGHT_DECAY: 2e-5 59 | MOMENTUM: 0.9 60 | WARMUP_EPOCHS: 0.0 61 | CLIP_GRAD_L2NORM: 1.0 62 | MODEL: 63 | NUM_CLASSES: 1000 64 | ARCH: 2d 65 | MODEL_NAME: WSConvMixer 66 | LOSS_FUNC: soft_cross_entropy 67 | DROPOUT_RATE: 0.5 68 | TEST: 69 | ENABLE: False 70 | DATASET: imagenet 71 | BATCH_SIZE: 256 72 | CHECKPOINT_FILE_PATH: 73 | DATA_LOADER: 74 | NUM_WORKERS: 1 75 | PIN_MEMORY: True 76 | NUM_GPUS: 4 77 | NUM_SHARDS: 1 78 | RNG_SEED: 0 79 | OUTPUT_DIR: ./ 80 | -------------------------------------------------------------------------------- /configs/WS-ConvMixer.yaml: -------------------------------------------------------------------------------- 1 | # Configs for Weight Sharing ConvMixer 2 | TRAIN: 3 | ENABLE: True 4 | DATASET: imagenet 5 | BATCH_SIZE: 64 6 | EVAL_PERIOD: 10 7 | CHECKPOINT_PERIOD: 10 8 | AUTO_RESUME: True 9 | MIXED_PRECISION: True 10 | DATA: 11 | PATH_TO_DATA_DIR: 12 | MEAN: [0.485, 0.456, 0.406] 13 | STD: [0.229, 0.224, 0.225] 14 | NUM_FRAMES: 1 15 | TRAIN_CROP_SIZE: 224 16 | TEST_CROP_SIZE: 224 17 | INPUT_CHANNEL_NUM: [3] 18 | TRAIN_JITTER_SCALES_RELATIVE: [0.08, 1.0] 19 | TRAIN_JITTER_ASPECT_RELATIVE: [0.75, 1.3333333333333333] 20 | CONVMIXER: 21 | CHANNEL: 768 22 | PATCH_KERNEL: 14 23 | PATCH_STRIDE: 14 24 | PATCH_PADDING: 0 25 | KERNEL: 3 26 | DEPTH: 32 27 | ACT_FUNC: RELU 28 | WEIGHT_SHARE: 29 | ENABLE: True 30 | SHARING_DISTRIBUTION: uniform 31 | SHARING_MAPPING: sequential 32 | SHARE_RATE: 2 33 | AUG: 34 | ENABLE: True 35 | COLOR_JITTER: 0.4 36 | AA_TYPE: rand-m9-mstd0.5-inc1 37 | INTERPOLATION: 
bicubic 38 | RE_PROB: 0.25 39 | RE_MODE: pixel 40 | RE_COUNT: 1 41 | RE_SPLIT: False 42 | MIXUP: 43 | ENABLE: True 44 | ALPHA: 0.5 45 | CUTMIX_ALPHA: 0.5 46 | PROB: 1.0 47 | SWITCH_PROB: 0.5 48 | LABEL_SMOOTH_VALUE: 0.1 49 | SOLVER: 50 | BASE_LR: 0.01 51 | BASE_LR_SCALE_NUM_SHARDS: False 52 | LR_POLICY: onecycle 53 | MAX_EPOCH: 300 54 | OPTIMIZING_METHOD: adamw 55 | WEIGHT_DECAY: 2e-5 56 | MOMENTUM: 0.9 57 | WARMUP_EPOCHS: 0.0 58 | CLIP_GRAD_L2NORM: 1.0 59 | MODEL: 60 | NUM_CLASSES: 1000 61 | ARCH: 2d 62 | MODEL_NAME: WSConvMixer 63 | LOSS_FUNC: soft_cross_entropy 64 | DROPOUT_RATE: 0.5 65 | TEST: 66 | ENABLE: False 67 | DATASET: imagenet 68 | BATCH_SIZE: 256 69 | CHECKPOINT_FILE_PATH: 70 | DATA_LOADER: 71 | NUM_WORKERS: 1 72 | PIN_MEMORY: True 73 | NUM_GPUS: 4 74 | NUM_SHARDS: 1 75 | RNG_SEED: 0 76 | OUTPUT_DIR: ./ 77 | -------------------------------------------------------------------------------- /models/convmixer_helper.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ConvMixerPatchEmbed(nn.Module): 11 | """ConvMixer 2D PatchEmbedding.""" 12 | 13 | def __init__( 14 | self, dim_in=3, dim_out=768, kernel=7, stride=7, padding=0, activation=nn.GELU 15 | ): 16 | super().__init__() 17 | self.conv = nn.Conv2d( 18 | dim_in, dim_out, kernel_size=kernel, stride=stride, padding=padding 19 | ) 20 | self.act = activation() 21 | self.bn = nn.BatchNorm2d(dim_out) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | x = self.act(x) 26 | x = self.bn(x) 27 | 28 | return x 29 | 30 | 31 | class ConvMixerHead(nn.Module): 32 | """ConvMixer Head""" 33 | 34 | def __init__(self, dim=768, dropout_rate=0.0, classes=400): 35 | super().__init__() 36 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 37 | if dropout_rate > 0.0: 38 | self.dropout = nn.Dropout(dropout_rate) 39 | self.fc = nn.Linear(dim, classes, bias=True) 40 | 41 | def forward(self, x): 42 | x = self.pool(x) 43 | x = x.flatten(1) 44 | if hasattr(self, "dropout"): 45 | x = self.dropout(x) 46 | x = self.fc(x) 47 | 48 | return x 49 | 50 | 51 | class WSConvMixerBlock(nn.Module): 52 | """Main block of WSConvMixer""" 53 | 54 | def __init__(self, dim=768, kernel_size=7, padding=1, activation=nn.GELU): 55 | super().__init__() 56 | self.dwise_blk = nn.Sequential( 57 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=padding), 58 | activation(), 59 | nn.BatchNorm2d(dim), 60 | ) 61 | self.act2 = activation() 62 | self.bn2 = nn.BatchNorm2d(dim) 63 | 64 | def pwise_blk(self, x, pwise_w, pwise_b): 65 | x = F.conv2d(x, pwise_w, pwise_b) 66 | x = self.act2(x) 67 | x = self.bn2(x) 68 | return x 69 | 70 | def forward(self, x, pwise_w, pwise_b): 71 | x = x + self.dwise_blk(x) 72 | x = x + self.pwise_blk(x, pwise_w, pwise_b) 73 | return x 74 | -------------------------------------------------------------------------------- /models/fuse_weights.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import slowfast.utils.checkpoint as cu 7 | import torch 8 | import torch.nn as nn 9 | from slowfast.config.defaults import assert_and_infer_cfg 10 | from slowfast.models import build_model 11 | from slowfast.utils.misc import launch_job 12 | from slowfast.utils.parser import load_config, parse_args 13 | 14 | 15 | def fuse_weights(cfg): 16 | model = build_model(cfg) 17 | checkpoint = torch.load(cfg.TRAIN.CHECKPOINT_FILE_PATH) 18 | model.load_state_dict(checkpoint["model_state"]) 19 | 20 | fused_pwise_w, fused_pwise_b = model.weight_fusion()  # fuse the grouped point-wise weights into one shared set 21 | num_fused_lay = len(fused_pwise_w) 22 | 23 | new_pwise_w = nn.ParameterList( 24 | [nn.Parameter(data=fused_pwise_w[i]) for i in range(num_fused_lay)] 25 | ) 26 | new_pwise_b = nn.ParameterList( 27 | [nn.Parameter(data=fused_pwise_b[i]) for i in range(num_fused_lay)] 28 | ) 29 | 30 | model.pwise_w = new_pwise_w 31 | model.pwise_b = new_pwise_b 32 | model.reduction_fn = None  # disable on-the-fly fusion; the exported checkpoint stores the fused weights 33 | 34 | cu.save_checkpoint("./pretrained", model, None, 0, cfg, None) 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | cfg = load_config(args) 40 | cfg = assert_and_infer_cfg(cfg) 41 | 42 | launch_job(cfg=cfg, init_method=args.init_method, func=fuse_weights) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /models/image_model_builder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import slowfast.utils.weight_init_helper as init_helper 7 | import torch 8 | import torch.nn as nn 9 | from convmixer_helper import ConvMixerHead, ConvMixerPatchEmbed, WSConvMixerBlock 10 | from slowfast.models.build import MODEL_REGISTRY 11 | from weight_sharing_helper import WeightSharingTopology 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class WSConvMixer(nn.Module): 16 | """The main class of the weight sharing ConvMixer 17 | 18 | We support non-weight-sharing, weight-sharing and weight-sharing with pretrained weights 19 | all in this single module. The mode is configured by the cfg file fed to __init__(). 20 | """ 21 | 22 | def __init__(self, cfg): 23 | """The `__init__` method of the weight sharing ConvMixer 24 | 25 | Args: 26 | cfg (CfgNode): model building configs, details are in the comments of the config file.
27 | """ 28 | super().__init__() 29 | channel = cfg.CONVMIXER.CHANNEL 30 | patch_kernel = cfg.CONVMIXER.PATCH_KERNEL 31 | patch_stride = cfg.CONVMIXER.PATCH_STRIDE 32 | patch_pad = cfg.CONVMIXER.PATCH_PADDING 33 | kernel = cfg.CONVMIXER.KERNEL 34 | act_func = cfg.CONVMIXER.ACT_FUNC 35 | classes = cfg.MODEL.NUM_CLASSES 36 | pad = self._cal_pad(kernel) 37 | self.depth = cfg.CONVMIXER.DEPTH 38 | 39 | # weight sharing module related parameters 40 | self.share_weight = cfg.CONVMIXER.WEIGHT_SHARE.ENABLE 41 | self.share_rate = cfg.CONVMIXER.WEIGHT_SHARE.SHARE_RATE 42 | self.num_share = int(self.depth // self.share_rate) 43 | self.share_dist = cfg.CONVMIXER.WEIGHT_SHARE.SHARING_DISTRIBUTION 44 | self.share_map = cfg.CONVMIXER.WEIGHT_SHARE.SHARING_MAPPING 45 | self.reduction_fn = cfg.CONVMIXER.WEIGHT_SHARE.REDUCTION_FN 46 | 47 | if act_func == "GELU": 48 | self.activation = nn.GELU 49 | elif act_func == "RELU": 50 | self.activation = nn.ReLU 51 | else: 52 | raise NotImplementedError( 53 | "{} is not supported as an activation function".format(act) 54 | ) 55 | 56 | _SUPPORTED_REDUCTION_FNS = [ 57 | "mean", 58 | "choose_first", 59 | "scalar_weighted_mean", 60 | "channel_weighted_mean", 61 | ] 62 | if self.reduction_fn not in _SUPPORTED_REDUCTION_FNS: 63 | raise NotImplementedError( 64 | "Only [mean, choose_first, channel_weighted_mean, scalar_weighted_mean] weight fusion strategies are supported." 65 | ) 66 | 67 | if self.share_weight and self.reduction_fn == None: 68 | num_pwise_w = self.num_share 69 | else: 70 | num_pwise_w = self.depth 71 | 72 | self.pwise_w = nn.ParameterList( 73 | [ 74 | nn.Parameter(torch.zeros(channel, channel, 1, 1)) 75 | for _ in range(num_pwise_w) 76 | ] 77 | ) 78 | self.pwise_b = nn.ParameterList( 79 | [nn.Parameter(torch.zeros(channel)) for _ in range(num_pwise_w)] 80 | ) 81 | 82 | if self.reduction_fn == "scalar_weighted_mean": 83 | self.reduction_w = nn.ParameterList( 84 | [nn.Parameter(data=torch.ones(1)) for _ in range(self.depth)] 85 | ) 86 | elif self.reduction_fn == "channel_weighted_mean": 87 | self.reduction_w = nn.ParameterList( 88 | [nn.Parameter(data=torch.ones(channel)) for _ in range(self.depth)] 89 | ) 90 | 91 | # build sharing mapping 92 | self.pwise_map = self.build_layer_mapping() 93 | 94 | self.patch_embed = ConvMixerPatchEmbed( 95 | dim_in=3, 96 | dim_out=channel, 97 | kernel=patch_kernel, 98 | stride=patch_stride, 99 | activation=self.activation, 100 | padding=patch_pad, 101 | ) 102 | 103 | self.blocks = nn.ModuleList() 104 | for _ in range(self.depth): 105 | self.blocks.append( 106 | WSConvMixerBlock( 107 | dim=channel, 108 | kernel_size=kernel, 109 | padding=pad, 110 | activation=self.activation, 111 | ) 112 | ) 113 | 114 | self.head = ConvMixerHead( 115 | dim=channel, dropout_rate=cfg.MODEL.DROPOUT_RATE, classes=classes 116 | ) 117 | 118 | # init weights 119 | for w in self.pwise_w: 120 | nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu") 121 | init_helper.init_weights(self) 122 | 123 | def _cal_pad(self, kernel, stride=1): 124 | """Helper function to calculate padding size.""" 125 | if kernel > stride: 126 | raise ValueError("size of kernel should be larger than stride") 127 | pad = int((kernel - stride) / 2) 128 | return pad 129 | 130 | def build_layer_mapping(self): 131 | """Function to define the sharing mapping.""" 132 | if not self.share_weight: 133 | mapping = [i for i in range(self.depth)] 134 | else: 135 | mapping = getattr( 136 | WeightSharingTopology, f"{self.share_dist}_{self.share_map}" 137 | )(self.share_rate, 
self.num_share) 138 | assert self.depth == len(mapping), "Length of Mapping doesn't match model depth" 139 | return mapping 140 | 141 | def weight_fusion(self): 142 | """Generate weights from pretrained weights during forward.""" 143 | w = self.pwise_w 144 | b = self.pwise_b 145 | 146 | reduced_w = [] 147 | for i in range(self.num_share): 148 | if self.reduction_fn == "mean": 149 | p = torch.mean( 150 | torch.stack( 151 | [w[i * self.share_rate + j] for j in range(self.share_rate)] 152 | ), 153 | dim=0, 154 | ) 155 | elif self.reduction_fn == "choose_first": 156 | p = w[i * self.share_rate] 157 | elif ( 158 | self.reduction_fn == "channel_weighted_mean" 159 | or self.reduction_fn == "scalar_weighted_mean" 160 | ): 161 | tf_p = [ 162 | self.reduction_w[i * self.share_rate + j] 163 | for j in range(self.share_rate) 164 | ] 165 | p = [w[i * self.share_rate + j] for j in range(self.share_rate)] 166 | p = [t.view(-1, 1, 1, 1) * x for t, x in zip(tf_p, p)] 167 | p = torch.mean(torch.stack(p), dim=0) 168 | reduced_w.append(p) 169 | 170 | reduced_b = [] 171 | for i in range(self.num_share): 172 | if self.reduction_fn == "choose_first": 173 | p = b[i * self.share_rate] 174 | elif ( 175 | self.reduction_fn == "mean" 176 | or self.reduction_fn == "scalar_weighted_mean" 177 | or self.reduction_fn == "channel_weighted_mean" 178 | ): 179 | p = torch.mean( 180 | torch.stack( 181 | [b[i * self.share_rate + j] for j in range(self.share_rate)] 182 | ), 183 | dim=0, 184 | ) 185 | reduced_b.append(p) 186 | 187 | return reduced_w, reduced_b 188 | 189 | def forward(self, x): 190 | """The forward method of WSConvMixer module.""" 191 | x = x[0] 192 | x = self.patch_embed(x) 193 | 194 | if self.reduction_fn is not None: 195 | w, b = self.weight_fusion() 196 | else: 197 | w, b = self.pwise_w, self.pwise_b 198 | 199 | for i in range(self.depth): 200 | ptr = self.pwise_map[i] 201 | x = self.blocks[i](x, w[ptr], b[ptr]) 202 | 203 | x = self.head(x) 204 | return x 205 | -------------------------------------------------------------------------------- /models/weight_sharing_helper.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import random 7 | 8 | 9 | # class of functions of different weight sharing mappings 10 | class WeightSharingTopology: 11 | """Class of predefined functions of different sharing structures.""" 12 | 13 | def uniform_sequential(share_rate, num_share): 14 | mapping = [] 15 | for i in range(num_share): 16 | for _ in range(share_rate): 17 | mapping.append(i) 18 | return mapping 19 | 20 | def uniform_strided(share_rate, num_share): 21 | mapping = [] 22 | for _ in range(share_rate): 23 | for i in range(num_share): 24 | mapping.append(i) 25 | return mapping 26 | 27 | def uniform_random(share_rate, num_share): 28 | mapping = WeightSharingTopology.uniform_sequential(share_rate, num_share)  # start from a sequential mapping, then shuffle 29 | random.shuffle(mapping) 30 | return mapping 31 | -------------------------------------------------------------------------------- /pretrained/ConvMixer_768_32_14_3-Stripped.pyth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3cea4e26b756452a47348fc86abbb6e268450e1be893d5d3a19758304837442c 3 | size 82432896 4 | -------------------------------------------------------------------------------- /pretrained/WF-Mean-WS-ConvMixer_768_32_14_3-Fused-Stripped.pyth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e6e766509ca6ef21ebeba50c3c7c6073faac928233feea806e2a115c486ebd84 3 | size 44624480 4 | -------------------------------------------------------------------------------- /pretrained/WS-ConvMixer_768_32_14_3-Stripped.pyth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b60730182330a23947bcec455170f8f047db79840c33f98176daf03ba3cebbe4 3 | size 44624480 4 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | PL_TORCH_DISTRIBUTED_BACKEND=gloo python SlowFast/tools/run_net.py \ 7 | --cfg configs/WS-ConvMixer.yaml \ 8 | DATA.PATH_TO_DATA_DIR /PATH/TO/IMAGENET/DATA/ \ 9 | NUM_GPUS 4 \ 10 | DATA_LOADER.NUM_WORKERS 4 \ 11 | TRAIN.BATCH_SIZE 256 \ 12 | TENSORBOARD.ENABLE True \ 13 | TRAIN.ENABLE True \ 14 | TEST.ENABLE False 15 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | git clone https://github.com/facebookresearch/SlowFast.git 7 | cd SlowFast 8 | git checkout 5bf7e7 9 | git am ../spin-slowfast.patch 10 | -------------------------------------------------------------------------------- /spin-slowfast.patch: -------------------------------------------------------------------------------- 1 | From e7da51cf65760421cc5d86b405dc2ec0d7e240f8 Mon Sep 17 00:00:00 2001 2 | From: Chien-Yu Lin 3 | Date: Tue, 14 Jun 2022 23:06:30 +0000 4 | Subject: [PATCH 1/2] update slowfast for spin 5 | 6 | --- 7 | setup.py | 2 +- 8 | slowfast/config/defaults.py | 46 ++++++++++++++++++++++++++++ 9 | slowfast/datasets/imagenet.py | 2 +- 10 | slowfast/models/__init__.py | 7 +++++ 11 | slowfast/models/build.py | 3 +- 12 | slowfast/utils/lr_policy.py | 10 ++++++ 13 | slowfast/utils/misc.py | 6 +++- 14 | slowfast/utils/weight_init_helper.py | 4 +-- 15 | tools/train_net.py | 4 +-- 16 | 9 files changed, 75 insertions(+), 9 deletions(-) 17 | 18 | diff --git a/setup.py b/setup.py 19 | index afeab49..9daae52 100644 20 | --- a/setup.py 21 | +++ b/setup.py 22 | @@ -23,7 +23,7 @@ setup( 23 | "opencv-python", 24 | "pandas", 25 | "torchvision>=0.4.2", 26 | - "PIL", 27 | + "Pillow", 28 | "sklearn", 29 | "tensorboard", 30 | ], 31 | diff --git a/slowfast/config/defaults.py b/slowfast/config/defaults.py 32 | index e20ef02..eead08c 100644 33 | --- a/slowfast/config/defaults.py 34 | +++ b/slowfast/config/defaults.py 35 | @@ -378,6 +378,52 @@ _C.MVIT.SEP_POS_EMBED = False 36 | _C.MVIT.DROPOUT_RATE = 0.0 37 | 38 | 39 | +# ----------------------------------------------------------------------------- 40 | +# ConvMixer options 41 | +# ----------------------------------------------------------------------------- 42 | +_C.CONVMIXER = CfgNode() 43 | + 44 | +# Number of weight layers. 45 | +_C.CONVMIXER.DEPTH = 32 46 | + 47 | +# Dim of the hidden features 48 | +_C.CONVMIXER.CHANNEL = 768 49 | + 50 | +# Size of the input patch kernel 51 | +_C.CONVMIXER.PATCH_KERNEL = 7 52 | + 53 | +# Size of the input patch kernel stride 54 | +_C.CONVMIXER.PATCH_STRIDE = 7 55 | + 56 | +# Size of the input patch kernel stride 57 | +_C.CONVMIXER.PATCH_PADDING = 0 58 | + 59 | +# Size of the conv kernel 60 | +_C.CONVMIXER.KERNEL = 3 61 | + 62 | +# Size of the conv kernel 63 | +_C.CONVMIXER.ACT_FUNC = "RELU" 64 | + 65 | +# Create a new node for configure WeightSharedConvMixer block 66 | +_C.CONVMIXER.WEIGHT_SHARE = CfgNode() 67 | + 68 | +# Whether to use WeightSharedConvMixer 69 | +_C.CONVMIXER.WEIGHT_SHARE.ENABLE = False 70 | + 71 | +# The overall sharing structure. 72 | +_C.CONVMIXER.WEIGHT_SHARE.SHARING_DISTRIBUTION = "uniform" 73 | + 74 | +# The sharing mapping scheme. Choose from ["sequential", "strided", "random"] 75 | +_C.CONVMIXER.WEIGHT_SHARE.SHARING_MAPPING = "sequential" 76 | + 77 | +# Sharing rate 78 | +_C.CONVMIXER.WEIGHT_SHARE.SHARE_RATE = 1 79 | + 80 | +# Wegith fusion methods. 
Choose from ["mean", "choose_first", "scalar_weighted_mean", "channel_weighted_mean"] 81 | +# No weight fusion will be applied when it's set to None 82 | +_C.CONVMIXER.WEIGHT_SHARE.REDUCTION_FN = None 83 | + 84 | + 85 | # ----------------------------------------------------------------------------- 86 | # SlowFast options 87 | # ----------------------------------------------------------------------------- 88 | diff --git a/slowfast/datasets/imagenet.py b/slowfast/datasets/imagenet.py 89 | index af96403..ef6fb8d 100644 90 | --- a/slowfast/datasets/imagenet.py 91 | +++ b/slowfast/datasets/imagenet.py 92 | @@ -110,7 +110,7 @@ class Imagenet(torch.utils.data.Dataset): 93 | else: 94 | # For testing use scale and center crop 95 | im, _ = transform.uniform_crop( 96 | - im, test_size, spatial_idx=1, scale_size=train_size 97 | + im, test_size, spatial_idx=1, scale_size=test_size 98 | ) 99 | # For training and testing use color normalization 100 | im = transform.color_normalization( 101 | diff --git a/slowfast/models/__init__.py b/slowfast/models/__init__.py 102 | index ce97190..906b230 100644 103 | --- a/slowfast/models/__init__.py 104 | +++ b/slowfast/models/__init__.py 105 | @@ -5,6 +5,13 @@ from .build import MODEL_REGISTRY, build_model # noqa 106 | from .custom_video_model_builder import * # noqa 107 | from .video_model_builder import ResNet, SlowFast # noqa 108 | 109 | +# import our model from outside directory 110 | +import os, sys 111 | +dir_path = os.path.dirname(os.path.realpath(__file__)) 112 | +dir_path = os.path.join(dir_path, '..', '..', '..', 'models') 113 | +sys.path.insert(1, dir_path) 114 | +from image_model_builder import * 115 | + 116 | try: 117 | from .ptv_model_builder import ( 118 | PTVCSN, 119 | diff --git a/slowfast/models/build.py b/slowfast/models/build.py 120 | index a88eb51..a1424ff 100644 121 | --- a/slowfast/models/build.py 122 | +++ b/slowfast/models/build.py 123 | @@ -48,6 +48,7 @@ def build_model(cfg, gpu_id=None): 124 | if cfg.NUM_GPUS > 1: 125 | # Make model replica operate on the current device 126 | model = torch.nn.parallel.DistributedDataParallel( 127 | - module=model, device_ids=[cur_device], output_device=cur_device 128 | + module=model, device_ids=[cur_device], output_device=cur_device, 129 | + find_unused_parameters=True 130 | ) 131 | return model 132 | diff --git a/slowfast/utils/lr_policy.py b/slowfast/utils/lr_policy.py 133 | index 80cfdd4..1c4e1ac 100644 134 | --- a/slowfast/utils/lr_policy.py 135 | +++ b/slowfast/utils/lr_policy.py 136 | @@ -4,6 +4,7 @@ 137 | """Learning rate policy.""" 138 | 139 | import math 140 | +import numpy as np 141 | 142 | 143 | def get_lr_at_epoch(cfg, cur_epoch): 144 | @@ -27,6 +28,15 @@ def get_lr_at_epoch(cfg, cur_epoch): 145 | return lr 146 | 147 | 148 | +def lr_func_onecycle(cfg, cur_epoch): 149 | + t_initial = cfg.SOLVER.MAX_EPOCH 150 | + t = cur_epoch 151 | + lr_max = cfg.SOLVER.BASE_LR 152 | + lr = np.interp([t], [0, t_initial*2//5, t_initial*4//5, t_initial], 153 | + [0, lr_max, lr_max/20.0, 0])[0] 154 | + return lr 155 | + 156 | + 157 | def lr_func_cosine(cfg, cur_epoch): 158 | """ 159 | Retrieve the learning rate to specified values at specified epoch with the 160 | diff --git a/slowfast/utils/misc.py b/slowfast/utils/misc.py 161 | index d7217c2..901e07d 100644 162 | --- a/slowfast/utils/misc.py 163 | +++ b/slowfast/utils/misc.py 164 | @@ -163,7 +163,11 @@ def get_model_stats(model, cfg, mode, use_train_input): 165 | model_mode = model.training 166 | model.eval() 167 | inputs = _get_model_analysis_input(cfg, 
use_train_input) 168 | - count_dict, *_ = model_stats_fun(model, inputs) 169 | + # modification for ImageNet 170 | + if cfg.TRAIN.DATASET == "imagenet": 171 | + count_dict, *_ = model_stats_fun(model, inputs[0]) 172 | + else: 173 | + count_dict, *_ = model_stats_fun(model, inputs) 174 | count = sum(count_dict.values()) 175 | model.train(model_mode) 176 | return count 177 | diff --git a/slowfast/utils/weight_init_helper.py b/slowfast/utils/weight_init_helper.py 178 | index 0a2d65d..2be56a4 100644 179 | --- a/slowfast/utils/weight_init_helper.py 180 | +++ b/slowfast/utils/weight_init_helper.py 181 | @@ -16,7 +16,7 @@ def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 182 | every bottleneck. 183 | """ 184 | for m in model.modules(): 185 | - if isinstance(m, nn.Conv3d): 186 | + if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d): 187 | """ 188 | Follow the initialization method proposed in: 189 | {He, Kaiming, et al. 190 | @@ -25,7 +25,7 @@ def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 191 | arXiv preprint arXiv:1502.01852 (2015)} 192 | """ 193 | c2_msra_fill(m) 194 | - elif isinstance(m, nn.BatchNorm3d): 195 | + elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d): 196 | if ( 197 | hasattr(m, "transform_final_bn") 198 | and m.transform_final_bn 199 | diff --git a/tools/train_net.py b/tools/train_net.py 200 | index 6b6ff22..1c64f25 100644 201 | --- a/tools/train_net.py 202 | +++ b/tools/train_net.py 203 | @@ -205,7 +205,6 @@ def train_epoch( 204 | train_meter.log_epoch_stats(cur_epoch) 205 | train_meter.reset() 206 | 207 | - 208 | @torch.no_grad() 209 | def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): 210 | """ 211 | @@ -220,14 +219,13 @@ def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): 212 | writer (TensorboardWriter, optional): TensorboardWriter object 213 | to writer Tensorboard log. 214 | """ 215 | - 216 | # Evaluation mode enabled. The running stats would not be updated. 217 | model.eval() 218 | val_meter.iter_tic() 219 | 220 | for cur_iter, (inputs, labels, _, meta) in enumerate(val_loader): 221 | if cfg.NUM_GPUS: 222 | - # Transferthe data to the current GPU device. 223 | + # Transfer the data to the current GPU device. 224 | if isinstance(inputs, (list,)): 225 | for i in range(len(inputs)): 226 | inputs[i] = inputs[i].cuda(non_blocking=True) 227 | -- 228 | 2.17.1 229 | 230 | 231 | From f588dcfdd04e70331bd137477851605778a45a12 Mon Sep 17 00:00:00 2001 232 | From: Chien-Yu Lin 233 | Date: Fri, 17 Jun 2022 23:29:10 +0000 234 | Subject: [PATCH 2/2] update checkpoint 235 | 236 | --- 237 | slowfast/utils/checkpoint.py | 6 ++++-- 238 | 1 file changed, 4 insertions(+), 2 deletions(-) 239 | 240 | diff --git a/slowfast/utils/checkpoint.py b/slowfast/utils/checkpoint.py 241 | index 227a657..abf1de7 100644 242 | --- a/slowfast/utils/checkpoint.py 243 | +++ b/slowfast/utils/checkpoint.py 244 | @@ -127,9 +127,11 @@ def save_checkpoint(path_to_job, model, optimizer, epoch, cfg, scaler=None): 245 | checkpoint = { 246 | "epoch": epoch, 247 | "model_state": normalized_sd, 248 | - "optimizer_state": optimizer.state_dict(), 249 | + # "optimizer_state": optimizer.state_dict(), 250 | "cfg": cfg.dump(), 251 | } 252 | + if optimizer is not None: 253 | + checkpoint["optimizer_state"] = optimizer.state_dict() 254 | if scaler is not None: 255 | checkpoint["scaler_state"] = scaler.state_dict() 256 | # Write the checkpoint. 
257 | @@ -491,7 +493,7 @@ def load_test_checkpoint(cfg, model): 258 | ) 259 | 260 | 261 | -def load_train_checkpoint(cfg, model, optimizer, scaler=None): 262 | +def load_train_checkpoint(cfg, model, optimizer=None, scaler=None): 263 | """ 264 | Loading checkpoint logic for training. 265 | """ 266 | -- 267 | 2.17.1 268 | 269 | --------------------------------------------------------------------------------
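
As a quick reference for how the released configuration maps blocks to shared weights, the sketch below reproduces the two `uniform` sharing mappings from `models/weight_sharing_helper.py` in standalone Python. It is not part of the repository; the values are taken from the configs above. With `DEPTH: 32` and `SHARE_RATE: 2`, `WSConvMixer` keeps 16 point-wise weight tensors and `pwise_map` assigns each of the 32 blocks one of them, either in consecutive pairs (sequential) or cyclically (strided).
```python
# Standalone sketch of the sharing mappings used by WSConvMixer.build_layer_mapping.
# Mirrors models/weight_sharing_helper.py for the released config (DEPTH=32, SHARE_RATE=2).

def uniform_sequential(share_rate, num_share):
    # Consecutive blocks share a weight: [0, 0, 1, 1, ..., 15, 15]
    return [i for i in range(num_share) for _ in range(share_rate)]

def uniform_strided(share_rate, num_share):
    # Blocks cycle through the shared weights: [0, 1, ..., 15, 0, 1, ..., 15]
    return [i for _ in range(share_rate) for i in range(num_share)]

depth, share_rate = 32, 2
num_share = depth // share_rate  # 16 shared point-wise conv weights
print(uniform_sequential(share_rate, num_share))
print(uniform_strided(share_rate, num_share))
```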