├── .coveragerc ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── mobileone_pytorch ├── __init__.py ├── _depthwise_convolution.py ├── _mobileone_block.py ├── _mobileone_block_down.py ├── _mobileone_block_up.py ├── _mobileone_component.py ├── _mobileone_getters.py ├── _mobileone_network.py ├── _pointwise_convolution.py ├── _reparametrizable_module.py └── _reparametrize.py ├── pymarkdown.cfg ├── pyproject.toml ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── _tools ├── __init__.py └── _count_parameters.py ├── test_depthwise_convolution.py ├── test_mobileone_block.py ├── test_mobileone_block_down.py ├── test_mobileone_block_up.py ├── test_mobileone_getters.py ├── test_mobileone_network.py ├── test_pointwise_convolution.py └── test_reparametrize.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = 3 | mobileone_pytorch 4 | tests 5 | 6 | [report] 7 | exclude_lines = 8 | pragma: no cover 9 | 10 | def __repr__ 11 | if self\.debug 12 | 13 | raise AssertionError 14 | raise NotImplementedError 15 | 16 | if 0: 17 | if __name__ == .__main__.: 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | \#*\# 2 | *\# 3 | *~ 4 | __pycache__ 5 | .pytest_cache 6 | *.swp 7 | .coverage 8 | cov_html/ 9 | .mypy_cache 10 | .#* 11 | .idea/ 12 | .vscode/ 13 | .DS_Store 14 | .python-version 15 | dist 16 | *.egg-info 17 | .pytype 18 | build/ 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.8 3 | fail_fast: true 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v2.5.0 7 | hooks: 8 | - id: check-added-large-files 9 | - id: 
check-yaml 10 | - id: check-ast 11 | - id: check-case-conflict 12 | stages: [push] 13 | - repo: local 14 | hooks: 15 | - id: isort 16 | name: isort 17 | stages: [commit] 18 | language: system 19 | entry: pipenv run isort 20 | types: [python] 21 | 22 | - id: black 23 | name: black 24 | stages: [commit] 25 | language: system 26 | entry: pipenv run black 27 | types: [python] 28 | 29 | - id: flake8 30 | name: flake8 31 | stages: [commit] 32 | language: system 33 | entry: pipenv run flake8 34 | types: [python] 35 | exclude: setup.py 36 | 37 | - id: mypy 38 | name: mypy 39 | stages: [commit] 40 | language: system 41 | entry: pipenv run mypy 42 | types: [python] 43 | pass_filenames: false 44 | 45 | - id: pytype 46 | name: pytype 47 | stages: [push] 48 | language: system 49 | entry: pipenv run pytype 50 | types: [python] 51 | pass_filenames: false 52 | args: [-j, auto] 53 | 54 | - id: pymarkdown 55 | name: pymarkdown 56 | stages: [commit] 57 | language: system 58 | entry: pipenv run pymarkdown 59 | args: [--config, pymarkdown.cfg, scan, README.md] 60 | pass_filenames: false 61 | 62 | - id: pytest-cov 63 | name: pytest-cov 64 | stages: [push] 65 | language: system 66 | entry: pipenv run pytest --cov --cov-fail-under=100 67 | types: [python] 68 | pass_filenames: false 69 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | black = "==22.3.0" 8 | flake8 = "==4.0.1" 9 | mypy = "==v0.942" 10 | isort = "==5.10.1" 11 | pytype = "==2021.11.12" 12 | pytest = "*" 13 | pytest-cov = "*" 14 | pre-commit = "*" 15 | types-toml = "*" 16 | types-pyyaml = "*" 17 | pymarkdownlnt = "==0.9.2" 18 | virtualenv = "==20.13.0" 19 | 20 | [packages] 21 | torch = "*" 22 | 23 | [requires] 24 | python_version = "3.8" 25 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "cfd20183fdb87b3e801707c074feec10a88829e3c31dac6e1dd1fee7126483a8" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.8" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | 
"verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "torch": { 20 | "hashes": [ 21 | "sha256:0399746f83b4541bcb5b219a18dbe8cade760aba1c660d2748a38c6dc338ebc7", 22 | "sha256:0986685f2ec8b7c4d3593e8cfe96be85d462943f1a8f54112fc48d4d9fbbe903", 23 | "sha256:13c7cca6b2ea3704d775444f02af53c5f072d145247e17b8cd7813ac57869f03", 24 | "sha256:201abf43a99bb4980cc827dd4b38ac28f35e4dddac7832718be3d5479cafd2c1", 25 | "sha256:2143d5fe192fd908b70b494349de5b1ac02854a8a902bd5f47d13d85b410e430", 26 | "sha256:2568f011dddeb5990d8698cc375d237f14568ffa8489854e3b94113b4b6b7c8b", 27 | "sha256:3322d33a06e440d715bb214334bd41314c94632d9a2f07d22006bf21da3a2be4", 28 | "sha256:349ea3ba0c0e789e0507876c023181f13b35307aebc2e771efd0e045b8e03e84", 29 | "sha256:44a3804e9bb189574f5d02ccc2dc6e32e26a81b3e095463b7067b786048c6072", 30 | "sha256:5ed69d5af232c5c3287d44cef998880dadcc9721cd020e9ae02f42e56b79c2e4", 31 | "sha256:60d06ee2abfa85f10582d205404d52889d69bcbb71f7e211cfc37e3957ac19ca", 32 | "sha256:63341f96840a223f277e498d2737b39da30d9f57c7a1ef88857b920096317739", 33 | "sha256:72207b8733523388c49d43ffcc4416d1d8cd64c40f7826332e714605ace9b1d2", 34 | "sha256:7ddb167827170c4e3ff6a27157414a00b9fef93dea175da04caf92a0619b7aee", 35 | "sha256:844f1db41173b53fe40c44b3e04fcca23a6ce00ac328b7099f2800e611766845", 36 | "sha256:a1325c9c28823af497cbf443369bddac9ac59f67f1e600f8ab9b754958e55b76", 37 | "sha256:abbdc5483359b9495dc76e3bd7911ccd2ddc57706c117f8316832e31590af871", 38 | "sha256:c0313438bc36448ffd209f5fb4e5f325b3af158cdf61c8829b8ddaf128c57816", 39 | "sha256:e3e8348edca3e3cee5a67a2b452b85c57712efe1cc3ffdb87c128b3dde54534e", 40 | "sha256:fb47291596677570246d723ee6abbcbac07eeba89d8f83de31e3954f21f44879" 41 | ], 42 | "index": "pypi", 43 | "version": "==1.12.0" 44 | }, 45 | "typing-extensions": { 46 | "hashes": [ 47 | "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02", 48 | "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6" 49 | ], 50 | "markers": 
"python_full_version >= '3.7.0'", 51 | "version": "==4.3.0" 52 | } 53 | }, 54 | "develop": { 55 | "application-properties": { 56 | "hashes": [ 57 | "sha256:4babf595ccc6142d0678cc2f4ffaae9369754c309c6182ec9a800c47d5c86876", 58 | "sha256:c020c1b0da868ffd46070c258d2af9cccd992d3a6e940d75c2e07186d3ab274b" 59 | ], 60 | "markers": "python_version >= '3.8'", 61 | "version": "==0.5.0" 62 | }, 63 | "attrs": { 64 | "hashes": [ 65 | "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4", 66 | "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd" 67 | ], 68 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", 69 | "version": "==21.4.0" 70 | }, 71 | "black": { 72 | "hashes": [ 73 | "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b", 74 | "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176", 75 | "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09", 76 | "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a", 77 | "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015", 78 | "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79", 79 | "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb", 80 | "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20", 81 | "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464", 82 | "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968", 83 | "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82", 84 | "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21", 85 | "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0", 86 | "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265", 87 | "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b", 88 | 
"sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a", 89 | "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72", 90 | "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce", 91 | "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0", 92 | "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a", 93 | "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163", 94 | "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad", 95 | "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d" 96 | ], 97 | "index": "pypi", 98 | "version": "==22.3.0" 99 | }, 100 | "cfgv": { 101 | "hashes": [ 102 | "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426", 103 | "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736" 104 | ], 105 | "markers": "python_full_version >= '3.6.1'", 106 | "version": "==3.3.1" 107 | }, 108 | "click": { 109 | "hashes": [ 110 | "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e", 111 | "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48" 112 | ], 113 | "markers": "python_version >= '3.7'", 114 | "version": "==8.1.3" 115 | }, 116 | "columnar": { 117 | "hashes": [ 118 | "sha256:3fda53a5c8858e5103a647bd23dc9ad11ed107888a409aeca1f3a89327abd0db", 119 | "sha256:b7c3bf4d35e7b66db63077343bed15cbe5298f5215b3b3302543d9cf221440cc" 120 | ], 121 | "version": "==1.3.1" 122 | }, 123 | "coverage": { 124 | "extras": [ 125 | "toml" 126 | ], 127 | "hashes": [ 128 | "sha256:01c5615d13f3dd3aa8543afc069e5319cfa0c7d712f6e04b920431e5c564a749", 129 | "sha256:106c16dfe494de3193ec55cac9640dd039b66e196e4641fa8ac396181578b982", 130 | "sha256:129cd05ba6f0d08a766d942a9ed4b29283aff7b2cccf5b7ce279d50796860bb3", 131 | "sha256:145f296d00441ca703a659e8f3eb48ae39fb083baba2d7ce4482fb2723e050d9", 132 | 
"sha256:1480ff858b4113db2718848d7b2d1b75bc79895a9c22e76a221b9d8d62496428", 133 | "sha256:269eaa2c20a13a5bf17558d4dc91a8d078c4fa1872f25303dddcbba3a813085e", 134 | "sha256:26dff09fb0d82693ba9e6231248641d60ba606150d02ed45110f9ec26404ed1c", 135 | "sha256:2bd9a6fc18aab8d2e18f89b7ff91c0f34ff4d5e0ba0b33e989b3cd4194c81fd9", 136 | "sha256:309ce4a522ed5fca432af4ebe0f32b21d6d7ccbb0f5fcc99290e71feba67c264", 137 | "sha256:3384f2a3652cef289e38100f2d037956194a837221edd520a7ee5b42d00cc605", 138 | "sha256:342d4aefd1c3e7f620a13f4fe563154d808b69cccef415415aece4c786665397", 139 | "sha256:39ee53946bf009788108b4dd2894bf1349b4e0ca18c2016ffa7d26ce46b8f10d", 140 | "sha256:4321f075095a096e70aff1d002030ee612b65a205a0a0f5b815280d5dc58100c", 141 | "sha256:4803e7ccf93230accb928f3a68f00ffa80a88213af98ed338a57ad021ef06815", 142 | "sha256:4ce1b258493cbf8aec43e9b50d89982346b98e9ffdfaae8ae5793bc112fb0068", 143 | "sha256:664a47ce62fe4bef9e2d2c430306e1428ecea207ffd68649e3b942fa8ea83b0b", 144 | "sha256:75ab269400706fab15981fd4bd5080c56bd5cc07c3bccb86aab5e1d5a88dc8f4", 145 | "sha256:83c4e737f60c6936460c5be330d296dd5b48b3963f48634c53b3f7deb0f34ec4", 146 | "sha256:84631e81dd053e8a0d4967cedab6db94345f1c36107c71698f746cb2636c63e3", 147 | "sha256:84e65ef149028516c6d64461b95a8dbcfce95cfd5b9eb634320596173332ea84", 148 | "sha256:865d69ae811a392f4d06bde506d531f6a28a00af36f5c8649684a9e5e4a85c83", 149 | "sha256:87f4f3df85aa39da00fd3ec4b5abeb7407e82b68c7c5ad181308b0e2526da5d4", 150 | "sha256:8c08da0bd238f2970230c2a0d28ff0e99961598cb2e810245d7fc5afcf1254e8", 151 | "sha256:961e2fb0680b4f5ad63234e0bf55dfb90d302740ae9c7ed0120677a94a1590cb", 152 | "sha256:9b3e07152b4563722be523e8cd0b209e0d1a373022cfbde395ebb6575bf6790d", 153 | "sha256:a7f3049243783df2e6cc6deafc49ea123522b59f464831476d3d1448e30d72df", 154 | "sha256:bf5601c33213d3cb19d17a796f8a14a9eaa5e87629a53979a5981e3e3ae166f6", 155 | "sha256:cec3a0f75c8f1031825e19cd86ee787e87cf03e4fd2865c79c057092e69e3a3b", 156 | 
"sha256:d42c549a8f41dc103a8004b9f0c433e2086add8a719da00e246e17cbe4056f72", 157 | "sha256:d67d44996140af8b84284e5e7d398e589574b376fb4de8ccd28d82ad8e3bea13", 158 | "sha256:d9c80df769f5ec05ad21ea34be7458d1dc51ff1fb4b2219e77fe24edf462d6df", 159 | "sha256:e57816f8ffe46b1df8f12e1b348f06d164fd5219beba7d9433ba79608ef011cc", 160 | "sha256:ee2ddcac99b2d2aec413e36d7a429ae9ebcadf912946b13ffa88e7d4c9b712d6", 161 | "sha256:f02cbbf8119db68455b9d763f2f8737bb7db7e43720afa07d8eb1604e5c5ae28", 162 | "sha256:f1d5aa2703e1dab4ae6cf416eb0095304f49d004c39e9db1d86f57924f43006b", 163 | "sha256:f5b66caa62922531059bc5ac04f836860412f7f88d38a476eda0a6f11d4724f4", 164 | "sha256:f69718750eaae75efe506406c490d6fc5a6161d047206cc63ce25527e8a3adad", 165 | "sha256:fb73e0011b8793c053bfa85e53129ba5f0250fdc0392c1591fd35d915ec75c46", 166 | "sha256:fd180ed867e289964404051a958f7cccabdeed423f91a899829264bb7974d3d3", 167 | "sha256:fdb6f7bd51c2d1714cea40718f6149ad9be6a2ee7d93b19e9f00934c0f2a74d9", 168 | "sha256:ffa9297c3a453fba4717d06df579af42ab9a28022444cae7fa605af4df612d54" 169 | ], 170 | "markers": "python_version >= '3.7'", 171 | "version": "==6.4.1" 172 | }, 173 | "distlib": { 174 | "hashes": [ 175 | "sha256:6564fe0a8f51e734df6333d08b8b94d4ea8ee6b99b5ed50613f731fd4089f34b", 176 | "sha256:e4b58818180336dc9c529bfb9a0b58728ffc09ad92027a3f30b7cd91e3458579" 177 | ], 178 | "version": "==0.3.4" 179 | }, 180 | "filelock": { 181 | "hashes": [ 182 | "sha256:37def7b658813cda163b56fc564cdc75e86d338246458c4c28ae84cabefa2404", 183 | "sha256:3a0fd85166ad9dbab54c9aec96737b744106dc5f15c0b09a6744a445299fcf04" 184 | ], 185 | "markers": "python_version >= '3.7'", 186 | "version": "==3.7.1" 187 | }, 188 | "flake8": { 189 | "hashes": [ 190 | "sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d", 191 | "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d" 192 | ], 193 | "index": "pypi", 194 | "version": "==4.0.1" 195 | }, 196 | "identify": { 197 | "hashes": [ 198 | 
"sha256:0dca2ea3e4381c435ef9c33ba100a78a9b40c0bab11189c7cf121f75815efeaa", 199 | "sha256:3d11b16f3fe19f52039fb7e39c9c884b21cb1b586988114fbe42671f03de3e82" 200 | ], 201 | "markers": "python_version >= '3.7'", 202 | "version": "==2.5.1" 203 | }, 204 | "importlab": { 205 | "hashes": [ 206 | "sha256:744bd75d4410744962d203bd1eb71a950b19e8fb8eb5f0b805461dc0a2da329b" 207 | ], 208 | "markers": "python_version >= '3.6'", 209 | "version": "==0.7" 210 | }, 211 | "iniconfig": { 212 | "hashes": [ 213 | "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3", 214 | "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32" 215 | ], 216 | "version": "==1.1.1" 217 | }, 218 | "isort": { 219 | "hashes": [ 220 | "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7", 221 | "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951" 222 | ], 223 | "index": "pypi", 224 | "version": "==5.10.1" 225 | }, 226 | "libcst": { 227 | "hashes": [ 228 | "sha256:00f72f86915a32e68ffffce4d45ef89af54f34113fdb3b000e262dc28dd1d8c2", 229 | "sha256:05d2bdb8d2bed5e3c55066804084e0e26a982aed2b357d3503a214aa66e21348", 230 | "sha256:0d3c50eeee2d6003e7ca0098f519e8f531e0c973bfbbe274d241ed3f617e905b", 231 | "sha256:1685b99ceaff8296e5f2395d4f8dbea26baaba730f49d73f608dfcd77e5d86ef", 232 | "sha256:298756f36604dee1bbc10f39ee6ef592047c0a4faca4c7b3a863ff48934af9fa", 233 | "sha256:3116776ea9e56a48da98ae4c94c2c5c8be4c45c7247855d9461eedae0cffae4f", 234 | "sha256:46e40161d44c2fe043722d07584a2b0b85672570b2fca673a9d8760e3678f276", 235 | "sha256:4e66d7349d02b947199621c159327b9a19767280308556c1ab54613a64e3eee2", 236 | "sha256:5012ccfed435c6cd00f506e5705cc7d325b585562a23b84a78395094ad342ddc", 237 | "sha256:5e123209fb81c958cace1b9a1750cf443515265bec231d52f84ed93347463a16", 238 | "sha256:72b52e74c9c3950e43424e706e01991a83671fa7b140191875dc1fb810ef41c1", 239 | "sha256:8b9e2108235cd5ad67c9a344267fba6c29ecb3bbeff01d3de4701f31f18fbeb3", 240 | 
"sha256:9bff5b970d52eaf723d2232ebb2db2cfeb9bf3d68978316ee698f49ba869b310", 241 | "sha256:a17442b62a22bef6ce0734ff33801378575ab8a9f9a33dbafe236270cdbcdb3c", 242 | "sha256:a17d7512b996af17ce649363963332405d4566b0d54a404e019918caf78c1dee", 243 | "sha256:a28ec30e3bcd8ce15311dd68a707f608e4c4bde282ea2657ecb13c24b2bf206f", 244 | "sha256:ac24e390532687ceca3d5d1ff5cca36abe3947d4da235de47f2803c65f78eda1", 245 | "sha256:b1e178a57a294afba329f02d0a7af4f06792a65d7e7bf36157519311c1d9a849", 246 | "sha256:bb3248b6eb50df73a712b517a643d0d20b2afd8ebb405d76d0ab1af2fbb2a2a6", 247 | "sha256:bdac3fcd49329659dceef8ddeb88a0f565235f7a2e0d953c27bc91316574f49d", 248 | "sha256:c9d8948ab7dc2dfd1c5b078f29ba3987d4107075ad4aa011db42bc805154ec9f", 249 | "sha256:d7dddb5f8894aa5637fa0cc3b4a0f7b069d209d1bd91437719f3bd62a358998a", 250 | "sha256:e73a119c3b9d0f225d3f6416805fd87245fe5441028c2adfefb949c7728b6bbe", 251 | "sha256:fc41cf393be128d8c2023c1871f9bafb39ddd9875ff53aef9cb9dda03ca0ab5d" 252 | ], 253 | "markers": "python_version >= '3.7'", 254 | "version": "==0.4.5" 255 | }, 256 | "mccabe": { 257 | "hashes": [ 258 | "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", 259 | "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" 260 | ], 261 | "version": "==0.6.1" 262 | }, 263 | "mypy": { 264 | "hashes": [ 265 | "sha256:0e2dd88410937423fba18e57147dd07cd8381291b93d5b1984626f173a26543e", 266 | "sha256:10daab80bc40f84e3f087d896cdb53dc811a9f04eae4b3f95779c26edee89d16", 267 | "sha256:17e44649fec92e9f82102b48a3bf7b4a5510ad0cd22fa21a104826b5db4903e2", 268 | "sha256:1a0459c333f00e6a11cbf6b468b870c2b99a906cb72d6eadf3d1d95d38c9352c", 269 | "sha256:246e1aa127d5b78488a4a0594bd95f6d6fb9d63cf08a66dafbff8595d8891f67", 270 | "sha256:2b184db8c618c43c3a31b32ff00cd28195d39e9c24e7c3b401f3db7f6e5767f5", 271 | "sha256:2bc249409a7168d37c658e062e1ab5173300984a2dada2589638568ddc1db02b", 272 | "sha256:3841b5433ff936bff2f4dc8d54cf2cdbfea5d8e88cedfac45c161368e5770ba6", 273 | 
"sha256:4c3e497588afccfa4334a9986b56f703e75793133c4be3a02d06a3df16b67a58", 274 | "sha256:5bf44840fb43ac4074636fd47ee476d73f0039f4f54e86d7265077dc199be24d", 275 | "sha256:64235137edc16bee6f095aba73be5334677d6f6bdb7fa03cfab90164fa294a17", 276 | "sha256:6776e5fa22381cc761df53e7496a805801c1a751b27b99a9ff2f0ca848c7eca0", 277 | "sha256:6ce34a118d1a898f47def970a2042b8af6bdcc01546454726c7dd2171aa6dfca", 278 | "sha256:6f6ad963172152e112b87cc7ec103ba0f2db2f1cd8997237827c052a3903eaa6", 279 | "sha256:6f7106cbf9cc2f403693bf50ed7c9fa5bb3dfa9007b240db3c910929abe2a322", 280 | "sha256:7742d2c4e46bb5017b51c810283a6a389296cda03df805a4f7869a6f41246534", 281 | "sha256:9521c1265ccaaa1791d2c13582f06facf815f426cd8b07c3a485f486a8ffc1f3", 282 | "sha256:a1b383fe99678d7402754fe90448d4037f9512ce70c21f8aee3b8bf48ffc51db", 283 | "sha256:b840cfe89c4ab6386c40300689cd8645fc8d2d5f20101c7f8bd23d15fca14904", 284 | "sha256:d8d3ba77e56b84cd47a8ee45b62c84b6d80d32383928fe2548c9a124ea0a725c", 285 | "sha256:dcd955f36e0180258a96f880348fbca54ce092b40fbb4b37372ae3b25a0b0a46", 286 | "sha256:e865fec858d75b78b4d63266c9aff770ecb6a39dfb6d6b56c47f7f8aba6baba8", 287 | "sha256:edf7237137a1a9330046dbb14796963d734dd740a98d5e144a3eb1d267f5f9ee" 288 | ], 289 | "index": "pypi", 290 | "version": "==v0.942" 291 | }, 292 | "mypy-extensions": { 293 | "hashes": [ 294 | "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d", 295 | "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8" 296 | ], 297 | "version": "==0.4.3" 298 | }, 299 | "networkx": { 300 | "hashes": [ 301 | "sha256:5e53f027c0d567cf1f884dbb283224df525644e43afd1145d64c9d88a3584762", 302 | "sha256:6933b9b3174a0bdf03c911bb4a1ee43a86ce3edeb813e37e1d4c553b3f4a2c4f" 303 | ], 304 | "markers": "python_version >= '3.8'", 305 | "version": "==2.8.4" 306 | }, 307 | "ninja": { 308 | "hashes": [ 309 | "sha256:0560eea57199e41e86ac2c1af0108b63ae77c3ca4d05a9425a750e908135935a", 310 | 
"sha256:21a1d84d4c7df5881bfd86c25cce4cf7af44ba2b8b255c57bc1c434ec30a2dfc", 311 | "sha256:279836285975e3519392c93c26e75755e8a8a7fafec9f4ecbb0293119ee0f9c6", 312 | "sha256:29570a18d697fc84d361e7e6330f0021f34603ae0fcb0ef67ae781e9814aae8d", 313 | "sha256:5ea785bf6a15727040835256577239fa3cf5da0d60e618c307aa5efc31a1f0ce", 314 | "sha256:688167841b088b6802e006f911d911ffa925e078c73e8ef2f88286107d3204f8", 315 | "sha256:6bd76a025f26b9ae507cf8b2b01bb25bb0031df54ed685d85fc559c411c86cf4", 316 | "sha256:740d61fefb4ca13573704ee8fe89b973d40b8dc2a51aaa4e9e68367233743bb6", 317 | "sha256:840a0b042d43a8552c4004966e18271ec726e5996578f28345d9ce78e225b67e", 318 | "sha256:84be6f9ec49f635dc40d4b871319a49fa49b8d55f1d9eae7cd50d8e57ddf7a85", 319 | "sha256:9ca8dbece144366d5f575ffc657af03eb11c58251268405bc8519d11cf42f113", 320 | "sha256:cc8b31b5509a2129e4d12a35fc21238c157038022560aaf22e49ef0a77039086", 321 | "sha256:d5e0275d28997a750a4f445c00bdd357b35cc334c13cdff13edf30e544704fbd", 322 | "sha256:e1b86ad50d4e681a7dbdff05fc23bb52cb773edb90bc428efba33fa027738408" 323 | ], 324 | "version": "==1.10.2.3" 325 | }, 326 | "nodeenv": { 327 | "hashes": [ 328 | "sha256:27083a7b96a25f2f5e1d8cb4b6317ee8aeda3bdd121394e5ac54e498028a042e", 329 | "sha256:e0e7f7dfb85fc5394c6fe1e8fa98131a2473e04311a45afb6508f7cf1836fa2b" 330 | ], 331 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6'", 332 | "version": "==1.7.0" 333 | }, 334 | "packaging": { 335 | "hashes": [ 336 | "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb", 337 | "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522" 338 | ], 339 | "markers": "python_version >= '3.6'", 340 | "version": "==21.3" 341 | }, 342 | "pathspec": { 343 | "hashes": [ 344 | "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a", 345 | "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1" 346 | ], 347 | "version": "==0.9.0" 348 | }, 349 | "platformdirs": { 
350 | "hashes": [ 351 | "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788", 352 | "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19" 353 | ], 354 | "markers": "python_version >= '3.7'", 355 | "version": "==2.5.2" 356 | }, 357 | "pluggy": { 358 | "hashes": [ 359 | "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159", 360 | "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3" 361 | ], 362 | "markers": "python_version >= '3.6'", 363 | "version": "==1.0.0" 364 | }, 365 | "pre-commit": { 366 | "hashes": [ 367 | "sha256:10c62741aa5704faea2ad69cb550ca78082efe5697d6f04e5710c3c229afdd10", 368 | "sha256:4233a1e38621c87d9dda9808c6606d7e7ba0e087cd56d3fe03202a01d2919615" 369 | ], 370 | "index": "pypi", 371 | "version": "==2.19.0" 372 | }, 373 | "py": { 374 | "hashes": [ 375 | "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", 376 | "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378" 377 | ], 378 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", 379 | "version": "==1.11.0" 380 | }, 381 | "pycodestyle": { 382 | "hashes": [ 383 | "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20", 384 | "sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f" 385 | ], 386 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", 387 | "version": "==2.8.0" 388 | }, 389 | "pyflakes": { 390 | "hashes": [ 391 | "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c", 392 | "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e" 393 | ], 394 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", 395 | "version": "==2.4.0" 396 | }, 397 | "pymarkdownlnt": { 398 | "hashes": [ 399 | "sha256:273f406c0ed271dce663d3444e336fa03436ebc981bf27cbb2a596642df3b658", 400 | 
"sha256:425dfa56f01449d670de561d0239251da925fe6386e42bbcb22ca1c4e6bfb61b" 401 | ], 402 | "index": "pypi", 403 | "version": "==0.9.2" 404 | }, 405 | "pyparsing": { 406 | "hashes": [ 407 | "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb", 408 | "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc" 409 | ], 410 | "markers": "python_full_version >= '3.6.8'", 411 | "version": "==3.0.9" 412 | }, 413 | "pytest": { 414 | "hashes": [ 415 | "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c", 416 | "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45" 417 | ], 418 | "index": "pypi", 419 | "version": "==7.1.2" 420 | }, 421 | "pytest-cov": { 422 | "hashes": [ 423 | "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6", 424 | "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470" 425 | ], 426 | "index": "pypi", 427 | "version": "==3.0.0" 428 | }, 429 | "pytype": { 430 | "hashes": [ 431 | "sha256:05cfb5f8681edeeefb5ddce6048044465b4d6d4dde8f2963557430631e484e4e", 432 | "sha256:1833cb074bc21d91fa5f7f451447630f1d66b6e3219163f8c90021a7e4c02d75", 433 | "sha256:3b75272ba0e10973ec8c5a30a82ba42d63dcc6fb3bed0106879b539ede17c6cb", 434 | "sha256:62553824000096407aa3f901c13b21b845365c8626248f2b216ab1feeac7efc6", 435 | "sha256:66e8e4a198e17895da4137fbb9e0cdd5494d1e905b90fd9d77952e49577e8a43", 436 | "sha256:79fa4e3715cc72fbc2fc170347467565ec50a13adfb81c5b96b6221752dc742a", 437 | "sha256:82c8952b2394096b61447804d20e08a4cf3219a4671970c6f8dd332dc142663c", 438 | "sha256:8d06b1349228730f0cb84a1cfcf32c25d67203a7d570114743b024fd52ebc09e", 439 | "sha256:9f4e0cda5d7e30d20e2c17cf5a3e4ee4992244631188ec20f801f5c703baf44e", 440 | "sha256:a8d2289a3b35e109e888c81fdf586188fdcaefca247cb441215a171a8668a48f", 441 | "sha256:cca706b2110eb5e62fca4c2d79739cd7ec13cc65c7b050d6a49b8aaa5c6bffd4", 442 | "sha256:e5a6111e9ae4e5e7e791510d74de0d831ef480f38039e5ce5a2b009bcb47b5fa", 443 | 
"sha256:e6951e35668e70f22cfaadddbea9a2b989df2cf3e6a1ab3158df920c9caa8136" 444 | ], 445 | "index": "pypi", 446 | "version": "==2021.11.12" 447 | }, 448 | "pyyaml": { 449 | "hashes": [ 450 | "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293", 451 | "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b", 452 | "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57", 453 | "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b", 454 | "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4", 455 | "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07", 456 | "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba", 457 | "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9", 458 | "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287", 459 | "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513", 460 | "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0", 461 | "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0", 462 | "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92", 463 | "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f", 464 | "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2", 465 | "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc", 466 | "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c", 467 | "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86", 468 | "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4", 469 | "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c", 470 | "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34", 471 | "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b", 472 | 
"sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c", 473 | "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb", 474 | "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737", 475 | "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3", 476 | "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d", 477 | "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53", 478 | "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78", 479 | "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803", 480 | "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a", 481 | "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174", 482 | "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5" 483 | ], 484 | "markers": "python_version >= '3.6'", 485 | "version": "==6.0" 486 | }, 487 | "setuptools": { 488 | "hashes": [ 489 | "sha256:990a4f7861b31532871ab72331e755b5f14efbe52d336ea7f6118144dd478741", 490 | "sha256:c1848f654aea2e3526d17fc3ce6aeaa5e7e24e66e645b5be2171f3f6b4e5a178" 491 | ], 492 | "markers": "python_version >= '3.7'", 493 | "version": "==62.6.0" 494 | }, 495 | "six": { 496 | "hashes": [ 497 | "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", 498 | "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254" 499 | ], 500 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'", 501 | "version": "==1.16.0" 502 | }, 503 | "tabulate": { 504 | "hashes": [ 505 | "sha256:0ba055423dbaa164b9e456abe7920c5e8ed33fcc16f6d1b2f2d152c8e1e8b4fc", 506 | "sha256:436f1c768b424654fce8597290d2764def1eea6a77cfa5c33be00b1bc0f4f63d", 507 | "sha256:6c57f3f3dd7ac2782770155f3adb2db0b1a269637e42f27599925e64b114f519" 508 | ], 509 | "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'", 510 | "version": 
"==0.8.10" 511 | }, 512 | "toml": { 513 | "hashes": [ 514 | "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", 515 | "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f" 516 | ], 517 | "markers": "python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'", 518 | "version": "==0.10.2" 519 | }, 520 | "tomli": { 521 | "hashes": [ 522 | "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", 523 | "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f" 524 | ], 525 | "markers": "python_version >= '3.7'", 526 | "version": "==2.0.1" 527 | }, 528 | "toolz": { 529 | "hashes": [ 530 | "sha256:6b312d5e15138552f1bda8a4e66c30e236c831b612b2bf0005f8a1df10a4bc33", 531 | "sha256:a5700ce83414c64514d82d60bcda8aabfde092d1c1a8663f9200c07fdcc6da8f" 532 | ], 533 | "markers": "python_version >= '3.5'", 534 | "version": "==0.11.2" 535 | }, 536 | "typed-ast": { 537 | "hashes": [ 538 | "sha256:0261195c2062caf107831e92a76764c81227dae162c4f75192c0d489faf751a2", 539 | "sha256:0fdbcf2fef0ca421a3f5912555804296f0b0960f0418c440f5d6d3abb549f3e1", 540 | "sha256:183afdf0ec5b1b211724dfef3d2cad2d767cbefac291f24d69b00546c1837fb6", 541 | "sha256:211260621ab1cd7324e0798d6be953d00b74e0428382991adfddb352252f1d62", 542 | "sha256:267e3f78697a6c00c689c03db4876dd1efdfea2f251a5ad6555e82a26847b4ac", 543 | "sha256:2efae9db7a8c05ad5547d522e7dbe62c83d838d3906a3716d1478b6c1d61388d", 544 | "sha256:370788a63915e82fd6f212865a596a0fefcbb7d408bbbb13dea723d971ed8bdc", 545 | "sha256:39e21ceb7388e4bb37f4c679d72707ed46c2fbf2a5609b8b8ebc4b067d977df2", 546 | "sha256:3e123d878ba170397916557d31c8f589951e353cc95fb7f24f6bb69adc1a8a97", 547 | "sha256:4879da6c9b73443f97e731b617184a596ac1235fe91f98d279a7af36c796da35", 548 | "sha256:4e964b4ff86550a7a7d56345c7864b18f403f5bd7380edf44a3c1fb4ee7ac6c6", 549 | "sha256:639c5f0b21776605dd6c9dbe592d5228f021404dafd377e2b7ac046b0349b1a1", 550 | 
"sha256:669dd0c4167f6f2cd9f57041e03c3c2ebf9063d0757dc89f79ba1daa2bfca9d4", 551 | "sha256:6778e1b2f81dfc7bc58e4b259363b83d2e509a65198e85d5700dfae4c6c8ff1c", 552 | "sha256:683407d92dc953c8a7347119596f0b0e6c55eb98ebebd9b23437501b28dcbb8e", 553 | "sha256:79b1e0869db7c830ba6a981d58711c88b6677506e648496b1f64ac7d15633aec", 554 | "sha256:7d5d014b7daa8b0bf2eaef684295acae12b036d79f54178b92a2b6a56f92278f", 555 | "sha256:98f80dee3c03455e92796b58b98ff6ca0b2a6f652120c263efdba4d6c5e58f72", 556 | "sha256:a94d55d142c9265f4ea46fab70977a1944ecae359ae867397757d836ea5a3f47", 557 | "sha256:a9916d2bb8865f973824fb47436fa45e1ebf2efd920f2b9f99342cb7fab93f72", 558 | "sha256:c542eeda69212fa10a7ada75e668876fdec5f856cd3d06829e6aa64ad17c8dfe", 559 | "sha256:cf4afcfac006ece570e32d6fa90ab74a17245b83dfd6655a6f68568098345ff6", 560 | "sha256:ebd9d7f80ccf7a82ac5f88c521115cc55d84e35bf8b446fcd7836eb6b98929a3", 561 | "sha256:ed855bbe3eb3715fca349c80174cfcfd699c2f9de574d40527b8429acae23a66" 562 | ], 563 | "markers": "python_version >= '3.6'", 564 | "version": "==1.5.4" 565 | }, 566 | "types-pyyaml": { 567 | "hashes": [ 568 | "sha256:33ae75c84b8f61fddf0c63e9c7e557db9db1694ad3c2ee8628ec5efebb5a5e9b", 569 | "sha256:b738e9ef120da0af8c235ba49d3b72510f56ef9bcc308fc8e7357100ff122284" 570 | ], 571 | "index": "pypi", 572 | "version": "==6.0.9" 573 | }, 574 | "types-toml": { 575 | "hashes": [ 576 | "sha256:05a8da4bfde2f1ee60e90c7071c063b461f74c63a9c3c1099470c08d6fa58615", 577 | "sha256:a567fe2614b177d537ad99a661adc9bfc8c55a46f95e66370a4ed2dd171335f9" 578 | ], 579 | "index": "pypi", 580 | "version": "==0.10.7" 581 | }, 582 | "typing-extensions": { 583 | "hashes": [ 584 | "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02", 585 | "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6" 586 | ], 587 | "markers": "python_full_version >= '3.7.0'", 588 | "version": "==4.3.0" 589 | }, 590 | "typing-inspect": { 591 | "hashes": [ 592 | 
"sha256:047d4097d9b17f46531bf6f014356111a1b6fb821a24fe7ac909853ca2a782aa", 593 | "sha256:3cd7d4563e997719a710a3bfe7ffb544c6b72069b6812a02e9b414a8fa3aaa6b", 594 | "sha256:b1f56c0783ef0f25fb064a01be6e5407e54cf4a4bf4f3ba3fe51e0bd6dcea9e5" 595 | ], 596 | "version": "==0.7.1" 597 | }, 598 | "virtualenv": { 599 | "hashes": [ 600 | "sha256:339f16c4a86b44240ba7223d0f93a7887c3ca04b5f9c8129da7958447d079b09", 601 | "sha256:d8458cf8d59d0ea495ad9b34c2599487f8a7772d796f9910858376d1600dd2dd" 602 | ], 603 | "index": "pypi", 604 | "version": "==20.13.0" 605 | }, 606 | "wcwidth": { 607 | "hashes": [ 608 | "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784", 609 | "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83" 610 | ], 611 | "version": "==0.2.5" 612 | } 613 | } 614 | } 615 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MobileOne PyTorch 2 | 3 | Unofficial PyTorch implementation of 4 | [**An Improved One millisecond Mobile Backbone**](https://arxiv.org/pdf/2206.04040.pdf) paper. 5 | 6 | ## Quickstart 7 | 8 | Install with `pip install mobileone_pytorch` and create a MobileOne with: 9 | 10 | ```python 11 | from mobileone_pytorch import mobileone_s1 12 | model = mobileone_s1() 13 | ``` 14 | 15 | ## Overview 16 | 17 | This repository contains an implementation of MobileOne. 18 | 19 | Features: 20 | 21 | - Implementation of all MobileOne versions 22 | - Reparametrization for model deployment 23 | 24 | *Upcoming features*: 25 | 26 | - Squeeze-and-Excitation block for MobileOne S4 27 | 28 | **Help wanted**: 29 | 30 | - Training models on ImageNet 31 | 32 | ## Table of contents 33 | 34 | 1. [About MobileOne](#about-mobileone) 35 | 2. [Installation](#installation) 36 | 3. 
[Usage](#usage) 37 | - [Create models](#create-models) 38 | - [Deployment via reparametrization](#deployment) 39 | 40 | ### About MobileOne 41 | 42 | MobileOne is a novel architecture that, with its variants, achieves an inference time 43 | under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. 44 | 45 | - MobileOne achieves state-of-the-art performance 46 | within the efficient architectures while being many times faster 47 | on mobile. 48 | 49 | - The best model (S4) obtains similar performance on ImageNet 50 | as [Mobile-Former](https://arxiv.org/abs/2108.05895) while being 38× faster. 51 | Moreover, it obtains 2.3% better top-1 accuracy on ImageNet 52 | than [EfficientNet](https://arxiv.org/abs/1905.11946) at similar latency. 53 | 54 | ### Installation 55 | 56 | Install via pip: 57 | 58 | ```bash 59 | pip install mobileone_pytorch 60 | ``` 61 | 62 | Or install from source: 63 | 64 | ```bash 65 | git clone https://github.com/federicopozzi33/MobileOne-PyTorch.git 66 | cd mobileone_pytorch 67 | pip install -e . 68 | ``` 69 | 70 | ### Usage 71 | 72 | #### Create models 73 | 74 | Create MobileOne models: 75 | 76 | ```python 77 | from mobileone_pytorch import ( 78 | mobileone_s0, 79 | mobileone_s1, 80 | mobileone_s2, 81 | mobileone_s3, 82 | mobileone_s4 83 | ) 84 | 85 | model_s0 = mobileone_s0() 86 | model_s1 = mobileone_s1() 87 | model_s2 = mobileone_s2() 88 | model_s3 = mobileone_s3() 89 | model_s4 = mobileone_s4() 90 | ``` 91 | 92 | #### Deployment 93 | 94 | Deploy a MobileOne through reparametrization: 95 | 96 | ```python 97 | import torch 98 | from mobileone_pytorch import mobileone_s1 99 | 100 | x = torch.rand(1, 3, 224, 224) 101 | 102 | model = mobileone_s1() 103 | deployed = model.reparametrize() 104 | 105 | model.eval() 106 | deployed.eval() 107 | 108 | out1 = model(x) 109 | out2 = deployed(x) 110 | 111 | torch.testing.assert_close(out1, out2) 112 | ``` 113 | 114 | ### Contributing 115 | 116 | If you find a bug, create a GitHub issue. 
117 | Similarly, if you have questions, simply post them as GitHub issues. 118 | -------------------------------------------------------------------------------- /mobileone_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from ._mobileone_getters import ( # noqa 2 | mobileone_s0, 3 | mobileone_s1, 4 | mobileone_s2, 5 | mobileone_s3, 6 | mobileone_s4, 7 | ) 8 | -------------------------------------------------------------------------------- /mobileone_pytorch/_depthwise_convolution.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | from ._reparametrizable_module import ReparametrizableModule 7 | from ._reparametrize import rep_conv2d_bn_to_conv2d 8 | 9 | 10 | class DepthwiseConvolutionBlock(ReparametrizableModule): 11 | def __init__( 12 | self, 13 | in_channels: int, 14 | kernel_size: Tuple[int, int] = (3, 3), 15 | stride: int = 1, 16 | padding: int = 1, 17 | ): 18 | super().__init__() 19 | self._conv = nn.Conv2d( 20 | in_channels=in_channels, 21 | out_channels=in_channels, 22 | kernel_size=kernel_size, 23 | stride=(stride, stride), 24 | padding=(padding, padding), 25 | groups=in_channels, 26 | bias=False, 27 | ) 28 | self._bn = nn.BatchNorm2d(in_channels) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | return self._bn(self._conv(x)) 32 | 33 | def reparametrize(self, to_3x3: bool = False) -> nn.Conv2d: 34 | return rep_conv2d_bn_to_conv2d( 35 | self._conv, 36 | self._bn, 37 | to_3x3, 38 | ) 39 | 40 | 41 | def create_depthwise_blocks( 42 | k: int, 43 | in_channels: int, 44 | kernel_size: Tuple[int, int] = (3, 3), 45 | stride: int = 1, 46 | ) -> nn.ModuleList: 47 | return nn.ModuleList( 48 | [ 49 | DepthwiseConvolutionBlock( 50 | in_channels, 51 | kernel_size, 52 | stride, 53 | ) 54 | for i in range(k) 55 | ] 56 | ) 57 | 
-------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_block.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from ._mobileone_block_down import MobileOneBlockDown 8 | from ._mobileone_block_up import MobileOneBlockUp 9 | from ._reparametrizable_module import ReparametrizableModule 10 | 11 | 12 | class MobileOneBlock(ReparametrizableModule): 13 | """ 14 | MobileOne block implementation. 15 | 16 | It is composed of two sequential blocks: 17 | - MobileOneBlockUp 18 | - MobileOneBlockDown 19 | 20 | Each block is followed by the ReLU activation. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | k: int, 26 | in_channels: int, 27 | out_channels: int, 28 | stride: int = 1, 29 | ): 30 | super().__init__() 31 | self._mobileone_block_up = MobileOneBlockUp( 32 | k=k, 33 | in_channels=in_channels, 34 | stride=stride, 35 | ) 36 | self._mobileone_block_down = MobileOneBlockDown( 37 | k=k, in_channels=in_channels, out_channels=out_channels 38 | ) 39 | 40 | @property 41 | def num_blocks(self) -> int: 42 | return self._mobileone_block_down.num_blocks 43 | 44 | def forward(self, x: Tensor) -> Tensor: 45 | x = self._mobileone_block_up(x) 46 | x = F.relu(x) 47 | x = self._mobileone_block_down(x) 48 | x = F.relu(x) 49 | 50 | return x 51 | 52 | def reparametrize(self) -> nn.Sequential: 53 | l1 = self._mobileone_block_up.reparametrize() 54 | l2 = self._mobileone_block_down.reparametrize() 55 | 56 | return nn.Sequential( 57 | OrderedDict( 58 | [ 59 | ("block_up", l1), 60 | ("relu1", nn.ReLU()), 61 | ("block_down", l2), 62 | ("relu2", nn.ReLU()), 63 | ] 64 | ) 65 | ) 66 | -------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_block_down.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | from ._mobileone_component import MobileOneBlockComponent 6 | from ._pointwise_convolution import create_pointwise_blocks 7 | from ._reparametrizable_module import ReparametrizableParallelModule 8 | from ._reparametrize import rep_bn_to_conv2d 9 | 10 | 11 | class MobileOneBlockDown(MobileOneBlockComponent, ReparametrizableParallelModule): 12 | """ 13 | The second component of the MobileOne block. 14 | 15 | This block can expand the number of channels 16 | using the out_channels parameter. 17 | 18 | It is composed of two parallel branches. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | k: int, 24 | in_channels: int, 25 | out_channels: int, 26 | ): 27 | super().__init__( 28 | k=k, 29 | in_channels=in_channels, 30 | out_channels=out_channels, 31 | kernel_size=(1, 1), 32 | stride=1, 33 | groups=1, 34 | ) 35 | 36 | self._pointwise_blocks = create_pointwise_blocks( 37 | k=self._k, 38 | in_channels=self._in_channels, 39 | out_channels=self._out_channels, 40 | ) 41 | self._batch_norm_branch = ( 42 | nn.BatchNorm2d(self._out_channels) if self._in_channels == self._out_channels else None 43 | ) 44 | 45 | @property 46 | def num_branches(self) -> int: 47 | return self._k + 1 48 | 49 | @property 50 | def num_blocks(self) -> int: 51 | return self._k 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | outputs = [] 55 | 56 | if self._batch_norm_branch is not None: 57 | outputs.append(self._batch_norm_branch(x)) 58 | 59 | for layer in self._pointwise_blocks: 60 | outputs.append(layer(x)) 61 | 62 | return torch.stack(outputs, dim=0).sum(dim=0) 63 | 64 | def _reparametrize_layers(self) -> nn.ModuleList: 65 | layers = [] 66 | 67 | if self._batch_norm_branch is not None: 68 | layers.append( 69 | rep_bn_to_conv2d( 70 | batch_norm=self._batch_norm_branch, 71 | groups=self._groups, 72 | kernel_size=self._kernel_size, 73 | ) 
74 | ) 75 | 76 | layers.extend([branch.reparametrize() for branch in self._pointwise_blocks]) 77 | 78 | return nn.ModuleList(layers) 79 | -------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_block_up.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | from mobileone_pytorch._depthwise_convolution import ( 6 | DepthwiseConvolutionBlock, 7 | create_depthwise_blocks, 8 | ) 9 | 10 | from ._mobileone_component import MobileOneBlockComponent 11 | from ._reparametrizable_module import ReparametrizableParallelModule 12 | from ._reparametrize import rep_bn_to_conv2d 13 | 14 | 15 | class MobileOneBlockUp(MobileOneBlockComponent, ReparametrizableParallelModule): 16 | """ 17 | The first component of the MobileOne block. 18 | 19 | This block can reduce the resolution using the stride parameter. 20 | 21 | It is composed of three parallel branches. 
22 | """ 23 | 24 | def __init__( 25 | self, 26 | k: int, 27 | in_channels: int, 28 | stride: int = 1, 29 | ): 30 | super().__init__( 31 | k=k, 32 | in_channels=in_channels, 33 | out_channels=in_channels, 34 | kernel_size=(3, 3), 35 | stride=stride, 36 | groups=in_channels, 37 | ) 38 | self._depthwise_blocks = create_depthwise_blocks( 39 | k=self._k, 40 | in_channels=self._in_channels, 41 | kernel_size=self._kernel_size, 42 | stride=self._stride, 43 | ) 44 | 45 | self._batch_norm_branch = nn.BatchNorm2d(self._in_channels) if self._stride == 1 else None 46 | 47 | self._depthwise_branch = DepthwiseConvolutionBlock( 48 | in_channels=self._in_channels, 49 | kernel_size=(1, 1), 50 | stride=self._stride, 51 | padding=0, 52 | ) 53 | 54 | @property 55 | def num_branches(self) -> int: 56 | return self._k + 2 57 | 58 | @property 59 | def num_blocks(self) -> int: 60 | return self._k 61 | 62 | def forward(self, x: Tensor) -> Tensor: 63 | outputs = [] 64 | 65 | if self._batch_norm_branch is not None: 66 | outputs.append(self._batch_norm_branch(x)) 67 | 68 | outputs.append(self._depthwise_branch(x)) 69 | 70 | for branch in self._depthwise_blocks: 71 | outputs.append(branch(x)) 72 | 73 | return torch.stack(outputs, dim=0).sum(dim=0) 74 | 75 | def _reparametrize_layers(self) -> nn.ModuleList: 76 | layers = [] 77 | if self._batch_norm_branch is not None: 78 | layers.append( 79 | rep_bn_to_conv2d( 80 | batch_norm=self._batch_norm_branch, 81 | groups=self._groups, 82 | kernel_size=self._kernel_size, 83 | ) 84 | ) 85 | 86 | layers.append(self._depthwise_branch.reparametrize(to_3x3=True)) 87 | layers.extend([branch.reparametrize() for branch in self._depthwise_blocks]) 88 | 89 | return nn.ModuleList(layers) 90 | -------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_component.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 
4 | from torch.nn import Module 5 | 6 | 7 | class MobileOneBlockComponent(Module, ABC): 8 | def __init__( 9 | self, 10 | k: int, 11 | in_channels: int, 12 | out_channels: int, 13 | kernel_size: Tuple[int, int], 14 | stride: int, 15 | groups: int, 16 | ): 17 | super().__init__() 18 | self._k = k 19 | self._in_channels = in_channels 20 | self._out_channels = out_channels 21 | self._kernel_size = kernel_size 22 | self._stride = stride 23 | self._groups = groups 24 | 25 | @property 26 | @abstractmethod 27 | def num_branches(self) -> int: 28 | """Return the number of parallel branches.""" 29 | 30 | @property 31 | @abstractmethod 32 | def num_blocks(self) -> int: 33 | """Return the number of blocks.""" 34 | -------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_getters.py: -------------------------------------------------------------------------------- 1 | from ._mobileone_network import MobileOneConfiguration, MobileOneNetwork, MobileOneSize, get_params 2 | 3 | 4 | def mobileone_s0(num_classes: int = 1000) -> MobileOneNetwork: 5 | return _get_mobileone(get_params(MobileOneSize.S0, num_classes)) 6 | 7 | 8 | def mobileone_s1(num_classes: int = 1000) -> MobileOneNetwork: 9 | return _get_mobileone(get_params(MobileOneSize.S1, num_classes=num_classes)) 10 | 11 | 12 | def mobileone_s2(num_classes: int = 1000) -> MobileOneNetwork: 13 | return _get_mobileone(get_params(MobileOneSize.S2, num_classes=num_classes)) 14 | 15 | 16 | def mobileone_s3(num_classes: int = 1000) -> MobileOneNetwork: 17 | return _get_mobileone(get_params(MobileOneSize.S3, num_classes=num_classes)) 18 | 19 | 20 | def mobileone_s4(num_classes: int = 1000) -> MobileOneNetwork: 21 | return _get_mobileone(get_params(MobileOneSize.S4, num_classes=num_classes)) 22 | 23 | 24 | def _get_mobileone(conf: MobileOneConfiguration) -> MobileOneNetwork: 25 | return MobileOneNetwork( 26 | ks=conf.ks, 27 | out_channels=conf.out_channels, 28 | 
num_blocks=conf.num_blocks, 29 | strides=conf.strides, 30 | num_classes=conf.num_classes, 31 | ) 32 | -------------------------------------------------------------------------------- /mobileone_pytorch/_mobileone_network.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Dict, List 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | 10 | from ._mobileone_block import MobileOneBlock 11 | from ._reparametrizable_module import ReparametrizableModule, ReparametrizableSequential 12 | 13 | 14 | class MobileOneSize(Enum): 15 | S0 = "s0" 16 | S1 = "s1" 17 | S2 = "s2" 18 | S3 = "s3" 19 | S4 = "s4" 20 | 21 | 22 | @dataclass 23 | class MobileOneConfiguration: 24 | num_blocks: List[int] 25 | out_channels: List[int] 26 | strides: List[int] 27 | ks: List[int] 28 | num_classes: int 29 | 30 | 31 | @dataclass 32 | class NetworkConfig: 33 | ks: List[int] 34 | alphas: List[float] 35 | 36 | 37 | @dataclass 38 | class NetworkBasicConfig: 39 | num_blocks: List[int] 40 | strides: List[int] 41 | out_channels: List[int] 42 | 43 | 44 | BaseConfig = Dict[str, List[int]] 45 | NetworkConfigs = Dict[MobileOneSize, NetworkConfig] 46 | 47 | _BASE_CONFIG = NetworkBasicConfig( 48 | num_blocks=[1, 2, 8, 5, 5, 1], 49 | strides=[2, 2, 2, 2, 1, 2], 50 | out_channels=[64, 64, 128, 256, 256, 512], 51 | ) 52 | 53 | _NETWORK_CONFIGS: NetworkConfigs = { 54 | MobileOneSize.S0: NetworkConfig( 55 | alphas=[0.75, 0.75, 1.0, 1.0, 1.0, 2.0], 56 | ks=[4, 4, 4, 4, 4, 4], 57 | ), 58 | MobileOneSize.S1: NetworkConfig( 59 | alphas=[1.5, 1.5, 1.5, 2.0, 2.0, 2.5], 60 | ks=[1, 1, 1, 1, 1, 1], 61 | ), 62 | MobileOneSize.S2: NetworkConfig( 63 | alphas=[1.5, 1.5, 2.0, 2.5, 2.5, 4.0], 64 | ks=[1, 1, 1, 1, 1, 1], 65 | ), 66 | MobileOneSize.S3: NetworkConfig( 67 | alphas=[2.0, 2.0, 2.5, 3.0, 3.0, 4.0], 68 | ks=[1, 1, 1, 1, 1, 1], 69 | ), 70 
| MobileOneSize.S4: NetworkConfig( 71 | alphas=[3.0, 3.0, 3.5, 3.5, 3.5, 4.0], 72 | ks=[1, 1, 1, 1, 1, 1], 73 | ), 74 | } 75 | 76 | 77 | def get_params(size: MobileOneSize, num_classes: int) -> MobileOneConfiguration: 78 | conf = _NETWORK_CONFIGS[size] 79 | 80 | out_channels = [ 81 | int(out_ch * alfa) 82 | for out_ch, alfa in zip( 83 | _BASE_CONFIG.out_channels, 84 | conf.alphas, 85 | ) 86 | ] 87 | 88 | return MobileOneConfiguration( 89 | num_blocks=_BASE_CONFIG.num_blocks, 90 | out_channels=out_channels, 91 | strides=_BASE_CONFIG.strides, 92 | ks=conf.ks, 93 | num_classes=num_classes, 94 | ) 95 | 96 | 97 | class MobileOneNetwork(ReparametrizableModule): 98 | """MobileOne network. 99 | 100 | Described in detail here: https://arxiv.org/abs/2206.04040 101 | """ 102 | 103 | def __init__( 104 | self, 105 | ks: List[int], 106 | out_channels: List[int], 107 | num_blocks: List[int], 108 | strides: List[int], 109 | num_classes: int = 1000, 110 | ): 111 | super().__init__() 112 | self._features = ReparametrizableSequential( 113 | OrderedDict( 114 | [ 115 | ( 116 | f"_stage{i+1}", 117 | _compose_stage( 118 | num_blocks=num_blocks[i], 119 | k=ks[i], 120 | in_channels=3 if i == 0 else out_channels[i - 1], 121 | out_channels=out_channels[i], 122 | stride=strides[i], 123 | ), 124 | ) 125 | for i in range(len(num_blocks)) 126 | ] 127 | ) 128 | ) 129 | self._average_pooling = nn.AdaptiveAvgPool2d((1, 1)) 130 | self._linear = nn.Linear( 131 | in_features=out_channels[-1], 132 | out_features=num_classes, 133 | ) 134 | 135 | @property 136 | def num_classes(self) -> int: 137 | return self._linear.out_features 138 | 139 | def forward(self, x: Tensor) -> Tensor: 140 | x = self._features(x) 141 | x = self._average_pooling(x) 142 | x = torch.flatten(x, 1) 143 | x = self._linear(x) 144 | 145 | return x 146 | 147 | def reparametrize(self) -> nn.Sequential: 148 | return nn.Sequential( 149 | *[ 150 | self._features.reparametrize(), 151 | self._average_pooling, 152 | nn.Flatten(), 153 | 
self._linear, 154 | ] 155 | ) 156 | 157 | 158 | def _compose_stage( 159 | num_blocks: int, 160 | k: int, 161 | in_channels: int, 162 | out_channels: int, 163 | stride: int, 164 | ) -> ReparametrizableSequential: 165 | return ReparametrizableSequential( 166 | *[ 167 | MobileOneBlock( 168 | k=k, 169 | in_channels=in_channels if i == 0 else out_channels, 170 | out_channels=out_channels, 171 | stride=stride if i == 0 else 1, 172 | ) 173 | for i in range(num_blocks) 174 | ] 175 | ) 176 | -------------------------------------------------------------------------------- /mobileone_pytorch/_pointwise_convolution.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | 4 | from ._reparametrizable_module import ReparametrizableModule 5 | from ._reparametrize import rep_conv2d_bn_to_conv2d 6 | 7 | 8 | class PointwiseConvolutionBlock(ReparametrizableModule): 9 | def __init__(self, in_channels: int, out_channels: int): 10 | super().__init__() 11 | self._conv = nn.Conv2d( 12 | in_channels=in_channels, 13 | out_channels=out_channels, 14 | kernel_size=(1, 1), 15 | bias=False, 16 | ) 17 | self._bn = nn.BatchNorm2d(out_channels) 18 | 19 | def forward(self, x: Tensor): 20 | return self._bn(self._conv(x)) 21 | 22 | def reparametrize(self) -> nn.Conv2d: 23 | return rep_conv2d_bn_to_conv2d(self._conv, self._bn) 24 | 25 | 26 | def create_pointwise_blocks( 27 | k: int, 28 | in_channels: int, 29 | out_channels: int, 30 | ) -> nn.ModuleList: 31 | return nn.ModuleList( 32 | [ 33 | PointwiseConvolutionBlock( 34 | in_channels, 35 | out_channels, 36 | ) 37 | for i in range(k) 38 | ] 39 | ) 40 | -------------------------------------------------------------------------------- /mobileone_pytorch/_reparametrizable_module.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch.nn as nn 4 | from torch.nn import Module 5 | 6 | 
# --- mobileone_pytorch/_reparametrizable_module.py ---
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module


class ReparametrizableModule(Module, ABC):
    """Base class for modules that can be folded into a cheaper inference form."""

    @abstractmethod
    def reparametrize(self) -> Module:
        """Return a reparametrized module."""


class ReparametrizableSequential(nn.Sequential):
    """`nn.Sequential` whose children all implement `reparametrize`."""

    def reparametrize(self) -> nn.Sequential:
        return nn.Sequential(*(module.reparametrize() for module in self))


class ReparametrizableParallelModule(ReparametrizableModule, ABC):
    """Module made of parallel branches that fold into a single `nn.Conv2d`."""

    def reparametrize(self) -> nn.Conv2d:
        return self._reparametrize_parallel_layers(self._reparametrize_layers())

    @abstractmethod
    def _reparametrize_layers(self) -> nn.ModuleList:
        """Return a ModuleList of reparametrized (conv) layers, one per branch."""

    def _reparametrize_parallel_layers(self, layers: nn.ModuleList) -> nn.Conv2d:
        return rep_parallel_convs2d(*layers)


# --- mobileone_pytorch/_reparametrize.py ---
def pad_1x1_to_3x3_tensor(kernel1x1: Tensor) -> Tensor:
    """Zero-pad an (O, I, 1, 1) kernel to (O, I, 3, 3), keeping the weight centered."""
    return F.pad(kernel1x1, [1, 1, 1, 1])


def rep_conv2d_bn_to_conv2d(
    conv: nn.Conv2d, batch_norm: nn.BatchNorm2d, to_3x3: bool = False
) -> nn.Conv2d:
    """Reparametrize `conv` + `batch_norm` into a single Conv2d layer.

    Supported `kernel_size` values: (1, 1) and (3, 3).

    Bug fix: when `conv` has a bias, that bias passes through the batch
    norm and must therefore be scaled by gamma/std before being merged —
    previously it was added to the folded bias unscaled (masked in tests
    only because a freshly constructed BatchNorm has gamma ~= std ~= 1).
    """
    kernel, bias = _rep_conv_bn(conv.weight, batch_norm, conv_bias=conv.bias)

    if conv.kernel_size == (1, 1) and to_3x3:
        kernel = pad_1x1_to_3x3_tensor(kernel1x1=kernel)

    # Growing the kernel from 1x1 to 3x3 needs one extra pixel of padding on
    # each side to preserve the output size: out = ((w - k + 2p) / s) + 1.
    assert isinstance(conv.padding, tuple)
    assert len(conv.padding) == 2  # mypy
    padding = (
        conv.padding
        if not to_3x3 or conv.kernel_size == (3, 3)
        else (conv.padding[0] + 1, conv.padding[1] + 1)
    )

    return _create_conv2d(
        kernel,
        bias,
        conv.in_channels,
        conv.out_channels,
        padding,  # type: ignore
        conv.groups,
        conv.stride,  # type: ignore
    )


def rep_bn_to_conv2d(
    batch_norm: nn.BatchNorm2d,
    kernel_size: Tuple[int, int],
    groups: int = 1,
) -> nn.Conv2d:
    """Reparametrize a standalone `batch_norm` into an identity-kernel Conv2d.

    Supported `kernel_size` values: (1, 1) and (3, 3).
    """
    in_channels = batch_norm.num_features
    input_dim = in_channels // groups

    # Identity kernel: channel i maps to itself (modulo grouping), with the
    # weight at the spatial center of the kernel.
    kernel = torch.zeros((in_channels, input_dim, kernel_size[0], kernel_size[1]))
    center = 1 if kernel_size == (3, 3) else 0
    for i in range(in_channels):
        kernel[i, i % input_dim, center, center] = 1

    kernel, bias = _rep_conv_bn(kernel, batch_norm)
    padding = (1, 1) if kernel_size == (3, 3) else (0, 0)

    return _create_conv2d(kernel, bias, in_channels, in_channels, padding, groups)


def rep_parallel_convs2d(*convs: nn.Conv2d) -> nn.Conv2d:
    """Reparametrize identical parallel conv layers into one by summing.

    All conv layers are expected to have bias and share hyper-parameters
    (kernel size, stride, padding, groups).
    """
    kernel = torch.stack([conv.weight.data for conv in convs], dim=0).sum(dim=0)
    bias = torch.stack([conv.bias.data for conv in convs], dim=0).sum(dim=0)  # type: ignore

    return _create_conv2d(
        kernel,
        bias,
        convs[0].in_channels,
        convs[0].out_channels,
        convs[0].padding,  # type: ignore
        convs[0].groups,
        stride=convs[0].stride,  # type: ignore
    )


def _create_conv2d(
    kernel: Tensor,
    bias: Tensor,
    in_channels: int,
    out_channels: int,
    padding: Tuple[int, int],
    groups: int,
    stride: Tuple[int, int] = (1, 1),
) -> nn.Conv2d:
    """Build an `nn.Conv2d` (with bias) around precomputed weight/bias tensors."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=(kernel.shape[-2], kernel.shape[-1]),
        padding=padding,
        groups=groups,
        stride=stride,
        bias=True,
    )

    conv.weight.data = kernel
    assert conv.bias is not None  # mypy
    conv.bias.data = bias

    return conv


def _rep_conv_bn(
    kernel: Tensor,
    bn: nn.BatchNorm2d,
    conv_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Fold batch-norm statistics into `kernel` (and an optional conv bias).

    bn(conv(x)) = gamma * (W x + b - mu) / std + beta
                = (W * gamma / std) x + (beta + (b - mu) * gamma / std)
    """
    mu = bn.running_mean
    var = bn.running_var
    assert var is not None
    gamma = bn.weight
    beta = bn.bias

    std = torch.sqrt(var + bn.eps)
    scale = gamma / std

    # Without a conv bias, b = 0 and the shift reduces to -mu.
    shift = -mu if conv_bias is None else conv_bias - mu
    bias = beta + shift * scale
    kernel = kernel * scale.reshape(-1, 1, 1, 1)

    return kernel, bias
# --- setup.py ---
from setuptools import find_packages, setup

# Package metadata for distribution.
NAME = "mobileone_pytorch"
DESCRIPTION = "MobileOne implemented in PyTorch."
URL = "https://github.com/federicopozzi33/MobileOne-PyTorch"
EMAIL = "f.pozzi33@campus.unimib.it"
AUTHOR = "Federico Pozzi"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = "0.1.0"
REQUIRED = ["torch"]

# Guard the call so importing this module has no side effects; pip/setuptools
# still execute setup.py with __name__ == "__main__".
if __name__ == "__main__":
    setup(
        name=NAME,
        version=VERSION,
        description=DESCRIPTION,
        url=URL,
        author=AUTHOR,
        author_email=EMAIL,
        packages=find_packages(include=["mobileone_pytorch"]),
        install_requires=REQUIRED,
        python_requires=REQUIRES_PYTHON,
    )


# --- tests/_tools/_count_parameters.py ---
from torch.nn import Module


def count_parameters(model: Module) -> int:
    """Return the total number of parameters in `model` (trainable or not)."""
    return sum(p.numel() for p in model.parameters())
# --- tests/test_depthwise_convolution.py ---
from typing import Tuple

import pytest
import torch

from mobileone_pytorch._depthwise_convolution import (
    DepthwiseConvolutionBlock,
    create_depthwise_blocks,
)


class TestDepthwiseConvolutionBlock:
    def test_creation(self):
        # Construction alone must not raise.
        DepthwiseConvolutionBlock(in_channels=3, kernel_size=(3, 3), stride=2)

    def test_inference(self):
        block = DepthwiseConvolutionBlock(in_channels=3, kernel_size=(3, 3), stride=2)
        batch = torch.randn(1, 3, 28, 28)

        result = block(batch)

        # Stride 2 halves the spatial resolution; channel count is preserved.
        assert result.shape == (1, 3, 14, 14)

    def test_get_depthwise_blocks(self):
        blocks = create_depthwise_blocks(k=2, in_channels=3, kernel_size=(3, 3), stride=2)

        assert len(blocks) == 2
        assert all(isinstance(block, DepthwiseConvolutionBlock) for block in blocks)

    def test_depthwise_blocks_inference(self):
        blocks = create_depthwise_blocks(k=2, in_channels=3, kernel_size=(3, 3), stride=2)
        batch = torch.randn(1, 3, 28, 28)

        shapes = [branch(batch).shape for branch in blocks]

        assert len(shapes) == 2
        assert all(shape == (1, 3, 14, 14) for shape in shapes)

    @pytest.mark.parametrize("stride", [1, 2])
    @pytest.mark.parametrize("padding", [0, 1, 2])
    @pytest.mark.parametrize("kernel_size", [(3, 3), (1, 1)])
    def test_reparametrize(
        self,
        kernel_size: Tuple[int, int],
        padding: int,
        stride: int,
    ):
        block = DepthwiseConvolutionBlock(
            in_channels=3,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        block.eval()
        batch = torch.randn(1, 3, 28, 28)

        before = block(batch)
        after = block.reparametrize()(batch)

        assert torch.allclose(before, after, atol=1e-6, rtol=1e-4)


# --- tests/test_mobileone_block.py ---
import pytest
import torch

from mobileone_pytorch._mobileone_block import MobileOneBlock
from tests._tools import count_parameters


class TestMobileOneBlock:
    def test_creation(self):
        block = MobileOneBlock(k=2, in_channels=3, out_channels=1, stride=2)

        assert block.num_blocks == 2

    def test_inference(self):
        block = MobileOneBlock(k=2, in_channels=3, out_channels=1, stride=2)
        batch = torch.randn(1, 3, 28, 28)

        result = block(batch)

        assert result.shape == (1, 1, 14, 14)

    @pytest.mark.parametrize("out_channels", [1, 3])
    @pytest.mark.parametrize("k", [1, 2, 3])
    @pytest.mark.parametrize("stride", [1, 2])
    def test_reparametrization(self, stride: int, k: int, out_channels: int):
        # Fix: the parametrized `out_channels` was previously ignored
        # (the block was always built with out_channels=1).
        block = MobileOneBlock(
            k=k,
            in_channels=3,
            out_channels=out_channels,
            stride=stride,
        )
        block.eval()
        batch = torch.randn(1, 3, 28, 28)

        before = block(batch)

        folded = block.reparametrize()
        folded.eval()
        after = folded(batch)

        assert torch.allclose(before, after, atol=1e-6, rtol=1e-4)

    def test_num_params(self):
        block = MobileOneBlock(k=2, in_channels=3, out_channels=64, stride=2)

        # Folding the parallel branches must strictly shrink the model.
        assert count_parameters(block) > count_parameters(block.reparametrize())
# --- tests/test_mobileone_block_down.py ---
import pytest
import torch

from mobileone_pytorch._mobileone_block_down import MobileOneBlockDown
from tests._tools import count_parameters


class TestMobileOneBlockDown:
    def test_creation(self):
        block = MobileOneBlockDown(k=2, in_channels=3, out_channels=1)

        assert block.num_blocks == 2
        assert block.num_branches == 2 + 1

    @pytest.mark.parametrize("out_channels", [3, 1])
    def test_inference(self, out_channels: int):
        block = MobileOneBlockDown(k=2, in_channels=3, out_channels=out_channels)
        batch = torch.randn(1, 3, 28, 28)

        result = block(batch)

        # The "down" block keeps the spatial resolution.
        assert result.shape == (1, out_channels, 28, 28)

    @pytest.mark.parametrize("k", [1, 2, 3])
    @pytest.mark.parametrize("out_channels", [3])
    def test_reparametrization(self, out_channels: int, k: int):
        block = MobileOneBlockDown(k=k, in_channels=3, out_channels=out_channels)
        block.eval()
        batch = torch.randn(1, 3, 28, 28)

        before = block(batch)

        folded = block.reparametrize()
        folded.eval()
        after = folded(batch)

        # assert_close replaces the deprecated torch.testing.assert_allclose,
        # matching the other test modules in this suite.
        torch.testing.assert_close(before, after)

    def test_num_params(self):
        block = MobileOneBlockDown(k=2, in_channels=3, out_channels=64)

        assert count_parameters(block) > count_parameters(block.reparametrize())


# --- tests/test_mobileone_block_up.py ---
import pytest
import torch

from mobileone_pytorch._mobileone_block_up import MobileOneBlockUp
from tests._tools import count_parameters


class TestMobileOneBlockUp:
    def test_creation(self):
        block = MobileOneBlockUp(k=2, in_channels=3, stride=2)

        assert block.num_blocks == 2
        assert block.num_branches == 2 + 2

    @pytest.mark.parametrize("stride", [2, 1])
    def test_inference(self, stride: int):
        block = MobileOneBlockUp(k=2, in_channels=3, stride=stride)
        batch = torch.randn(1, 3, 28, 28)

        result = block(batch)

        expected_hw = 28 // stride
        assert result.shape == (1, 3, expected_hw, expected_hw)

    @pytest.mark.parametrize("k", [1, 2, 3])
    @pytest.mark.parametrize("stride", [2])
    def test_reparametrization(self, stride: int, k: int):
        block = MobileOneBlockUp(k=k, in_channels=3, stride=stride)
        block.eval()
        batch = torch.randn(1, 3, 28, 28)

        before = block(batch)

        folded = block.reparametrize()
        folded.eval()
        after = folded(batch)

        # assert_close replaces the deprecated torch.testing.assert_allclose.
        torch.testing.assert_close(before, after)

    def test_num_params(self):
        block = MobileOneBlockUp(k=2, in_channels=64, stride=2)

        assert count_parameters(block) > count_parameters(block.reparametrize())
# --- tests/test_mobileone_getters.py ---
from typing import Callable

import pytest
import torch
from torch import Tensor

from mobileone_pytorch import mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4
from mobileone_pytorch._mobileone_network import MobileOneNetwork
from tests._tools import count_parameters

# Every size getter, used by the parametrized tests below.
_ALL_GETTERS = [mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4]


def get_random_input() -> Tensor:
    """A single random ImageNet-sized batch."""
    return torch.rand(1, 3, 224, 224)


class TestMobileOneInference:
    num_classes = 1000

    @pytest.mark.parametrize("model_getter", _ALL_GETTERS)
    def test_model_inference(self, model_getter: Callable[[int], MobileOneNetwork]):
        network = model_getter(self.num_classes)

        logits = network(get_random_input())

        assert logits.shape == (1, self.num_classes)
        assert network.num_classes == self.num_classes

    @pytest.mark.parametrize("model_getter", _ALL_GETTERS)
    def test_model_inference_reparametrized(
        self,
        model_getter: Callable[[], MobileOneNetwork],
    ):
        batch = get_random_input()

        network = model_getter()
        folded = network.reparametrize()
        network.eval()
        folded.eval()

        # Folding accumulates small numeric error, hence the loose tolerances.
        torch.testing.assert_close(network(batch), folded(batch), rtol=1e-3, atol=1e-5)

    @pytest.mark.parametrize(
        "model_getter, expected_num_params",
        [
            (mobileone_s0, 5_292_741),
            (mobileone_s1, 4_827_250),
            (mobileone_s2, 7_886_706),
            (mobileone_s3, 10_178_450),
            (mobileone_s4, 13_336_018),
        ],
    )
    def test_num_params(
        self,
        model_getter: Callable[[], MobileOneNetwork],
        expected_num_params: int,
    ):
        assert count_parameters(model_getter()) == expected_num_params

    @pytest.mark.parametrize(
        "model_getter, expected_num_params",
        [
            (mobileone_s0, 2_077_382),
            (mobileone_s1, 4_766_854),
            (mobileone_s2, 7_810_182),
            (mobileone_s3, 10_085_894),
            (mobileone_s4, 13_222_406),
        ],
    )
    def test_num_params_reparametrize(
        self,
        model_getter: Callable[[], MobileOneNetwork],
        expected_num_params: int,
    ):
        assert count_parameters(model_getter().reparametrize()) == expected_num_params


# --- tests/test_mobileone_network.py ---
import torch

from mobileone_pytorch._mobileone_network import MobileOneNetwork, _compose_stage
from tests._tools import count_parameters


class TestMobileOneNetwork:
    def test_compare_rep_inference(self):
        batch = torch.rand(1, 3, 224, 224)

        network = MobileOneNetwork(
            ks=[1, 1, 1, 1, 1, 1],
            out_channels=[64, 64, 128, 256, 256, 512],
            num_blocks=[1, 2, 8, 5, 5, 1],
            strides=[2, 2, 2, 2, 1, 2],
            num_classes=1000,
        )
        folded = network.reparametrize()
        network.eval()
        folded.eval()

        torch.testing.assert_close(network(batch), folded(batch))


class TestMobileOneNetworkUtils:
    num_blocks, k = 2, 3
    stride = 2
    in_channels, out_channels = 3, 1

    def _make_stage(self):
        # Shared stage factory for the tests below.
        return _compose_stage(
            num_blocks=self.num_blocks,
            k=self.k,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
        )

    def test_compose_stage(self):
        stage = self._make_stage()
        batch = torch.randn(1, self.in_channels, 28, 28)

        result = stage(batch)

        assert len(stage) == self.num_blocks
        assert stage[0].num_blocks == stage[1].num_blocks == self.k
        expected_hw = 28 // self.stride
        assert result.shape == (1, self.out_channels, expected_hw, expected_hw)

    def test_reparametrize(self):
        batch = torch.randn(1, self.in_channels, 28, 28)

        stage = self._make_stage()
        folded = stage.reparametrize()
        stage.eval()
        folded.eval()

        torch.testing.assert_close(stage(batch), folded(batch))

    def test_num_params(self):
        stage = self._make_stage()

        assert count_parameters(stage) > count_parameters(stage.reparametrize())
# --- tests/test_pointwise_convolution.py ---
import torch

from mobileone_pytorch._pointwise_convolution import (
    PointwiseConvolutionBlock,
    create_pointwise_blocks,
)


class TestPointwiseConvolutionBlock:
    def test_creation(self):
        # Construction alone must not raise.
        PointwiseConvolutionBlock(in_channels=3, out_channels=11)

    def test_inference(self):
        block = PointwiseConvolutionBlock(in_channels=3, out_channels=1)
        batch = torch.randn(1, 3, 28, 28)

        result = block(batch)

        # A 1x1 convolution preserves the spatial resolution.
        assert result.shape == (1, 1, 28, 28)

    def test_get_pointwise_blocks(self):
        branches = create_pointwise_blocks(k=2, in_channels=3, out_channels=1)
        batch = torch.randn(1, 3, 28, 28)

        shapes = [branch(batch).shape for branch in branches]

        assert len(shapes) == 2
        assert all(shape == (1, 1, 28, 28) for shape in shapes)

    def test_reparametrize(self):
        block = PointwiseConvolutionBlock(in_channels=3, out_channels=1)
        block.eval()
        batch = torch.randn(1, 3, 28, 28)

        before = block(batch)
        after = block.reparametrize()(batch)

        assert torch.allclose(before, after, atol=1e-6, rtol=1e-4)


# --- tests/test_reparametrize.py ---
from typing import Tuple

import pytest
import torch
import torch.nn as nn

from mobileone_pytorch._pointwise_convolution import create_pointwise_blocks
from mobileone_pytorch._reparametrize import (
    pad_1x1_to_3x3_tensor,
    rep_bn_to_conv2d,
    rep_conv2d_bn_to_conv2d,
    rep_parallel_convs2d,
)


class TestReparametrize:
    def test_pad_1x1_to_3x3_tensor(self):
        padded = pad_1x1_to_3x3_tensor(torch.tensor([[[[1]]]]))

        expected = torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]])
        torch.testing.assert_close(padded, expected)

    @pytest.mark.parametrize("kernel_size", [(1, 1), (3, 3)])
    @pytest.mark.parametrize("num_features, groups", [(1, 1), (3, 1), (3, 3)])
    def test_rep_bn_to_conv(
        self,
        num_features: int,
        groups: int,
        kernel_size: Tuple[int, int],
    ):
        batch_norm = nn.BatchNorm2d(num_features)
        batch_norm.eval()
        batch = torch.rand(1, num_features, 28, 28)

        folded = rep_bn_to_conv2d(batch_norm, kernel_size, groups)

        torch.testing.assert_close(batch_norm(batch), folded(batch))

    @pytest.mark.parametrize(
        "kernel_size,to_3x3",
        [((1, 1), False), ((1, 1), True), ((3, 3), False)],
    )
    @pytest.mark.parametrize("stride", [1, 2, 3])
    @pytest.mark.parametrize("groups", [1, 3])
    @pytest.mark.parametrize("padding", [3])
    @pytest.mark.parametrize("bias", [True, False])
    def test_rep_bn_conv_to_conv(
        self,
        bias: bool,
        padding: int,
        groups: int,
        stride: int,
        kernel_size: Tuple[int, int],
        to_3x3: bool,
    ):
        conv = nn.Conv2d(
            3,
            3,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            groups=groups,
            stride=stride,
        )
        batch_norm = nn.BatchNorm2d(3)
        batch_norm.eval()
        batch = torch.rand(1, 3, 28, 28)

        folded = rep_conv2d_bn_to_conv2d(conv, batch_norm, to_3x3)

        torch.testing.assert_close(batch_norm(conv(batch)), folded(batch))

    def test_rep_parallel_conv(self):
        def make_branch() -> nn.Conv2d:
            # A 1x1 conv + bn branch already folded into a single conv.
            conv = nn.Conv2d(3, 3, kernel_size=(1, 1), padding=0, bias=True, groups=1, stride=1)
            bn = nn.BatchNorm2d(3)
            bn.eval()
            return rep_conv2d_bn_to_conv2d(conv, bn, False)

        bn_only = nn.BatchNorm2d(3)
        bn_only.eval()

        branch1 = make_branch()
        branch2 = make_branch()
        branch3 = rep_bn_to_conv2d(bn_only, (1, 1), 1)

        batch = torch.rand(1, 3, 28, 28)
        summed = branch1(batch) + branch2(batch) + branch3(batch)
        merged = rep_parallel_convs2d(branch1, branch2, branch3)(batch)

        torch.testing.assert_close(summed, merged)

    def test_rep_pointswise_blocks_and_batch_norm_parallel(self):
        blocks = create_pointwise_blocks(k=2, in_channels=3, out_channels=3)
        blocks.eval()
        batch_norm = nn.BatchNorm2d(3)
        batch_norm.eval()
        batch = torch.rand(1, 3, 28, 28)

        expected = blocks[0](batch) + blocks[1](batch) + batch_norm(batch)

        folded_blocks = [block.reparametrize() for block in blocks]
        folded_bn = rep_bn_to_conv2d(batch_norm, kernel_size=(1, 1))
        actual = folded_blocks[0](batch) + folded_blocks[1](batch) + folded_bn(batch)

        torch.testing.assert_close(expected, actual)