├── .flake8 ├── .github └── workflows │ └── workflow.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── flop_count.md ├── fvcore ├── __init__.py ├── common │ ├── __init__.py │ ├── benchmark.py │ ├── checkpoint.py │ ├── config.py │ ├── download.py │ ├── file_io.py │ ├── history_buffer.py │ ├── param_scheduler.py │ ├── registry.py │ └── timer.py ├── nn │ ├── __init__.py │ ├── activation_count.py │ ├── distributed.py │ ├── flop_count.py │ ├── focal_loss.py │ ├── giou_loss.py │ ├── jit_analysis.py │ ├── jit_handles.py │ ├── parameter_count.py │ ├── precise_bn.py │ ├── print_model_statistics.py │ ├── smooth_l1_loss.py │ ├── squeeze_excitation.py │ └── weight_init.py └── transforms │ ├── __init__.py │ ├── transform.py │ └── transform_util.py ├── io_tests └── test_file_io.py ├── linter.sh ├── packaging ├── build_all_conda.sh ├── build_conda.sh └── fvcore │ └── meta.yaml ├── setup.cfg ├── setup.py └── tests ├── bm_common.py ├── bm_focal_loss.py ├── bm_main.py ├── configs ├── base.yaml ├── base2.yaml ├── config.yaml └── config_multi_base.yaml ├── param_scheduler ├── test_scheduler_composite.py ├── test_scheduler_constant.py ├── test_scheduler_cosine.py ├── test_scheduler_exponential.py ├── test_scheduler_linear.py ├── test_scheduler_multi_step.py ├── test_scheduler_polynomial.py ├── test_scheduler_step.py └── test_scheduler_step_with_fixed_gamma.py ├── test_activation_count.py ├── test_checkpoint.py ├── test_common.py ├── test_flop_count.py ├── test_focal_loss.py ├── test_giou_loss.py ├── test_jit_model_analysis.py ├── test_layers_squeeze_excitation.py ├── test_param_count.py ├── test_precise_bn.py ├── test_print_model_statistics.py ├── test_smooth_l1_loss.py ├── test_transform.py ├── test_transform_util.py └── test_weight_init.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, E221, E741, B905 3 | max-line-length = 88 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | exclude = build,__init__.py 7 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [push, pull_request] 3 | 4 | # Run linter with github actions for quick feedback. 5 | # The unittests will be run on CircleCI instead. 6 | jobs: 7 | linter: 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: [3.8, 3.9] # importlib-metadata v5 requires 3.8+ 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install flake8==3.8.1 flake8-bugbear flake8-comprehensions isort==4.3.21 23 | pip install black==24.2.0 24 | flake8 --version 25 | - name: Lint 26 | run: | 27 | echo "Running isort" 28 | isort -c -sp . 29 | echo "Running black" 30 | black --check . 31 | echo "Running flake8" 32 | flake8 . 
33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *egg-info* 4 | 5 | # C extensions 6 | *.so 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to fvcore 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to fvcore, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fvcore 2 | 3 | [![Support Ukraine](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://opensource.fb.com/support-ukraine) 4 | 5 | fvcore is a light-weight core library that provides the most common and essential 6 | functionality shared in various computer vision frameworks developed at FAIR, 7 | such as [Detectron2](https://github.com/facebookresearch/detectron2/), 8 | [PySlowFast](https://github.com/facebookresearch/SlowFast), and 9 | [ClassyVision](https://github.com/facebookresearch/ClassyVision). 10 | All components in this library are type-annotated, tested, and benchmarked. 11 | 12 | The computer vision team at FAIR is responsible for maintaining this library. 13 | 14 | ## Features: 15 | 16 | Besides some basic utilities, fvcore includes the following features: 17 | * Common pytorch layers, functions and losses in [fvcore.nn](fvcore/nn/). 18 | * A hierarchical per-operator flop counting tool: see [this note for details](./docs/flop_count.md). 19 | * Recursive parameter counting: see [API doc](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.parameter_count). 20 | * Recompute BatchNorm population statistics: see its [API doc](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.update_bn_stats). 21 | * A stateless, scale-invariant hyperparameter scheduler: see its [API doc](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.ParamScheduler). 22 | 23 | ## Install: 24 | 25 | fvcore requires pytorch and python >= 3.6. 26 | 27 | Use one of the following ways to install: 28 | 29 | ### 1. Install from PyPI (updated nightly) 30 | ``` 31 | pip install -U fvcore 32 | ``` 33 | 34 | ### 2. Install from Anaconda Cloud (updated nightly) 35 | 36 | ``` 37 | conda install -c fvcore -c iopath -c conda-forge fvcore 38 | ``` 39 | 40 | ### 3. Install latest from GitHub 41 | ``` 42 | pip install -U 'git+https://github.com/facebookresearch/fvcore' 43 | ``` 44 | 45 | ### 4. Install from a local clone 46 | ``` 47 | git clone https://github.com/facebookresearch/fvcore 48 | pip install -e fvcore 49 | ``` 50 | 51 | ## License 52 | 53 | This library is released under the [Apache 2.0 license](https://github.com/facebookresearch/fvcore/blob/main/LICENSE). 54 | -------------------------------------------------------------------------------- /docs/flop_count.md: -------------------------------------------------------------------------------- 1 | # Flop Counter for PyTorch Models 2 | 3 | fvcore contains a flop-counting tool for pytorch models -- the __first__ tool that can provide both __operator-level__ and __module-level__ flop counts together. We also provide functions to display the results according to the module hierarchy. We hope this tool can help pytorch users analyze their models more easily! 4 | 5 | ## Existing Approaches: 6 | 7 | To our knowledge, a good flop counter for pytorch models that satisfies our needs does not yet exist. 
We review some existing solutions below: 8 | 9 | ### Count per-module flops in module-hooks 10 | 11 | There are many existing tools (in [pytorch-OpCounter](https://github.com/Lyken17/pytorch-OpCounter), [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch), [mmcv](https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py), [pytorch_model_summary](https://github.com/ceykmc/pytorch_model_summary), and our own [mobile_cv](https://github.com/facebookresearch/mobile-vision/blob/master/mobile_cv/lut/lib/pt/flops_utils.py)) that count per-module flops using Pytorch’s [module forward hooks](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=module%20hook#torch.nn.Module.register_forward_hook); a minimal sketch of this approach appears at the end of this section. They work well for many models, but suffer from the same limitations that make it hard to get accurate results: 12 | 13 | * They are accurate only if every custom module implements a corresponding flop counter. This is too much extra effort for users. 14 | * In addition, determining the flop counter of a complicated module requires manual inspection of its forward code to see what raw ops are called. A refactor of the forward logic (replacing raw ops by submodules) might require a change in its flop counter. 15 | * When a module contains control flow internally, the counting code has to replicate some of the module’s forward logic. 16 | 17 | ### Count per-operator flops 18 | 19 | These limitations of per-module counting suggest that counting flops at the operator level would be better: unlike the large set of custom modules users typically create, custom operators are much less common. Also, operators typically don’t contain control logic that needs to be replicated in their flop counters. For accurate results, operator-level counting is the preferable approach. 20 | 21 | Pytorch’s profiler recently added [flop counting capability](https://github.com/pytorch/pytorch/pull/46506), which accumulates flops for all operators encountered during profiling. However, some features that are highly desirable for research are yet to be supported: 22 | 23 | * Module-level aggregation: `nn.Module` is the level of abstraction where users design models. To help design efficient models, providing flops per `nn.Module` in a recursive hierarchy is needed. 24 | * Customization: flop counts are in fact sometimes ambiguously defined, due to research/production needs or community convention. We’d like the ability to customize the counting by supplying a formula for each operator. 25 | 26 | ### Count *actual* hardware instructions 27 | `perf stat` can collect the actual instruction count of a command. After taking SIMD instructions into consideration, it may be used to compute the *actual* total flops, which are hardware- and implementation-dependent. We have also noticed a [blog post](http://www.bnikolic.co.uk/blog/python/flops/2019/10/01/pytorch-count-flops.html) that uses PAPI on Intel CPUs to count flops, but this tool can significantly undercount by a factor of 3-4x due to SIMD instructions. 
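As a concrete illustration of the hook-based approach above, here is a minimal sketch (not taken from any of the tools listed; the formulas are illustrative and cover only `Conv2d` and `Linear`). Any module type without a registered formula silently contributes zero flops, which is exactly the accuracy problem described above:

```
from collections import Counter

import torch
import torch.nn as nn

flops = Counter()

def conv2d_hook(module, inputs, output):
    # multiply-adds = output elements x (kernel volume x input channels per group)
    kernel_ops = (module.in_channels // module.groups) * module.kernel_size[0] * module.kernel_size[1]
    flops["conv"] += output.numel() * kernel_ops

def linear_hook(module, inputs, output):
    # one multiply-add per (output element, input feature) pair
    flops["addmm"] += output.numel() * module.in_features

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        m.register_forward_hook(conv2d_hook)
    elif isinstance(m, nn.Linear):
        m.register_forward_hook(linear_hook)

model(torch.randn(1, 3, 32, 32))
print(flops)  # only conv and addmm are counted; everything else is silently zero
```

Every new module type needs its own formula here, and each formula has to mirror what the module’s forward code actually does -- the maintenance burden that motivates the operator-level approach below.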
28 | 29 | ## Our Work 30 | 31 | We built a flop counting tool in fvcore, which: 32 | 33 | * is accurate for a majority of use cases: it observes all operator calls and collects operator-level flop counts 34 | * can provide aggregated flop counts for each module, and display the flop counts in a hierarchical way 35 | * can be customized from Python to supply flop counting formulas for each operator 36 | 37 | It has an interface like this: 38 | ``` 39 | >>> from fvcore.nn import FlopCountAnalysis 40 | >>> flops = FlopCountAnalysis(model, input) 41 | >>> flops.total() 42 | 274656 43 | >>> flops.by_operator() 44 | Counter({'conv': 194616, 'addmm': 80040}) 45 | >>> flops.by_module() 46 | Counter({'': 274656, 'conv1': 48600, 47 | 'conv2': 146016, 'fc1': 69120, 48 | 'fc2': 10080, 'fc3': 840}) 49 | >>> flops.by_module_and_operator() 50 | {'': Counter({'conv': 194616, 'addmm': 80040}), 51 | 'conv1': Counter({'conv': 48600}), 52 | 'conv2': Counter({'conv': 146016}), 53 | 'fc1': Counter({'addmm': 69120}), 54 | 'fc2': Counter({'addmm': 10080}), 55 | 'fc3': Counter({'addmm': 840})} 56 | ``` 57 | 58 | In addition to providing the results above, the class also allows users to add or override the counting formula for certain ops, or to ignore certain ops. See the [API documentation](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis) for details. 59 | 60 | We further supply functions to pretty-print the results in the two styles demonstrated below: 61 |
62 | *(images of the two pretty-print styles)* 63 |
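As a usage sketch (assuming the same `model` and `input` as in the snippet above), the two styles can be produced with the `flop_count_table` and `flop_count_str` helpers in `fvcore.nn`; `aten::my_custom_op` below is a hypothetical operator name, included only to illustrate the customization hook:

```
>>> from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table
>>> flops = FlopCountAnalysis(model, input)
>>> # Optionally ignore an op, or pass a function to supply your own formula:
>>> flops.set_op_handle("aten::my_custom_op", None)  # hypothetical op name
>>> print(flop_count_table(flops, max_depth=3))  # table style
>>> print(flop_count_str(flops))  # annotated model-printout style
```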
64 | 65 | Toy examples are not enough. Below are the pretty-print results of 3 real-world models; we hope they are complicated enough to convince you that the tool probably works for your model as well. 66 | 67 | * [Mask R-CNN in detectron2](https://gist.github.com/ppwwyyxx/1885ec8aaf5093a8d40cdde2b6559ab3#file-mask-r-cnn-from-detectron2) 68 | * [Roberta in fairseq](https://gist.github.com/ppwwyyxx/1885ec8aaf5093a8d40cdde2b6559ab3#file-roberta-from-fairseq) 69 | * [ViT in classyvision](https://gist.github.com/ppwwyyxx/1885ec8aaf5093a8d40cdde2b6559ab3#file-vit-from-classyvision) 70 | 71 | In addition, our approach is not limited to flop counting, but can collect other operator-level statistics during the execution of a model. For example, [recent research](https://arxiv.org/abs/2003.13678) shows that flop count is poorly correlated with GPU latency, and proposes to use “activation counts” or memory footprint as another metric. We have added `fvcore.nn.ActivationCountAnalysis`, which is able to produce this metric as well. 72 | 73 | 74 | ## Appendix: Mechanism & Limitations 75 | 76 | Briefly, here is how the tool works: 77 | 78 | 1. It uses pytorch to trace the execution of the model and obtain a graph. This graph records the input/output shapes of every operator, which allows us to compute flops. 79 | 2. During tracing, module forward hooks insert per-module information into the graph: upon entering and exiting a module’s forward method, we push/pop the jit tracing scope. After tracing, we use the scopes associated with each operator to figure out which module it belongs to. 80 | 81 | The approach still has the following limitations in corner cases, but we think none of them is going to be a deal-breaker for most users (as demonstrated by a few representative models shown above): 82 | 83 | 1. It runs `torch.jit.trace` on the given model & inputs, which means (1) only `model.forward` is used, but not other methods, and (2) the inputs/outputs of `model.forward` must be (tuples of) tensors, not arbitrary classes. 84 | 85 | When the above tracing requirements are not satisfied, a simple wrapper around the model is sufficient to make it traceable. (In detectron2 we built a [universal wrapper](https://github.com/facebookresearch/detectron2/blob/543fd075e146261c2e2b0770c9b537314bdae572/detectron2/utils/analysis.py#L63-L65) that recognizes common data structures, to automatically make a model traceable.) 86 | 87 | 2. Forward hooks are only triggered if a module is called by `__call__()`. When a submodule is called with an explicit `.forward()` or other methods, operators may unnaturally contribute to parent modules instead. This doesn’t affect the accuracy of total flop counts, though. 88 | 3. JIT tracing currently prunes away ops that are not used by the results. However, as tracing does not capture control flow, it may prune away useful ops whose results only affect control flow. This may lead to under-counting. 89 | 90 | We’d like to see if there are ways to disable the pruning. Meanwhile, it should be very rare that a heavy computation only affects control flow without being directly connected to the final outputs in the computation graph, so this corner case is probably unimportant. 91 | -------------------------------------------------------------------------------- /fvcore/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | # pyre-strict 4 | 5 | # This line will be programmatically read/written by setup.py. 6 | # Leave it at the bottom of this file and don't touch it. 7 | __version__ = "0.1.6" 8 | -------------------------------------------------------------------------------- /fvcore/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | -------------------------------------------------------------------------------- /fvcore/common/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,3] 5 | import sys 6 | import time 7 | from typing import Any, Callable, Dict, List 8 | 9 | import numpy as np 10 | 11 | 12 | def timeit( 13 | num_iters: int = -1, warmup_iters: int = 0 14 | ) -> Callable[[Callable[[], Any]], Callable[[], Dict[str, float]]]: 15 | """ 16 | This is intended to be used as a decorator to time any function. 17 | 18 | Args: 19 | num_iters (int): number of iterations used to compute the average time 20 | (sec) required to run the function. If negative, the number of 21 | iterations is determined dynamically by running the function a few 22 | times to make sure the estimate is stable. 23 | warmup_iters (int): number of iterations used to warm up the function. 24 | This is useful for functions that exhibit poor performance during 25 | the first few times they run (due to caches, autotuning, etc). 26 | Returns: 27 | Dict[str, float]: dictionary of the aggregated timing estimates. 28 | "iterations": number of iterations used to compute the estimated 29 | time. 30 | "mean": average time (sec) used to run the function. 31 | "median": median time (sec) used to run the function. 32 | "min": minimal time (sec) used to run the function. 33 | "max": maximal time (sec) used to run the function. 34 | "stddev": standard deviation of the time (sec) used to run the 35 | function. 36 | """ 37 | 38 | def decorator(func: Callable[[], Any]) -> Callable[[], Dict[str, float]]: 39 | def decorated(*args: Any, **kwargs: Any) -> Dict[str, float]: 40 | # Warmup phase. 41 | for _ in range(warmup_iters): 42 | func(*args, **kwargs) 43 | 44 | # Estimate the run time of the function. 45 | total_time: float = 0 46 | count = 0 47 | run_times: List[float] = [] 48 | max_num_iters = num_iters if num_iters > 0 else sys.maxsize 49 | for _ in range(max_num_iters): 50 | start_time = time.time() 51 | func(*args, **kwargs) 52 | run_time = time.time() - start_time 53 | 54 | run_times.append(run_time) 55 | total_time += run_time 56 | count += 1 57 | if num_iters < 0 and total_time >= 0.5: 58 | # If num_iters is negative, run the function enough times so 59 | # that we can have a more robust estimate of the average time. 
60 | break 61 | assert count == len(run_times) 62 | ret: Dict[str, float] = {} 63 | ret["iterations"] = count 64 | ret["mean"] = total_time / count 65 | ret["median"] = np.median(run_times) 66 | ret["min"] = np.min(run_times) 67 | ret["max"] = np.max(run_times) 68 | ret["stddev"] = np.std(run_times) 69 | return ret 70 | 71 | return decorated 72 | 73 | return decorator 74 | 75 | 76 | def benchmark( 77 | func: Callable[[], Any], 78 | bm_name: str, 79 | kwargs_list: List[Any], 80 | *, 81 | num_iters: int = -1, 82 | warmup_iters: int = 0, 83 | ) -> None: 84 | """ 85 | Benchmark the input function and print out the results. 86 | 87 | Args: 88 | func (callable): a closure that returns the function to benchmark, 89 | so that initialization can be done before timing. 90 | bm_name (str): name of the benchmark to print out, e.g. "BM_UPDATE". 91 | kwargs_list (list): a list of argument dicts to pass to the function. The 92 | input function will be timed separately for each argument dict. 93 | num_iters (int): number of iterations to run. Defaults to run until 0.5s. 94 | warmup_iters (int): number of iterations used to warm up the function. 95 | 96 | Outputs: 97 | For each argument dict, print out the time (in microseconds) required 98 | to run the function along with the number of iterations used to get 99 | the timing estimate. Example output: 100 | 101 | Benchmark Avg Time(μs) Peak Time(μs) Iterations 102 | ------------------------------------------------------------------- 103 | BM_UPDATE_100 820 914 610 104 | BM_UPDATE_1000 7655 8709 66 105 | BM_UPDATE_10000 78062 81748 7 106 | ------------------------------------------------------------------- 107 | """ 108 | 109 | print("") 110 | outputs = [] 111 | for kwargs in kwargs_list: 112 | func_bm = func(**kwargs) 113 | time_func = timeit(num_iters=num_iters, warmup_iters=warmup_iters)(func_bm) 114 | 115 | ret = time_func() 116 | name = bm_name 117 | if kwargs: 118 | name += "_" + "_".join(str(v) for k, v in kwargs.items()) 119 | outputs.append( 120 | [ 121 | name, 122 | str(ret["mean"] * 1000000), 123 | str(ret["max"] * 1000000), 124 | str(ret["iterations"]), 125 | ] 126 | ) 127 | outputs = np.array(outputs) 128 | # Calculate column widths for metrics table. 129 | c1 = len(max(outputs[:, 0], key=len)) 130 | c2 = len(max(outputs[:, 1], key=len)) 131 | c3 = len(max(outputs[:, 2], key=len)) 132 | c4 = len(max(outputs[:, 3], key=len)) 133 | dash = "-" * 80 134 | print( 135 | "{:{}s} {:>{}s} {:>{}s} {:>{}s}".format( 136 | "Benchmark", c1, "Avg Time(μs)", c2, "Peak Time(μs)", c3, "Iterations", c4 137 | ) 138 | ) 139 | print(dash) 140 | for output in outputs: 141 | print( 142 | "{:{}s} {:15.0f} {:15.0f} {:14d}".format( 143 | output[0], c1, float(output[1]), float(output[2]), int(output[3]) 144 | ) 145 | ) 146 | print(dash) 147 | -------------------------------------------------------------------------------- /fvcore/common/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import logging 6 | import os 7 | from typing import Any, Callable, Dict, IO, List, Union 8 | 9 | import yaml 10 | from iopath.common.file_io import g_pathmgr 11 | from yacs.config import CfgNode as _CfgNode 12 | 13 | 14 | BASE_KEY = "_BASE_" 15 | 16 | 17 | class CfgNode(_CfgNode): 18 | """ 19 | Our own extended version of :class:`yacs.config.CfgNode`. 20 | It contains the following extra features: 21 | 22 | 1. 
The :meth:`merge_from_file` method supports the "_BASE_" key, 23 | which allows the new CfgNode to inherit all the attributes from the 24 | base configuration file(s). 25 | 2. Keys that start with "COMPUTED_" are treated as insertion-only 26 | "computed" attributes. They can be inserted regardless of whether 27 | the CfgNode is frozen or not. 28 | 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate 29 | expressions in config. See examples in 30 | https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types 31 | Note that this may lead to arbitrary code execution: you must not 32 | load a config file from untrusted sources before manually inspecting 33 | the content of the file. 34 | """ 35 | 36 | @classmethod 37 | def _open_cfg(cls, filename: str) -> Union[IO[str], IO[bytes]]: 38 | """ 39 | Defines how a config file is opened. May be overridden to support 40 | different file schemas. 41 | """ 42 | return g_pathmgr.open(filename, "r") 43 | 44 | @classmethod 45 | def load_yaml_with_base( 46 | cls, filename: str, allow_unsafe: bool = False 47 | ) -> Dict[str, Any]: 48 | """ 49 | Just like `yaml.load(open(filename))`, but inherit attributes from its 50 | `_BASE_`. 51 | 52 | Args: 53 | filename (str or file-like object): the file name or file of the current config. 54 | Will be used to find the base config file. 55 | allow_unsafe (bool): whether to allow loading the config file with 56 | `yaml.unsafe_load`. 57 | 58 | Returns: 59 | (dict): the loaded yaml 60 | """ 61 | with cls._open_cfg(filename) as f: 62 | try: 63 | cfg = yaml.safe_load(f) 64 | except yaml.constructor.ConstructorError: 65 | if not allow_unsafe: 66 | raise 67 | logger = logging.getLogger(__name__) 68 | logger.warning( 69 | "Loading config {} with yaml.unsafe_load. Your machine may " 70 | "be at risk if the file contains malicious content.".format( 71 | filename 72 | ) 73 | ) 74 | f.close() 75 | with cls._open_cfg(filename) as f: 76 | cfg = yaml.unsafe_load(f) 77 | 78 | def merge_a_into_b(a: Dict[str, Any], b: Dict[str, Any]) -> None: 79 | # merge dict a into dict b. values in a will overwrite b. 80 | for k, v in a.items(): 81 | if isinstance(v, dict) and k in b: 82 | assert isinstance( 83 | b[k], dict 84 | ), "Cannot inherit key '{}' from base!".format(k) 85 | merge_a_into_b(v, b[k]) 86 | else: 87 | b[k] = v 88 | 89 | def _load_with_base(base_cfg_file: str) -> Dict[str, Any]: 90 | if base_cfg_file.startswith("~"): 91 | base_cfg_file = os.path.expanduser(base_cfg_file) 92 | if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])): 93 | # the path to base cfg is relative to the config file itself. 94 | base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file) 95 | return cls.load_yaml_with_base(base_cfg_file, allow_unsafe=allow_unsafe) 96 | 97 | if BASE_KEY in cfg: 98 | if isinstance(cfg[BASE_KEY], list): 99 | base_cfg: Dict[str, Any] = {} 100 | base_cfg_files = cfg[BASE_KEY] 101 | for base_cfg_file in base_cfg_files: 102 | merge_a_into_b(_load_with_base(base_cfg_file), base_cfg) 103 | else: 104 | base_cfg_file = cfg[BASE_KEY] 105 | base_cfg = _load_with_base(base_cfg_file) 106 | del cfg[BASE_KEY] 107 | 108 | merge_a_into_b(cfg, base_cfg) 109 | return base_cfg 110 | return cfg 111 | 112 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False) -> None: 113 | """ 114 | Merge configs from a given yaml file. 115 | 116 | Args: 117 | cfg_filename: the file name of the yaml config. 
118 | allow_unsafe: whether to allow loading the config file with 119 | `yaml.unsafe_load`. 120 | """ 121 | loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) 122 | loaded_cfg = type(self)(loaded_cfg) 123 | self.merge_from_other_cfg(loaded_cfg) 124 | 125 | # Forward the following calls to base, but with a check on the BASE_KEY. 126 | def merge_from_other_cfg(self, cfg_other: "CfgNode") -> Callable[[], None]: 127 | """ 128 | Args: 129 | cfg_other (CfgNode): configs to merge from. 130 | """ 131 | assert ( 132 | BASE_KEY not in cfg_other 133 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 134 | return super().merge_from_other_cfg(cfg_other) 135 | 136 | def merge_from_list(self, cfg_list: List[str]) -> Callable[[], None]: 137 | """ 138 | Args: 139 | cfg_list (list): list of alternating config keys and values to merge from. 140 | """ 141 | keys = set(cfg_list[0::2]) 142 | assert ( 143 | BASE_KEY not in keys 144 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 145 | return super().merge_from_list(cfg_list) 146 | 147 | def __setattr__(self, name: str, val: Any) -> None: # pyre-ignore 148 | if name.startswith("COMPUTED_"): 149 | if name in self: 150 | old_val = self[name] 151 | if old_val == val: 152 | return 153 | raise KeyError( 154 | "Computed attribute '{}' already exists " 155 | "with a different value! old={}, new={}.".format(name, old_val, val) 156 | ) 157 | self[name] = val 158 | else: 159 | super().__setattr__(name, val) 160 | -------------------------------------------------------------------------------- /fvcore/common/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | from iopath.common.download import download 5 | 6 | 7 | __all__ = ["download"] 8 | -------------------------------------------------------------------------------- /fvcore/common/file_io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import logging 6 | import os 7 | import tempfile 8 | from typing import Optional 9 | 10 | from iopath.common.file_io import ( # noqa, unused import required by some deps 11 | file_lock, 12 | HTTPURLHandler, 13 | LazyPath, 14 | NativePathHandler, 15 | OneDrivePathHandler, 16 | PathHandler, 17 | PathManager as PathManagerBase, 18 | ) 19 | 20 | 21 | __all__ = ["LazyPath", "PathManager", "get_cache_dir", "file_lock"] 22 | 23 | 24 | def get_cache_dir(cache_dir: Optional[str] = None) -> str: 25 | """ 26 | Returns a default directory to cache static files 27 | (usually downloaded from the Internet), if None is provided. 28 | 29 | Args: 30 | cache_dir (None or str): if not None, will be returned as is. 31 | If None, returns the default cache directory as: 32 | 33 | 1) $FVCORE_CACHE, if set 34 | 2) otherwise ~/.torch/fvcore_cache 35 | """ 36 | if cache_dir is None: 37 | cache_dir = os.path.expanduser( 38 | os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache") 39 | ) 40 | try: 41 | PathManager.mkdirs(cache_dir) 42 | assert os.access(cache_dir, os.W_OK) 43 | except (OSError, AssertionError): 44 | tmp_dir = os.path.join(tempfile.gettempdir(), "fvcore_cache") 45 | logger = logging.getLogger(__name__) 46 | logger.warning(f"{cache_dir} is not accessible! 
Using {tmp_dir} instead!") 47 | cache_dir = tmp_dir 48 | return cache_dir 49 | 50 | 51 | PathManager = PathManagerBase() 52 | """ 53 | A global PathManager. 54 | 55 | Any sufficiently complicated/important project should create its own 56 | PathManager instead of using the global PathManager, to avoid conflicts 57 | when multiple projects have conflicting PathHandlers. 58 | 59 | History: at first, PathManager was part of detectron2 *only*, and therefore 60 | did not consider cross-project conflict issues. It was later used by more 61 | projects and moved to fvcore to facilitate use across projects, which led 62 | to some conflicts. 63 | Now the class `PathManagerBase` is added to help create per-project path managers, 64 | and this global is still named "PathManager" to keep backward compatibility. 65 | """ 66 | 67 | PathManager.register_handler(HTTPURLHandler()) 68 | PathManager.register_handler(OneDrivePathHandler()) 69 | -------------------------------------------------------------------------------- /fvcore/common/history_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | from typing import List, Optional, Tuple 6 | 7 | import numpy as np 8 | 9 | 10 | class HistoryBuffer: 11 | """ 12 | Track a series of scalar values and provide access to smoothed values over a 13 | window or the global average of the series. 14 | """ 15 | 16 | def __init__(self, max_length: int = 1000000) -> None: 17 | """ 18 | Args: 19 | max_length: maximal number of values that can be stored in the 20 | buffer. When the capacity of the buffer is exhausted, old 21 | values will be removed. 22 | """ 23 | self._max_length: int = max_length 24 | self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs 25 | self._count: int = 0 26 | self._global_avg: float = 0 27 | 28 | def update(self, value: float, iteration: Optional[float] = None) -> None: 29 | """ 30 | Add a new scalar value produced at a certain iteration. If the length 31 | of the buffer exceeds self._max_length, the oldest element will be 32 | removed from the buffer. 33 | """ 34 | if iteration is None: 35 | iteration = self._count 36 | if len(self._data) == self._max_length: 37 | self._data.pop(0) 38 | self._data.append((value, iteration)) 39 | 40 | self._count += 1 41 | self._global_avg += (value - self._global_avg) / self._count 42 | 43 | def latest(self) -> float: 44 | """ 45 | Return the latest scalar value added to the buffer. 46 | """ 47 | return self._data[-1][0] 48 | 49 | def median(self, window_size: int) -> float: 50 | """ 51 | Return the median of the latest `window_size` values in the buffer. 52 | """ 53 | return np.median([x[0] for x in self._data[-window_size:]]) 54 | 55 | def avg(self, window_size: int) -> float: 56 | """ 57 | Return the mean of the latest `window_size` values in the buffer. 58 | """ 59 | return np.mean([x[0] for x in self._data[-window_size:]]) 60 | 61 | def global_avg(self) -> float: 62 | """ 63 | Return the mean of all the elements in the buffer. Note that this 64 | includes those getting removed due to limited buffer storage. 65 | """ 66 | return self._global_avg 67 | 68 | def values(self) -> List[Tuple[float, float]]: 69 | """ 70 | Returns: 71 | list[(number, iteration)]: content of the current buffer. 
72 | """ 73 | return self._data 74 | -------------------------------------------------------------------------------- /fvcore/common/param_scheduler.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | import bisect 3 | import math 4 | from typing import List, Optional, Sequence, Union 5 | 6 | 7 | # pyre-ignore-all-errors[58] # handle optional 8 | 9 | 10 | __all__ = [ 11 | "ParamScheduler", 12 | "ConstantParamScheduler", 13 | "CosineParamScheduler", 14 | "ExponentialParamScheduler", 15 | "LinearParamScheduler", 16 | "CompositeParamScheduler", 17 | "MultiStepParamScheduler", 18 | "StepParamScheduler", 19 | "StepWithFixedGammaParamScheduler", 20 | "PolynomialDecayParamScheduler", 21 | ] # ported from ClassyVision 22 | 23 | 24 | class ParamScheduler: 25 | """ 26 | Base class for parameter schedulers. 27 | A parameter scheduler defines a mapping from a progress value in [0, 1) to 28 | a number (e.g. learning rate). 29 | """ 30 | 31 | # To be used for comparisons with where 32 | WHERE_EPSILON = 1e-6 33 | 34 | def __call__(self, where: float) -> float: 35 | """ 36 | Get the value of the param for a given point in training. 37 | 38 | We update params (such as learning rate) based on the percent progress 39 | of training completed. This allows a scheduler to be agnostic to the 40 | exact length of a particular run (e.g. 120 epochs vs 90 epochs), as 41 | long as the relative progress where params should be updated is the same. 42 | However, it assumes that the total length of training is known. 43 | 44 | Args: 45 | where: A float in [0,1) that represents how far training has progressed 46 | 47 | """ 48 | raise NotImplementedError("Param schedulers must override __call__") 49 | 50 | 51 | class ConstantParamScheduler(ParamScheduler): 52 | """ 53 | Returns a constant value for a param. 54 | """ 55 | 56 | def __init__(self, value: float) -> None: 57 | self._value = value 58 | 59 | def __call__(self, where: float) -> float: 60 | if where >= 1.0: 61 | raise RuntimeError( 62 | f"where in ParamScheduler must be in [0, 1): got {where}" 63 | ) 64 | return self._value 65 | 66 | 67 | class CosineParamScheduler(ParamScheduler): 68 | """ 69 | Cosine decay or cosine warmup schedules based on start and end values. 70 | The schedule is updated based on the fraction of training progress. 71 | The schedule was proposed in 'SGDR: Stochastic Gradient Descent with 72 | Warm Restarts' (https://arxiv.org/abs/1608.03983). Note that this class 73 | only implements the cosine annealing part of SGDR, and not the restarts. 74 | 75 | Example: 76 | 77 | .. code-block:: python 78 | 79 | CosineParamScheduler(start_value=0.1, end_value=0.0001) 80 | """ 81 | 82 | def __init__( 83 | self, 84 | start_value: float, 85 | end_value: float, 86 | ) -> None: 87 | self._start_value = start_value 88 | self._end_value = end_value 89 | 90 | def __call__(self, where: float) -> float: 91 | return self._end_value + 0.5 * (self._start_value - self._end_value) * ( 92 | 1 + math.cos(math.pi * where) 93 | ) 94 | 95 | 96 | class ExponentialParamScheduler(ParamScheduler): 97 | """ 98 | Exponential schedule parameterized by a start value and decay. 99 | The schedule is updated based on the fraction of training 100 | progress, `where`, with the formula 101 | `param_t = start_value * (decay ** where)`. 102 | 103 | Example: 104 | 105 | .. code-block:: python 106 | ExponentialParamScheduler(start_value=2.0, decay=0.02) 107 | 108 | Corresponds to a decreasing schedule with values in [2.0, 0.04). 
109 | """ 110 | 111 | def __init__( 112 | self, 113 | start_value: float, 114 | decay: float, 115 | ) -> None: 116 | self._start_value = start_value 117 | self._decay = decay 118 | 119 | def __call__(self, where: float) -> float: 120 | return self._start_value * (self._decay**where) 121 | 122 | 123 | class LinearParamScheduler(ParamScheduler): 124 | """ 125 | Linearly interpolates parameter between ``start_value`` and ``end_value``. 126 | Can be used for either warmup or decay based on start and end values. 127 | The schedule is updated after every train step by default. 128 | 129 | Example: 130 | 131 | .. code-block:: python 132 | 133 | LinearParamScheduler(start_value=0.0001, end_value=0.01) 134 | 135 | Corresponds to a linear increasing schedule with values in [0.0001, 0.01) 136 | """ 137 | 138 | def __init__( 139 | self, 140 | start_value: float, 141 | end_value: float, 142 | ) -> None: 143 | self._start_value = start_value 144 | self._end_value = end_value 145 | 146 | def __call__(self, where: float) -> float: 147 | # interpolate between start and end values 148 | return self._end_value * where + self._start_value * (1 - where) 149 | 150 | 151 | class MultiStepParamScheduler(ParamScheduler): 152 | """ 153 | Takes a predefined schedule for a param value, and a list of epochs or steps 154 | which stand for the upper boundary (excluded) of each range. 155 | 156 | Example: 157 | 158 | .. code-block:: python 159 | 160 | MultiStepParamScheduler( 161 | values=[0.1, 0.01, 0.001, 0.0001], 162 | milestones=[30, 60, 80, 120] 163 | ) 164 | 165 | Then the param value will be 0.1 for epochs 0-29, 0.01 for 166 | epochs 30-59, 0.001 for epochs 60-79, 0.0001 for epochs 80-119. 167 | Note that when only ``milestones`` is given (as above), ``values`` and ``milestones`` 168 | must have equal length and the last milestone doubles as ``num_updates``; when 169 | ``num_updates`` is passed explicitly, ``values`` must be one longer than ``milestones``. 170 | """ 171 | 172 | def __init__( 173 | self, 174 | values: List[float], 175 | num_updates: Optional[int] = None, 176 | milestones: Optional[List[int]] = None, 177 | ) -> None: 178 | """ 179 | Args: 180 | values: param value in each range 181 | num_updates: the end of the last range. If None, will use ``milestones[-1]`` milestones: the boundary of each range. 
If None, will evenly split ``num_updates`` 182 | 183 | For example, all the following combinations define the same scheduler: 184 | 185 | * num_updates=90, milestones=[30, 60], values=[1, 0.1, 0.01] 186 | * num_updates=90, values=[1, 0.1, 0.01] 187 | * milestones=[30, 60, 90], values=[1, 0.1, 0.01] 188 | * milestones=[3, 6, 9], values=[1, 0.1, 0.01] (ParamScheduler is scale-invariant) 189 | """ 190 | if num_updates is None and milestones is None: 191 | raise ValueError("num_updates and milestones cannot both be None") 192 | if milestones is None: 193 | # Default equispaced drop_epochs behavior 194 | milestones = [] 195 | step_width = math.ceil(num_updates / float(len(values))) 196 | for idx in range(len(values) - 1): 197 | milestones.append(step_width * (idx + 1)) 198 | else: 199 | if not ( 200 | isinstance(milestones, Sequence) 201 | and len(milestones) == len(values) - int(num_updates is not None) 202 | ): 203 | raise ValueError( 204 | "MultiStep scheduler requires a list of %d milestones" 205 | % (len(values) - int(num_updates is not None)) 206 | ) 207 | 208 | if num_updates is None: 209 | num_updates, milestones = milestones[-1], milestones[:-1] 210 | if num_updates < len(values): 211 | raise ValueError( 212 | "Total num_updates must be at least the length of the param schedule" 213 | ) 214 | 215 | self._param_schedule = values 216 | self._num_updates = num_updates 217 | self._milestones: List[int] = milestones 218 | 219 | start_epoch = 0 220 | for milestone in self._milestones: 221 | # Do not exceed the total number of epochs 222 | if milestone >= self._num_updates: 223 | raise ValueError( 224 | "Milestone must be smaller than total number of updates: " 225 | "num_updates=%d, milestone=%d" % (self._num_updates, milestone) 226 | ) 227 | # Must be in ascending order 228 | if start_epoch >= milestone: 229 | raise ValueError( 230 | "Milestone must be larger than the previous one: previous=%d, milestone=%d" 231 | % (start_epoch, milestone) 232 | ) 233 | start_epoch = milestone 234 | 235 | def __call__(self, where: float) -> float: 236 | if where > 1.0: 237 | raise RuntimeError( 238 | f"where in ParamScheduler must be in [0, 1]: got {where}" 239 | ) 240 | epoch_num = int((where + self.WHERE_EPSILON) * self._num_updates) 241 | return self._param_schedule[bisect.bisect_right(self._milestones, epoch_num)] 242 | 243 | 244 | class PolynomialDecayParamScheduler(ParamScheduler): 245 | """ 246 | Decays the param value after every epoch according to a 247 | polynomial function with a fixed power. 248 | The schedule is updated after every train step by default. 249 | 250 | Example: 251 | 252 | .. code-block:: python 253 | 254 | PolynomialDecayParamScheduler(base_value=0.1, power=0.9) 255 | 256 | Then the param value will be 0.1 for epoch 0, 0.099 for epoch 1, and 257 | so on. 258 | """ 259 | 260 | def __init__( 261 | self, 262 | base_value: float, 263 | power: float, 264 | ) -> None: 265 | self._base_value = base_value 266 | self._power = power 267 | 268 | def __call__(self, where: float) -> float: 269 | return self._base_value * (1 - where) ** self._power 270 | 271 | 272 | class StepParamScheduler(ParamScheduler): 273 | """ 274 | Takes a fixed schedule for a param value. If the length of the 275 | fixed schedule is less than the number of epochs, then the epochs 276 | are divided evenly among the param schedule. 277 | The schedule is updated after every train epoch by default. 278 | 279 | Example: 280 | 281 | .. 
code-block:: python 282 | 283 | StepParamScheduler(values=[0.1, 0.01, 0.001, 0.0001], num_updates=120) 284 | 285 | Then the param value will be 0.1 for epochs 0-29, 0.01 for 286 | epochs 30-59, 0.001 for epochs 60-89, 0.0001 for epochs 90-119. 287 | """ 288 | 289 | def __init__( 290 | self, 291 | num_updates: Union[int, float], 292 | values: List[float], 293 | ) -> None: 294 | if num_updates <= 0: 295 | raise ValueError("Number of updates must be larger than 0") 296 | if not (isinstance(values, Sequence) and len(values) > 0): 297 | raise ValueError( 298 | "Step scheduler requires a list of at least one param value" 299 | ) 300 | self._param_schedule = values 301 | 302 | def __call__(self, where: float) -> float: 303 | ind = int((where + self.WHERE_EPSILON) * len(self._param_schedule)) 304 | return self._param_schedule[ind] 305 | 306 | 307 | class StepWithFixedGammaParamScheduler(ParamScheduler): 308 | """ 309 | Decays the param value by gamma at equally spaced steps, so as to have the 310 | specified total number of decays. 311 | 312 | Example: 313 | 314 | .. code-block:: python 315 | 316 | StepWithFixedGammaParamScheduler( 317 | base_value=0.1, gamma=0.1, num_decays=3, num_updates=120) 318 | 319 | Then the param value will be 0.1 for epochs 0-29, 0.01 for 320 | epochs 30-59, 0.001 for epochs 60-89, 0.0001 for epochs 90-119. 321 | """ 322 | 323 | def __init__( 324 | self, 325 | base_value: float, 326 | num_decays: int, 327 | gamma: float, 328 | num_updates: int, 329 | ) -> None: 330 | for k in [base_value, gamma]: 331 | if not (isinstance(k, (int, float)) and k > 0): 332 | raise ValueError("base_value and gamma must be positive numbers") 333 | for k in [num_decays, num_updates]: 334 | if not (isinstance(k, int) and k > 0): 335 | raise ValueError("num_decays and num_updates must be positive integers") 336 | 337 | self.base_value = base_value 338 | self.num_decays = num_decays 339 | self.gamma = gamma 340 | self.num_updates = num_updates 341 | values = [base_value] 342 | for _ in range(num_decays): 343 | values.append(values[-1] * gamma) 344 | 345 | self._step_param_scheduler = StepParamScheduler( 346 | num_updates=num_updates, values=values 347 | ) 348 | 349 | def __call__(self, where: float) -> float: 350 | return self._step_param_scheduler(where) 351 | 352 | 353 | class CompositeParamScheduler(ParamScheduler): 354 | """ 355 | Composite parameter scheduler composed of intermediate schedulers. 356 | Takes a list of schedulers and a list of lengths corresponding to 357 | the percentage of training each scheduler should run for. Schedulers 358 | are run in order. All values in lengths should sum to 1.0. 359 | 360 | Each scheduler also has a corresponding interval scale. If interval 361 | scale is 'fixed', the intermediate scheduler will be run without any rescaling 362 | of the time. If interval scale is 'rescaled', the intermediate scheduler is 363 | run such that each scheduler will start and end at the same values as it 364 | would if it were the only scheduler. Default is 'rescaled' for all schedulers. 365 | 366 | Example: 367 | 368 | .. code-block:: python 369 | 370 | schedulers = [ 371 | ConstantParamScheduler(value=0.42), 372 | CosineParamScheduler(start_value=0.42, end_value=1e-4) 373 | ] 374 | CompositeParamScheduler( 375 | schedulers=schedulers, 376 | interval_scaling=['rescaled', 'rescaled'], 377 | lengths=[0.3, 0.7]) 378 | 379 | The parameter value will be 0.42 for the first [0%, 30%) of steps, 380 | and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of 381 | training. 
382 | """ 383 | 384 | def __init__( 385 | self, 386 | schedulers: Sequence[ParamScheduler], 387 | lengths: List[float], 388 | interval_scaling: Sequence[str], 389 | ) -> None: 390 | if len(schedulers) != len(lengths): 391 | raise ValueError("Schedulers and lengths must be same length") 392 | if len(schedulers) == 0: 393 | raise ValueError( 394 | "There must be at least one scheduler in the composite scheduler" 395 | ) 396 | if abs(sum(lengths) - 1.0) >= 1e-3: 397 | raise ValueError("The sum of all values in lengths must be 1") 398 | if sum(lengths) != 1.0: 399 | lengths[-1] = 1.0 - sum(lengths[:-1]) 400 | for s in interval_scaling: 401 | if s not in ["rescaled", "fixed"]: 402 | raise ValueError(f"Unsupported interval_scaling: {s}") 403 | 404 | self._lengths = lengths 405 | self._schedulers = schedulers 406 | self._interval_scaling = interval_scaling 407 | 408 | def __call__(self, where: float) -> float: 409 | # Find scheduler corresponding to where 410 | i = 0 411 | running_total = self._lengths[i] 412 | while (where + self.WHERE_EPSILON) > running_total and i < len( 413 | self._schedulers 414 | ) - 1: 415 | i += 1 416 | running_total += self._lengths[i] 417 | scheduler = self._schedulers[i] 418 | scheduler_where = where 419 | interval_scale = self._interval_scaling[i] 420 | if interval_scale == "rescaled": 421 | # Calculate corresponding where % for scheduler 422 | scheduler_start = running_total - self._lengths[i] 423 | scheduler_where = (where - scheduler_start) / self._lengths[i] 424 | return scheduler(scheduler_where) 425 | -------------------------------------------------------------------------------- /fvcore/common/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,3] 5 | from typing import Any, Dict, Iterable, Iterator, Tuple 6 | 7 | from tabulate import tabulate 8 | 9 | 10 | class Registry(Iterable[Tuple[str, Any]]): 11 | """ 12 | The registry that provides name -> object mapping, to support third-party 13 | users' custom modules. 14 | 15 | To create a registry (e.g. a backbone registry): 16 | 17 | .. code-block:: python 18 | 19 | BACKBONE_REGISTRY = Registry('BACKBONE') 20 | 21 | To register an object: 22 | 23 | .. code-block:: python 24 | 25 | @BACKBONE_REGISTRY.register() 26 | class MyBackbone(): 27 | ... 28 | 29 | Or: 30 | 31 | .. code-block:: python 32 | 33 | BACKBONE_REGISTRY.register(MyBackbone) 34 | """ 35 | 36 | def __init__(self, name: str) -> None: 37 | """ 38 | Args: 39 | name (str): the name of this registry 40 | """ 41 | self._name: str = name 42 | self._obj_map: Dict[str, Any] = {} 43 | 44 | def _do_register(self, name: str, obj: Any) -> None: 45 | assert ( 46 | name not in self._obj_map 47 | ), "An object named '{}' was already registered in '{}' registry!".format( 48 | name, self._name 49 | ) 50 | self._obj_map[name] = obj 51 | 52 | def register(self, obj: Any = None) -> Any: 53 | """ 54 | Register the given object under the the name `obj.__name__`. 55 | Can be used as either a decorator or not. See docstring of this class for usage. 
56 | """ 57 | if obj is None: 58 | # used as a decorator 59 | def deco(func_or_class: Any) -> Any: 60 | name = func_or_class.__name__ 61 | self._do_register(name, func_or_class) 62 | return func_or_class 63 | 64 | return deco 65 | 66 | # used as a function call 67 | name = obj.__name__ 68 | self._do_register(name, obj) 69 | 70 | def get(self, name: str) -> Any: 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError( 74 | "No object named '{}' found in '{}' registry!".format(name, self._name) 75 | ) 76 | return ret 77 | 78 | def __contains__(self, name: str) -> bool: 79 | return name in self._obj_map 80 | 81 | def __repr__(self) -> str: 82 | table_headers = ["Names", "Objects"] 83 | table = tabulate( 84 | self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid" 85 | ) 86 | return "Registry of {}:\n".format(self._name) + table 87 | 88 | def __iter__(self) -> Iterator[Tuple[str, Any]]: 89 | return iter(self._obj_map.items()) 90 | 91 | # pyre-fixme[4]: Attribute must be annotated. 92 | __str__ = __repr__ 93 | -------------------------------------------------------------------------------- /fvcore/common/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # -*- coding: utf-8 -*- 3 | 4 | # pyre-strict 5 | 6 | from time import perf_counter 7 | from typing import Optional 8 | 9 | 10 | class Timer: 11 | """ 12 | A timer which computes the time elapsed since the start/reset of the timer. 13 | """ 14 | 15 | def __init__(self) -> None: 16 | self.reset() 17 | 18 | def reset(self) -> None: 19 | """ 20 | Reset the timer. 21 | """ 22 | self._start = perf_counter() 23 | self._paused: Optional[float] = None 24 | self._total_paused = 0 25 | self._count_start = 1 26 | 27 | def pause(self) -> None: 28 | """ 29 | Pause the timer. 30 | """ 31 | if self._paused is not None: 32 | raise ValueError("Trying to pause a Timer that is already paused!") 33 | self._paused = perf_counter() 34 | 35 | def is_paused(self) -> bool: 36 | """ 37 | Returns: 38 | bool: whether the timer is currently paused 39 | """ 40 | return self._paused is not None 41 | 42 | def resume(self) -> None: 43 | """ 44 | Resume the timer. 45 | """ 46 | if self._paused is None: 47 | raise ValueError("Trying to resume a Timer that is not paused!") 48 | # pyre-fixme[58]: `-` is not supported for operand types `float` and 49 | # `Optional[float]`. 50 | self._total_paused += perf_counter() - self._paused 51 | self._paused = None 52 | self._count_start += 1 53 | 54 | def seconds(self) -> float: 55 | """ 56 | Returns: 57 | (float): the total number of seconds since the start/reset of the 58 | timer, excluding the time when the timer is paused. 59 | """ 60 | if self._paused is not None: 61 | end_time: float = self._paused # type: ignore 62 | else: 63 | end_time = perf_counter() 64 | return end_time - self._start - self._total_paused 65 | 66 | def avg_seconds(self) -> float: 67 | """ 68 | Returns: 69 | (float): the average number of seconds between every start/reset and 70 | pause. 71 | """ 72 | return self.seconds() / self._count_start 73 | -------------------------------------------------------------------------------- /fvcore/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | # pyre-ignore-all-errors 4 | from .activation_count import activation_count, ActivationCountAnalysis 5 | from .flop_count import flop_count, FlopCountAnalysis 6 | from .focal_loss import ( 7 | sigmoid_focal_loss, 8 | sigmoid_focal_loss_jit, 9 | sigmoid_focal_loss_star, 10 | sigmoid_focal_loss_star_jit, 11 | ) 12 | from .giou_loss import giou_loss 13 | from .parameter_count import parameter_count, parameter_count_table 14 | from .precise_bn import get_bn_modules, update_bn_stats 15 | from .print_model_statistics import flop_count_str, flop_count_table 16 | from .smooth_l1_loss import smooth_l1_loss 17 | from .weight_init import c2_msra_fill, c2_xavier_fill 18 | 19 | 20 | # pyre-fixme[5]: Global expression must be annotated. 21 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 22 | -------------------------------------------------------------------------------- /fvcore/nn/activation_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,33] 5 | 6 | from collections import defaultdict 7 | from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union 8 | 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | from .jit_analysis import JitModelAnalysis 13 | from .jit_handles import generic_activation_jit, Handle 14 | 15 | 16 | # A dictionary that maps supported operations to their activation count handles. 17 | _DEFAULT_SUPPORTED_OPS: Dict[str, Handle] = { 18 | "aten::_convolution": generic_activation_jit("conv"), 19 | "aten::addmm": generic_activation_jit(), 20 | "aten::bmm": generic_activation_jit(), 21 | "aten::einsum": generic_activation_jit(), 22 | "aten::matmul": generic_activation_jit(), 23 | "aten::linear": generic_activation_jit(), 24 | } 25 | 26 | 27 | class ActivationCountAnalysis(JitModelAnalysis): 28 | """ 29 | Provides access to per-submodule model activation count obtained by 30 | tracing a model with PyTorch's jit tracing functionality. By default, 31 | comes with standard activation counters for convolutional and dot-product 32 | operators. 33 | 34 | Handles for additional operators may be added, or the default ones 35 | overwritten, using the ``.set_op_handle(name, func)`` method. 36 | See the method documentation for details. 37 | 38 | Activation counts can be obtained as: 39 | 40 | * ``.total(module_name="")``: total activation count for a module 41 | * ``.by_operator(module_name="")``: activation counts for the module, as a 42 | Counter over different operator types 43 | * ``.by_module()``: Counter of activation counts for all submodules 44 | * ``.by_module_and_operator()``: dictionary, indexed by descendant module name, 45 | of Counters over different operator types 46 | 47 | An operator is treated as within a module if it is executed inside the 48 | module's ``__call__`` method. Note that this does not include calls to 49 | other methods of the module or explicit calls to ``module.forward(...)``. 50 | 51 | Example usage: 52 | 53 | >>> import torch.nn as nn 54 | >>> import torch 55 | >>> class TestModel(nn.Module): 56 | ... def __init__(self): 57 | ... super().__init__() 58 | ... self.fc = nn.Linear(in_features=1000, out_features=10) 59 | ... self.conv = nn.Conv2d( 60 | ... in_channels=3, out_channels=10, kernel_size=1 61 | ... ) 62 | ... self.act = nn.ReLU() 63 | ... def forward(self, x): 64 | ...
return self.fc(self.act(self.conv(x)).flatten(1)) 65 | 66 | >>> model = TestModel() 67 | >>> inputs = (torch.randn((1,3,10,10)),) 68 | >>> acts = ActivationCountAnalysis(model, inputs) 69 | >>> acts.total() 70 | 1010 71 | >>> acts.total("fc") 72 | 10 73 | >>> acts.by_operator() 74 | Counter({"conv" : 1000, "addmm" : 10}) 75 | >>> acts.by_module() 76 | Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0}) 77 | >>> acts.by_module_and_operator() 78 | {"" : Counter({"conv" : 1000, "addmm" : 10}), 79 | "fc" : Counter({"addmm" : 10}), 80 | "conv" : Counter({"conv" : 1000}), 81 | "act" : Counter() 82 | } 83 | """ 84 | 85 | def __init__( 86 | self, 87 | model: nn.Module, 88 | inputs: Union[Tensor, Tuple[Tensor, ...]], 89 | ) -> None: 90 | super().__init__(model=model, inputs=inputs) 91 | self.set_op_handle(**_DEFAULT_SUPPORTED_OPS) 92 | 93 | __init__.__doc__ = JitModelAnalysis.__init__.__doc__ 94 | 95 | 96 | def activation_count( 97 | model: nn.Module, 98 | inputs: Tuple[Any, ...], 99 | supported_ops: Optional[Dict[str, Handle]] = None, 100 | ) -> Tuple[DefaultDict[str, float], Counter[str]]: 101 | """ 102 | Given a model and an input to the model, compute the total number of 103 | activations of the model. 104 | 105 | Args: 106 | model (nn.Module): The model to compute activation counts. 107 | inputs (tuple): Inputs that are passed to `model` to count activations. 108 | Inputs need to be in a tuple. 109 | supported_ops (dict(str,Callable) or None) : provide additional 110 | handlers for extra ops, or overwrite the existing handlers for 111 | convolution and matmul. The key is operator name and the value 112 | is a function that takes (inputs, outputs) of the op. 113 | 114 | Returns: 115 | tuple[defaultdict, Counter]: A dictionary that records the number of 116 | activations (in millions) for each operation and a Counter that records the 117 | number of unsupported operations. 118 | """ 119 | if supported_ops is None: 120 | supported_ops = {} 121 | act_counter = ActivationCountAnalysis(model, inputs).set_op_handle(**supported_ops) 122 | mega_acts = defaultdict(float) 123 | for op, act in act_counter.by_operator().items(): 124 | mega_acts[op] = act / 1e6 125 | return mega_acts, act_counter.unsupported_ops() 126 | -------------------------------------------------------------------------------- /fvcore/nn/distributed.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from typing import List, Tuple 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.autograd.function import Function 7 | 8 | 9 | # pyre-ignore-all-errors[2,14,16] 10 | 11 | 12 | class _AllReduce(Function): 13 | @staticmethod 14 | def forward(ctx, input: torch.Tensor) -> torch.Tensor: 15 | input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] 16 | # Use allgather instead of allreduce since I don't trust in-place operations .. 17 | dist.all_gather(input_list, input, async_op=False) 18 | inputs = torch.stack(input_list, dim=0) 19 | return torch.sum(inputs, dim=0) 20 | 21 | @staticmethod 22 | def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: 23 | dist.all_reduce(grad_output, async_op=False) 24 | return grad_output 25 | 26 | 27 | def differentiable_all_reduce(input: torch.Tensor) -> torch.Tensor: 28 | """ 29 | Differentiable counterpart of `dist.all_reduce`.
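In the forward pass this returns the sum of ``input`` across all workers; in the backward pass the gradient is summed across workers as well. A minimal usage sketch (assumes ``torch.distributed`` is already initialized):

.. code-block:: python

    x = torch.ones(3, requires_grad=True)
    y = differentiable_all_reduce(x).sum()
    y.backward()  # x.grad is all-reduced across workers too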
30 | """ 31 | if ( 32 | not dist.is_available() 33 | or not dist.is_initialized() 34 | or dist.get_world_size() == 1 35 | ): 36 | return input 37 | return _AllReduce.apply(input) 38 | 39 | 40 | class _AllGather(Function): 41 | @staticmethod 42 | def forward(ctx, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: 43 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 44 | dist.all_gather(output, x) 45 | return tuple(output) 46 | 47 | @staticmethod 48 | def backward(ctx, *grads: torch.Tensor) -> torch.Tensor: 49 | all_gradients = torch.stack(grads) 50 | dist.all_reduce(all_gradients) 51 | return all_gradients[dist.get_rank()] 52 | 53 | 54 | def differentiable_all_gather(input: torch.Tensor) -> List[torch.Tensor]: 55 | """ 56 | Differentiable counterpart of `dist.all_gather`. 57 | """ 58 | if ( 59 | not dist.is_available() 60 | or not dist.is_initialized() 61 | or dist.get_world_size() == 1 62 | ): 63 | return [input] 64 | return list(_AllGather.apply(input)) 65 | -------------------------------------------------------------------------------- /fvcore/nn/flop_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,33] 5 | 6 | from collections import defaultdict 7 | from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union 8 | 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | from .jit_analysis import JitModelAnalysis 13 | from .jit_handles import ( 14 | addmm_flop_jit, 15 | batchnorm_flop_jit, 16 | bmm_flop_jit, 17 | conv_flop_jit, 18 | einsum_flop_jit, 19 | elementwise_flop_counter, 20 | Handle, 21 | linear_flop_jit, 22 | matmul_flop_jit, 23 | norm_flop_counter, 24 | ) 25 | 26 | 27 | # A dictionary that maps supported operations to their flop count jit handles. 28 | _DEFAULT_SUPPORTED_OPS: Dict[str, Handle] = { 29 | "aten::addmm": addmm_flop_jit, 30 | "aten::bmm": bmm_flop_jit, 31 | "aten::_convolution": conv_flop_jit, 32 | "aten::einsum": einsum_flop_jit, 33 | "aten::matmul": matmul_flop_jit, 34 | "aten::mm": matmul_flop_jit, 35 | "aten::linear": linear_flop_jit, 36 | # You might want to ignore BN flops due to inference-time fusion. 37 | # Use `set_op_handle("aten::batch_norm", None) 38 | "aten::batch_norm": batchnorm_flop_jit, 39 | "aten::group_norm": norm_flop_counter(2), 40 | "aten::layer_norm": norm_flop_counter(2), 41 | "aten::instance_norm": norm_flop_counter(1), 42 | "aten::upsample_nearest2d": elementwise_flop_counter(0, 1), 43 | "aten::upsample_bilinear2d": elementwise_flop_counter(0, 4), 44 | "aten::adaptive_avg_pool2d": elementwise_flop_counter(1, 0), 45 | "aten::grid_sampler": elementwise_flop_counter(0, 4), # assume bilinear 46 | } 47 | 48 | 49 | class FlopCountAnalysis(JitModelAnalysis): 50 | """ 51 | Provides access to per-submodule model flop count obtained by 52 | tracing a model with pytorch's jit tracing functionality. By default, 53 | comes with standard flop counters for a few common operators. 54 | Note that: 55 | 56 | 1. Flop is not a well-defined concept. We just produce our best estimate. 57 | 2. We count one fused multiply-add as one flop. 58 | 59 | Handles for additional operators may be added, or the default ones 60 | overwritten, using the ``.set_op_handle(name, func)`` method. 61 | See the method documentation for details. 
62 | 63 | Flop counts can be obtained as: 64 | 65 | * ``.total(module_name="")``: total flop count for the module 66 | * ``.by_operator(module_name="")``: flop counts for the module, as a Counter 67 | over different operator types 68 | * ``.by_module()``: Counter of flop counts for all submodules 69 | * ``.by_module_and_operator()``: dictionary, indexed by descendant module name, 70 | of Counters over different operator types 71 | 72 | An operator is treated as within a module if it is executed inside the 73 | module's ``__call__`` method. Note that this does not include calls to 74 | other methods of the module or explicit calls to ``module.forward(...)``. 75 | 76 | Example usage: 77 | 78 | >>> import torch.nn as nn 79 | >>> import torch 80 | >>> class TestModel(nn.Module): 81 | ... def __init__(self): 82 | ... super().__init__() 83 | ... self.fc = nn.Linear(in_features=1000, out_features=10) 84 | ... self.conv = nn.Conv2d( 85 | ... in_channels=3, out_channels=10, kernel_size=1 86 | ... ) 87 | ... self.act = nn.ReLU() 88 | ... def forward(self, x): 89 | ... return self.fc(self.act(self.conv(x)).flatten(1)) 90 | 91 | >>> model = TestModel() 92 | >>> inputs = (torch.randn((1,3,10,10)),) 93 | >>> flops = FlopCountAnalysis(model, inputs) 94 | >>> flops.total() 95 | 13000 96 | >>> flops.total("fc") 97 | 10000 98 | >>> flops.by_operator() 99 | Counter({"addmm" : 10000, "conv" : 3000}) 100 | >>> flops.by_module() 101 | Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0}) 102 | >>> flops.by_module_and_operator() 103 | {"" : Counter({"addmm" : 10000, "conv" : 3000}), 104 | "fc" : Counter({"addmm" : 10000}), 105 | "conv" : Counter({"conv" : 3000}), 106 | "act" : Counter() 107 | } 108 | """ 109 | 110 | def __init__( 111 | self, 112 | model: nn.Module, 113 | inputs: Union[Tensor, Tuple[Tensor, ...]], 114 | ) -> None: 115 | super().__init__(model=model, inputs=inputs) 116 | self.set_op_handle(**_DEFAULT_SUPPORTED_OPS) 117 | 118 | __init__.__doc__ = JitModelAnalysis.__init__.__doc__ 119 | 120 | 121 | def flop_count( 122 | model: nn.Module, 123 | inputs: Tuple[Any, ...], 124 | supported_ops: Optional[Dict[str, Handle]] = None, 125 | ) -> Tuple[DefaultDict[str, float], Counter[str]]: 126 | """ 127 | Given a model and an input to the model, compute the per-operator Gflops 128 | of the given model. 129 | 130 | Args: 131 | model (nn.Module): The model to compute flop counts. 132 | inputs (tuple): Inputs that are passed to `model` to count flops. 133 | Inputs need to be in a tuple. 134 | supported_ops (dict(str,Callable) or None) : provide additional 135 | handlers for extra ops, or overwrite the existing handlers for 136 | convolution, matmul, and einsum. The key is operator name and the value 137 | is a function that takes (inputs, outputs) of the op. We count 138 | one Multiply-Add as one FLOP. 139 | 140 | Returns: 141 | tuple[defaultdict, Counter]: A dictionary that records the number of 142 | gflops for each operation and a Counter that records the number of 143 | unsupported operations.
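Example (reusing the ``TestModel`` from the class docstring above):

>>> gflops, unsupported = flop_count(model, inputs)
>>> gflops["addmm"]
1e-05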
144 | """ 145 | if supported_ops is None: 146 | supported_ops = {} 147 | flop_counter = FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops) 148 | giga_flops = defaultdict(float) 149 | for op, flop in flop_counter.by_operator().items(): 150 | giga_flops[op] = flop / 1e9 151 | return giga_flops, flop_counter.unsupported_ops() 152 | -------------------------------------------------------------------------------- /fvcore/nn/focal_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def sigmoid_focal_loss( 10 | inputs: torch.Tensor, 11 | targets: torch.Tensor, 12 | alpha: float = -1, 13 | gamma: float = 2, 14 | reduction: str = "none", 15 | ) -> torch.Tensor: 16 | """ 17 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 18 | Args: 19 | inputs: A float tensor of arbitrary shape. 20 | The predictions for each example. 21 | targets: A float tensor with the same shape as inputs. Stores the binary 22 | classification label for each element in inputs 23 | (0 for the negative class and 1 for the positive class). 24 | alpha: (optional) Weighting factor in range (0,1) to balance 25 | positive vs negative examples. Default = -1 (no weighting). 26 | gamma: Exponent of the modulating factor (1 - p_t) to 27 | balance easy vs hard examples. 28 | reduction: 'none' | 'mean' | 'sum' 29 | 'none': No reduction will be applied to the output. 30 | 'mean': The output will be averaged. 31 | 'sum': The output will be summed. 32 | Returns: 33 | Loss tensor with the reduction option applied. 34 | """ 35 | inputs = inputs.float() 36 | targets = targets.float() 37 | p = torch.sigmoid(inputs) 38 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 39 | p_t = p * targets + (1 - p) * (1 - targets) 40 | loss = ce_loss * ((1 - p_t) ** gamma) 41 | 42 | if alpha >= 0: 43 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 44 | loss = alpha_t * loss 45 | 46 | if reduction == "mean": 47 | loss = loss.mean() 48 | elif reduction == "sum": 49 | loss = loss.sum() 50 | 51 | return loss 52 | 53 | 54 | # pyre-fixme[9]: sigmoid_focal_loss_jit has type `ScriptModule`; used as 55 | # `ScriptFunction[..., typing.Any]`. 56 | sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss) 57 | 58 | 59 | def sigmoid_focal_loss_star( 60 | inputs: torch.Tensor, 61 | targets: torch.Tensor, 62 | alpha: float = -1, 63 | gamma: float = 1, 64 | reduction: str = "none", 65 | ) -> torch.Tensor: 66 | """ 67 | FL* described in RetinaNet paper Appendix: https://arxiv.org/abs/1708.02002. 68 | Args: 69 | inputs: A float tensor of arbitrary shape. 70 | The predictions for each example. 71 | targets: A float tensor with the same shape as inputs. Stores the binary 72 | classification label for each element in inputs 73 | (0 for the negative class and 1 for the positive class). 74 | alpha: (optional) Weighting factor in range (0,1) to balance 75 | positive vs negative examples. Default = -1 (no weighting). 76 | gamma: Gamma parameter described in FL*. Default = 1 (no weighting). 77 | reduction: 'none' | 'mean' | 'sum' 78 | 'none': No reduction will be applied to the output. 79 | 'mean': The output will be averaged. 80 | 'sum': The output will be summed. 81 | Returns: 82 | Loss tensor with the reduction option applied. 
83 | """ 84 | inputs = inputs.float() 85 | targets = targets.float() 86 | shifted_inputs = gamma * (inputs * (2 * targets - 1)) 87 | loss = -(F.logsigmoid(shifted_inputs)) / gamma 88 | 89 | if alpha >= 0: 90 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 91 | loss *= alpha_t 92 | 93 | if reduction == "mean": 94 | loss = loss.mean() 95 | elif reduction == "sum": 96 | loss = loss.sum() 97 | 98 | return loss 99 | 100 | 101 | # pyre-fixme[9]: sigmoid_focal_loss_star_jit has type `ScriptModule`; used as 102 | # `ScriptFunction[..., typing.Any]`. 103 | sigmoid_focal_loss_star_jit: "torch.jit.ScriptModule" = torch.jit.script( 104 | sigmoid_focal_loss_star 105 | ) 106 | -------------------------------------------------------------------------------- /fvcore/nn/giou_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import torch 6 | 7 | 8 | def giou_loss( 9 | boxes1: torch.Tensor, 10 | boxes2: torch.Tensor, 11 | reduction: str = "none", 12 | eps: float = 1e-7, 13 | ) -> torch.Tensor: 14 | """ 15 | Generalized Intersection over Union Loss (Hamid Rezatofighi et. al) 16 | https://arxiv.org/abs/1902.09630 17 | 18 | Gradient-friendly IoU loss with an additional penalty that is non-zero when the 19 | boxes do not overlap and scales with the size of their smallest enclosing box. 20 | This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. 21 | 22 | Args: 23 | boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). 24 | reduction: 'none' | 'mean' | 'sum' 25 | 'none': No reduction will be applied to the output. 26 | 'mean': The output will be averaged. 27 | 'sum': The output will be summed. 28 | eps (float): small number to prevent division by zero 29 | """ 30 | 31 | x1, y1, x2, y2 = boxes1.unbind(dim=-1) 32 | x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) 33 | 34 | assert (x2 >= x1).all(), "bad box: x1 larger than x2" 35 | assert (y2 >= y1).all(), "bad box: y1 larger than y2" 36 | 37 | # Intersection keypoints 38 | xkis1 = torch.max(x1, x1g) 39 | ykis1 = torch.max(y1, y1g) 40 | xkis2 = torch.min(x2, x2g) 41 | ykis2 = torch.min(y2, y2g) 42 | 43 | intsctk = torch.zeros_like(x1) 44 | mask = (ykis2 > ykis1) & (xkis2 > xkis1) 45 | intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) 46 | unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk 47 | iouk = intsctk / (unionk + eps) 48 | 49 | # smallest enclosing box 50 | xc1 = torch.min(x1, x1g) 51 | yc1 = torch.min(y1, y1g) 52 | xc2 = torch.max(x2, x2g) 53 | yc2 = torch.max(y2, y2g) 54 | 55 | area_c = (xc2 - xc1) * (yc2 - yc1) 56 | miouk = iouk - ((area_c - unionk) / (area_c + eps)) 57 | 58 | loss = 1 - miouk 59 | 60 | if reduction == "mean": 61 | loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() 62 | elif reduction == "sum": 63 | loss = loss.sum() 64 | 65 | return loss 66 | -------------------------------------------------------------------------------- /fvcore/nn/jit_handles.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,3,16,33,6,23] 5 | # NOTE: most Any type in this file should be torch._C.Value - which was not yet annotated. 
6 | # pyre also doesn't work well with many Optional in this file 7 | 8 | import typing 9 | from collections import Counter, OrderedDict 10 | from numbers import Number 11 | from typing import Any, Callable, List, Optional, Union 12 | 13 | import numpy as np 14 | 15 | 16 | try: 17 | from math import prod 18 | except ImportError: 19 | from numpy import prod 20 | 21 | 22 | Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], Number]] 23 | 24 | 25 | def get_shape(val: Any) -> Optional[List[int]]: 26 | """ 27 | Get the shapes from a jit value object. 28 | 29 | Args: 30 | val (torch._C.Value): jit value object. 31 | 32 | Returns: 33 | list(int): return a list of ints. 34 | """ 35 | if val.isCompleteTensor(): 36 | return val.type().sizes() 37 | else: 38 | return None 39 | 40 | 41 | """ 42 | Below are flop/activation counters for various ops. Every counter has the following signature: 43 | 44 | Args: 45 | inputs (list(torch._C.Value)): The inputs of the op in the form of a list of jit object. 46 | outputs (list(torch._C.Value)): The outputs of the op in the form of a list of jit object. 47 | 48 | Returns: 49 | number: The number of flops/activations for the operation, 50 | or a Counter[str] keyed by operator name. 51 | """ 52 | 53 | 54 | def generic_activation_jit(op_name: Optional[str] = None) -> Handle: 55 | """ 56 | This method returns a handle that counts the number of activations from the 57 | output shape for the specified operation. 58 | 59 | Args: 60 | op_name (str): The name of the operation. If given, the handle will 61 | return a counter using this name. 62 | 63 | Returns: 64 | Callable: An activation handle for the given operation. 65 | """ 66 | 67 | def _generic_activation_jit( 68 | i: Any, outputs: List[Any] 69 | ) -> Union[typing.Counter[str], Number]: 70 | """ 71 | This is a generic jit handle that counts the number of activations for any 72 | operation given the output shape. 73 | """ 74 | out_shape = get_shape(outputs[0]) 75 | ac_count = prod(out_shape) 76 | if op_name is None: 77 | return ac_count 78 | else: 79 | return Counter({op_name: ac_count}) 80 | 81 | return _generic_activation_jit 82 | 83 | 84 | def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 85 | """ 86 | Count flops for fully connected layers. 87 | """ 88 | # Count flop for nn.Linear 89 | # inputs is a list of length 3. 90 | input_shapes = [get_shape(v) for v in inputs[1:3]] 91 | # input_shapes[0]: [batch size, input feature dimension] 92 | # input_shapes[1]: [input feature dimension, output feature dimension] 93 | assert len(input_shapes[0]) == 2, input_shapes[0] 94 | assert len(input_shapes[1]) == 2, input_shapes[1] 95 | batch_size, input_dim = input_shapes[0] 96 | output_dim = input_shapes[1][1] 97 | flops = batch_size * input_dim * output_dim 98 | return flops 99 | 100 | 101 | def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 102 | """ 103 | Count flops for the aten::linear operator. 104 | """ 105 | # Inputs is a list of length 3; unlike aten::addmm, it is the first 106 | # two elements that are relevant. 107 | input_shapes = [get_shape(v) for v in inputs[0:2]] 108 | # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] 109 | # input_shapes[1]: [output_feature_dim, input_feature_dim] 110 | assert input_shapes[0][-1] == input_shapes[1][-1] 111 | flops = prod(input_shapes[0]) * input_shapes[1][0] 112 | return flops 113 | 114 | 115 | def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 116 | """ 117 | Count flops for the bmm operation.
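With input shapes (n, c, t) and (n, t, d), the flop count is n * c * t * d, counting one multiply-add as one flop; e.g. bmm of (8, 64, 32) and (8, 32, 128) tensors counts 8 * 64 * 32 * 128 = 2,097,152 flops.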
118 | """ 119 | # Inputs should be a list of length 2. 120 | # Inputs contains the shapes of two tensor. 121 | assert len(inputs) == 2, len(inputs) 122 | input_shapes = [get_shape(v) for v in inputs] 123 | n, c, t = input_shapes[0] 124 | d = input_shapes[-1][-1] 125 | flop = n * c * t * d 126 | return flop 127 | 128 | 129 | def conv_flop_count( 130 | x_shape: List[int], 131 | w_shape: List[int], 132 | out_shape: List[int], 133 | transposed: bool = False, 134 | ) -> Number: 135 | """ 136 | Count flops for convolution. Note only multiplication is 137 | counted. Computation for addition and bias is ignored. 138 | 139 | Flops for a transposed convolution are calculated as 140 | flops = (x_shape[2:] * prod(w_shape) * batch_size). 141 | 142 | Args: 143 | x_shape (list(int)): The input shape before convolution. 144 | w_shape (list(int)): The filter shape. 145 | out_shape (list(int)): The output shape after convolution. 146 | transposed (bool): is the convolution transposed 147 | Returns: 148 | int: the number of flops 149 | """ 150 | batch_size = x_shape[0] 151 | conv_shape = (x_shape if transposed else out_shape)[2:] 152 | flop = batch_size * prod(w_shape) * prod(conv_shape) 153 | return flop 154 | 155 | 156 | def conv_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 157 | """ 158 | Count flops for convolution. 159 | """ 160 | # Inputs of Convolution should be a list of length 12 or 13. They represent: 161 | # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding, 162 | # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn, 163 | # 10) deterministic_cudnn and 11) user_enabled_cudnn. 164 | # starting with #40737 it will be 12) user_enabled_tf32 165 | assert len(inputs) == 12 or len(inputs) == 13, len(inputs) 166 | x, w = inputs[:2] 167 | x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0])) 168 | transposed = inputs[6].toIValue() 169 | 170 | # use a custom name instead of "_convolution" 171 | return Counter( 172 | {"conv": conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)} 173 | ) 174 | 175 | 176 | def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 177 | """ 178 | Count flops for the einsum operation. 179 | """ 180 | # Inputs of einsum should be a list of length 2+. 181 | # Inputs[0] stores the equation used for einsum. 182 | # Inputs[1] stores the list of input shapes. 183 | # Inputs[2] optionally stores the optimized path of contraction. 184 | assert len(inputs) >= 2, len(inputs) 185 | equation = inputs[0].toIValue() 186 | # Get rid of white space in the equation string. 187 | equation = equation.replace(" ", "") 188 | input_shapes_jit = inputs[1].node().inputs() 189 | input_shapes = [get_shape(v) for v in input_shapes_jit] 190 | 191 | # Re-map equation so that same equation with different alphabet 192 | # representations will look the same. 
193 | letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() 194 | mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} 195 | equation = equation.translate(mapping) 196 | 197 | if equation == "abc,abd->acd": 198 | n, c, t = input_shapes[0] 199 | p = input_shapes[-1][-1] 200 | flop = n * c * t * p 201 | return flop 202 | 203 | elif equation == "abc,adc->adb": 204 | n, t, g = input_shapes[0] 205 | c = input_shapes[-1][1] 206 | flop = n * t * g * c 207 | return flop 208 | else: 209 | np_arrs = [np.zeros(s) for s in input_shapes] 210 | optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] 211 | for line in optim.split("\n"): 212 | if "optimized flop" in line.lower(): 213 | # divided by 2 because we count MAC (multiply-add counted as one flop) 214 | flop = float(np.floor(float(line.split(":")[-1]) / 2)) 215 | return flop 216 | raise NotImplementedError("Unsupported einsum operation.") 217 | 218 | 219 | def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 220 | """ 221 | Count flops for matmul. 222 | """ 223 | # Inputs should be a list of length 2. 224 | # Inputs contains the shapes of two matrices. 225 | input_shapes = [get_shape(v) for v in inputs] 226 | assert len(input_shapes) == 2, input_shapes 227 | assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes 228 | flop = prod(input_shapes[0]) * input_shapes[-1][-1] 229 | return flop 230 | 231 | 232 | def norm_flop_counter(affine_arg_index: int) -> Handle: 233 | """ 234 | Args: 235 | affine_arg_index: index of the affine argument in inputs 236 | """ 237 | 238 | def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 239 | """ 240 | Count flops for norm layers. 241 | """ 242 | # Inputs[0] contains the shape of the input. 243 | input_shape = get_shape(inputs[0]) 244 | has_affine = get_shape(inputs[affine_arg_index]) is not None 245 | assert 2 <= len(input_shape) <= 5, input_shape 246 | # 5 is just a rough estimate 247 | flop = prod(input_shape) * (5 if has_affine else 4) 248 | return flop 249 | 250 | return norm_flop_jit 251 | 252 | 253 | def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: 254 | training = inputs[5].toIValue() 255 | assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" 256 | if training: 257 | return norm_flop_counter(1)(inputs, outputs) # pyre-ignore 258 | has_affine = get_shape(inputs[1]) is not None 259 | input_shape = prod(get_shape(inputs[0])) 260 | return input_shape * (2 if has_affine else 1) 261 | 262 | 263 | def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Handle: 264 | """ 265 | Count flops by 266 | input_tensor.numel() * input_scale + output_tensor.numel() * output_scale 267 | 268 | Args: 269 | input_scale: scale of the input tensor (first argument) 270 | output_scale: scale of the output tensor (first element in outputs) 271 | """ 272 | 273 | def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: 274 | ret = 0 275 | if input_scale != 0: 276 | shape = get_shape(inputs[0]) 277 | ret += input_scale * prod(shape) 278 | if output_scale != 0: 279 | shape = get_shape(outputs[0]) 280 | ret += output_scale * prod(shape) 281 | return ret 282 | 283 | return elementwise_flop 284 | -------------------------------------------------------------------------------- /fvcore/nn/parameter_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | # pyre-strict 4 | 5 | import typing 6 | from collections import defaultdict 7 | 8 | import tabulate 9 | from torch import nn 10 | 11 | 12 | def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]: 13 | """ 14 | Count parameters of a model and its submodules. 15 | 16 | Args: 17 | model: a torch module 18 | 19 | Returns: 20 | dict (str-> int): the key is either a parameter name or a module name. 21 | The value is the number of elements in the parameter, or in all 22 | parameters of the module. The key "" corresponds to the total 23 | number of parameters of the model. 24 | """ 25 | r = defaultdict(int) 26 | for name, prm in model.named_parameters(): 27 | size = prm.numel() 28 | name = name.split(".") 29 | for k in range(0, len(name) + 1): 30 | prefix = ".".join(name[:k]) 31 | r[prefix] += size 32 | return r 33 | 34 | 35 | def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str: 36 | """ 37 | Format the parameter count of the model (and its submodules or parameters) 38 | in a nice table. It looks like this: 39 | 40 | :: 41 | 42 | | name | #elements or shape | 43 | |:--------------------------------|:---------------------| 44 | | model | 37.9M | 45 | | backbone | 31.5M | 46 | | backbone.fpn_lateral3 | 0.1M | 47 | | backbone.fpn_lateral3.weight | (256, 512, 1, 1) | 48 | | backbone.fpn_lateral3.bias | (256,) | 49 | | backbone.fpn_output3 | 0.6M | 50 | | backbone.fpn_output3.weight | (256, 256, 3, 3) | 51 | | backbone.fpn_output3.bias | (256,) | 52 | | backbone.fpn_lateral4 | 0.3M | 53 | | backbone.fpn_lateral4.weight | (256, 1024, 1, 1) | 54 | | backbone.fpn_lateral4.bias | (256,) | 55 | | backbone.fpn_output4 | 0.6M | 56 | | backbone.fpn_output4.weight | (256, 256, 3, 3) | 57 | | backbone.fpn_output4.bias | (256,) | 58 | | backbone.fpn_lateral5 | 0.5M | 59 | | backbone.fpn_lateral5.weight | (256, 2048, 1, 1) | 60 | | backbone.fpn_lateral5.bias | (256,) | 61 | | backbone.fpn_output5 | 0.6M | 62 | | backbone.fpn_output5.weight | (256, 256, 3, 3) | 63 | | backbone.fpn_output5.bias | (256,) | 64 | | backbone.top_block | 5.3M | 65 | | backbone.top_block.p6 | 4.7M | 66 | | backbone.top_block.p7 | 0.6M | 67 | | backbone.bottom_up | 23.5M | 68 | | backbone.bottom_up.stem | 9.4K | 69 | | backbone.bottom_up.res2 | 0.2M | 70 | | backbone.bottom_up.res3 | 1.2M | 71 | | backbone.bottom_up.res4 | 7.1M | 72 | | backbone.bottom_up.res5 | 14.9M | 73 | | ...... | ..... | 74 | 75 | Args: 76 | model: a torch module 77 | max_depth (int): maximum depth to recursively print submodules or 78 | parameters 79 | 80 | Returns: 81 | str: the table to be printed 82 | """ 83 | count: typing.DefaultDict[str, int] = parameter_count(model) 84 | # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. 85 | param_shape: typing.Dict[str, typing.Tuple] = { 86 | k: tuple(v.shape) for k, v in model.named_parameters() 87 | } 88 | 89 | # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. 
90 | table: typing.List[typing.Tuple] = [] 91 | 92 | def format_size(x: int) -> str: 93 | if x > 1e8: 94 | return "{:.1f}G".format(x / 1e9) 95 | if x > 1e5: 96 | return "{:.1f}M".format(x / 1e6) 97 | if x > 1e2: 98 | return "{:.1f}K".format(x / 1e3) 99 | return str(x) 100 | 101 | def fill(lvl: int, prefix: str) -> None: 102 | if lvl >= max_depth: 103 | return 104 | for name, v in count.items(): 105 | if name.count(".") == lvl and name.startswith(prefix): 106 | indent = " " * (lvl + 1) 107 | if name in param_shape: 108 | table.append((indent + name, indent + str(param_shape[name]))) 109 | else: 110 | table.append((indent + name, indent + format_size(v))) 111 | fill(lvl + 1, name + ".") 112 | 113 | table.append(("model", format_size(count.pop("")))) 114 | fill(0, "") 115 | 116 | old_ws = tabulate.PRESERVE_WHITESPACE 117 | tabulate.PRESERVE_WHITESPACE = True 118 | tab = tabulate.tabulate( 119 | table, headers=["name", "#elements or shape"], tablefmt="pipe" 120 | ) 121 | tabulate.PRESERVE_WHITESPACE = old_ws 122 | return tab 123 | -------------------------------------------------------------------------------- /fvcore/nn/precise_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2,6,16] 5 | 6 | import itertools 7 | import logging 8 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Type 9 | 10 | import torch 11 | import tqdm 12 | from torch import nn 13 | 14 | 15 | # pyre-fixme[9]: BN_MODULE_TYPES has type `Tuple[Type[Module]]`; used as 16 | # `Tuple[Type[BatchNorm1d], Type[BatchNorm2d], Type[BatchNorm3d], 17 | # Type[SyncBatchNorm]]`. 18 | BN_MODULE_TYPES: Tuple[Type[nn.Module]] = ( 19 | torch.nn.BatchNorm1d, 20 | torch.nn.BatchNorm2d, 21 | torch.nn.BatchNorm3d, 22 | torch.nn.SyncBatchNorm, 23 | ) 24 | 25 | logger: logging.Logger = logging.getLogger(__name__) 26 | 27 | 28 | class _MeanOfBatchVarianceEstimator: 29 | """ 30 | Note that PyTorch's running_var means "running average of 31 | Bessel-corrected batch variance". (PyTorch's BN normalizes by biased 32 | variance, but updates EMA by unbiased (Bessel-corrected) variance). 33 | So we estimate population variance by "simple average of Bessel-corrected 34 | batch variance". This is the same as in the BatchNorm paper, Sec 3.1. 35 | This estimator converges to the population variance as long as the batch size 36 | is not too small and the total #samples for PreciseBN is large enough; its 37 | convergence degrades when batches are small. 38 | 39 | In this implementation, we also don't distinguish differences in batch size. 40 | We assume every batch contributes equally to the population statistics. 41 | """ 42 | 43 | def __init__(self, mean_buffer: torch.Tensor, var_buffer: torch.Tensor) -> None: 44 | self.pop_mean: torch.Tensor = torch.zeros_like(mean_buffer) 45 | self.pop_var: torch.Tensor = torch.zeros_like(var_buffer) 46 | self.ind = 0 47 | 48 | def update( 49 | self, batch_mean: torch.Tensor, batch_var: torch.Tensor, batch_size: int 50 | ) -> None: 51 | self.ind += 1 52 | self.pop_mean += (batch_mean - self.pop_mean) / self.ind 53 | self.pop_var += (batch_var - self.pop_var) / self.ind 54 | 55 | 56 | class _PopulationVarianceEstimator: 57 | """ 58 | Alternatively, one can estimate population variance by the sample variance 59 | of all batches combined. This needs to use the batch size of each batch 60 | in this function to undo the Bessel correction.
61 | This produces better estimation when each batch is small. 62 | See Appendix of the paper "Rethinking Batch in BatchNorm" for details. 63 | 64 | In this implementation, we also take into account varying batch sizes. 65 | A batch of N1 samples with a mean of M1 and a batch of N2 samples with a 66 | mean of M2 will produce a population mean of (N1*M1 + N2*M2)/(N1 + N2) instead 67 | of (M1 + M2)/2. 68 | """ 69 | 70 | def __init__(self, mean_buffer: torch.Tensor, var_buffer: torch.Tensor) -> None: 71 | self.pop_mean: torch.Tensor = torch.zeros_like(mean_buffer) 72 | self.pop_square_mean: torch.Tensor = torch.zeros_like(var_buffer) 73 | self.tot = 0 74 | 75 | def update( 76 | self, batch_mean: torch.Tensor, batch_var: torch.Tensor, batch_size: int 77 | ) -> None: 78 | self.tot += batch_size 79 | batch_square_mean = batch_mean.square() + batch_var * ( 80 | (batch_size - 1) / batch_size 81 | ) 82 | self.pop_mean += (batch_mean - self.pop_mean) * (batch_size / self.tot) 83 | self.pop_square_mean += (batch_square_mean - self.pop_square_mean) * ( 84 | batch_size / self.tot 85 | ) 86 | 87 | @property 88 | def pop_var(self) -> torch.Tensor: 89 | return self.pop_square_mean - self.pop_mean.square() 90 | 91 | 92 | @torch.no_grad() 93 | def update_bn_stats( 94 | model: nn.Module, 95 | data_loader: Iterable[Any], 96 | num_iters: int = 200, 97 | progress: Optional[str] = None, 98 | ) -> None: 99 | """ 100 | Recompute and update the batch norm stats to make them more precise. During 101 | training both BN stats and the weight are changing after every iteration, so 102 | the running average cannot precisely reflect the actual stats of the 103 | current model. 104 | In this function, the BN stats are recomputed with fixed weights, to make 105 | the running average more precise. Specifically, it computes the true average 106 | of per-batch mean/variance instead of the running average. 107 | See Sec. 3 of the paper "Rethinking Batch in BatchNorm" for details. 108 | 109 | Args: 110 | model (nn.Module): the model whose bn stats will be recomputed. 111 | 112 | Note that: 113 | 114 | 1. This function will not alter the training mode of the given model. 115 | Users are responsible for setting the layers that need 116 | precise-BN to training mode, prior to calling this function. 117 | 118 | 2. Be careful if your models contain other stateful layers in 119 | addition to BN, i.e. layers whose state can change in forward 120 | iterations. This function will alter their state. If you wish 121 | to keep them unchanged, either pass in a submodule without 122 | those layers, or back up their states. 123 | data_loader (iterator): an iterator that produces data as inputs to the model. 124 | num_iters (int): number of iterations to compute the stats. 125 | progress: None or "tqdm". If set, use tqdm to report the progress. 126 | """ 127 | bn_layers = get_bn_modules(model) 128 | 129 | if len(bn_layers) == 0: 130 | return 131 | logger.info(f"Computing precise BN statistics for {len(bn_layers)} BN layers ...") 132 | 133 | # In order to make the running stats only reflect the current batch, the 134 | # momentum is disabled. 135 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 136 | # Set the momentum to 1.0 to compute the stats without momentum.
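# With momentum == 1.0, after each forward pass: bn.running_mean == batch_mean and bn.running_var == batch_var (Bessel-corrected).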
137 | momentum_actual = [bn.momentum for bn in bn_layers] 138 | for bn in bn_layers: 139 | bn.momentum = 1.0 140 | 141 | batch_size_per_bn_layer: Dict[nn.Module, int] = {} 142 | 143 | def get_bn_batch_size_hook( 144 | module: nn.Module, input: Tuple[torch.Tensor] 145 | ) -> Tuple[torch.Tensor]: 146 | assert ( 147 | module not in batch_size_per_bn_layer 148 | ), "Some BN layers are reused. This is not supported and probably not desired." 149 | x = input[0] 150 | assert isinstance( 151 | x, torch.Tensor 152 | ), f"BN layer should take tensor as input. Got {input}" 153 | # consider spatial dimensions as batch as well 154 | batch_size = x.numel() // x.shape[1] 155 | batch_size_per_bn_layer[module] = batch_size 156 | return (x,) 157 | 158 | hooks_to_remove = [ 159 | bn.register_forward_pre_hook(get_bn_batch_size_hook) for bn in bn_layers 160 | ] 161 | 162 | estimators = [ 163 | _PopulationVarianceEstimator(bn.running_mean, bn.running_var) 164 | for bn in bn_layers 165 | ] 166 | 167 | ind = -1 168 | for inputs in tqdm.tqdm( 169 | itertools.islice(data_loader, num_iters), 170 | total=num_iters, 171 | disable=progress != "tqdm", 172 | ): 173 | ind += 1 174 | batch_size_per_bn_layer.clear() 175 | model(inputs) 176 | 177 | for i, bn in enumerate(bn_layers): 178 | # Accumulates the bn stats. 179 | batch_size = batch_size_per_bn_layer.get(bn, None) 180 | if batch_size is None: 181 | continue # the layer was unused in this forward 182 | estimators[i].update(bn.running_mean, bn.running_var, batch_size) 183 | assert ind == num_iters - 1, ( 184 | "update_bn_stats is meant to run for {} iterations, " 185 | "but the data loader stopped after {} iterations.".format(num_iters, ind + 1) 186 | ) 187 | 188 | for i, bn in enumerate(bn_layers): 189 | # Sets the precise bn stats. 190 | bn.running_mean = estimators[i].pop_mean 191 | bn.running_var = estimators[i].pop_var 192 | bn.momentum = momentum_actual[i] 193 | for hook in hooks_to_remove: 194 | hook.remove() 195 | 196 | 197 | def get_bn_modules(model: nn.Module) -> List[nn.Module]: 198 | """ 199 | Find all BatchNorm (BN) modules that are in training mode. See 200 | fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are 201 | included in this search. 202 | 203 | Args: 204 | model (nn.Module): a model possibly containing BN modules. 205 | 206 | Returns: 207 | list[nn.Module]: all BN modules in the model that are in training mode. 208 | """ 209 | # Finds all the bn layers. 210 | bn_layers = [ 211 | m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES) 212 | ] 213 | return bn_layers 214 | -------------------------------------------------------------------------------- /fvcore/nn/smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import torch 6 | 7 | 8 | def smooth_l1_loss( 9 | input: torch.Tensor, target: torch.Tensor, beta: float, reduction: str = "none" 10 | ) -> torch.Tensor: 11 | """ 12 | Smooth L1 loss defined in the Fast R-CNN paper as: 13 | :: 14 | | 0.5 * x ** 2 / beta if abs(x) < beta 15 | smoothl1(x) = | 16 | | abs(x) - 0.5 * beta otherwise, 17 | 18 | where x = input - target. 19 | 20 | Smooth L1 loss is related to Huber loss, which is defined as: 21 | :: 22 | | 0.5 * x ** 2 if abs(x) < beta 23 | huber(x) = | 24 | | beta * (abs(x) - 0.5 * beta) otherwise 25 | 26 | Smooth L1 loss is equal to huber(x) / beta.
This leads to the following 27 | differences: 28 | 29 | - As beta -> 0, Smooth L1 loss converges to L1 loss, while Huber loss 30 | converges to a constant 0 loss. 31 | - As beta -> +inf, Smooth L1 converges to a constant 0 loss, while Huber loss 32 | converges to L2 loss. 33 | - For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant 34 | slope of 1. For Huber loss, the slope of the L1 segment is beta. 35 | 36 | Smooth L1 loss can be seen as exactly L1 loss, but with the abs(x) < beta 37 | portion replaced with a quadratic function such that at abs(x) = beta, its 38 | slope is 1. The quadratic segment smooths the L1 loss near x = 0. 39 | 40 | Args: 41 | input (Tensor): input tensor of any shape 42 | target (Tensor): target value tensor with the same shape as input 43 | beta (float): L1 to L2 change point. 44 | For beta values < 1e-5, L1 loss is computed. 45 | reduction: 'none' | 'mean' | 'sum' 46 | 'none': No reduction will be applied to the output. 47 | 'mean': The output will be averaged. 48 | 'sum': The output will be summed. 49 | 50 | Returns: 51 | The loss with the reduction option applied. 52 | 53 | Note: 54 | PyTorch's builtin "Smooth L1 loss" implementation does not actually 55 | implement Smooth L1 loss, nor does it implement Huber loss. It implements 56 | the special case of both in which they are equal (beta=1). 57 | See: https://pytorch.org/docs/stable/nn.html#torch.nn.SmoothL1Loss. 58 | """ 59 | if beta < 1e-5: 60 | # if beta == 0, then torch.where will result in nan gradients when 61 | # the chain rule is applied due to pytorch implementation details 62 | # (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of 63 | # zeros, rather than "no gradient"). To avoid this issue, we define 64 | # small values of beta to be exactly l1 loss. 65 | loss = torch.abs(input - target) 66 | else: 67 | n = torch.abs(input - target) 68 | cond = n < beta 69 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. 70 | loss = torch.where(cond, 0.5 * n**2 / beta, n - 0.5 * beta) 71 | 72 | if reduction == "mean": 73 | loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() 74 | elif reduction == "sum": 75 | loss = loss.sum() 76 | return loss 77 | -------------------------------------------------------------------------------- /fvcore/nn/squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SqueezeExcitation(nn.Module): 12 | """ 13 | Generic 2d/3d extension of Squeeze-and-Excitation (SE) block described in: 14 | *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* 15 | Squeezing spatially and exciting channel-wise 16 | """ 17 | 18 | block: nn.Module 19 | is_3d: bool 20 | 21 | def __init__( 22 | self, 23 | num_channels: int, 24 | num_channels_reduced: Optional[int] = None, 25 | reduction_ratio: float = 2.0, 26 | is_3d: bool = False, 27 | activation: Optional[nn.Module] = None, 28 | ) -> None: 29 | """ 30 | Args: 31 | num_channels (int): Number of input channels. 32 | num_channels_reduced (int): 33 | Number of reduced channels. If none, uses reduction_ratio to calculate. 34 | reduction_ratio (float): 35 | How much num_channels should be reduced if num_channels_reduced is not provided. 36 | is_3d (bool): Whether we're operating on 3d data (or 2d), default 2d. 
37 | activation (nn.Module): Activation function used, defaults to ReLU. 38 | """ 39 | super().__init__() 40 | 41 | if num_channels_reduced is None: 42 | num_channels_reduced = int(num_channels // reduction_ratio) 43 | 44 | if activation is None: 45 | activation = nn.ReLU() 46 | 47 | if is_3d: 48 | conv1 = nn.Conv3d( 49 | num_channels, num_channels_reduced, kernel_size=1, bias=True 50 | ) 51 | conv2 = nn.Conv3d( 52 | num_channels_reduced, num_channels, kernel_size=1, bias=True 53 | ) 54 | else: 55 | conv1 = nn.Conv2d( 56 | num_channels, num_channels_reduced, kernel_size=1, bias=True 57 | ) 58 | conv2 = nn.Conv2d( 59 | num_channels_reduced, num_channels, kernel_size=1, bias=True 60 | ) 61 | 62 | self.is_3d = is_3d 63 | self.block = nn.Sequential( 64 | conv1, 65 | activation, 66 | conv2, 67 | nn.Sigmoid(), 68 | ) 69 | 70 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 71 | """ 72 | Args: 73 | input_tensor: X, shape = (batch_size, num_channels, H, W). 74 | For 3d X, shape = (batch_size, num_channels, T, H, W). 75 | Returns: the output tensor. 76 | """ 77 | mean_tensor = ( 78 | input_tensor.mean(dim=[2, 3, 4], keepdim=True) 79 | if self.is_3d 80 | else input_tensor.mean(dim=[2, 3], keepdim=True) 81 | ) 82 | output_tensor = torch.mul(input_tensor, self.block(mean_tensor)) 83 | 84 | return output_tensor 85 | 86 | 87 | class SpatialSqueezeExcitation(nn.Module): 88 | """ 89 | Generic 2d/3d extension of SE block 90 | squeezing channel-wise and exciting spatially, as described in: 91 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation 92 | in Fully Convolutional Networks, MICCAI 2018* 93 | """ 94 | 95 | block: nn.Module 96 | 97 | def __init__( 98 | self, 99 | num_channels: int, 100 | is_3d: bool = False, 101 | ) -> None: 102 | """ 103 | Args: 104 | num_channels (int): Number of input channels. 105 | is_3d (bool): Whether we're operating on 3d data. 106 | """ 107 | super().__init__() 108 | 109 | if is_3d: 110 | conv = nn.Conv3d(num_channels, 1, kernel_size=1, bias=True) 111 | else: 112 | conv = nn.Conv2d(num_channels, 1, kernel_size=1, bias=True) 113 | 114 | self.block = nn.Sequential( 115 | conv, 116 | nn.Sigmoid(), 117 | ) 118 | 119 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 120 | """ 121 | Args: 122 | input_tensor: X, shape = (batch_size, num_channels, H, W). 123 | For 3d X, shape = (batch_size, num_channels, T, H, W). 124 | Returns: the output tensor. 125 | """ 126 | output_tensor = torch.mul(input_tensor, self.block(input_tensor)) 127 | 128 | return output_tensor 129 | 130 | 131 | class ChannelSpatialSqueezeExcitation(nn.Module): 132 | """ 133 | Generic 2d/3d extension of concurrent spatial and channel squeeze & excitation: 134 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation 135 | in Fully Convolutional Networks, arXiv:1803.02579* 136 | """ 137 | 138 | def __init__( 139 | self, 140 | num_channels: int, 141 | num_channels_reduced: Optional[int] = None, 142 | reduction_ratio: float = 16.0, 143 | is_3d: bool = False, 144 | activation: Optional[nn.Module] = None, 145 | ) -> None: 146 | """ 147 | Args: 148 | num_channels (int): Number of input channels. 149 | num_channels_reduced (int): 150 | Number of reduced channels. If None, uses reduction_ratio to calculate. 151 | reduction_ratio (float): 152 | How much num_channels should be reduced if num_channels_reduced is not provided. 153 | is_3d (bool): Whether we're operating on 3d data (or 2d), default 2d. 154 | activation (nn.Module): Activation function used, defaults to ReLU.
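Example (a minimal 2d sketch)::

    cscse = ChannelSpatialSqueezeExcitation(num_channels=64)
    out = cscse(torch.randn(2, 64, 32, 32))  # output has the same shape as the input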
155 | """ 156 | super().__init__() 157 | self.channel = SqueezeExcitation( 158 | num_channels=num_channels, 159 | num_channels_reduced=num_channels_reduced, 160 | reduction_ratio=reduction_ratio, 161 | is_3d=is_3d, 162 | activation=activation, 163 | ) 164 | self.spatial = SpatialSqueezeExcitation(num_channels=num_channels, is_3d=is_3d) 165 | 166 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 167 | """ 168 | Args: 169 | input_tensor: X, shape = (batch_size, num_channels, H, W) 170 | For 3d X, shape = (batch_size, num_channels, T, H, W) 171 | output tensor 172 | """ 173 | output_tensor = torch.max( 174 | self.channel(input_tensor), self.spatial(input_tensor) 175 | ) 176 | 177 | return output_tensor 178 | -------------------------------------------------------------------------------- /fvcore/nn/weight_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import torch.nn as nn 6 | 7 | 8 | def c2_xavier_fill(module: nn.Module) -> None: 9 | """ 10 | Initialize `module.weight` using the "XavierFill" implemented in Caffe2. 11 | Also initializes `module.bias` to 0. 12 | 13 | Args: 14 | module (torch.nn.Module): module to initialize. 15 | """ 16 | # Caffe2 implementation of XavierFill in fact 17 | # corresponds to kaiming_uniform_ in PyTorch 18 | # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Module, Tensor]`. 19 | nn.init.kaiming_uniform_(module.weight, a=1) 20 | if module.bias is not None: 21 | # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Module, 22 | # Tensor]`. 23 | nn.init.constant_(module.bias, 0) 24 | 25 | 26 | def c2_msra_fill(module: nn.Module) -> None: 27 | """ 28 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. 29 | Also initializes `module.bias` to 0. 30 | 31 | Args: 32 | module (torch.nn.Module): module to initialize. 33 | """ 34 | # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Module, Tensor]`. 35 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 36 | if module.bias is not None: 37 | # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Module, 38 | # Tensor]`. 39 | nn.init.constant_(module.bias, 0) 40 | -------------------------------------------------------------------------------- /fvcore/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | from .transform import * 6 | -------------------------------------------------------------------------------- /fvcore/transforms/transform_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | # pyre-ignore-all-errors 8 | def to_float_tensor(numpy_array: np.ndarray) -> torch.Tensor: 9 | """ 10 | Convert the numpy array to torch float tensor with dimension of NxCxHxW. 11 | Pytorch is not fully supporting uint8, so convert tensor to float if the 12 | numpy_array is uint8. 13 | Args: 14 | numpy_array (ndarray): of shape NxHxWxC, or HxWxC or HxW to 15 | represent an image. The array can be of type uint8 in range 16 | [0, 255], or floating point in range [0, 1] or [0, 255]. 17 | Returns: 18 | float_tensor (tensor): converted float tensor. 
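    Example (illustrative, not part of the original docstring):
        arr = np.zeros((4, 3), dtype=np.uint8)     # HxW, uint8
        t = to_float_tensor(arr)                   # float tensor, shape (1, 1, 4, 3)
        arr = np.zeros((4, 3, 2), dtype=np.uint8)  # HxWxC
        t = to_float_tensor(arr)                   # float tensor, shape (1, 2, 4, 3)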
19 | """ 20 | assert isinstance(numpy_array, np.ndarray) 21 | assert len(numpy_array.shape) in (2, 3, 4) 22 | 23 | # Some input numpy arrays have negative strides. PyTorch currently 24 | # does not support negative strides, so perform ascontiguousarray to 25 | # resolve the issue. 26 | float_tensor = torch.from_numpy(np.ascontiguousarray(numpy_array)) 27 | if numpy_array.dtype in (np.uint8, np.int32, np.int64): 28 | float_tensor = float_tensor.float() 29 | 30 | if len(numpy_array.shape) == 2: 31 | # HxW -> 1x1xHxW. 32 | float_tensor = float_tensor[None, None, :, :] 33 | elif len(numpy_array.shape) == 3: 34 | # HxWxC -> 1xCxHxW. 35 | float_tensor = float_tensor.permute(2, 0, 1) 36 | float_tensor = float_tensor[None, :, :, :] 37 | elif len(numpy_array.shape) == 4: 38 | # NxHxWxC -> NxCxHxW 39 | float_tensor = float_tensor.permute(0, 3, 1, 2) 40 | else: 41 | raise NotImplementedError( 42 | "Unknown numpy_array dimension of {}".format(float_tensor.shape) 43 | ) 44 | return float_tensor 45 | 46 | 47 | def to_numpy( 48 | float_tensor: torch.Tensor, target_shape: list, target_dtype: np.dtype 49 | ) -> np.ndarray: 50 | """ 51 | Convert float tensor with dimension of NxCxHxW back to numpy array. 52 | Args: 53 | float_tensor (tensor): a float pytorch tensor with shape of NxCxHxW. 54 | target_shape (list): the target shape of the numpy array to represent 55 | the image as output. Options include NxHxWxC, or HxWxC or HxW. 56 | target_dtype (dtype): the target dtype of the numpy array to represent 57 | the image as output. The array can be of type uint8 in range 58 | [0, 255], or floating point in range [0, 1] or [0, 255]. 59 | Returns: 60 | (ndarray): converted numpy array. 61 | """ 62 | assert len(target_shape) in (2, 3, 4) 63 | 64 | if len(target_shape) == 2: 65 | # 1x1xHxW -> HxW. 66 | assert float_tensor.shape[0] == 1 67 | assert float_tensor.shape[1] == 1 68 | float_tensor = float_tensor[0, 0, :, :] 69 | elif len(target_shape) == 3: 70 | assert float_tensor.shape[0] == 1 71 | # 1xCxHxW -> HxWxC. 72 | float_tensor = float_tensor[0].permute(1, 2, 0) 73 | elif len(target_shape) == 4: 74 | # NxCxHxW -> NxHxWxC 75 | float_tensor = float_tensor.permute(0, 2, 3, 1) 76 | else: 77 | raise NotImplementedError( 78 | "Unknown target shape dimension of {}".format(target_shape) 79 | ) 80 | if target_dtype == np.uint8: 81 | # Need to specifically call round here; note that in PyTorch rounding 82 | # is half-to-even. 83 | # https://github.com/pytorch/pytorch/issues/16498 84 | float_tensor = float_tensor.round().byte() 85 | return float_tensor.numpy() 86 | -------------------------------------------------------------------------------- /io_tests/test_file_io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
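# The native-path tests below exercise this basic PathManager pattern
# (a minimal sketch; "example.txt" is a hypothetical local path):
#
#     from fvcore.common.file_io import PathManager
#     if PathManager.exists("example.txt") and PathManager.isfile("example.txt"):
#         with PathManager.open("example.txt", "r") as f:
#             contents = f.read()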
2 | 3 | # pyre-strict 4 | 5 | import os 6 | import shutil 7 | import tempfile 8 | import unittest 9 | import uuid 10 | from typing import Optional 11 | from unittest.mock import MagicMock 12 | 13 | from fvcore.common.file_io import get_cache_dir, LazyPath, PathManager 14 | 15 | 16 | class TestNativeIO(unittest.TestCase): 17 | _tmpdir: Optional[str] = None 18 | _tmpfile: Optional[str] = None 19 | _tmpfile_contents = "Hello, World" 20 | 21 | @classmethod 22 | def setUpClass(cls) -> None: 23 | cls._tmpdir = tempfile.mkdtemp() 24 | with open(os.path.join(cls._tmpdir, "test.txt"), "w") as f: 25 | cls._tmpfile = f.name 26 | f.write(cls._tmpfile_contents) 27 | f.flush() 28 | 29 | @classmethod 30 | def tearDownClass(cls) -> None: 31 | # Cleanup temp working dir. 32 | if cls._tmpdir is not None: 33 | shutil.rmtree(cls._tmpdir) # type: ignore 34 | 35 | def test_open(self) -> None: 36 | # pyre-ignore 37 | with PathManager.open(self._tmpfile, "r") as f: 38 | self.assertEqual(f.read(), self._tmpfile_contents) 39 | 40 | def test_open_args(self) -> None: 41 | PathManager.set_strict_kwargs_checking(True) 42 | f = PathManager.open( 43 | self._tmpfile, # type: ignore 44 | mode="r", 45 | buffering=1, 46 | encoding="UTF-8", 47 | errors="ignore", 48 | newline=None, 49 | closefd=True, 50 | opener=None, 51 | ) 52 | f.close() 53 | 54 | def test_get_local_path(self) -> None: 55 | self.assertEqual( 56 | # pyre-ignore 57 | PathManager.get_local_path(self._tmpfile), 58 | self._tmpfile, 59 | ) 60 | 61 | def test_exists(self) -> None: 62 | # pyre-ignore 63 | self.assertTrue(PathManager.exists(self._tmpfile)) 64 | # pyre-fixme[6]: For 1st argument expected `Union[PathLike[str], str]` but 65 | # got `Optional[str]`. 66 | fake_path = os.path.join(self._tmpdir, uuid.uuid4().hex) 67 | self.assertFalse(PathManager.exists(fake_path)) 68 | 69 | def test_isfile(self) -> None: 70 | self.assertTrue(PathManager.isfile(self._tmpfile)) # pyre-ignore 71 | # This is a directory, not a file, so it should fail 72 | self.assertFalse(PathManager.isfile(self._tmpdir)) # pyre-ignore 73 | # This is a non-existing path, so it should fail 74 | # pyre-fixme[6]: For 1st argument expected `Union[PathLike[str], str]` but 75 | # got `Optional[str]`. 76 | fake_path = os.path.join(self._tmpdir, uuid.uuid4().hex) 77 | self.assertFalse(PathManager.isfile(fake_path)) 78 | 79 | def test_isdir(self) -> None: 80 | # pyre-ignore 81 | self.assertTrue(PathManager.isdir(self._tmpdir)) 82 | # This is a file, not a directory, so it should fail 83 | # pyre-ignore 84 | self.assertFalse(PathManager.isdir(self._tmpfile)) 85 | # This is a non-existing path, so it should fail 86 | # pyre-fixme[6]: For 1st argument expected `Union[PathLike[str], str]` but 87 | # got `Optional[str]`. 88 | fake_path = os.path.join(self._tmpdir, uuid.uuid4().hex) 89 | self.assertFalse(PathManager.isdir(fake_path)) 90 | 91 | def test_ls(self) -> None: 92 | # Create some files in the tempdir to ls out. 93 | # pyre-fixme[6]: For 1st argument expected `typing_extensions.LiteralString` 94 | # but got `Optional[str]`. 
95 | root_dir = os.path.join(self._tmpdir, "ls") 96 | os.makedirs(root_dir, exist_ok=True) 97 | files = sorted(["foo.txt", "bar.txt", "baz.txt"]) 98 | for f in files: 99 | open(os.path.join(root_dir, f), "a").close() 100 | 101 | children = sorted(PathManager.ls(root_dir)) 102 | self.assertListEqual(children, files) 103 | 104 | # Cleanup the tempdir 105 | shutil.rmtree(root_dir) 106 | 107 | def test_mkdirs(self) -> None: 108 | # pyre-fixme[6]: For 1st argument expected `typing_extensions.LiteralString` 109 | # but got `Optional[str]`. 110 | new_dir_path = os.path.join(self._tmpdir, "new", "tmp", "dir") 111 | self.assertFalse(PathManager.exists(new_dir_path)) 112 | PathManager.mkdirs(new_dir_path) 113 | self.assertTrue(PathManager.exists(new_dir_path)) 114 | 115 | def test_copy(self) -> None: 116 | _tmpfile_2 = self._tmpfile + "2" # pyre-ignore 117 | _tmpfile_2_contents = "something else" 118 | with open(_tmpfile_2, "w") as f: 119 | f.write(_tmpfile_2_contents) 120 | f.flush() 121 | # pyre-ignore 122 | assert PathManager.copy(self._tmpfile, _tmpfile_2, True) 123 | with PathManager.open(_tmpfile_2, "r") as f: 124 | self.assertEqual(f.read(), self._tmpfile_contents) 125 | 126 | def test_symlink(self) -> None: 127 | _symlink = self._tmpfile + "_symlink" # pyre-ignore 128 | assert PathManager.symlink(self._tmpfile, _symlink) # pyre-ignore 129 | with PathManager.open(_symlink) as f: 130 | self.assertEqual(f.read(), self._tmpfile_contents) 131 | assert os.readlink(_symlink) == self._tmpfile 132 | os.remove(_symlink) 133 | 134 | def test_rm(self) -> None: 135 | # pyre-fixme[6]: For 1st argument expected `typing_extensions.LiteralString` 136 | # but got `Optional[str]`. 137 | with open(os.path.join(self._tmpdir, "test_rm.txt"), "w") as f: 138 | rm_file = f.name 139 | f.write(self._tmpfile_contents) 140 | f.flush() 141 | self.assertTrue(PathManager.exists(rm_file)) 142 | self.assertTrue(PathManager.isfile(rm_file)) 143 | PathManager.rm(rm_file) 144 | self.assertFalse(PathManager.exists(rm_file)) 145 | self.assertFalse(PathManager.isfile(rm_file)) 146 | 147 | def test_bad_args(self) -> None: 148 | # TODO (T58240718): Replace with dynamic checks 149 | with self.assertRaises(ValueError): 150 | PathManager.copy(self._tmpfile, self._tmpfile, foo="foo") # type: ignore 151 | with self.assertRaises(ValueError): 152 | PathManager.exists(self._tmpfile, foo="foo") # type: ignore 153 | with self.assertRaises(ValueError): 154 | PathManager.get_local_path(self._tmpfile, foo="foo") # type: ignore 155 | with self.assertRaises(ValueError): 156 | PathManager.isdir(self._tmpfile, foo="foo") # type: ignore 157 | with self.assertRaises(ValueError): 158 | PathManager.isfile(self._tmpfile, foo="foo") # type: ignore 159 | with self.assertRaises(ValueError): 160 | PathManager.ls(self._tmpfile, foo="foo") # type: ignore 161 | with self.assertRaises(ValueError): 162 | PathManager.mkdirs(self._tmpfile, foo="foo") # type: ignore 163 | with self.assertRaises(ValueError): 164 | PathManager.open(self._tmpfile, foo="foo") # type: ignore 165 | with self.assertRaises(ValueError): 166 | PathManager.rm(self._tmpfile, foo="foo") # type: ignore 167 | 168 | PathManager.set_strict_kwargs_checking(False) 169 | 170 | PathManager.copy(self._tmpfile, self._tmpfile, foo="foo") # type: ignore 171 | PathManager.exists(self._tmpfile, foo="foo") # type: ignore 172 | PathManager.get_local_path(self._tmpfile, foo="foo") # type: ignore 173 | PathManager.isdir(self._tmpfile, foo="foo") # type: ignore 174 | PathManager.isfile(self._tmpfile, foo="foo") # 
type: ignore 175 | PathManager.ls(self._tmpdir, foo="foo") # type: ignore 176 | PathManager.mkdirs(self._tmpdir, foo="foo") # type: ignore 177 | f = PathManager.open(self._tmpfile, foo="foo") # type: ignore 178 | f.close() 179 | # pyre-fixme[6]: For 1st argument expected `typing_extensions.LiteralString` 180 | # but got `Optional[str]`. 181 | with open(os.path.join(self._tmpdir, "test_rm.txt"), "w") as f: 182 | rm_file = f.name 183 | f.write(self._tmpfile_contents) 184 | f.flush() 185 | PathManager.rm(rm_file, foo="foo") # type: ignore 186 | 187 | 188 | class TestHTTPIO(unittest.TestCase): 189 | _remote_uri = "https://www.facebook.com" 190 | _filename = "facebook.html" 191 | _cache_dir: str = os.path.join(get_cache_dir(), __name__) 192 | 193 | def setUp(self) -> None: 194 | if os.path.exists(self._cache_dir): 195 | shutil.rmtree(self._cache_dir) 196 | os.makedirs(self._cache_dir, exist_ok=True) 197 | 198 | def test_open_writes(self) -> None: 199 | # HTTPURLHandler does not support writing, only reading. 200 | with self.assertRaises(AssertionError): 201 | with PathManager.open(self._remote_uri, "w") as f: 202 | f.write("foobar") 203 | 204 | def test_bad_args(self) -> None: 205 | with self.assertRaises(NotImplementedError): 206 | PathManager.copy( 207 | self._remote_uri, 208 | self._remote_uri, 209 | foo="foo", # type: ignore 210 | ) 211 | with self.assertRaises(NotImplementedError): 212 | PathManager.exists(self._remote_uri, foo="foo") # type: ignore 213 | with self.assertRaises(ValueError): 214 | PathManager.get_local_path(self._remote_uri, foo="foo") # type: ignore 215 | with self.assertRaises(NotImplementedError): 216 | PathManager.isdir(self._remote_uri, foo="foo") # type: ignore 217 | with self.assertRaises(NotImplementedError): 218 | PathManager.isfile(self._remote_uri, foo="foo") # type: ignore 219 | with self.assertRaises(NotImplementedError): 220 | PathManager.ls(self._remote_uri, foo="foo") # type: ignore 221 | with self.assertRaises(NotImplementedError): 222 | PathManager.mkdirs(self._remote_uri, foo="foo") # type: ignore 223 | with self.assertRaises(ValueError): 224 | PathManager.open(self._remote_uri, foo="foo") # type: ignore 225 | with self.assertRaises(NotImplementedError): 226 | PathManager.rm(self._remote_uri, foo="foo") # type: ignore 227 | 228 | PathManager.set_strict_kwargs_checking(False) 229 | 230 | PathManager.get_local_path(self._remote_uri, foo="foo") # type: ignore 231 | f = PathManager.open(self._remote_uri, foo="foo") # type: ignore 232 | f.close() 233 | PathManager.set_strict_kwargs_checking(True) 234 | 235 | 236 | class TestLazyPath(unittest.TestCase): 237 | def test_materialize(self) -> None: 238 | f = MagicMock(return_value="test") 239 | x = LazyPath(f) 240 | f.assert_not_called() 241 | 242 | p = os.fspath(x) 243 | f.assert_called() 244 | self.assertEqual(p, "test") 245 | 246 | p = os.fspath(x) 247 | # should only be called once 248 | f.assert_called_once() 249 | self.assertEqual(p, "test") 250 | 251 | def test_join(self) -> None: 252 | f = MagicMock(return_value="test") 253 | x = LazyPath(f) 254 | p = os.path.join(x, "a.txt") 255 | f.assert_called_once() 256 | self.assertEqual(p, "test/a.txt") 257 | 258 | def test_getattr(self) -> None: 259 | x = LazyPath(lambda: "abc") 260 | with self.assertRaises(AttributeError): 261 | x.startswith("ab") 262 | _ = os.fspath(x) 263 | self.assertTrue(x.startswith("ab")) 264 | 265 | def test_PathManager(self) -> None: 266 | x = LazyPath(lambda: "./") 267 | output = PathManager.ls(x) # pyre-ignore 268 | output_gt = 
PathManager.ls("./") 269 | self.assertEqual(sorted(output), sorted(output_gt)) 270 | 271 | def test_getitem(self) -> None: 272 | x = LazyPath(lambda: "abc") 273 | with self.assertRaises(TypeError): 274 | x[0] 275 | _ = os.fspath(x) 276 | self.assertEqual(x[0], "a") 277 | 278 | 279 | class TestOneDrive(unittest.TestCase): 280 | _url = "https://1drv.ms/u/s!Aus8VCZ_C_33gQbJsUPTIj3rQu99" 281 | 282 | def test_one_drive_download(self) -> None: 283 | from fvcore.common.file_io import OneDrivePathHandler 284 | 285 | _direct_url = OneDrivePathHandler().create_one_drive_direct_download(self._url) 286 | _gt_url = ( 287 | "https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBd" 288 | + "XM4VkNaX0NfMzNnUWJKc1VQVElqM3JRdTk5/root/content" 289 | ) 290 | self.assertEqual(_direct_url, _gt_url) 291 | -------------------------------------------------------------------------------- /linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ev 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # Run this script at project root by "./linter.sh" before you commit. 5 | 6 | { 7 | black --version | grep -E "24.2.0" > /dev/null 8 | } || { 9 | echo "Linter requires 'black==24.2.0' !" 10 | exit 1 11 | } 12 | 13 | echo "Running isort..." 14 | isort -y -sp . 15 | 16 | echo "Running black..." 17 | black . 18 | 19 | echo "Running flake8..." 20 | if [ -x "$(command -v flake8)" ]; then 21 | flake8 . 22 | else 23 | python3 -m flake8 . 24 | fi 25 | 26 | command -v arc > /dev/null && { 27 | echo "Running arc lint ..." 28 | arc lint 29 | } 30 | -------------------------------------------------------------------------------- /packaging/build_all_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | set -ex 4 | 5 | for PV in 3.6 3.7 3.8 6 | do 7 | PYTHON_VERSION=$PV bash packaging/build_conda.sh 8 | done 9 | 10 | ls -Rl packaging 11 | 12 | for version in 36 37 38 13 | do 14 | (cd packaging/out && conda convert -p win-64 linux-64/fvcore-*-py$version.tar.bz2) 15 | (cd packaging/out && conda convert -p osx-64 linux-64/fvcore-*-py$version.tar.bz2) 16 | done 17 | 18 | ls -Rl packaging 19 | 20 | for dir in win-64 osx-64 linux-64 21 | do 22 | this_out_dir=packaging/output_files/$dir 23 | mkdir -p $this_out_dir 24 | cp packaging/out/$dir/*.tar.bz2 $this_out_dir 25 | done 26 | 27 | ls -Rl packaging 28 | -------------------------------------------------------------------------------- /packaging/build_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | set -ex 4 | 5 | mkdir -p packaging/out 6 | 7 | version=$(python -c "exec(open('fvcore/__init__.py').read()); print(__version__)") 8 | build_version=$version.post$(date +%Y%m%d) 9 | 10 | export BUILD_VERSION=$build_version 11 | 12 | conda build -c defaults -c conda-forge -c iopath --no-anaconda-upload --python "$PYTHON_VERSION" --output-folder packaging/out packaging/fvcore 13 | -------------------------------------------------------------------------------- /packaging/fvcore/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: fvcore 3 | version: "{{ environ.get('BUILD_VERSION') }}" 4 | 5 | source: 6 | path: ../.. 
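# Note: BUILD_VERSION is exported by packaging/build_conda.sh (above) as
# "<package version>.post<YYYYMMDD>" before `conda build` is invoked.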
7 | 8 | requirements: 9 | 10 | host: 11 | - python 12 | - setuptools 13 | 14 | run: 15 | - python 16 | - yacs 17 | - pyyaml 18 | - tqdm 19 | - iopath 20 | - termcolor 21 | - pillow 22 | - tabulate 23 | 24 | build: 25 | string: py{{py}} 26 | script: BUILD_NIGHTLY=1 python setup.py install --single-version-externally-managed --record=record.txt # [not win] 27 | 28 | about: 29 | home: https://github.com/facebookresearch/fvcore 30 | license: Apache 2.0 31 | license_file: LICENSE 32 | summary: "Collection of common code that's shared among different research projects in FAIR computer vision team." 33 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 88 3 | multi_line_output = 3 4 | include_trailing_comma = True 5 | force_grid_wrap = 0 6 | default_section = THIRDPARTY 7 | lines_after_imports = 2 8 | combine_as_imports = True 9 | # Using force_alphabetical_sort_within_sections to match other Meta codebase 10 | # convention. 11 | force_alphabetical_sort_within_sections = True 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import os 5 | from os import path 6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | def get_version(): 11 | init_py_path = path.join( 12 | path.abspath(path.dirname(__file__)), "fvcore", "__init__.py" 13 | ) 14 | init_py = open(init_py_path, "r").readlines() 15 | version_line = [l.strip() for l in init_py if l.startswith("__version__")][0] 16 | version = version_line.split("=")[-1].strip().strip("'\"") 17 | 18 | # Used by CI to build nightly packages. Users should never use it. 19 | # To build a nightly wheel, run: 20 | # BUILD_NIGHTLY=1 python setup.py sdist 21 | if os.getenv("BUILD_NIGHTLY", "0") == "1": 22 | from datetime import datetime 23 | 24 | date_str = datetime.today().strftime("%Y%m%d") 25 | # pip can perform proper comparison for ".post" suffix, 26 | # i.e., "1.1.post1234" >= "1.1" 27 | version = version + ".post" + date_str 28 | 29 | new_init_py = [l for l in init_py if not l.startswith("__version__")] 30 | new_init_py.append('__version__ = "{}"\n'.format(version)) 31 | with open(init_py_path, "w") as f: 32 | f.write("".join(new_init_py)) 33 | return version 34 | 35 | 36 | setup( 37 | name="fvcore", 38 | version=get_version(), 39 | author="FAIR", 40 | license="Apache 2.0", 41 | url="https://github.com/facebookresearch/fvcore", 42 | description="Collection of common code shared among different research " 43 | "projects in FAIR computer vision team", 44 | python_requires=">=3.6", 45 | install_requires=[ 46 | "numpy", 47 | "yacs>=0.1.6", 48 | "pyyaml>=5.1", 49 | "tqdm", 50 | "termcolor>=1.1", 51 | "Pillow", 52 | "tabulate", 53 | "iopath>=0.1.7", 54 | "dataclasses; python_version<'3.7'", 55 | ], 56 | extras_require={"all": ["shapely"]}, 57 | packages=find_packages(exclude=("tests",)), 58 | ) 59 | -------------------------------------------------------------------------------- /tests/bm_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
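# fvcore.common.benchmark.benchmark, used below, runs a target callable once
# per kwargs dict and reports timings. Shape of a call, sketched from this
# file's usage only (the timing output format is not specified here):
#
#     benchmark(func, "BM_NAME", [{"num_values": 100}], warmup_iters=1)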
2 | 3 | # pyre-strict 4 | 5 | from fvcore.common.benchmark import benchmark 6 | from test_common import TestHistoryBuffer 7 | 8 | 9 | def bm_history_buffer_update() -> None: 10 | kwargs_list = [ 11 | {"num_values": 100}, 12 | {"num_values": 1000}, 13 | {"num_values": 10000}, 14 | {"num_values": 100000}, 15 | {"num_values": 1000000}, 16 | ] 17 | benchmark( 18 | TestHistoryBuffer.create_buffer_with_init, 19 | "BM_UPDATE", 20 | kwargs_list, 21 | warmup_iters=1, 22 | ) 23 | -------------------------------------------------------------------------------- /tests/bm_focal_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import torch 6 | from fvcore.common.benchmark import benchmark 7 | from test_focal_loss import TestFocalLoss, TestFocalLossStar 8 | 9 | 10 | def bm_focal_loss() -> None: 11 | if not torch.cuda.is_available(): 12 | print("Skipped: CUDA unavailable") 13 | return 14 | 15 | kwargs_list = [ 16 | {"N": 100}, 17 | {"N": 100, "alpha": 0}, 18 | {"N": 1000}, 19 | {"N": 1000, "alpha": 0}, 20 | {"N": 10000}, 21 | {"N": 10000, "alpha": 0}, 22 | ] 23 | benchmark( 24 | TestFocalLoss.focal_loss_with_init, "Focal_loss", kwargs_list, warmup_iters=1 25 | ) 26 | benchmark( 27 | TestFocalLoss.focal_loss_jit_with_init, 28 | "Focal_loss_JIT", 29 | kwargs_list, 30 | warmup_iters=1, 31 | ) 32 | 33 | 34 | def bm_focal_loss_star() -> None: 35 | if not torch.cuda.is_available(): 36 | print("Skipped: CUDA unavailable") 37 | return 38 | 39 | kwargs_list = [ 40 | {"N": 100}, 41 | {"N": 100, "alpha": 0}, 42 | {"N": 1000}, 43 | {"N": 1000, "alpha": 0}, 44 | {"N": 10000}, 45 | {"N": 10000, "alpha": 0}, 46 | ] 47 | benchmark( 48 | TestFocalLossStar.focal_loss_star_with_init, 49 | "Focal_loss_star", 50 | kwargs_list, 51 | warmup_iters=1, 52 | ) 53 | benchmark( 54 | TestFocalLossStar.focal_loss_star_jit_with_init, 55 | "Focal_loss_star_JIT", 56 | kwargs_list, 57 | warmup_iters=1, 58 | ) 59 | -------------------------------------------------------------------------------- /tests/bm_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | # pyre-strict 5 | 6 | import glob 7 | import importlib 8 | import sys 9 | from os.path import basename, dirname, isfile, join 10 | 11 | 12 | def main() -> None: 13 | if len(sys.argv) > 1: 14 | # Parse from flags. 15 | module_names = [n for n in sys.argv if n.startswith("bm_")] 16 | else: 17 | # Get all the benchmark files (starting with "bm_"). 18 | bm_files = glob.glob(join(dirname(__file__), "bm_*.py")) 19 | module_names = [ 20 | basename(f)[:-3] 21 | for f in bm_files 22 | if isfile(f) and not f.endswith("bm_main.py") 23 | ] 24 | 25 | for module_name in module_names: 26 | module = importlib.import_module(module_name) 27 | for attr in dir(module): 28 | # Run all the functions with names "bm_*" in the module. 
29 | if attr.startswith("bm_"): 30 | print("Running benchmarks for " + module_name + "/" + attr + "...") 31 | getattr(module, attr)() 32 | 33 | 34 | if __name__ == "__main__": 35 | main() # pragma: no cover 36 | -------------------------------------------------------------------------------- /tests/configs/base.yaml: -------------------------------------------------------------------------------- 1 | KEY1: "base" 2 | KEY2: "base" 3 | -------------------------------------------------------------------------------- /tests/configs/base2.yaml: -------------------------------------------------------------------------------- 1 | KEY1: "base2" 2 | -------------------------------------------------------------------------------- /tests/configs/config.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | KEY2: "config" 3 | EXPRESSION: !!python/object/apply:eval ["[x ** 2 for x in [1, 2, 3]]"] 4 | -------------------------------------------------------------------------------- /tests/configs/config_multi_base.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: ["base.yaml", "base2.yaml"] 2 | KEY2: "config" 3 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_composite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # pyre-ignore-all-errors 4 | 5 | import copy 6 | import unittest 7 | 8 | from fvcore.common.param_scheduler import ( 9 | CompositeParamScheduler, 10 | ConstantParamScheduler, 11 | CosineParamScheduler, 12 | LinearParamScheduler, 13 | StepParamScheduler, 14 | ) 15 | 16 | 17 | class TestCompositeScheduler(unittest.TestCase): 18 | _num_updates = 10 19 | 20 | def _get_valid_long_config(self): 21 | return { 22 | "schedulers": [ 23 | ConstantParamScheduler(0.1), 24 | ConstantParamScheduler(0.2), 25 | ConstantParamScheduler(0.3), 26 | ConstantParamScheduler(0.4), 27 | ], 28 | "lengths": [0.2, 0.4, 0.1, 0.3], 29 | "interval_scaling": ["rescaled"] * 4, 30 | } 31 | 32 | def _get_lengths_sum_less_one_config(self): 33 | return { 34 | "schedulers": [ 35 | ConstantParamScheduler(0.1), 36 | ConstantParamScheduler(0.2), 37 | ], 38 | "lengths": [0.7, 0.2999], 39 | "interval_scaling": ["rescaled", "rescaled"], 40 | } 41 | 42 | def _get_valid_mixed_config(self): 43 | return { 44 | "schedulers": [ 45 | StepParamScheduler(values=[0.1, 0.2, 0.3, 0.4, 0.5], num_updates=10), 46 | CosineParamScheduler(start_value=0.42, end_value=0.0001), 47 | ], 48 | "lengths": [0.5, 0.5], 49 | "interval_scaling": ["rescaled", "rescaled"], 50 | } 51 | 52 | def _get_valid_linear_config(self): 53 | return { 54 | "schedulers": [ 55 | LinearParamScheduler(start_value=0.0, end_value=0.5), 56 | LinearParamScheduler(start_value=0.5, end_value=1.0), 57 | ], 58 | "lengths": [0.5, 0.5], 59 | "interval_scaling": ["rescaled", "rescaled"], 60 | } 61 | 62 | def test_invalid_config(self): 63 | config = self._get_valid_mixed_config() 64 | bad_config = copy.deepcopy(config) 65 | 66 | # Size of schedulers and lengths doesn't match 67 | bad_config["schedulers"] = copy.deepcopy(config["schedulers"]) 68 | bad_config["lengths"] = copy.deepcopy(config["lengths"]) 69 | bad_config["schedulers"].append(bad_config["schedulers"][-1]) 70 | with self.assertRaises(ValueError): 71 | CompositeParamScheduler(**bad_config) 72 | 73 | # Sum of lengths < 1 74 | bad_config["schedulers"] = 
copy.deepcopy(config["schedulers"]) 75 | bad_config["lengths"][-1] -= 0.1 76 | with self.assertRaises(ValueError): 77 | CompositeParamScheduler(**bad_config) 78 | 79 | # Sum of lengths > 1 80 | bad_config["lengths"] = copy.deepcopy(config["lengths"]) 81 | bad_config["lengths"][-1] += 0.1 82 | with self.assertRaises(ValueError): 83 | CompositeParamScheduler(**bad_config) 84 | 85 | # Bad value for composition_mode 86 | bad_config["interval_scaling"] = ["rescaled", "rescaleds"] 87 | with self.assertRaises(ValueError): 88 | CompositeParamScheduler(**bad_config) 89 | 90 | # Wrong number composition modes 91 | bad_config["interval_scaling"] = ["rescaled"] 92 | with self.assertRaises(ValueError): 93 | CompositeParamScheduler(**bad_config) 94 | 95 | def test_long_scheduler(self): 96 | config = self._get_valid_long_config() 97 | 98 | scheduler = CompositeParamScheduler(**config) 99 | schedule = [ 100 | scheduler(epoch_num / self._num_updates) 101 | for epoch_num in range(self._num_updates) 102 | ] 103 | expected_schedule = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4] 104 | 105 | self.assertEqual(schedule, expected_schedule) 106 | 107 | def test_scheduler_lengths_within_epsilon_of_one(self): 108 | config = self._get_lengths_sum_less_one_config() 109 | scheduler = CompositeParamScheduler(**config) 110 | schedule = [ 111 | scheduler(epoch_num / self._num_updates) 112 | for epoch_num in range(self._num_updates) 113 | ] 114 | expected_schedule = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2] 115 | self.assertEqual(schedule, expected_schedule) 116 | 117 | def test_scheduler_with_mixed_types(self): 118 | config = self._get_valid_mixed_config() 119 | scheduler_0 = config["schedulers"][0] 120 | scheduler_1 = config["schedulers"][1] 121 | 122 | # Check scaled 123 | config["interval_scaling"] = ["rescaled", "rescaled"] 124 | scheduler = CompositeParamScheduler(**config) 125 | scaled_schedule = [ 126 | round(scheduler(epoch_num / self._num_updates), 4) 127 | for epoch_num in range(self._num_updates) 128 | ] 129 | expected_schedule = [ 130 | round(scheduler_0(epoch_num / self._num_updates), 4) 131 | for epoch_num in range(0, self._num_updates, 2) 132 | ] + [ 133 | round(scheduler_1(epoch_num / self._num_updates), 4) 134 | for epoch_num in range(0, self._num_updates, 2) 135 | ] 136 | self.assertEqual(scaled_schedule, expected_schedule) 137 | 138 | # Check fixed 139 | config["interval_scaling"] = ["fixed", "fixed"] 140 | scheduler = CompositeParamScheduler(**config) 141 | fixed_schedule = [ 142 | round(scheduler(epoch_num / self._num_updates), 4) 143 | for epoch_num in range(self._num_updates) 144 | ] 145 | expected_schedule = [ 146 | round(scheduler_0(epoch_num / self._num_updates), 4) 147 | for epoch_num in range(0, int(self._num_updates / 2)) 148 | ] + [ 149 | round(scheduler_1(epoch_num / self._num_updates), 4) 150 | for epoch_num in range(int(self._num_updates / 2), self._num_updates) 151 | ] 152 | self.assertEqual(fixed_schedule, expected_schedule) 153 | 154 | # Check warmup of rescaled then fixed 155 | config["interval_scaling"] = ["rescaled", "fixed"] 156 | scheduler = CompositeParamScheduler(**config) 157 | fixed_schedule = [ 158 | round(scheduler(epoch_num / self._num_updates), 4) 159 | for epoch_num in range(self._num_updates) 160 | ] 161 | expected_schedule = [ 162 | round(scheduler_0(epoch_num / self._num_updates), 4) 163 | for epoch_num in range(0, int(self._num_updates), 2) 164 | ] + [ 165 | round(scheduler_1(epoch_num / self._num_updates), 4) 166 | for epoch_num in 
range(int(self._num_updates / 2), self._num_updates) 167 | ] 168 | self.assertEqual(fixed_schedule, expected_schedule) 169 | 170 | def test_linear_scheduler_no_gaps(self): 171 | config = self._get_valid_linear_config() 172 | 173 | # Check rescaled 174 | scheduler = CompositeParamScheduler(**config) 175 | schedule = [ 176 | scheduler(epoch_num / self._num_updates) 177 | for epoch_num in range(self._num_updates) 178 | ] 179 | expected_schedule = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 180 | self.assertEqual(expected_schedule, schedule) 181 | 182 | # Check fixed composition gives same result as only 1 scheduler 183 | config["schedulers"][1] = config["schedulers"][0] 184 | config["interval_scaling"] = ["fixed", "fixed"] 185 | scheduler = CompositeParamScheduler(**config) 186 | linear_scheduler = config["schedulers"][0] 187 | schedule = [ 188 | scheduler(epoch_num / self._num_updates) 189 | for epoch_num in range(self._num_updates) 190 | ] 191 | expected_schedule = [ 192 | linear_scheduler(epoch_num / self._num_updates) 193 | for epoch_num in range(self._num_updates) 194 | ] 195 | self.assertEqual(expected_schedule, schedule) 196 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_constant.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # pyre-ignore-all-errors 4 | 5 | import unittest 6 | 7 | from fvcore.common.param_scheduler import ConstantParamScheduler 8 | 9 | 10 | class TestConstantScheduler(unittest.TestCase): 11 | _num_epochs = 12 12 | 13 | def test_scheduler(self): 14 | scheduler = ConstantParamScheduler(0.1) 15 | schedule = [ 16 | scheduler(epoch_num / self._num_epochs) 17 | for epoch_num in range(self._num_epochs) 18 | ] 19 | expected_schedule = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] 20 | 21 | self.assertEqual(schedule, expected_schedule) 22 | # The input for the scheduler should be in the interval [0;1), open 23 | with self.assertRaises(RuntimeError): 24 | scheduler(1) 25 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_cosine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
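# The expected values below are consistent with a half-cosine interpolation,
# a sketch inferred from the test data rather than copied from the source:
#
#     value(t) = end_value + 0.5 * (start_value - end_value) * (1 + cos(pi * t))
#
# e.g. start_value=0.1, end_value=0, t=0.1 gives
# 0.05 * (1 + cos(0.1 * pi)) ~= 0.0976, the first intermediate value below.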
2 | 3 | # pyre-ignore-all-errors 4 | 5 | import copy 6 | import unittest 7 | 8 | from fvcore.common.param_scheduler import CosineParamScheduler 9 | 10 | 11 | class TestCosineScheduler(unittest.TestCase): 12 | _num_epochs = 10 13 | 14 | def _get_valid_decay_config(self): 15 | return {"start_value": 0.1, "end_value": 0} 16 | 17 | def _get_valid_decay_config_intermediate_values(self): 18 | return [0.0976, 0.0905, 0.0794, 0.0655, 0.05, 0.0345, 0.0206, 0.0095, 0.0024] 19 | 20 | def test_scheduler_as_decay(self): 21 | config = self._get_valid_decay_config() 22 | 23 | scheduler = CosineParamScheduler(**config) 24 | schedule = [ 25 | round(scheduler(epoch_num / self._num_epochs), 4) 26 | for epoch_num in range(self._num_epochs) 27 | ] 28 | expected_schedule = [ 29 | config["start_value"] 30 | ] + self._get_valid_decay_config_intermediate_values() 31 | 32 | self.assertEqual(schedule, expected_schedule) 33 | 34 | def test_scheduler_as_warmup(self): 35 | config = self._get_valid_decay_config() 36 | # Swap start and end lr to change to warmup 37 | tmp = config["start_value"] 38 | config["start_value"] = config["end_value"] 39 | config["end_value"] = tmp 40 | 41 | scheduler = CosineParamScheduler(**config) 42 | schedule = [ 43 | round(scheduler(epoch_num / self._num_epochs), 4) 44 | for epoch_num in range(self._num_epochs) 45 | ] 46 | # Schedule should be decay reversed 47 | expected_schedule = [config["start_value"]] + list( 48 | reversed(self._get_valid_decay_config_intermediate_values()) 49 | ) 50 | 51 | self.assertEqual(schedule, expected_schedule) 52 | 53 | def test_scheduler_warmup_decay_match(self): 54 | decay_config = self._get_valid_decay_config() 55 | decay_scheduler = CosineParamScheduler(**decay_config) 56 | 57 | warmup_config = copy.deepcopy(decay_config) 58 | # Swap start and end lr to change to warmup 59 | tmp = warmup_config["start_value"] 60 | warmup_config["start_value"] = warmup_config["end_value"] 61 | warmup_config["end_value"] = tmp 62 | warmup_scheduler = CosineParamScheduler(**warmup_config) 63 | 64 | decay_schedule = [ 65 | round(decay_scheduler(epoch_num / 1000), 8) for epoch_num in range(1, 1000) 66 | ] 67 | warmup_schedule = [ 68 | round(warmup_scheduler(epoch_num / 1000), 8) for epoch_num in range(1, 1000) 69 | ] 70 | 71 | self.assertEqual(decay_schedule, list(reversed(warmup_schedule))) 72 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_exponential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
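# The expected values below match a plain exponential interpolation,
# a sketch inferred from the test data:
#
#     value(t) = start_value * decay ** t,  for t in [0, 1)
#
# e.g. start_value=2.0, decay=0.1, t=0.1 gives 2.0 * 0.1 ** 0.1 ~= 1.5887.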
2 | 3 | # pyre-ignore-all-errors 4 | 5 | import unittest 6 | 7 | from fvcore.common.param_scheduler import ExponentialParamScheduler 8 | 9 | 10 | class TestExponentialScheduler(unittest.TestCase): 11 | _num_epochs = 10 12 | 13 | def _get_valid_config(self): 14 | return {"start_value": 2.0, "decay": 0.1} 15 | 16 | def _get_valid_intermediate_values(self): 17 | return [1.5887, 1.2619, 1.0024, 0.7962, 0.6325, 0.5024, 0.3991, 0.3170, 0.2518] 18 | 19 | def test_scheduler(self): 20 | config = self._get_valid_config() 21 | 22 | scheduler = ExponentialParamScheduler(**config) 23 | schedule = [ 24 | round(scheduler(epoch_num / self._num_epochs), 4) 25 | for epoch_num in range(self._num_epochs) 26 | ] 27 | expected_schedule = [ 28 | config["start_value"] 29 | ] + self._get_valid_intermediate_values() 30 | 31 | self.assertEqual(schedule, expected_schedule) 32 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # pyre-ignore-all-errors 4 | 5 | import unittest 6 | 7 | from fvcore.common.param_scheduler import LinearParamScheduler 8 | 9 | 10 | class TestLinearScheduler(unittest.TestCase): 11 | _num_epochs = 10 12 | 13 | def _get_valid_intermediate(self): 14 | return [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09] 15 | 16 | def _get_valid_config(self): 17 | return {"start_value": 0.0, "end_value": 0.1} 18 | 19 | def test_scheduler(self): 20 | config = self._get_valid_config() 21 | 22 | # Check as warmup 23 | scheduler = LinearParamScheduler(**config) 24 | schedule = [ 25 | round(scheduler(epoch_num / self._num_epochs), 4) 26 | for epoch_num in range(self._num_epochs) 27 | ] 28 | expected_schedule = [config["start_value"]] + self._get_valid_intermediate() 29 | self.assertEqual(schedule, expected_schedule) 30 | 31 | # Check as decay 32 | tmp = config["start_value"] 33 | config["start_value"] = config["end_value"] 34 | config["end_value"] = tmp 35 | scheduler = LinearParamScheduler(**config) 36 | schedule = [ 37 | round(scheduler(epoch_num / self._num_epochs), 4) 38 | for epoch_num in range(self._num_epochs) 39 | ] 40 | expected_schedule = [config["start_value"]] + list( 41 | reversed(self._get_valid_intermediate()) 42 | ) 43 | self.assertEqual(schedule, expected_schedule) 44 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_multi_step.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
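# Semantics exercised below, restated for clarity: values[i] applies from
# milestone i-1 (or from update 0) up to milestone i, in updates out of
# num_updates, and omitting milestones falls back to equispaced drops. E.g.
# values=[0.1, 0.01, 0.001, 0.0001], milestones=[4, 6, 8], num_updates=12
# yields 0.1 for updates 0-3, 0.01 for 4-5, 0.001 for 6-7, 0.0001 for 8-11.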
2 | 3 | # pyre-ignore-all-errors 4 | 5 | import copy 6 | import unittest 7 | 8 | from fvcore.common.param_scheduler import MultiStepParamScheduler 9 | 10 | 11 | class TestMultiStepParamScheduler(unittest.TestCase): 12 | _num_updates = 12 13 | 14 | def _get_valid_config(self): 15 | return { 16 | "num_updates": self._num_updates, 17 | "values": [0.1, 0.01, 0.001, 0.0001], 18 | "milestones": [4, 6, 8], 19 | } 20 | 21 | def test_invalid_config(self): 22 | # Invalid num epochs 23 | config = self._get_valid_config() 24 | 25 | bad_config = copy.deepcopy(config) 26 | bad_config["num_updates"] = -1 27 | with self.assertRaises(ValueError): 28 | MultiStepParamScheduler(**bad_config) 29 | 30 | bad_config["values"] = {"a": "b"} 31 | with self.assertRaises(ValueError): 32 | MultiStepParamScheduler(**bad_config) 33 | 34 | bad_config["values"] = [] 35 | with self.assertRaises(ValueError): 36 | MultiStepParamScheduler(**bad_config) 37 | 38 | # Invalid drop epochs 39 | bad_config["values"] = config["values"] 40 | bad_config["milestones"] = {"a": "b"} 41 | with self.assertRaises(ValueError): 42 | MultiStepParamScheduler(**bad_config) 43 | 44 | # Too many 45 | bad_config["milestones"] = [3, 6, 8, 12] 46 | with self.assertRaises(ValueError): 47 | MultiStepParamScheduler(**bad_config) 48 | 49 | # Too few 50 | bad_config["milestones"] = [3, 6] 51 | with self.assertRaises(ValueError): 52 | MultiStepParamScheduler(**bad_config) 53 | 54 | # Exceeds num_updates 55 | bad_config["milestones"] = [3, 6, 12] 56 | with self.assertRaises(ValueError): 57 | MultiStepParamScheduler(**bad_config) 58 | 59 | # Out of order 60 | bad_config["milestones"] = [3, 8, 6] 61 | with self.assertRaises(ValueError): 62 | MultiStepParamScheduler(**bad_config) 63 | 64 | def _test_config_scheduler(self, config, expected_schedule): 65 | scheduler = MultiStepParamScheduler(**config) 66 | schedule = [ 67 | scheduler(epoch_num / self._num_updates) 68 | for epoch_num in range(self._num_updates) 69 | ] 70 | self.assertEqual(schedule, expected_schedule) 71 | 72 | def test_scheduler(self): 73 | config = self._get_valid_config() 74 | expected_schedule = [ 75 | 0.1, 76 | 0.1, 77 | 0.1, 78 | 0.1, 79 | 0.01, 80 | 0.01, 81 | 0.001, 82 | 0.001, 83 | 0.0001, 84 | 0.0001, 85 | 0.0001, 86 | 0.0001, 87 | ] 88 | self._test_config_scheduler(config, expected_schedule) 89 | 90 | def test_default_config(self): 91 | config = self._get_valid_config() 92 | default_config = copy.deepcopy(config) 93 | # Default equispaced drop_epochs behavior 94 | del default_config["milestones"] 95 | expected_schedule = [ 96 | 0.1, 97 | 0.1, 98 | 0.1, 99 | 0.01, 100 | 0.01, 101 | 0.01, 102 | 0.001, 103 | 0.001, 104 | 0.001, 105 | 0.0001, 106 | 0.0001, 107 | 0.0001, 108 | ] 109 | self._test_config_scheduler(default_config, expected_schedule) 110 | 111 | def test_optional_args(self): 112 | v = [1, 0.1, 0.01] 113 | s1 = MultiStepParamScheduler(v, num_updates=90, milestones=[30, 60]) 114 | s2 = MultiStepParamScheduler(v, num_updates=90) 115 | s3 = MultiStepParamScheduler(v, milestones=[30, 60, 90]) 116 | for i in range(10): 117 | k = i / 10 118 | self.assertEqual(s1(k), s2(k)) 119 | self.assertEqual(s1(k), s3(k)) 120 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_polynomial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
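# The expected values below match a polynomial decay, a sketch inferred
# from the test data:
#
#     value(t) = base_value * (1 - t) ** power,  for t in [0, 1)
#
# e.g. base_value=0.1, power=1 decays linearly: 0.1, 0.09, ..., 0.01.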
2 | 3 | # pyre-ignore-all-errors 4 | 5 | import unittest 6 | 7 | from fvcore.common.param_scheduler import PolynomialDecayParamScheduler 8 | 9 | 10 | class TestPolynomialScheduler(unittest.TestCase): 11 | _num_epochs = 10 12 | 13 | def test_scheduler(self): 14 | scheduler = PolynomialDecayParamScheduler(base_value=0.1, power=1) 15 | schedule = [ 16 | round(scheduler(epoch_num / self._num_epochs), 2) 17 | for epoch_num in range(self._num_epochs) 18 | ] 19 | expected_schedule = [0.1, 0.09, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01] 20 | 21 | self.assertEqual(schedule, expected_schedule) 22 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_step.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # pyre-ignore-all-errors 4 | 5 | import copy 6 | import unittest 7 | from typing import Any, Dict 8 | 9 | from fvcore.common.param_scheduler import StepParamScheduler 10 | 11 | 12 | class TestStepScheduler(unittest.TestCase): 13 | _num_updates = 12 14 | 15 | def _get_valid_config(self) -> Dict[str, Any]: 16 | return { 17 | "num_updates": self._num_updates, 18 | "values": [0.1, 0.01, 0.001, 0.0001], 19 | } 20 | 21 | def test_invalid_config(self): 22 | # Invalid num epochs 23 | config = self._get_valid_config() 24 | 25 | bad_config = copy.deepcopy(config) 26 | bad_config["num_updates"] = -1 27 | with self.assertRaises(ValueError): 28 | StepParamScheduler(**bad_config) 29 | 30 | bad_config["values"] = {"a": "b"} 31 | with self.assertRaises(ValueError): 32 | StepParamScheduler(**bad_config) 33 | 34 | bad_config["values"] = [] 35 | with self.assertRaises(ValueError): 36 | StepParamScheduler(**bad_config) 37 | 38 | def test_scheduler(self): 39 | config = self._get_valid_config() 40 | 41 | scheduler = StepParamScheduler(**config) 42 | schedule = [ 43 | scheduler(epoch_num / self._num_updates) 44 | for epoch_num in range(self._num_updates) 45 | ] 46 | expected_schedule = [ 47 | 0.1, 48 | 0.1, 49 | 0.1, 50 | 0.01, 51 | 0.01, 52 | 0.01, 53 | 0.001, 54 | 0.001, 55 | 0.001, 56 | 0.0001, 57 | 0.0001, 58 | 0.0001, 59 | ] 60 | 61 | self.assertEqual(schedule, expected_schedule) 62 | -------------------------------------------------------------------------------- /tests/param_scheduler/test_scheduler_step_with_fixed_gamma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
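# As exercised below, StepWithFixedGammaParamScheduler behaves like a step
# schedule over geometrically decayed values (restating the test data):
#
#     values = [base_value * gamma ** i for i in range(num_decays + 1)]
#
# e.g. base_value=1, gamma=0.1, num_decays=3 over 12 updates gives
# [1]*3 + [0.1]*3 + [0.01]*3 + [0.001]*3.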
2 | 3 | # pyre-ignore-all-errors 4 | 5 | import copy 6 | import unittest 7 | 8 | from fvcore.common.param_scheduler import StepWithFixedGammaParamScheduler 9 | 10 | 11 | class TestStepWithFixedGammaScheduler(unittest.TestCase): 12 | _num_updates = 12 13 | 14 | def _get_valid_config(self): 15 | return { 16 | "base_value": 1, 17 | "gamma": 0.1, 18 | "num_decays": 3, 19 | "num_updates": self._num_updates, 20 | } 21 | 22 | def test_invalid_config(self): 23 | config = self._get_valid_config() 24 | 25 | # Invalid num epochs 26 | bad_config = copy.deepcopy(config) 27 | bad_config["num_updates"] = -1 28 | with self.assertRaises(ValueError): 29 | StepWithFixedGammaParamScheduler(**bad_config) 30 | 31 | # Invalid num_decays 32 | bad_config["num_decays"] = 0 33 | with self.assertRaises(ValueError): 34 | StepWithFixedGammaParamScheduler(**bad_config) 35 | 36 | # Invalid base_value 37 | bad_config = copy.deepcopy(config) 38 | bad_config["base_value"] = -0.01 39 | with self.assertRaises(ValueError): 40 | StepWithFixedGammaParamScheduler(**bad_config) 41 | 42 | # Invalid gamma 43 | bad_config = copy.deepcopy(config) 44 | bad_config["gamma"] = [2] 45 | with self.assertRaises(ValueError): 46 | StepWithFixedGammaParamScheduler(**bad_config) 47 | 48 | def test_scheduler(self): 49 | config = self._get_valid_config() 50 | 51 | scheduler = StepWithFixedGammaParamScheduler(**config) 52 | schedule = [ 53 | scheduler(epoch_num / self._num_updates) 54 | for epoch_num in range(self._num_updates) 55 | ] 56 | expected_schedule = [ 57 | 1, 58 | 1, 59 | 1, 60 | 0.1, 61 | 0.1, 62 | 0.1, 63 | 0.01, 64 | 0.01, 65 | 0.01, 66 | 0.001, 67 | 0.001, 68 | 0.001, 69 | ] 70 | 71 | for param, expected_param in zip(schedule, expected_schedule): 72 | self.assertAlmostEqual(param, expected_param) 73 | -------------------------------------------------------------------------------- /tests/test_activation_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | # pyre-ignore-all-errors[2] 5 | 6 | import typing 7 | import unittest 8 | from collections import Counter, defaultdict 9 | from typing import Any, Dict, List, Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from fvcore.nn.activation_count import activation_count, ActivationCountAnalysis 15 | from fvcore.nn.jit_handles import Handle 16 | from numpy import prod 17 | 18 | 19 | class SmallConvNet(nn.Module): 20 | """ 21 | A network with three conv layers. This is used for testing convolution 22 | layers for activation count. 23 | """ 24 | 25 | def __init__(self, input_dim: int) -> None: 26 | super(SmallConvNet, self).__init__() 27 | conv_dim1 = 8 28 | conv_dim2 = 4 29 | conv_dim3 = 2 30 | self.conv1 = nn.Conv2d(input_dim, conv_dim1, 1, 1) 31 | self.conv2 = nn.Conv2d(conv_dim1, conv_dim2, 1, 2) 32 | self.conv3 = nn.Conv2d(conv_dim2, conv_dim3, 1, 2) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | x = self.conv1(x) 36 | x = self.conv2(x) 37 | x = self.conv3(x) 38 | return x 39 | 40 | def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]: 41 | x = self.conv1(x) 42 | count1 = prod(list(x.size())) 43 | x = self.conv2(x) 44 | count2 = prod(list(x.size())) 45 | x = self.conv3(x) 46 | count3 = prod(list(x.size())) 47 | return (count1, count2, count3) 48 | 49 | 50 | class TestActivationCountAnalysis(unittest.TestCase): 51 | """ 52 | Unittest for activation_count. 
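    The functional API exercised here, restated for clarity (the name
    "skipped_ops" for the second return value is illustrative):
        ac_dict, skipped_ops = activation_count(model, (inputs,))
    where ac_dict maps an operator name such as "conv" to its activation
    count in millions.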
53 | """ 54 | 55 | def setUp(self) -> None: 56 | # nn.Linear uses a different operator based on version, so make sure 57 | # we are testing the right thing. 58 | lin = nn.Linear(10, 10) 59 | lin_x: torch.Tensor = torch.randn(10, 10) 60 | trace = torch.jit.trace(lin, (lin_x,)) 61 | node_kinds = [node.kind() for node in trace.graph.nodes()] 62 | assert "aten::addmm" in node_kinds or "aten::linear" in node_kinds 63 | if "aten::addmm" in node_kinds: 64 | self.lin_op = "addmm" 65 | else: 66 | self.lin_op = "linear" 67 | 68 | def test_conv2d(self) -> None: 69 | """ 70 | Test the activation count for convolutions. 71 | """ 72 | batch_size = 1 73 | input_dim = 3 74 | spatial_dim = 32 75 | x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) 76 | convNet = SmallConvNet(input_dim) 77 | ac_dict, _ = activation_count(convNet, (x,)) 78 | gt_count = sum(convNet.get_gt_activation(x)) 79 | 80 | gt_dict = defaultdict(float) 81 | gt_dict["conv"] = gt_count / 1e6 82 | self.assertDictEqual( 83 | gt_dict, 84 | ac_dict, 85 | "ConvNet with 3 layers failed to pass the activation count test.", 86 | ) 87 | 88 | def test_linear(self) -> None: 89 | """ 90 | Test the activation count for fully connected layer. 91 | """ 92 | batch_size = 1 93 | input_dim = 10 94 | output_dim = 20 95 | netLinear = nn.Linear(input_dim, output_dim) 96 | x = torch.randn(batch_size, input_dim) 97 | ac_dict, _ = activation_count(netLinear, (x,)) 98 | gt_count = batch_size * output_dim 99 | gt_dict = defaultdict(float) 100 | gt_dict[self.lin_op] = gt_count / 1e6 101 | self.assertEqual( 102 | gt_dict, ac_dict, "FC layer failed to pass the activation count test." 103 | ) 104 | 105 | def test_supported_ops(self) -> None: 106 | """ 107 | Test the activation count for user provided handles. 108 | """ 109 | 110 | def dummy_handle(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 111 | return Counter({"conv": 100}) 112 | 113 | batch_size = 1 114 | input_dim = 3 115 | spatial_dim = 32 116 | x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) 117 | convNet = SmallConvNet(input_dim) 118 | sp_ops: Dict[str, Handle] = {"aten::_convolution": dummy_handle} 119 | ac_dict, _ = activation_count(convNet, (x,), sp_ops) 120 | gt_dict = defaultdict(float) 121 | conv_layers = 3 122 | gt_dict["conv"] = 100 * conv_layers / 1e6 123 | self.assertDictEqual( 124 | gt_dict, 125 | ac_dict, 126 | "ConvNet with 3 layers failed to pass the activation count test.", 127 | ) 128 | 129 | def test_activation_count_class(self) -> None: 130 | """ 131 | Tests ActivationCountAnalysis. 
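        Usage pattern asserted here (an illustrative restatement):
            acts = ActivationCountAnalysis(model, (inputs,))
            acts.by_module()  # Counter of raw activation counts per module name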
132 | """ 133 | batch_size = 1 134 | input_dim = 10 135 | output_dim = 20 136 | netLinear = nn.Linear(input_dim, output_dim) 137 | x = torch.randn(batch_size, input_dim) 138 | gt_count = batch_size * output_dim 139 | gt_dict = Counter( 140 | { 141 | "": gt_count, 142 | } 143 | ) 144 | acts_counter = ActivationCountAnalysis(netLinear, (x,)) 145 | self.assertEqual(acts_counter.by_module(), gt_dict) 146 | 147 | batch_size = 1 148 | input_dim = 3 149 | spatial_dim = 32 150 | x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) 151 | convNet = SmallConvNet(input_dim) 152 | acts_counter = ActivationCountAnalysis(convNet, (x,)) 153 | gt_counts = convNet.get_gt_activation(x) 154 | gt_dict = Counter( 155 | { 156 | "": sum(gt_counts), 157 | "conv1": gt_counts[0], 158 | "conv2": gt_counts[1], 159 | "conv3": gt_counts[2], 160 | } 161 | ) 162 | 163 | self.assertDictEqual(gt_dict, acts_counter.by_module()) 164 | -------------------------------------------------------------------------------- /tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import copy 6 | import os 7 | import random 8 | import string 9 | import typing 10 | import unittest 11 | from collections import OrderedDict 12 | from tempfile import TemporaryDirectory 13 | from typing import Tuple 14 | from unittest.mock import MagicMock 15 | 16 | import torch 17 | from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer 18 | from torch import nn 19 | 20 | 21 | TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) 22 | if TORCH_VERSION >= (1, 11): 23 | from torch.ao import quantization 24 | from torch.ao.quantization import ( 25 | disable_observer, 26 | enable_fake_quant, 27 | get_default_qat_qconfig, 28 | prepare_qat, 29 | ) 30 | elif ( 31 | TORCH_VERSION >= (1, 8) 32 | and hasattr(torch.quantization, "FakeQuantizeBase") 33 | and hasattr(torch.quantization, "ObserverBase") 34 | ): 35 | from torch import quantization 36 | from torch.quantization import ( 37 | disable_observer, 38 | enable_fake_quant, 39 | get_default_qat_qconfig, 40 | prepare_qat, 41 | ) 42 | 43 | 44 | class TestCheckpointer(unittest.TestCase): 45 | def _create_model(self) -> nn.Module: 46 | """ 47 | Create a simple model. 48 | """ 49 | return nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1)) 50 | 51 | def _create_complex_model( 52 | self, 53 | ) -> typing.Tuple[nn.Module, typing.Dict[str, torch.Tensor]]: 54 | """ 55 | Create a complex model. 56 | """ 57 | m = nn.Module() 58 | # pyre-fixme[16]: `Module` has no attribute `block1`. 59 | m.block1 = nn.Module() 60 | # pyre-fixme[16]: `Module` has no attribute `layer1`. 61 | # pyre-fixme[16]: `Tensor` has no attribute `layer1`. 62 | m.block1.layer1 = nn.Linear(2, 3) 63 | # pyre-fixme[16]: `Module` has no attribute `layer2`. 64 | m.layer2 = nn.Linear(3, 2) 65 | # pyre-fixme[16]: `Module` has no attribute `res`. 66 | m.res = nn.Module() 67 | # pyre-fixme[16]: `Tensor` has no attribute `layer2`. 
68 | m.res.layer2 = nn.Linear(3, 2) 69 | 70 | state_dict = OrderedDict() 71 | state_dict["layer1.weight"] = torch.rand(3, 2) 72 | state_dict["layer1.bias"] = torch.rand(3) 73 | state_dict["layer2.weight"] = torch.rand(2, 3) 74 | state_dict["layer2.bias"] = torch.rand(2) 75 | state_dict["res.layer2.weight"] = torch.rand(2, 3) 76 | state_dict["res.layer2.bias"] = torch.rand(2) 77 | 78 | return m, state_dict 79 | 80 | @unittest.skipIf( # pyre-fixme[56] 81 | (not hasattr(quantization, "ObserverBase")) 82 | or (not hasattr(quantization, "FakeQuantizeBase")), 83 | "quantization per-channel observer base classes not supported", 84 | ) 85 | def test_loading_objects_with_expected_shape_mismatches(self) -> None: 86 | def _get_model() -> torch.nn.Module: 87 | m = nn.Sequential(nn.Conv2d(2, 2, 1)) 88 | # pyre-fixme[16]: `Sequential` has no attribute `qconfig`. 89 | m.qconfig = get_default_qat_qconfig("fbgemm") 90 | m = prepare_qat(m) 91 | return m 92 | 93 | m1, m2 = _get_model(), _get_model() 94 | # Calibrate m1 with data to populate the observer stats 95 | m1(torch.randn(4, 2, 4, 4)) 96 | # Load m1's checkpoint into m2. This should work without errors even 97 | # though the shapes of per-channel observer buffers do not match. 98 | with TemporaryDirectory() as f: 99 | checkpointer = Checkpointer(m1, save_dir=f) 100 | checkpointer.save("checkpoint_file") 101 | 102 | # in the same folder 103 | fresh_checkpointer = Checkpointer(m2, save_dir=f) 104 | self.assertTrue(fresh_checkpointer.has_checkpoint()) 105 | self.assertEqual( 106 | fresh_checkpointer.get_checkpoint_file(), 107 | os.path.join(f, "checkpoint_file.pth"), 108 | ) 109 | fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file()) 110 | # Run the expected input through the network with observers 111 | # disabled and fake_quant enabled. If buffers were loaded correctly 112 | # into per-channel observers, this line will not crash. 113 | m2.apply(disable_observer) 114 | m2.apply(enable_fake_quant) 115 | m2(torch.randn(4, 2, 4, 4)) 116 | 117 | def test_from_last_checkpoint_model(self) -> None: 118 | """ 119 | test that loading works even if they differ by a prefix. 120 | """ 121 | for trained_model, fresh_model in [ 122 | (self._create_model(), self._create_model()), 123 | (nn.DataParallel(self._create_model()), self._create_model()), 124 | (self._create_model(), nn.DataParallel(self._create_model())), 125 | ( 126 | nn.DataParallel(self._create_model()), 127 | nn.DataParallel(self._create_model()), 128 | ), 129 | ]: 130 | with TemporaryDirectory() as f: 131 | checkpointer = Checkpointer(trained_model, save_dir=f) 132 | checkpointer.save("checkpoint_file") 133 | 134 | # in the same folder 135 | fresh_checkpointer = Checkpointer(fresh_model, save_dir=f) 136 | self.assertTrue(fresh_checkpointer.has_checkpoint()) 137 | self.assertEqual( 138 | fresh_checkpointer.get_checkpoint_file(), 139 | os.path.join(f, "checkpoint_file.pth"), 140 | ) 141 | fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file()) 142 | 143 | for trained_p, loaded_p in zip( 144 | trained_model.parameters(), fresh_model.parameters() 145 | ): 146 | # different tensor references 147 | self.assertFalse(id(trained_p) == id(loaded_p)) 148 | # same content 149 | self.assertTrue(trained_p.cpu().equal(loaded_p.cpu())) 150 | 151 | def test_from_name_file_model(self) -> None: 152 | """ 153 | test that loading works even if they differ by a prefix. 
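        Unlike the test above, the fresh checkpointer here lives in a
        different folder and loads the checkpoint via an explicit file path.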
154 | """ 155 | for trained_model, fresh_model in [ 156 | (self._create_model(), self._create_model()), 157 | (nn.DataParallel(self._create_model()), self._create_model()), 158 | (self._create_model(), nn.DataParallel(self._create_model())), 159 | ( 160 | nn.DataParallel(self._create_model()), 161 | nn.DataParallel(self._create_model()), 162 | ), 163 | ]: 164 | with TemporaryDirectory() as f: 165 | checkpointer = Checkpointer( 166 | trained_model, save_dir=f, save_to_disk=True 167 | ) 168 | checkpointer.save("checkpoint_file") 169 | 170 | # in a different folder. 171 | with TemporaryDirectory() as g: 172 | fresh_checkpointer = Checkpointer(fresh_model, save_dir=g) 173 | self.assertFalse(fresh_checkpointer.has_checkpoint()) 174 | self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "") 175 | fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth")) 176 | 177 | for trained_p, loaded_p in zip( 178 | trained_model.parameters(), fresh_model.parameters() 179 | ): 180 | # different tensor references. 181 | self.assertFalse(id(trained_p) == id(loaded_p)) 182 | # same content. 183 | self.assertTrue(trained_p.cpu().equal(loaded_p.cpu())) 184 | 185 | def test_checkpointables(self) -> None: 186 | """ 187 | Test saving and loading checkpointables. 188 | """ 189 | 190 | class CheckpointableObj: 191 | """ 192 | A dummy checkpointable object with state_dict and load_state_dict 193 | methods. 194 | """ 195 | 196 | def __init__(self): 197 | self.state = { 198 | self.random_handle(): self.random_handle() for i in range(10) 199 | } 200 | 201 | def random_handle(self, str_len=100) -> str: 202 | """ 203 | Generate a random string of fixed length. 204 | Args: 205 | str_len (int): length of the output string. 206 | Returns: 207 | (str): randomly generated handle. 208 | """ 209 | letters = string.ascii_uppercase 210 | return "".join(random.choice(letters) for i in range(str_len)) 211 | 212 | def state_dict(self): 213 | """ 214 | Return the state. 215 | Returns: 216 | (dict): the state. 217 | """ 218 | return self.state 219 | 220 | def load_state_dict(self, state) -> None: 221 | """ 222 | Load the state from a given state dict. 223 | Args: 224 | state (dict): a key-value dictionary. 
225 | """ 226 | self.state = copy.deepcopy(state) 227 | 228 | trained_model, fresh_model = self._create_model(), self._create_model() 229 | with TemporaryDirectory() as f: 230 | checkpointables = CheckpointableObj() 231 | checkpointer = Checkpointer( 232 | trained_model, 233 | save_dir=f, 234 | save_to_disk=True, 235 | checkpointables=checkpointables, 236 | ) 237 | checkpointer.save("checkpoint_file") 238 | # in the same folder 239 | fresh_checkpointer = Checkpointer(fresh_model, save_dir=f) 240 | self.assertTrue(fresh_checkpointer.has_checkpoint()) 241 | self.assertEqual( 242 | fresh_checkpointer.get_checkpoint_file(), 243 | os.path.join(f, "checkpoint_file.pth"), 244 | ) 245 | checkpoint = fresh_checkpointer.load( 246 | fresh_checkpointer.get_checkpoint_file() 247 | ) 248 | state_dict = checkpointables.state_dict() 249 | for key, _ in state_dict.items(): 250 | self.assertTrue(checkpoint["checkpointables"].get(key) is not None) 251 | self.assertTrue(checkpoint["checkpointables"][key] == state_dict[key]) 252 | 253 | def test_load_reused_params(self) -> None: 254 | class Model(nn.Module): 255 | def __init__(self, has_y: bool) -> None: 256 | super().__init__() 257 | self.x = nn.Linear(10, 10) 258 | if has_y: 259 | self.y = self.x 260 | 261 | model = Model(has_y=False) 262 | model.x.bias.data.fill_(5.0) 263 | data = {"model": model.state_dict()} 264 | new_model = Model(has_y=True) 265 | chkpt = Checkpointer(new_model) 266 | chkpt.logger = logger = MagicMock() 267 | incompatible = chkpt._load_model(data) 268 | chkpt._log_incompatible_keys(incompatible) 269 | self.assertTrue( 270 | # pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`. 271 | torch.allclose(new_model.y.bias - 5.0, torch.zeros_like(new_model.y.bias)) 272 | ) 273 | logger.info.assert_not_called() 274 | 275 | @unittest.skipIf( # pyre-fixme[56] 276 | not hasattr(nn, "LazyLinear"), "LazyModule not supported" 277 | ) 278 | def test_load_lazy_module(self) -> None: 279 | def _get_model() -> nn.Sequential: 280 | return nn.Sequential(nn.LazyLinear(10)) 281 | 282 | m1, m2 = _get_model(), _get_model() 283 | m1(torch.randn(4, 2, 4, 4)) # initialize m1, but not m2 284 | # Load m1's checkpoint into m2. 285 | with TemporaryDirectory() as f: 286 | checkpointer = Checkpointer(m1, save_dir=f) 287 | checkpointer.save("checkpoint_file") 288 | 289 | fresh_checkpointer = Checkpointer(m2, save_dir=f) 290 | self.assertTrue(fresh_checkpointer.has_checkpoint()) 291 | self.assertEqual( 292 | fresh_checkpointer.get_checkpoint_file(), 293 | os.path.join(f, "checkpoint_file.pth"), 294 | ) 295 | fresh_checkpointer.load(fresh_checkpointer.get_checkpoint_file()) 296 | # pyre-fixme[6]: Incompatible parameter type: In call `torch._C._VariableFunctions.eq... 297 | self.assertTrue(torch.equal(m1[0].weight, m2[0].weight)) 298 | 299 | 300 | class TestPeriodicCheckpointer(unittest.TestCase): 301 | def _create_model(self) -> nn.Module: 302 | """ 303 | Create a simple model. 304 | """ 305 | return nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1)) 306 | 307 | def test_periodic_checkpointer(self) -> None: 308 | """ 309 | Test that checkpoints are written periodically, and only at the expected iterations. 
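Checkpoints are expected as model_{iteration:07d}.pth files in the save directory, one every `_period` iterations and none in between.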
310 | """ 311 | _period = 10 312 | _max_iter = 100 313 | for trained_model in [ 314 | self._create_model(), 315 | nn.DataParallel(self._create_model()), 316 | ]: 317 | with TemporaryDirectory() as f: 318 | checkpointer = Checkpointer( 319 | trained_model, save_dir=f, save_to_disk=True 320 | ) 321 | periodic_checkpointer = PeriodicCheckpointer(checkpointer, _period, _max_iter - 1) 322 | for iteration in range(_max_iter): 323 | periodic_checkpointer.step(iteration) 324 | path = os.path.join(f, "model_{:07d}.pth".format(iteration)) 325 | if (iteration + 1) % _period == 0: 326 | self.assertTrue(os.path.exists(path)) 327 | else: 328 | self.assertFalse(os.path.exists(path)) 329 | 330 | def test_periodic_checkpointer_max_to_keep(self) -> None: 331 | """ 332 | Test parameter: max_to_keep 333 | """ 334 | _period = 10 335 | _max_iter = 100 336 | _max_to_keep = 3 337 | for trained_model in [ 338 | self._create_model(), 339 | nn.DataParallel(self._create_model()), 340 | ]: 341 | with TemporaryDirectory() as f: 342 | checkpointer = Checkpointer( 343 | trained_model, save_dir=f, save_to_disk=True 344 | ) 345 | periodic_checkpointer = PeriodicCheckpointer( 346 | checkpointer, _period, _max_iter - 1, max_to_keep=_max_to_keep 347 | ) 348 | for _ in range(2): 349 | checkpoint_paths = [] 350 | 351 | for iteration in range(_max_iter): 352 | periodic_checkpointer.step(iteration) 353 | if (iteration + 1) % _period == 0: 354 | path = os.path.join(f, "model_{:07d}.pth".format(iteration)) 355 | checkpoint_paths.append(path) 356 | 357 | for path in checkpoint_paths[:-_max_to_keep]: 358 | self.assertFalse(os.path.exists(path)) 359 | 360 | for path in checkpoint_paths[-_max_to_keep:]: 361 | self.assertTrue(os.path.exists(path)) 362 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import math 6 | import time 7 | import typing 8 | import unittest 9 | 10 | import numpy as np 11 | from fvcore.common.config import CfgNode 12 | from fvcore.common.history_buffer import HistoryBuffer 13 | from fvcore.common.registry import Registry 14 | from fvcore.common.timer import Timer 15 | from yaml.constructor import ConstructorError 16 | 17 | 18 | class TestHistoryBuffer(unittest.TestCase): 19 | def setUp(self) -> None: 20 | super().setUp() 21 | np.random.seed(42) 22 | 23 | @staticmethod 24 | def create_buffer_with_init( 25 | num_values: int, 26 | buffer_len: int = 1000000, 27 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 28 | ) -> typing.Callable[[], typing.Union[object, np.ndarray]]: 29 | """ 30 | Return a function that creates a HistoryBuffer of the given length, fills it with random numbers, and also returns those numbers. 31 | 32 | Args: 33 | num_values: number of random numbers added to the history buffer. 34 | buffer_len: length of the created history buffer. 35 | """ 36 | 37 | max_value = 1000 38 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 39 | values: np.ndarray = np.random.randint(max_value, size=num_values) 40 | 41 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 42 | def create_buffer() -> typing.Union[object, np.ndarray]: 43 | buf = HistoryBuffer(buffer_len) 44 | for v in values: 45 | buf.update(v) 46 | return buf, values 47 | 48 | return create_buffer 49 | 50 | def test_buffer(self) -> None: 51 | """ 52 | Test creation of HistoryBuffer and the methods provided in the class. 
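The buffer is deliberately created shorter than the number of values pushed into it: values(), median() and avg() are compared against numpy ground truth computed over the retained tail of the stream, while global_avg() is compared against the mean of the full stream.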
53 | """ 54 | 55 | num_iters = 100 56 | for _ in range(num_iters): 57 | gt_len = 1000 58 | buffer_len = np.random.randint(1, gt_len) 59 | create_buffer = TestHistoryBuffer.create_buffer_with_init( 60 | gt_len, buffer_len 61 | ) 62 | # pyre-fixme[23]: Unable to unpack `Union[ndarray[typing.Any, 63 | # typing.Any], object]` into 2 values. 64 | buf, gt = create_buffer() 65 | 66 | values, iterations = zip(*buf.values()) 67 | self.assertEqual(len(values), buffer_len) 68 | self.assertEqual(len(iterations), buffer_len) 69 | self.assertTrue((values == gt[-buffer_len:]).all()) 70 | iterations_gt = np.arange(gt_len - buffer_len, gt_len) 71 | self.assertTrue( 72 | (iterations == iterations_gt).all(), 73 | ", ".join(str(x) for x in iterations), 74 | ) 75 | self.assertAlmostEqual(buf.global_avg(), gt.mean()) 76 | w = 100 77 | effective_w = min(w, buffer_len) 78 | self.assertAlmostEqual( 79 | buf.median(w), 80 | np.median(gt[-effective_w:]), 81 | None, 82 | " ".join(str(x) for x in gt[-effective_w:]), 83 | ) 84 | self.assertAlmostEqual( 85 | buf.avg(w), 86 | np.mean(gt[-effective_w:]), 87 | None, 88 | " ".join(str(x) for x in gt[-effective_w:]), 89 | ) 90 | 91 | 92 | class TestTimer(unittest.TestCase): 93 | def test_timer(self) -> None: 94 | """ 95 | Test basic timer functions (pause, resume, and reset). 96 | """ 97 | timer = Timer() 98 | time.sleep(0.5) 99 | self.assertTrue(0.99 > timer.seconds() >= 0.5) 100 | 101 | timer.pause() 102 | time.sleep(0.5) 103 | 104 | self.assertTrue(0.99 > timer.seconds() >= 0.5) 105 | 106 | timer.resume() 107 | time.sleep(0.5) 108 | self.assertTrue(1.49 > timer.seconds() >= 1.0) 109 | 110 | timer.reset() 111 | self.assertTrue(0.49 > timer.seconds() >= 0) 112 | 113 | def test_avg_second(self) -> None: 114 | """ 115 | Test avg_seconds that counts the average time. 116 | """ 117 | for pause_second in (0.1, 0.15): 118 | timer = Timer() 119 | for t in (pause_second,) * 10: 120 | if timer.is_paused(): 121 | timer.resume() 122 | time.sleep(t) 123 | timer.pause() 124 | self.assertTrue( 125 | math.isclose(pause_second, timer.avg_seconds(), rel_tol=1e-1), 126 | msg="{}: {}".format(pause_second, timer.avg_seconds()), 127 | ) 128 | 129 | 130 | class TestCfgNode(unittest.TestCase): 131 | @staticmethod 132 | def gen_default_cfg() -> CfgNode: 133 | cfg = CfgNode() 134 | cfg.KEY1 = "default" 135 | cfg.KEY2 = "default" 136 | cfg.EXPRESSION = [3.0] 137 | 138 | return cfg 139 | 140 | def test_merge_from_file(self) -> None: 141 | """ 142 | Test merge_from_file function provided in the class. 
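config.yaml extends base.yaml via the _BASE_ key and contains a python/object/apply:eval tag, so merging it must raise unless allow_unsafe=True is passed; config_multi_base.yaml inherits from more than one base file.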
143 | """ 144 | import pkg_resources 145 | 146 | base_yaml = pkg_resources.resource_filename(__name__, "configs/base.yaml") 147 | config_yaml = pkg_resources.resource_filename(__name__, "configs/config.yaml") 148 | config_multi_base_yaml = pkg_resources.resource_filename( 149 | __name__, "configs/config_multi_base.yaml" 150 | ) 151 | 152 | cfg = TestCfgNode.gen_default_cfg() 153 | cfg.merge_from_file(base_yaml) 154 | self.assertEqual(cfg.KEY1, "base") 155 | self.assertEqual(cfg.KEY2, "base") 156 | 157 | cfg = TestCfgNode.gen_default_cfg() 158 | 159 | with self.assertRaisesRegex(ConstructorError, "python/object/apply:eval"): 160 | # config.yaml contains unsafe yaml tags, 161 | # test if an exception is thrown 162 | cfg.merge_from_file(config_yaml) 163 | 164 | cfg.merge_from_file(config_yaml, allow_unsafe=True) 165 | self.assertEqual(cfg.KEY1, "base") 166 | self.assertEqual(cfg.KEY2, "config") 167 | self.assertEqual(cfg.EXPRESSION, [1, 4, 9]) 168 | 169 | cfg = TestCfgNode.gen_default_cfg() 170 | cfg.merge_from_file(config_multi_base_yaml, allow_unsafe=True) 171 | self.assertEqual(cfg.KEY1, "base2") 172 | self.assertEqual(cfg.KEY2, "config") 173 | 174 | def test_merge_from_list(self) -> None: 175 | """ 176 | Test merge_from_list function provided in the class. 177 | """ 178 | cfg = TestCfgNode.gen_default_cfg() 179 | cfg.merge_from_list(["KEY1", "list1", "KEY2", "list2"]) 180 | self.assertEqual(cfg.KEY1, "list1") 181 | self.assertEqual(cfg.KEY2, "list2") 182 | 183 | def test_setattr(self) -> None: 184 | """ 185 | Test __setattr__ function provided in the class. 186 | """ 187 | cfg = TestCfgNode.gen_default_cfg() 188 | cfg.KEY1 = "new1" 189 | cfg.KEY3 = "new3" 190 | self.assertEqual(cfg.KEY1, "new1") 191 | self.assertEqual(cfg.KEY3, "new3") 192 | 193 | # Test computed attributes, which can be inserted regardless of whether 194 | # the CfgNode is frozen or not. 195 | cfg = TestCfgNode.gen_default_cfg() 196 | cfg.COMPUTED_1 = "computed1" 197 | self.assertEqual(cfg.COMPUTED_1, "computed1") 198 | cfg.freeze() 199 | cfg.COMPUTED_2 = "computed2" 200 | self.assertEqual(cfg.COMPUTED_2, "computed2") 201 | 202 | # Test computed attributes, which should be 'insert only' (could not be 203 | # updated). 204 | cfg = TestCfgNode.gen_default_cfg() 205 | cfg.COMPUTED_1 = "computed1" 206 | with self.assertRaises(KeyError) as err: 207 | cfg.COMPUTED_1 = "update_computed1" 208 | self.assertTrue( 209 | "Computed attributed 'COMPUTED_1' already exists" in str(err.exception) 210 | ) 211 | 212 | # Resetting the same value should be safe: 213 | cfg.COMPUTED_1 = "computed1" 214 | 215 | 216 | class TestRegistry(unittest.TestCase): 217 | def test_registry(self) -> None: 218 | """ 219 | Test registering and accessing objects in the Registry. 220 | """ 221 | OBJECT_REGISTRY = Registry("OBJECT") 222 | 223 | @OBJECT_REGISTRY.register() 224 | class Object1: 225 | pass 226 | 227 | with self.assertRaises(AssertionError) as err: 228 | OBJECT_REGISTRY.register(Object1) 229 | self.assertTrue( 230 | "An object named 'Object1' was already registered in 'OBJECT' registry!" 231 | in str(err.exception) 232 | ) 233 | 234 | self.assertEqual(OBJECT_REGISTRY.get("Object1"), Object1) 235 | 236 | with self.assertRaises(KeyError) as err: 237 | OBJECT_REGISTRY.get("Object2") 238 | self.assertTrue( 239 | "No object named 'Object2' found in 'OBJECT' registry!" 
240 | in str(err.exception) 241 | ) 242 | 243 | items = list(OBJECT_REGISTRY) 244 | self.assertListEqual( 245 | items, [("Object1", Object1)], "Registry iterable contains valid item" 246 | ) 247 | -------------------------------------------------------------------------------- /tests/test_giou_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | import torch 9 | from fvcore.nn import giou_loss 10 | 11 | 12 | class TestGIoULoss(unittest.TestCase): 13 | def setUp(self) -> None: 14 | super().setUp() 15 | np.random.seed(42) 16 | 17 | def test_giou_loss(self) -> None: 18 | # Identical boxes should have loss of 0 19 | box = torch.tensor([-1, -1, 1, 1], dtype=torch.float32) 20 | loss = giou_loss(box, box) 21 | self.assertTrue(np.allclose(loss, [0.0])) 22 | 23 | # quarter size box inside other box = IoU of 0.25 24 | box2 = torch.tensor([0, 0, 1, 1], dtype=torch.float32) 25 | loss = giou_loss(box, box2) 26 | self.assertTrue(np.allclose(loss, [0.75])) 27 | 28 | # Two side by side boxes, area=union 29 | # IoU=0 and GIoU=0 (loss 1.0) 30 | box3 = torch.tensor([0, 1, 1, 2], dtype=torch.float32) 31 | loss = giou_loss(box2, box3) 32 | self.assertTrue(np.allclose(loss, [1.0])) 33 | 34 | # Two diagonally adjacent boxes, area=2*union 35 | # IoU=0 and GIoU=-0.5 (loss 1.5) 36 | box4 = torch.tensor([1, 1, 2, 2], dtype=torch.float32) 37 | loss = giou_loss(box2, box4) 38 | self.assertTrue(np.allclose(loss, [1.5])) 39 | 40 | # Test batched loss and reductions 41 | box1s = torch.stack([box2, box2], dim=0) 42 | box2s = torch.stack([box3, box4], dim=0) 43 | 44 | loss = giou_loss(box1s, box2s, reduction="sum") 45 | self.assertTrue(np.allclose(loss, [2.5])) 46 | 47 | loss = giou_loss(box1s, box2s, reduction="mean") 48 | self.assertTrue(np.allclose(loss, [1.25])) 49 | 50 | def test_empty_inputs(self) -> None: 51 | box1 = torch.randn([0, 4], dtype=torch.float32).requires_grad_() 52 | box2 = torch.randn([0, 4], dtype=torch.float32).requires_grad_() 53 | loss = giou_loss(box1, box2, reduction="mean") 54 | loss.backward() 55 | 56 | self.assertEqual(loss.detach().numpy(), 0.0) 57 | self.assertIsNotNone(box1.grad) 58 | self.assertIsNotNone(box2.grad) 59 | 60 | loss = giou_loss(box1, box2, reduction="none") 61 | self.assertEqual(loss.numel(), 0) 62 | -------------------------------------------------------------------------------- /tests/test_layers_squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | import itertools 5 | import unittest 6 | from typing import Iterable 7 | 8 | import torch 9 | from fvcore.nn.squeeze_excitation import ( 10 | ChannelSpatialSqueezeExcitation, 11 | SpatialSqueezeExcitation, 12 | SqueezeExcitation, 13 | ) 14 | 15 | 16 | class TestSqueezeExcitation(unittest.TestCase): 17 | def setUp(self) -> None: 18 | super().setUp() 19 | torch.manual_seed(42) 20 | 21 | def test_build_se(self) -> None: 22 | """ 23 | Test SE model builder. 24 | """ 25 | for layer, num_channels, is_3d in itertools.product( 26 | ( 27 | SqueezeExcitation, 28 | SpatialSqueezeExcitation, 29 | ChannelSpatialSqueezeExcitation, 30 | ), 31 | (16, 32), 32 | (True, False), 33 | ): 34 | model = layer( 35 | num_channels=num_channels, 36 | is_3d=is_3d, 37 | ) 38 | 39 | # Test forwarding. 
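# _get_inputs intentionally also yields tensors whose channel dimension differs from num_channels (the "Forward failed" shapes in the helpers below); those must raise a RuntimeError, while matching inputs must preserve the input shape.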
40 | for input_tensor in TestSqueezeExcitation._get_inputs( 41 | num_channels=num_channels, is_3d=is_3d 42 | ): 43 | if input_tensor.shape[1] != num_channels: 44 | with self.assertRaises(RuntimeError): 45 | output_tensor = model(input_tensor) 46 | continue 47 | else: 48 | output_tensor = model(input_tensor) 49 | 50 | input_shape = input_tensor.shape 51 | output_shape = output_tensor.shape 52 | 53 | self.assertEqual( 54 | input_shape, 55 | output_shape, 56 | "Input shape {} is different from output shape {}".format( 57 | input_shape, output_shape 58 | ), 59 | ) 60 | 61 | @staticmethod 62 | def _get_inputs3d(num_channels: int = 8) -> Iterable[torch.Tensor]: 63 | """ 64 | Provide different tensors as test cases. 65 | 66 | Yield: 67 | (torch.tensor): tensor as test case input. 68 | """ 69 | # Prepare random tensor as test cases. 70 | shapes = ( 71 | # Forward succeeded. 72 | (1, num_channels, 5, 7, 7), 73 | (2, num_channels, 5, 7, 7), 74 | (4, num_channels, 5, 7, 7), 75 | (4, num_channels, 5, 7, 7), 76 | (4, num_channels, 7, 7, 7), 77 | (4, num_channels, 7, 7, 14), 78 | (4, num_channels, 7, 14, 7), 79 | (4, num_channels, 7, 14, 14), 80 | # Forward failed. 81 | (8, num_channels * 2, 3, 7, 7), 82 | (8, num_channels * 4, 5, 7, 7), 83 | ) 84 | for shape in shapes: 85 | yield torch.rand(shape) 86 | 87 | @staticmethod 88 | def _get_inputs2d(num_channels: int = 8) -> Iterable[torch.Tensor]: 89 | """ 90 | Provide different tensors as test cases. 91 | 92 | Yield: 93 | (torch.tensor): tensor as test case input. 94 | """ 95 | # Prepare random tensor as test cases. 96 | shapes = ( 97 | # Forward succeeded. 98 | (1, num_channels, 7, 7), 99 | (2, num_channels, 7, 7), 100 | (4, num_channels, 7, 7), 101 | (4, num_channels, 7, 14), 102 | (4, num_channels, 14, 7), 103 | (4, num_channels, 14, 14), 104 | # Forward failed. 105 | (8, num_channels * 2, 7, 7), 106 | (8, num_channels * 4, 7, 7), 107 | ) 108 | for shape in shapes: 109 | yield torch.rand(shape) 110 | 111 | @staticmethod 112 | def _get_inputs( 113 | num_channels: int = 8, 114 | is_3d: bool = False, 115 | ) -> Iterable[torch.Tensor]: 116 | """ 117 | Provide different tensors as test cases. 118 | 119 | Yield: 120 | (torch.tensor): tensor as test case input. 121 | """ 122 | if is_3d: 123 | return TestSqueezeExcitation._get_inputs3d(num_channels=num_channels) 124 | else: 125 | return TestSqueezeExcitation._get_inputs2d(num_channels=num_channels) 126 | -------------------------------------------------------------------------------- /tests/test_param_count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
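# parameter_count returns a Counter keyed by submodule name, with "" holding the total for the whole model; a parameter shared between submodules (see NetWithReuse(reuse=True) below) is counted only at its first occurrence.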
2 | 3 | # pyre-strict 4 | 5 | 6 | import unittest 7 | 8 | from fvcore.nn.parameter_count import parameter_count, parameter_count_table 9 | from torch import nn 10 | 11 | 12 | class NetWithReuse(nn.Module): 13 | def __init__(self, reuse: bool = False) -> None: 14 | super().__init__() 15 | self.conv1 = nn.Conv2d(100, 100, 3) 16 | self.conv2 = nn.Conv2d(100, 100, 3) 17 | if reuse: 18 | self.conv2.weight = self.conv1.weight 19 | 20 | 21 | class NetWithDupPrefix(nn.Module): 22 | def __init__(self) -> None: 23 | super().__init__() 24 | self.conv1 = nn.Conv2d(100, 100, 3) 25 | self.conv111 = nn.Conv2d(100, 100, 3) 26 | 27 | 28 | class TestParamCount(unittest.TestCase): 29 | def test_param(self) -> None: 30 | net = NetWithReuse() 31 | count = parameter_count(net) 32 | self.assertEqual(count[""], 180200) 33 | self.assertEqual(count["conv2"], 90100) 34 | 35 | def test_param_with_reuse(self) -> None: 36 | net = NetWithReuse(reuse=True) 37 | count = parameter_count(net) 38 | self.assertEqual(count[""], 90200) 39 | self.assertEqual(count["conv2"], 100) 40 | 41 | def test_param_with_same_prefix(self) -> None: 42 | net = NetWithDupPrefix() 43 | table = parameter_count_table(net) 44 | c = ["conv111.weight" in line for line in table.split("\n")] 45 | self.assertEqual( 46 | sum(c), 1 47 | ) # "conv111.weight" appears exactly once, even though "conv1" is a prefix of "conv111" 48 | -------------------------------------------------------------------------------- /tests/test_precise_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # -*- coding: utf-8 -*- 3 | 4 | # pyre-strict 5 | 6 | import itertools 7 | import unittest 8 | from typing import List, Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from fvcore.nn import update_bn_stats 13 | from torch import nn 14 | 15 | 16 | class TestPreciseBN(unittest.TestCase): 17 | def setUp(self) -> None: 18 | torch.set_rng_state(torch.manual_seed(42).get_state()) 19 | 20 | @staticmethod 21 | def compute_bn_stats( 22 | tensors: List[torch.Tensor], 23 | dims: List[int], 24 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 25 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 26 | """ 27 | Given a list of randomly initialized tensors, compute the mean and 28 | variance. 29 | Args: 30 | tensors (list): list of randomly initialized tensors. 31 | dims (list): list of dimensions to compute the mean and variance. 32 | """ 33 | mean = ( 34 | torch.stack([tensor.mean(dim=dims) for tensor in tensors]) 35 | .mean(dim=0) 36 | .numpy() 37 | ) 38 | mean_of_batch_var = ( 39 | torch.stack([tensor.var(dim=dims, unbiased=True) for tensor in tensors]) 40 | .mean(dim=0) 41 | .numpy() 42 | ) 43 | var = torch.cat(tensors, dim=0).var(dim=dims, unbiased=False).numpy() 44 | return mean, mean_of_batch_var, var 45 | 46 | def test_precise_bn(self) -> None: 47 | # Number of batches to test. 
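# update_bn_stats consumes each of these batches exactly once (itertools.cycle plus an iteration count of len(tensors)) and overwrites running_mean/running_var with the exact statistics of the observed activations, which are then compared against the numpy ground truth from compute_bn_stats.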
48 | NB = 8 49 | _bn_types = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] 50 | _stats_dims = [[0, 2], [0, 2, 3], [0, 2, 3, 4]] 51 | _input_dims = [(16, 8, 24), (16, 8, 24, 8), (16, 8, 4, 12, 6)] 52 | assert len({len(_bn_types), len(_stats_dims), len(_input_dims)}) == 1 53 | 54 | for bn, stats_dim, input_dim in zip(_bn_types, _stats_dims, _input_dims): 55 | model = bn(input_dim[1]) 56 | model.train() 57 | tensors = [torch.randn(input_dim) for _ in range(NB)] 58 | mean, mean_of_batch_var, var = TestPreciseBN.compute_bn_stats( 59 | tensors, stats_dim 60 | ) 61 | 62 | old_weight = model.weight.detach().numpy() 63 | 64 | update_bn_stats( 65 | model, 66 | itertools.cycle(tensors), 67 | len(tensors), 68 | ) 69 | # pyre-fixme[16]: Optional type has no attribute `numpy`. 70 | self.assertTrue(np.allclose(model.running_mean.numpy(), mean)) 71 | self.assertTrue(np.allclose(model.running_var.numpy(), var)) 72 | 73 | # Test that the new estimator can handle varying batch size 74 | # It should obtain same results as earlier if the same input data are 75 | # split into different batch sizes. 76 | tensors = torch.split(torch.cat(tensors, dim=0), [2, 2, 4, 8, 16, 32, 64]) 77 | update_bn_stats( 78 | model, 79 | itertools.cycle(tensors), 80 | len(tensors), 81 | ) 82 | self.assertTrue(np.allclose(model.running_mean.numpy(), mean)) 83 | self.assertTrue(np.allclose(model.running_var.numpy(), var)) 84 | self.assertTrue(np.allclose(model.weight.detach().numpy(), old_weight)) 85 | 86 | def test_precise_bn_insufficient_data(self) -> None: 87 | input_dim = (16, 32, 24, 24) 88 | model = nn.BatchNorm2d(input_dim[1]) 89 | model.train() 90 | tensor = torch.randn(input_dim) 91 | with self.assertRaises(AssertionError): 92 | update_bn_stats(model, itertools.repeat(tensor, 10), 20) 93 | -------------------------------------------------------------------------------- /tests/test_smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | import torch 9 | from fvcore.nn import smooth_l1_loss 10 | 11 | 12 | class TestSmoothL1Loss(unittest.TestCase): 13 | def setUp(self) -> None: 14 | super().setUp() 15 | np.random.seed(42) 16 | 17 | def test_smooth_l1_loss(self) -> None: 18 | inputs = torch.tensor([1, 2, 3], dtype=torch.float32) 19 | targets = torch.tensor([1.1, 2, 4.5], dtype=torch.float32) 20 | beta = 0.5 21 | loss = smooth_l1_loss(inputs, targets, beta=beta, reduction="none").numpy() 22 | self.assertTrue(np.allclose(loss, [0.5 * 0.1**2 / beta, 0, 1.5 - 0.5 * beta])) 23 | 24 | beta = 0.05 25 | loss = smooth_l1_loss(inputs, targets, beta=beta, reduction="none").numpy() 26 | self.assertTrue(np.allclose(loss, [0.1 - 0.5 * beta, 0, 1.5 - 0.5 * beta])) 27 | 28 | def test_empty_inputs(self) -> None: 29 | inputs = torch.empty([0, 10], dtype=torch.float32).requires_grad_() 30 | targets = torch.empty([0, 10], dtype=torch.float32) 31 | loss = smooth_l1_loss(inputs, targets, beta=0.5, reduction="mean") 32 | loss.backward() 33 | 34 | self.assertEqual(loss.detach().numpy(), 0.0) 35 | self.assertIsNotNone(inputs.grad) 36 | -------------------------------------------------------------------------------- /tests/test_transform_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
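# The round trip to_float_tensor -> to_numpy must reproduce the original array (up to float precision) for both float and uint8 inputs in HW, HWC and NHWC layouts.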
2 | 3 | # pyre-strict 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | from fvcore.transforms.transform_util import to_float_tensor, to_numpy 9 | 10 | 11 | class TestTransformUtil(unittest.TestCase): 12 | def test_convert(self) -> None: 13 | N, C, H, W = 4, 64, 14, 14 14 | np.random.seed(0) 15 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 16 | array_HW: np.ndarray = np.random.rand(H, W) 17 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 18 | array_HWC: np.ndarray = np.random.rand(H, W, C) 19 | # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. 20 | array_NHWC: np.ndarray = np.random.rand(N, H, W, C) 21 | arrays = [ 22 | array_HW, 23 | (array_HW * 255).astype(np.uint8), 24 | array_HWC, 25 | (array_HWC * 255).astype(np.uint8), 26 | array_NHWC, 27 | (array_NHWC * 255).astype(np.uint8), 28 | ] 29 | 30 | for array in arrays: 31 | converted_tensor = to_float_tensor(array) 32 | # pyre-fixme[6]: For 2nd argument expected `List[Any]` but got 33 | # `tuple[int, ...]`. 34 | converted_array = to_numpy(converted_tensor, array.shape, array.dtype) 35 | self.assertTrue(np.allclose(array, converted_array)) 36 | -------------------------------------------------------------------------------- /tests/test_weight_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # pyre-strict 4 | 5 | import itertools 6 | import math 7 | import unittest 8 | 9 | import torch 10 | import torch.nn as nn 11 | from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill 12 | 13 | 14 | class TestWeightInit(unittest.TestCase): 15 | """ 16 | Test creation of WeightInit. 17 | """ 18 | 19 | def setUp(self) -> None: 20 | torch.set_rng_state(torch.manual_seed(42).get_state()) 21 | 22 | @staticmethod 23 | def msra_fill_std(fan_out: int) -> float: 24 | # Given the fan_out, calculate the expected standard deviation for msra 25 | # fill. 26 | # pyre-fixme[7]: Expected `float` but got `Tensor`. 27 | return torch.as_tensor(math.sqrt(2.0 / fan_out)) 28 | 29 | @staticmethod 30 | def xavier_fill_std(fan_in: int) -> float: 31 | # Given the fan_in, calculate the expected standard deviation for 32 | # xavier fill. 33 | # pyre-fixme[7]: Expected `float` but got `Tensor`. 34 | return torch.as_tensor(math.sqrt(1.0 / fan_in)) 35 | 36 | @staticmethod 37 | def weight_and_bias_dist_match( 38 | weight: torch.Tensor, 39 | bias: torch.Tensor, 40 | target_std: torch.Tensor, 41 | ) -> bool: 42 | # When the size of the weight is relative small, sampling on a small 43 | # number of elements would not give us a standard deviation that close 44 | # enough to the expected distribution. So the default rtol of 1e-8 will 45 | # break some test cases. Therefore a larger rtol is used. 46 | weight_dist_match = torch.allclose( 47 | target_std, torch.std(weight), rtol=1e-2, atol=0 48 | ) 49 | bias_dist_match = torch.nonzero(bias).nelement() == 0 50 | return weight_dist_match and bias_dist_match 51 | 52 | def test_conv_weight_init(self) -> None: 53 | # Test weight initialization for convolutional layers. 
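# For an nn.Conv{1,2,3}d with kernel size k over d spatial dims, fan_in = in_channels * k**d and fan_out = out_channels * k**d; c2_msra_fill should give weight std close to sqrt(2 / fan_out) and c2_xavier_fill std close to sqrt(1 / fan_in), with biases zeroed (see the helpers above).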
54 | kernel_sizes = [1, 3] 55 | channel_in_dims = [128, 256, 512] 56 | channel_out_dims = [256, 512, 1024] 57 | 58 | for layer in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: 59 | for k_size, c_in_dim, c_out_dim in itertools.product( 60 | kernel_sizes, channel_in_dims, channel_out_dims 61 | ): 62 | p = { 63 | "kernel_size": k_size, 64 | "in_channels": c_in_dim, 65 | "out_channels": c_out_dim, 66 | } 67 | # pyre-fixme[6]: For 1st argument expected `bool` but got `int`. 68 | # pyre-fixme[6]: For 1st argument expected `str` but got `int`. 69 | model = layer(**p) 70 | 71 | if layer is nn.Conv1d: 72 | spatial_dim = k_size 73 | elif layer is nn.Conv2d: 74 | spatial_dim = k_size**2 75 | elif layer is nn.Conv3d: 76 | spatial_dim = k_size**3 77 | 78 | # Calculate fan_in and fan_out. 79 | # pyre-fixme[61]: `spatial_dim` is undefined, or not always defined. 80 | fan_in = c_in_dim * spatial_dim 81 | # pyre-fixme[61]: `spatial_dim` is undefined, or not always defined. 82 | fan_out = c_out_dim * spatial_dim 83 | 84 | # Msra weight init check. 85 | c2_msra_fill(model) 86 | self.assertTrue( 87 | TestWeightInit.weight_and_bias_dist_match( 88 | model.weight, 89 | # pyre-fixme[6]: For 2nd argument expected `Tensor` but got 90 | # `Optional[Tensor]`. 91 | model.bias, 92 | # pyre-fixme[6]: For 3rd argument expected `Tensor` but got 93 | # `float`. 94 | TestWeightInit.msra_fill_std(fan_out), 95 | ) 96 | ) 97 | 98 | # Xavier weight init check. 99 | c2_xavier_fill(model) 100 | self.assertTrue( 101 | TestWeightInit.weight_and_bias_dist_match( 102 | model.weight, 103 | # pyre-fixme[6]: For 2nd argument expected `Tensor` but got 104 | # `Optional[Tensor]`. 105 | model.bias, 106 | # pyre-fixme[6]: For 3rd argument expected `Tensor` but got 107 | # `float`. 108 | TestWeightInit.xavier_fill_std(fan_in), 109 | ) 110 | ) 111 | 112 | def test_linear_weight_init(self) -> None: 113 | # Test weight initialization for linear layer. 114 | channel_in_dims = [128, 256, 512, 1024] 115 | channel_out_dims = [256, 512, 1024, 2048] 116 | 117 | for layer in [nn.Linear]: 118 | for c_in_dim, c_out_dim in itertools.product( 119 | channel_in_dims, channel_out_dims 120 | ): 121 | p = {"in_features": c_in_dim, "out_features": c_out_dim} 122 | # pyre-fixme[6]: For 1st argument expected `bool` but got `int`. 123 | model = layer(**p) 124 | 125 | # Calculate fan_in and fan_out. 126 | fan_in = c_in_dim 127 | fan_out = c_out_dim 128 | 129 | # Msra weight init check. 130 | c2_msra_fill(model) 131 | self.assertTrue( 132 | TestWeightInit.weight_and_bias_dist_match( 133 | model.weight, 134 | model.bias, 135 | # pyre-fixme[6]: For 3rd argument expected `Tensor` but got 136 | # `float`. 137 | TestWeightInit.msra_fill_std(fan_out), 138 | ) 139 | ) 140 | 141 | # Xavier weight init check. 142 | c2_xavier_fill(model) 143 | self.assertTrue( 144 | TestWeightInit.weight_and_bias_dist_match( 145 | model.weight, 146 | model.bias, 147 | # pyre-fixme[6]: For 3rd argument expected `Tensor` but got 148 | # `float`. 149 | TestWeightInit.xavier_fill_std(fan_in), 150 | ) 151 | ) 152 | --------------------------------------------------------------------------------
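A minimal usage sketch of the checkpointing API exercised by tests/test_checkpoint.py above; the save directory, iteration count, and elided training step are illustrative placeholders, not part of the repository:

import torch
from torch import nn
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer

model = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Any extra object with state_dict()/load_state_dict() can be passed as a
# keyword argument and is saved alongside the model (cf. test_checkpointables).
checkpointer = Checkpointer(model, save_dir="/tmp/ckpt", optimizer=optimizer)
periodic = PeriodicCheckpointer(checkpointer, period=10, max_iter=100)
for iteration in range(100):
    ...  # one training step
    # writes model_0000009.pth, model_0000019.pth, ... and model_final.pth
    periodic.step(iteration)

# Resuming picks up the newest checkpoint in the same directory, tolerating
# module-prefix differences such as a DataParallel wrapper.
resumed = Checkpointer(model, save_dir="/tmp/ckpt", optimizer=optimizer)
if resumed.has_checkpoint():
    resumed.load(resumed.get_checkpoint_file())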