├── .github └── workflows │ └── publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── requirements.txt └── source │ ├── api.rst │ ├── conf.py │ ├── examples.ipynb │ ├── index.rst │ └── readme_link.rst ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── test_extractor.py └── test_naming.py └── torchextractor ├── __init__.py ├── extractor.py ├── naming.py └── version.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Package 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | jobs: 9 | tests: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | include: 14 | # Python versions 15 | - python-version: 3.6 16 | torch-version: 1.7.1 17 | - python-version: 3.7 18 | torch-version: 1.7.1 19 | - python-version: 3.8 20 | torch-version: 1.7.1 21 | - python-version: 3.9 22 | torch-version: 1.7.1 23 | 24 | # PyTorch versions 25 | - python-version: 3.7 26 | torch-version: 1.4.0 27 | - python-version: 3.7 28 | torch-version: 1.5.0 29 | - python-version: 3.7 30 | torch-version: 1.5.1 31 | - python-version: 3.7 32 | torch-version: 1.6.0 33 | - python-version: 3.7 34 | torch-version: 1.7.0 35 | - python-version: 3.7 36 | torch-version: 1.7.1 37 | 38 | steps: 39 | - uses: actions/checkout@v2 40 | - name: Set up Python ${{ matrix.python-version }} 41 | uses: actions/setup-python@v2 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | 45 | - name: Install dependencies 46 | run: | 47 | sed -i 's/^torch>=.*/torch==${{ matrix.torch-version }}/g' requirements.txt 48 | pip install -r requirements.txt 49 | pip install -r requirements-dev.txt 50 | 51 | - name: Tests 52 | run: | 53 | python -m unittest discover -vs ./tests/ 54 | 55 | deploy: 56 | runs-on: ubuntu-latest 57 | needs: tests 58 | 59 | steps: 60 | - uses: actions/checkout@v2 61 | - name: Set up Python 62 | uses: actions/setup-python@v2 63 | with: 64 | python-version: 3.6 65 | 66 | - name: Install dependencies 67 | run: | 68 | python -m pip install --upgrade pip 69 | pip install setuptools wheel twine 70 | 71 | - name: Build and Publish 72 | env: 73 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 74 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 75 | run: | 76 | python setup.py sdist bdist_wheel 77 | twine upload dist/* --skip-existing 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # IDE 141 | .idea 142 | 143 | # ONNX Models 144 | *.onnx 145 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | args: ["--line-length", "119", "torchextractor/", "tests/"] 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.7.0 10 | hooks: 11 | - id: isort 12 | args: ["--line-length", "119", "--profile", "black", "--gitignore"] 13 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | sphinx: 7 | configuration: docs/source/conf.py 8 | 9 | python: 10 | version: 3.6 11 | install: 12 | - requirements: requirements.txt 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 
7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
203 | 
204 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `torchextractor`: PyTorch Intermediate Feature Extraction
2 | 
3 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchextractor)](https://pypi.org/project/torchextractor/)
4 | [![PyPI](https://img.shields.io/pypi/v/torchextractor)](https://pypi.org/project/torchextractor/)
5 | [![Read the Docs](https://img.shields.io/readthedocs/torchextractor)](https://torchextractor.readthedocs.io/en/latest/)
6 | [![Upload Python Package](https://github.com/antoinebrl/torchextractor/actions/workflows/publish.yml/badge.svg)](https://github.com/antoinebrl/torchextractor/actions/workflows/publish.yml)
7 | [![GitHub](https://img.shields.io/github/license/antoinebrl/torchextractor)](https://github.com/antoinebrl/torchextractor/blob/main/LICENSE)
8 | 
9 | 
10 | ## Introduction
11 | 
12 | Too many times, model definitions get remorselessly copy-pasted just because the
13 | `forward` function does not return what the person expects. You provide module names
14 | and `torchextractor` takes care of the extraction for you. It's never been easier to
15 | extract features, add an extra loss, or plug another head onto a network.
16 | Let us know what amazing things you build with `torchextractor`!
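As a quick illustration of the "extra loss" use case, here is a minimal sketch, not part of the shipped examples: the auxiliary head and the 0.3 loss weight are assumptions made up for this snippet, while `tx.Extractor` and its `(output, features)` return value are used exactly as in the Usage section below.

```python
import torch
import torch.nn.functional as F
import torchvision
import torchextractor as tx

# Wrap the backbone so the "layer3" feature map is returned alongside the logits
model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer3"])

# Hypothetical auxiliary head; "layer3" of ResNet-18 produces (N, 256, 14, 14) maps
aux_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(256, 1000),
)

x = torch.rand(7, 3, 224, 224)
targets = torch.randint(0, 1000, (7,))

logits, features = model(x)
loss = F.cross_entropy(logits, targets) + 0.3 * F.cross_entropy(aux_head(features["layer3"]), targets)
loss.backward()
```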
17 | 
18 | ## Installation
19 | 
20 | ```shell
21 | pip install torchextractor # stable
22 | pip install git+https://github.com/antoinebrl/torchextractor.git # latest
23 | ```
24 | 
25 | Requirements:
26 | - Python >= 3.6
27 | - torch >= 1.4.0
28 | 
29 | ## Usage
30 | 
31 | ```python
32 | import torch
33 | import torchvision
34 | import torchextractor as tx
35 | 
36 | model = torchvision.models.resnet18(pretrained=True)
37 | model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
38 | dummy_input = torch.rand(7, 3, 224, 224)
39 | model_output, features = model(dummy_input)
40 | feature_shapes = {name: f.shape for name, f in features.items()}
41 | print(feature_shapes)
42 | 
43 | # {
44 | #   'layer1': torch.Size([7, 64, 56, 56]),
45 | #   'layer2': torch.Size([7, 128, 28, 28]),
46 | #   'layer3': torch.Size([7, 256, 14, 14]),
47 | #   'layer4': torch.Size([7, 512, 7, 7]),
48 | # }
49 | ```
50 | 
51 | [See more examples](docs/source/examples.ipynb)
52 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/antoinebrl/torchextractor/HEAD?filepath=docs/source/examples.ipynb)
53 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/antoinebrl/torchextractor/blob/master/docs/source/examples.ipynb)
54 | 
55 | [Read the documentation](https://torchextractor.readthedocs.io/en/latest/)
56 | 
57 | ## FAQ
58 | 
59 | **• How do I know the names of the modules?**
60 | 
61 | You can print all module names like this:
62 | ```python
63 | tx.list_module_names(model)
64 | 
65 | # OR
66 | 
67 | for name, module in model.named_modules():
68 |     print(name)
69 | ```
70 | 
71 | **• Why do some operations not get listed?**
72 | 
73 | It is not possible to add hooks to operations that are not defined as modules.
74 | Therefore, `F.relu` cannot be captured but `nn.ReLU()` can.
75 | 
76 | **• How can I avoid listing all relevant modules?**
77 | 
78 | You can specify a custom filtering function to hook the relevant modules:
79 | ```python
80 | # Hook everything!
81 | module_filter_fn = lambda module, name: True
82 | 
83 | # Capture all modules inside the first layer
84 | module_filter_fn = lambda module, name: name.startswith("layer1")
85 | 
86 | # Focus on all convolutions
87 | module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)
88 | 
89 | model = tx.Extractor(model, module_filter_fn=module_filter_fn)
90 | ```
91 | 
92 | **• Is it compatible with ONNX?**
93 | 
94 | `tx.Extractor` is compatible with ONNX! This means you can also access intermediate feature maps after the export.
95 | 
96 | Pro-tip: name the output nodes by using `output_names` when calling `torch.onnx.export`.
97 | 
98 | **• Is it compatible with TorchScript?**
99 | 
100 | Not yet, but we are working on it. Compiling registered hooks of a module
101 | [was only recently added in PyTorch v1.8.0](https://github.com/pytorch/pytorch/pull/49544).
102 | 
103 | **• "One more thing!" :wink:**
104 | 
105 | By default, we capture the latest output of the relevant modules,
106 | but you can specify your own custom operations.
107 | 
108 | For example, to accumulate features over 10 forward passes, you
109 | can do the following:
110 | ```python
111 | import torch
112 | import torchvision
113 | import torchextractor as tx
114 | 
115 | model = torchvision.models.resnet18(pretrained=True)
116 | 
117 | def capture_fn(module, input, output, module_name, feature_maps):
118 |     if module_name not in feature_maps:
119 |         feature_maps[module_name] = []
120 |     feature_maps[module_name].append(output)
121 | 
122 | extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)
123 | 
124 | for i in range(20):
125 |     for i in range(10):
126 |         x = torch.rand(7, 3, 224, 224)
127 |         model(x)
128 |     feature_maps = extractor.collect()
129 | 
130 |     # Do your stuff here
131 | 
132 |     # Discard collected elements
133 |     extractor.clear_placeholder()
134 | ```
135 | 
136 | ## Contributing
137 | 
138 | All feedback and contributions are welcome. Feel free to report an issue or to create a pull request!
139 | 
140 | If you want to get hands-on:
141 | 1. (Fork and) clone the repo.
142 | 2. Create a virtual environment: `virtualenv -p python3 .venv && source .venv/bin/activate`
143 | 3. Install dependencies: `pip install -r requirements.txt && pip install -r requirements-dev.txt`
144 | 4. Hook auto-formatting tools: `pre-commit install`
145 | 5. Hack as much as you want!
146 | 6. Run tests: `python -m unittest discover -vs ./tests/`
147 | 7. Share your work and create a pull request.
148 | 
149 | To build the documentation:
150 | ```shell
151 | cd docs
152 | pip install -r requirements.txt
153 | make html
154 | ```
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS    ?=
7 | SPHINXBUILD   ?= sphinx-build
8 | SOURCEDIR     = source
9 | 
10 | BUILDDIR      = build
11 | 
12 | # Put it first so that "make" without argument is like "make help".
13 | help:
14 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
15 | 
16 | .PHONY: help Makefile
17 | 
18 | # Catch-all target: route all unknown targets to Sphinx using the new
19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
20 | %: Makefile
21 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
22 | 
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx_rtd_theme
3 | nbsphinx
4 | numpydoc
5 | m2r2
6 | jupyter
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API
2 | ===
3 | 
4 | Extractor
5 | ---------
6 | .. autoclass:: torchextractor.Extractor
7 |    :show-inheritance:
8 | 
9 | Utils
10 | -----
11 | 
12 | .. autofunction:: torchextractor.list_module_names
13 | 
14 | .. autofunction:: torchextractor.find_modules_by_names
15 | 
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options.
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "torchextractor" 22 | copyright = "2021, Antoine Broyelle" 23 | author = "Antoine Broyelle" 24 | 25 | # The full version, including alpha/beta/rc tags 26 | import torchextractor 27 | 28 | release = torchextractor.__version__ 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | "sphinx.ext.autodoc", 38 | "numpydoc", 39 | "nbsphinx", 40 | "m2r2", 41 | ] 42 | 43 | source_suffix = [".rst", ".md", ".ipynb"] 44 | 45 | # List of patterns, relative to source directory, that match files and 46 | # directories to ignore when looking for source files. 47 | # This pattern also affects html_static_path and html_extra_path. 48 | exclude_patterns = [] 49 | 50 | 51 | # -- Options for HTML output ------------------------------------------------- 52 | 53 | # The theme to use for HTML and HTML Help pages. 54 | html_theme = "sphinx_rtd_theme" 55 | 56 | # Run the notebooks manually 57 | nbsphinx_execute = "never" 58 | 59 | autodoc_default_flags = [] 60 | autoclass_content = "init" 61 | autodoc_mock_imports = ["torch"] 62 | -------------------------------------------------------------------------------- /docs/source/examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "invisible-australian", 6 | "metadata": {}, 7 | "source": [ 8 | "# Examples" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "continental-galaxy", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "cathedral-washington", 26 | "metadata": { 27 | "scrolled": true 28 | }, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "Requirement already satisfied: torch in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (1.8.0)\n", 35 | "Requirement already satisfied: typing-extensions in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (3.7.4.3)\n", 36 | "Requirement already satisfied: dataclasses in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (0.8)\n", 37 | "Requirement already satisfied: numpy in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (1.19.5)\n", 38 | "Requirement already satisfied: torchvision in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (0.9.0)\n", 39 | "Requirement already satisfied: numpy in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from 
torchvision) (1.19.5)\n", 40 | "Requirement already satisfied: torch==1.8.0 in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torchvision) (1.8.0)\n", 41 | "Requirement already satisfied: pillow>=4.1.1 in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torchvision) (8.1.2)\n", 42 | "Requirement already satisfied: typing-extensions in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch==1.8.0->torchvision) (3.7.4.3)\n", 43 | "Requirement already satisfied: dataclasses in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch==1.8.0->torchvision) (0.8)\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "!pip install torch\n", 49 | "!pip install torchvision\n", 50 | "\n", 51 | "# Uncomment one of the following\n", 52 | "# !pip install torchextractor # stable\n", 53 | "!pip install git+https://github.com/antoinebrl/torchextractor.git # latest\n", 54 | "# import sys, os; sys.path.insert(0, os.path.abspath(\"../..\")) # current code\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "primary-grounds", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import torch\n", 65 | "import torchvision\n", 66 | "import torchextractor as tx" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "id": "accessory-pipeline", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "model = torchvision.models.resnet18()\n", 77 | "dummy_input = torch.rand(7, 3, 224, 224)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "whole-mystery", 83 | "metadata": {}, 84 | "source": [ 85 | "### List module names" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "id": "fossil-knitting", 92 | "metadata": { 93 | "scrolled": true 94 | }, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "\n", 101 | "conv1\n", 102 | "bn1\n", 103 | "relu\n", 104 | "maxpool\n", 105 | "layer1\n", 106 | "layer1.0\n", 107 | "layer1.0.conv1\n", 108 | "layer1.0.bn1\n", 109 | "layer1.0.relu\n", 110 | "layer1.0.conv2\n", 111 | "layer1.0.bn2\n", 112 | "layer1.1\n", 113 | "layer1.1.conv1\n", 114 | "layer1.1.bn1\n", 115 | "layer1.1.relu\n", 116 | "layer1.1.conv2\n", 117 | "layer1.1.bn2\n", 118 | "layer2\n", 119 | "layer2.0\n", 120 | "layer2.0.conv1\n", 121 | "layer2.0.bn1\n", 122 | "layer2.0.relu\n", 123 | "layer2.0.conv2\n", 124 | "layer2.0.bn2\n", 125 | "layer2.0.downsample\n", 126 | "layer2.0.downsample.0\n", 127 | "layer2.0.downsample.1\n", 128 | "layer2.1\n", 129 | "layer2.1.conv1\n", 130 | "layer2.1.bn1\n", 131 | "layer2.1.relu\n", 132 | "layer2.1.conv2\n", 133 | "layer2.1.bn2\n", 134 | "layer3\n", 135 | "layer3.0\n", 136 | "layer3.0.conv1\n", 137 | "layer3.0.bn1\n", 138 | "layer3.0.relu\n", 139 | "layer3.0.conv2\n", 140 | "layer3.0.bn2\n", 141 | "layer3.0.downsample\n", 142 | "layer3.0.downsample.0\n", 143 | "layer3.0.downsample.1\n", 144 | "layer3.1\n", 145 | "layer3.1.conv1\n", 146 | "layer3.1.bn1\n", 147 | "layer3.1.relu\n", 148 | "layer3.1.conv2\n", 149 | "layer3.1.bn2\n", 150 | "layer4\n", 151 | "layer4.0\n", 152 | "layer4.0.conv1\n", 153 | "layer4.0.bn1\n", 154 | "layer4.0.relu\n", 155 | "layer4.0.conv2\n", 156 | "layer4.0.bn2\n", 157 | "layer4.0.downsample\n", 158 | "layer4.0.downsample.0\n", 159 | "layer4.0.downsample.1\n", 160 | "layer4.1\n", 161 | "layer4.1.conv1\n", 162 | "layer4.1.bn1\n", 163 | "layer4.1.relu\n", 164 | "layer4.1.conv2\n", 165 | 
"layer4.1.bn2\n", 166 | "avgpool\n", 167 | "fc\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "for name, module in model.named_modules():\n", 173 | " print(name)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "boolean-purpose", 179 | "metadata": {}, 180 | "source": [ 181 | "### Extract features" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 6, 187 | "id": "southeast-worship", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "{'layer1': torch.Size([7, 64, 56, 56]),\n", 194 | " 'layer2': torch.Size([7, 128, 28, 28]),\n", 195 | " 'layer3': torch.Size([7, 256, 14, 14]),\n", 196 | " 'layer4': torch.Size([7, 512, 7, 7])}" 197 | ] 198 | }, 199 | "execution_count": 6, 200 | "metadata": {}, 201 | "output_type": "execute_result" 202 | } 203 | ], 204 | "source": [ 205 | "model = torchvision.models.resnet18()\n", 206 | "model = tx.Extractor(model, [\"layer1\", \"layer2\", \"layer3\", \"layer4\"])\n", 207 | "\n", 208 | "model_output, features = model(dummy_input)\n", 209 | "{name: f.shape for name, f in features.items()}" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "amino-taiwan", 215 | "metadata": {}, 216 | "source": [ 217 | "### Extract features from nested modules" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 7, 223 | "id": "committed-costs", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "data": { 228 | "text/plain": [ 229 | "{'layer1': torch.Size([7, 64, 56, 56]),\n", 230 | " 'layer2.1.conv1': torch.Size([7, 128, 28, 28]),\n", 231 | " 'layer3.0.downsample.0': torch.Size([7, 256, 14, 14]),\n", 232 | " 'layer4.0': torch.Size([7, 512, 7, 7])}" 233 | ] 234 | }, 235 | "execution_count": 7, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "model = torchvision.models.resnet18()\n", 242 | "model = tx.Extractor(model, [\"layer1\", \"layer2.1.conv1\", \"layer3.0.downsample.0\", \"layer4.0\"])\n", 243 | "\n", 244 | "model_output, features = model(dummy_input)\n", 245 | "{name: f.shape for name, f in features.items()}" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "id": "invalid-albany", 251 | "metadata": {}, 252 | "source": [ 253 | "### Filter modules" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "id": "independent-energy", 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "text/plain": [ 265 | "{'conv1': torch.Size([7, 64, 112, 112]),\n", 266 | " 'layer1.0.conv1': torch.Size([7, 64, 56, 56]),\n", 267 | " 'layer1.0.conv2': torch.Size([7, 64, 56, 56]),\n", 268 | " 'layer1.1.conv1': torch.Size([7, 64, 56, 56]),\n", 269 | " 'layer1.1.conv2': torch.Size([7, 64, 56, 56]),\n", 270 | " 'layer2.0.conv1': torch.Size([7, 128, 28, 28]),\n", 271 | " 'layer2.0.conv2': torch.Size([7, 128, 28, 28]),\n", 272 | " 'layer2.0.downsample.0': torch.Size([7, 128, 28, 28]),\n", 273 | " 'layer2.1.conv1': torch.Size([7, 128, 28, 28]),\n", 274 | " 'layer2.1.conv2': torch.Size([7, 128, 28, 28]),\n", 275 | " 'layer3.0.conv1': torch.Size([7, 256, 14, 14]),\n", 276 | " 'layer3.0.conv2': torch.Size([7, 256, 14, 14]),\n", 277 | " 'layer3.0.downsample.0': torch.Size([7, 256, 14, 14]),\n", 278 | " 'layer3.1.conv1': torch.Size([7, 256, 14, 14]),\n", 279 | " 'layer3.1.conv2': torch.Size([7, 256, 14, 14]),\n", 280 | " 'layer4.0.conv1': torch.Size([7, 512, 7, 7]),\n", 281 | " 'layer4.0.conv2': torch.Size([7, 512, 7, 7]),\n", 282 | " 
'layer4.0.downsample.0': torch.Size([7, 512, 7, 7]),\n", 283 | " 'layer4.1.conv1': torch.Size([7, 512, 7, 7]),\n", 284 | " 'layer4.1.conv2': torch.Size([7, 512, 7, 7])}" 285 | ] 286 | }, 287 | "execution_count": 8, 288 | "metadata": {}, 289 | "output_type": "execute_result" 290 | } 291 | ], 292 | "source": [ 293 | "model = torchvision.models.resnet18()\n", 294 | "module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)\n", 295 | "model = tx.Extractor(model, module_filter_fn=module_filter_fn)\n", 296 | "\n", 297 | "model_output, features = model(dummy_input)\n", 298 | "{name: f.shape for name, f in features.items()}" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "id": "registered-raising", 304 | "metadata": {}, 305 | "source": [ 306 | "### ONNX export with named output nodes" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 9, 312 | "id": "amazing-demand", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "model = torchvision.models.resnet18()\n", 317 | "model = tx.Extractor(model, [\"layer3\", \"layer4\"])\n", 318 | "\n", 319 | "torch.onnx.export(model, dummy_input, \"resnet.onnx\", output_names=[\"classifier\", \"layer3\", \"layer4\"])" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "raised-causing", 325 | "metadata": {}, 326 | "source": [ 327 | "### Custom Operation" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 10, 333 | "id": "parliamentary-swedish", 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "name": "stdout", 338 | "output_type": "stream", 339 | "text": [ 340 | "layer3: 10 items\n", 341 | " 1 - torch.Size([7, 256, 14, 14])\n", 342 | " 2 - torch.Size([7, 256, 14, 14])\n", 343 | " 3 - torch.Size([7, 256, 14, 14])\n", 344 | " 4 - torch.Size([7, 256, 14, 14])\n", 345 | " 5 - torch.Size([7, 256, 14, 14])\n", 346 | " 6 - torch.Size([7, 256, 14, 14])\n", 347 | " 7 - torch.Size([7, 256, 14, 14])\n", 348 | " 8 - torch.Size([7, 256, 14, 14])\n", 349 | " 9 - torch.Size([7, 256, 14, 14])\n", 350 | " 10 - torch.Size([7, 256, 14, 14])\n", 351 | "layer4: 10 items\n", 352 | " 1 - torch.Size([7, 512, 7, 7])\n", 353 | " 2 - torch.Size([7, 512, 7, 7])\n", 354 | " 3 - torch.Size([7, 512, 7, 7])\n", 355 | " 4 - torch.Size([7, 512, 7, 7])\n", 356 | " 5 - torch.Size([7, 512, 7, 7])\n", 357 | " 6 - torch.Size([7, 512, 7, 7])\n", 358 | " 7 - torch.Size([7, 512, 7, 7])\n", 359 | " 8 - torch.Size([7, 512, 7, 7])\n", 360 | " 9 - torch.Size([7, 512, 7, 7])\n", 361 | " 10 - torch.Size([7, 512, 7, 7])\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "model = torchvision.models.resnet18()\n", 367 | "\n", 368 | "# Concatenate outputs of every runs\n", 369 | "def capture_fn(module, input, output, module_name, feature_maps):\n", 370 | " if module_name not in feature_maps:\n", 371 | " feature_maps[module_name] = []\n", 372 | " feature_maps[module_name].append(output)\n", 373 | " \n", 374 | "\n", 375 | "extractor = tx.Extractor(model, [\"layer3\", \"layer4\"], capture_fn=capture_fn)\n", 376 | "\n", 377 | "for i in range(10):\n", 378 | " x = torch.rand(7, 3, 224, 224)\n", 379 | " model(x)\n", 380 | "\n", 381 | "feature_maps = extractor.collect()\n", 382 | "for name, features in feature_maps.items():\n", 383 | " print(f\"{name}: {len(features)} items\")\n", 384 | " for i, f in enumerate(features):\n", 385 | " print(f\" {i+1} - {f.shape}\")" 386 | ] 387 | } 388 | ], 389 | "metadata": { 390 | "kernelspec": { 391 | "display_name": "Python 3", 392 | "language": "python", 393 | 
"name": "python3" 394 | }, 395 | "language_info": { 396 | "codemirror_mode": { 397 | "name": "ipython", 398 | "version": 3 399 | }, 400 | "file_extension": ".py", 401 | "mimetype": "text/x-python", 402 | "name": "python", 403 | "nbconvert_exporter": "python", 404 | "pygments_lexer": "ipython3", 405 | "version": "3.6.9" 406 | } 407 | }, 408 | "nbformat": 4, 409 | "nbformat_minor": 5 410 | } 411 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | torchextractor 2 | -------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: Contents: 7 | 8 | torchextractor 9 | api.rst 10 | Examples 11 | -------------------------------------------------------------------------------- /docs/source/readme_link.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../README.md 2 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit==2.10.1 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | numpy 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r", encoding="utf-8") as f: 4 | long_description = f.read() 5 | 6 | exec(open("torchextractor/version.py").read()) 7 | 8 | setup( 9 | name="torchextractor", # Replace with your own username 10 | version=__version__, 11 | author="Antoine Broyelle", 12 | author_email="antoine.broyelle@pm.me", 13 | description="Pytorch feature extraction made simple", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/antoinebrl/torchextractor", 17 | project_urls={ 18 | "Bug Tracker": "https://github.com/antoinebrl/torchextractor/issues", 19 | }, 20 | classifiers=[ 21 | "Development Status :: 4 - Beta", 22 | "Natural Language :: English", 23 | "License :: OSI Approved :: Apache Software License", 24 | "Operating System :: OS Independent", 25 | "Intended Audience :: Education", 26 | "Intended Audience :: Developers", 27 | "Intended Audience :: Science/Research", 28 | "Topic :: Scientific/Engineering", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | "Topic :: Software Development", 31 | "Topic :: Software Development :: Libraries", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3 :: Only", 35 | "Programming Language :: Python :: 3.6", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | "Programming Language :: Python :: 3.9", 39 | ], 40 | keywords="pytorch torch feature extraction", 41 | packages=find_packages(exclude=["tests"]), 42 | python_requires=">=3.6", 43 | install_requires=["numpy", "torch>=1.4.0"], 44 | ) 45 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/antoinebrl/torchextractor/b48ea752dd5089bc360be9b375d7c3fd01b2040a/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_extractor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from torchextractor.extractor import Extractor 8 | 9 | 10 | class MyTinyVGG(nn.Module): 11 | def __init__(self): 12 | super(MyTinyVGG, self).__init__() 13 | in_channels = 3 14 | nb_channels = 12 15 | nb_classes = 17 16 | 17 | self.block1 = self._make_layer(in_channels, nb_channels) 18 | in_channels, nb_channels = nb_channels, 2 * nb_channels 19 | self.block2 = self._make_layer(in_channels, nb_channels) 20 | in_channels, nb_channels = nb_channels, 2 * nb_channels 21 | self.block3 = self._make_layer(in_channels, nb_channels) 22 | 23 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 24 | self.classifier = nn.Sequential( 25 | nn.Linear(nb_channels, nb_channels), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(nb_channels, nb_channels), 29 | nn.ReLU(True), 30 | nn.Dropout(), 31 | nn.Linear(nb_channels, nb_classes), 32 | ) 33 | 34 | def _make_layer(self, in_channels, nb_channels): 35 | layer1 = nn.Sequential( 36 | nn.Conv2d(in_channels, nb_channels, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(nb_channels), 38 | nn.ReLU(inplace=True), 39 | ) 40 | layer2 = nn.Sequential( 41 | nn.Conv2d(nb_channels, nb_channels, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(nb_channels), 43 | nn.ReLU(inplace=True), 44 | ) 45 | return nn.Sequential( 46 | OrderedDict([("layer1", layer1), ("layer2", layer2), ("pool", nn.MaxPool2d(kernel_size=2, stride=2))]) 47 | ) 48 | 49 | def forward(self, x): 50 | x = self.block1(x) 51 | x = self.block2(x) 52 | x = self.block3(x) 53 | x = self.avgpool(x) 54 | x = x.squeeze(3).squeeze(2) 55 | x = self.classifier(x) 56 | return x 57 | 58 | 59 | class TestExtractor(unittest.TestCase): 60 | def test_model_output(self): 61 | model = MyTinyVGG() 62 | extractor = Extractor(model, ["block1", "block2"]).eval() 63 | input = torch.rand((5, 3, 32, 32)) 64 | 65 | output_model = model(input) 66 | output_extractor, _ = extractor(input) 67 | 68 | self.assertTrue(torch.allclose(output_model, output_extractor)) 69 | 70 | def test_forward_capture_feature_maps(self): 71 | model = MyTinyVGG() 72 | names = ["block1", "block2"] 73 | model = Extractor(model, names).eval() 74 | 75 | input = torch.rand((5, 3, 32, 32)) 76 | output, feature_maps = model(input) 77 | 78 | self.assertTrue(all(True if name in feature_maps else False for name in names)) 79 | self.assertEqual(list(feature_maps["block1"].shape), [5, 12, 16, 16]) 80 | self.assertEqual(list(feature_maps["block2"].shape), [5, 24, 8, 8]) 81 | 82 | def test_capture_latest_feature_map(self): 83 | model = MyTinyVGG() 84 | names = ["block1", "block2"] 85 | extractor = Extractor(model, names).eval() 86 | 87 | input1 = torch.rand((5, 3, 32, 32)) 88 | model(input1) 89 | feature_maps1 = extractor.collect() 90 | shapes = {name: list(output.shape) for name, output in feature_maps1.items()} 91 | 92 | input2 = torch.rand((5, 3, 64, 64)) 93 | model(input2) 94 | feature_maps2 = extractor.collect() 95 | 96 | self.assertTrue(all(True if name in feature_maps1 else False for name in names)) 97 | for name in names: 98 | self.assertNotEqual(shapes[name], list(feature_maps2[name].shape)) 99 | self.assertEqual(list(feature_maps1[name].shape), 
list(feature_maps2[name].shape)) 100 | 101 | def test_destroy_extractor(self): 102 | model = MyTinyVGG() 103 | names = ["block1", "block2"] 104 | extractor = Extractor(model, names).eval() 105 | 106 | input1 = torch.rand((5, 3, 32, 32)) 107 | model(input1) 108 | feature_maps = extractor.collect() 109 | shapes1 = {name: list(output.shape) for name, output in feature_maps.items()} 110 | 111 | del extractor 112 | 113 | input2 = torch.rand((5, 3, 64, 64)) 114 | model(input2) 115 | shapes2 = {name: list(output.shape) for name, output in feature_maps.items()} 116 | 117 | # captured content should not change if hooks are no longer operating 118 | for name in names: 119 | self.assertEqual(shapes1[name], shapes2[name]) 120 | 121 | def test_onnx_export(self): 122 | model = MyTinyVGG() 123 | names = ["block1", "block2"] 124 | model = Extractor(model, names).eval() 125 | input = torch.rand((5, 3, 32, 32)) 126 | output, feature_maps = model(input) 127 | 128 | torch.onnx.export(model, input, "/tmp/model.onnx", output_names=["classifier"] + list(feature_maps.keys())) 129 | 130 | try: 131 | import onnx 132 | 133 | model = onnx.load("/tmp/model.onnx") 134 | output_names = [node.name for node in model.graph.output] 135 | for name in names: 136 | self.assertTrue(name in output_names) 137 | except ImportError as e: 138 | print(e) 139 | 140 | 141 | if __name__ == "__main__": 142 | unittest.main() 143 | -------------------------------------------------------------------------------- /tests/test_naming.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import OrderedDict 3 | 4 | from torch import nn 5 | 6 | from torchextractor.naming import attach_name_to_modules, find_modules_by_names 7 | 8 | 9 | class TestNaming(unittest.TestCase): 10 | def test_sequential_model(self): 11 | model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()) 12 | 13 | attach_name_to_modules(model) 14 | 15 | for i in range(4): 16 | self.assertEqual(model[i]._extractor_fullname, str(i)) 17 | 18 | def test_sequential_model_with_dict(self): 19 | names = ["conv1", "relu1", "conv2", "relu2"] 20 | ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()] 21 | model = nn.Sequential(OrderedDict(zip(names, ops))) 22 | 23 | attach_name_to_modules(model) 24 | 25 | for i, name in enumerate(names): 26 | self.assertEqual(model[i]._extractor_fullname, name) 27 | 28 | def test_sequential_module_inheritance(self): 29 | class MyModel(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self.conv1 = nn.Conv2d(1, 20, 5) 33 | self.conv2 = nn.Conv2d(20, 20, 5) 34 | 35 | model = MyModel() 36 | attach_name_to_modules(model) 37 | 38 | module_iter = model.children() 39 | module1 = next(module_iter) 40 | self.assertEqual(module1._extractor_fullname, "conv1") 41 | module2 = next(module_iter) 42 | self.assertEqual(module2._extractor_fullname, "conv2") 43 | # Only two modules 44 | self.assertIsNone(next(module_iter, None)) 45 | 46 | def test_nested_modules(self): 47 | class MyModel(nn.Module): 48 | def __init__(self): 49 | super().__init__() 50 | self.block1 = nn.Sequential( 51 | nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(), nn.Linear(4, 1), nn.Sigmoid()), 52 | nn.Sigmoid(), 53 | ) 54 | 55 | model = MyModel() 56 | attach_name_to_modules(model) 57 | 58 | self.assertEqual(model.block1[0][2]._extractor_fullname, "block1.0.2") 59 | 60 | 61 | class TestModuleSearch(unittest.TestCase): 62 | def test_sequential_model(self): 63 | names = ["conv1", 
"relu1", "conv2", "relu2"] 64 | ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()] 65 | model = nn.Sequential(OrderedDict(zip(names, ops))) 66 | attach_name_to_modules(model) 67 | 68 | search_names = ["conv1", "relu2"] 69 | modules = find_modules_by_names(model, search_names) 70 | 71 | # All names have a match 72 | self.assertTrue(all([name in modules for name in search_names])) 73 | 74 | # Each name links to the right module 75 | self.assertEqual(id(modules["conv1"]), id(ops[0])) 76 | self.assertEqual(id(modules["relu2"]), id(ops[3])) 77 | 78 | def test_module_not_found(self): 79 | names = ["conv1", "relu1", "conv2", "relu2"] 80 | ops = [nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()] 81 | model = nn.Sequential(OrderedDict(zip(names, ops))) 82 | attach_name_to_modules(model) 83 | 84 | search_names = ["conv1", "azertyuiop"] 85 | modules = find_modules_by_names(model, search_names) 86 | 87 | # Each name links to the right module 88 | self.assertFalse("azertyuiop" in modules) 89 | 90 | def test_nested_modules(self): 91 | class MyModel(nn.Module): 92 | def __init__(self): 93 | super().__init__() 94 | self.block1 = nn.Sequential( 95 | nn.Sequential(nn.Linear(4, 4), nn.Sigmoid(), nn.Linear(4, 1), nn.Sigmoid()), 96 | nn.Sigmoid(), 97 | ) 98 | 99 | model = MyModel() 100 | attach_name_to_modules(model) 101 | modules = find_modules_by_names(model, ["block1.0.2"]) 102 | 103 | self.assertEqual(id(model.block1[0][2]), id(modules["block1.0.2"])) 104 | 105 | 106 | if __name__ == "__main__": 107 | unittest.main() 108 | -------------------------------------------------------------------------------- /torchextractor/__init__.py: -------------------------------------------------------------------------------- 1 | from .extractor import Extractor 2 | from .naming import find_modules_by_names, list_module_names 3 | from .version import __version__ 4 | -------------------------------------------------------------------------------- /torchextractor/extractor.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Dict 3 | from typing import Iterable as IterableType 4 | from typing import List, Tuple 5 | 6 | from torch import nn 7 | 8 | from .naming import attach_name_to_modules 9 | 10 | 11 | def hook_wrapper(module: nn.Module, input: Any, output: Any, capture_fn: Callable, feature_maps: Dict[str, Any]): 12 | """ 13 | Hook wrapper to expose module name to hook 14 | """ 15 | capture_fn(module, input, output, module._extractor_fullname, feature_maps) 16 | 17 | 18 | def hook_capture_module_output( 19 | module: nn.Module, input: Any, output: Any, module_name: str, feature_maps: Dict[str, Any] 20 | ): 21 | """ 22 | Hook function to capture the output of the module. 23 | 24 | Parameters 25 | ---------- 26 | module: nn.Module 27 | The module doing the computations. 28 | input: 29 | Whatever is provided to the module. 30 | output: 31 | Whatever is computed by the module. 32 | module_name: str 33 | Fully qualifying name of the module 34 | feature_maps: dictionary - keys: fully qualifying module names 35 | Placeholder to store the output of the modules so it can be used later on 36 | """ 37 | feature_maps[module_name] = output 38 | 39 | 40 | def register_hook(module_filter_fn: Callable, hook: Callable, hook_handles: List) -> Callable: 41 | """ 42 | Attach a hook to some relevant modules. 
43 | 
44 |     Parameters
45 |     ----------
46 |     module_filter_fn: callable
47 |         A filtering function called for each module. When it evaluates to `True`, a hook is registered.
48 |     hook: callable
49 |         The hook to register. See the documentation about PyTorch hooks:
50 |         https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook
51 |     hook_handles: list
52 |         A placeholder containing all newly registered hook handles.
53 | 
54 |     Returns
55 |     -------
56 |     callable function to apply on each module
57 | 
58 |     """
59 | 
60 |     def init_hook(module: nn.Module):
61 |         if module_filter_fn(module, module._extractor_fullname):
62 |             handle = module.register_forward_hook(hook)
63 |             hook_handles.append(handle)
64 | 
65 |     return init_hook
66 | 
67 | 
68 | class Extractor(nn.Module):
69 |     def __init__(
70 |         self,
71 |         model: nn.Module,
72 |         module_names: IterableType[str] = None,
73 |         module_filter_fn: Callable = None,
74 |         capture_fn: Callable = None,
75 |     ):
76 |         """
77 |         Capture the intermediate feature maps of the model.
78 | 
79 |         Parameters
80 |         ----------
81 |         model: nn.Module,
82 |             The model to extract features from.
83 | 
84 |         module_names: list of str, default None
85 |             The fully qualified names of the modules producing the relevant feature maps.
86 | 
87 |         module_filter_fn: callable, default None
88 |             A filtering function. Takes a module and module name as input and returns True for modules
89 |             producing the relevant features. Either `module_names` or `module_filter_fn` should be
90 |             provided but not both at the same time.
91 | 
92 |             Example::
93 | 
94 |                 def module_filter_fn(module, name):
95 |                     return isinstance(module, torch.nn.Conv2d)
96 | 
97 |         capture_fn: callable, default None
98 |             Operation to carry out at each forward pass. The function should comply with the following interface.
99 | 
100 |             Example::
101 | 
102 |                 def capture_fn(
103 |                     module: nn.Module,
104 |                     input: Any,
105 |                     output: Any,
106 |                     module_name: str,
107 |                     feature_maps: Dict[str, Any]
108 |                 ):
109 |                     feature_maps[module_name] = output
110 |         """
111 |         assert (
112 |             module_names is not None or module_filter_fn is not None
113 |         ), "Module names or a filtering function must be provided"
114 |         assert not (module_names is not None and module_filter_fn is not None), (
115 |             "You should either specify the fully qualifying names of the modules or a filtering function "
116 |             "but not both at the same time"
117 |         )
118 | 
119 |         super(Extractor, self).__init__()
120 |         self.model = attach_name_to_modules(model)
121 | 
122 |         self.feature_maps = {}
123 |         self.hook_handles = []
124 | 
125 |         module_filter_fn = module_filter_fn or (lambda module, name: name in module_names)
126 |         capture_fn = capture_fn or hook_capture_module_output
127 |         hook_fn = partial(hook_wrapper, capture_fn=capture_fn, feature_maps=self.feature_maps)
128 |         self.model.apply(register_hook(module_filter_fn, hook_fn, self.hook_handles))
129 | 
130 |     def collect(self) -> Dict[str, nn.Module]:
131 |         """
132 |         Returns the structure holding the most recent feature maps.
133 | 
134 |         Notes
135 |         -----
136 |         The returned structure is mutated at each forward pass of the model.
137 |         It is the caller's responsibility to duplicate the structure content if needed.
138 |         """
139 |         return self.feature_maps
140 | 
141 |     def clear_placeholder(self):
142 |         """
143 |         Resets the structure holding captured feature maps.
144 | """ 145 | self.feature_maps.clear() 146 | 147 | def forward(self, *args, **kwargs) -> Tuple[Any, Dict[str, nn.Module]]: 148 | """ 149 | Performs model computations and collects feature maps 150 | 151 | Returns 152 | ------- 153 | Model output and intermediate feature maps 154 | """ 155 | output = self.model(*args, **kwargs) 156 | return output, self.feature_maps 157 | 158 | def __del__(self): 159 | # Unregister hooks 160 | for handle in self.hook_handles: 161 | handle.remove() 162 | -------------------------------------------------------------------------------- /torchextractor/naming.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Iterable 3 | from typing import Dict 4 | from typing import Iterable as IterableType 5 | from typing import List 6 | 7 | from torch import nn 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def list_module_names(model: nn.Module) -> List[str]: 13 | """ 14 | List names of modules and submodules. 15 | 16 | Parameters 17 | ---------- 18 | model: nn.Module 19 | PyTorch model to examine. 20 | 21 | Returns 22 | ------- 23 | list[str]: 24 | List of names 25 | """ 26 | return [name for name, module in model.named_modules()] 27 | 28 | 29 | def attach_name_to_modules(model: nn.Module) -> nn.Module: 30 | """ 31 | Assign a unique name to each module based on the nested structure of the model. 32 | 33 | Parameters 34 | ---------- 35 | model: nn.Module 36 | PyTorch model to decorate with fully qualifying names for each module. 37 | 38 | Returns 39 | ------- 40 | model: nn.Module. 41 | The provided model as input. 42 | 43 | """ 44 | for name, module in model.named_modules(): 45 | module._extractor_fullname = name 46 | return model 47 | 48 | 49 | def find_modules_by_names(model: nn.Module, names: IterableType[str]) -> Dict[str, nn.Module]: 50 | """ 51 | Find some modules given their fully qualifying names. 52 | 53 | Parameters 54 | ---------- 55 | model: nn.Module 56 | PyTorch model to examine. 57 | names: list of str 58 | List of fully qualifying names. 59 | 60 | Returns 61 | ------- 62 | dict: name -> module 63 | If no match is found for a name, it is not added to the returned structure 64 | 65 | """ 66 | assert isinstance(names, (list, tuple)) 67 | found_modules = {name: module for name, module in model.named_modules() if name in names} 68 | return found_modules 69 | -------------------------------------------------------------------------------- /torchextractor/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.0" 2 | --------------------------------------------------------------------------------