├── .codecov.yml ├── .coveragerc ├── .github └── workflows │ ├── publish-to-pypi.yml │ └── test.yml ├── .gitignore ├── CHANGES.md ├── LICENSE ├── README.md ├── benchmarks └── bench_context_manager_overhead.py ├── conftest.py ├── continuous_integration ├── build_test_ext.sh ├── check_no_test_skipped.py ├── install.sh ├── install_blis.sh ├── install_flexiblas.sh └── run_tests.sh ├── dev-requirements.txt ├── multiple_openmp.md ├── pyproject.toml ├── tests ├── __init__.py ├── _openmp_test_helper │ ├── __init__.py │ ├── build_utils.py │ ├── nested_prange_blas.pyx │ ├── nested_prange_blas_custom.pyx │ ├── openmp_helpers_inner.pxd │ ├── openmp_helpers_inner.pyx │ ├── openmp_helpers_outer.pyx │ ├── setup_inner.py │ ├── setup_nested_prange_blas.py │ └── setup_outer.py ├── _pyMylib │ ├── __init__.py │ └── my_threaded_lib.c ├── test_threadpoolctl.py └── utils.py └── threadpoolctl.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | comment: off 2 | 3 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=threadpoolctl 3 | 4 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish threadpoolctl 🎮 distribution 📦 to PyPI and TestPyPI 2 | # Taken from: 3 | # https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ 4 | 5 | on: push 6 | 7 | jobs: 8 | build: 9 | name: Build distribution 📦 10 | # Don't run on forked repositories 11 | if: github.event.repository.fork != true 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | persist-credentials: false 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | 
with: 22 | python-version: "3.x" 23 | 24 | - name: Install pypa/build 25 | run: python -m pip install build --user 26 | 27 | - name: Build a binary wheel and a source tarball 28 | run: python -m build 29 | 30 | - name: Store the distribution packages 31 | uses: actions/upload-artifact@v4 32 | with: 33 | name: python-package-distributions 34 | path: dist/ 35 | retention-days: 1 36 | 37 | publish-to-pypi: 38 | name: >- 39 | Publish threadpoolctl 🎮 distribution 📦 to PyPI 40 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 41 | needs: 42 | - build 43 | runs-on: ubuntu-latest 44 | environment: 45 | name: pypi 46 | url: https://pypi.org/project/threadpoolctl/ 47 | permissions: 48 | id-token: write # IMPORTANT: mandatory for trusted publishing 49 | 50 | steps: 51 | - name: Download all the dists 52 | uses: actions/download-artifact@v4 53 | with: 54 | name: python-package-distributions 55 | path: dist/ 56 | - name: Publish distribution 📦 to PyPI 57 | uses: pypa/gh-action-pypi-publish@release/v1 58 | 59 | github-release: 60 | name: >- 61 | Sign the threadpoolctl 🎮 distribution 📦 with Sigstore 62 | and upload them to GitHub Release 63 | needs: 64 | - publish-to-pypi 65 | runs-on: ubuntu-latest 66 | 67 | permissions: 68 | contents: write # IMPORTANT: mandatory for making GitHub Releases 69 | id-token: write # IMPORTANT: mandatory for sigstore 70 | 71 | steps: 72 | - name: Download all the dists 73 | uses: actions/download-artifact@v4 74 | with: 75 | name: python-package-distributions 76 | path: dist/ 77 | - name: Sign the dists with Sigstore 78 | uses: sigstore/gh-action-sigstore-python@v3.0.0 79 | with: 80 | inputs: >- 81 | ./dist/*.tar.gz 82 | ./dist/*.whl 83 | - name: Create GitHub Release 84 | env: 85 | GITHUB_TOKEN: ${{ github.token }} 86 | run: >- 87 | gh release create 88 | "$GITHUB_REF_NAME" 89 | --repo "$GITHUB_REPOSITORY" 90 | --notes "" 91 | - name: Upload artifact signatures to GitHub Release 92 | env: 93 | GITHUB_TOKEN: ${{ 
github.token }} 94 | # Upload to GitHub Release using the `gh` CLI. 95 | # `dist/` contains the built packages, and the 96 | # sigstore-produced signatures and certificates. 97 | run: >- 98 | gh release upload 99 | "$GITHUB_REF_NAME" dist/** 100 | --repo "$GITHUB_REPOSITORY" 101 | 102 | publish-to-testpypi: 103 | name: Publish threadpoolctl 🎮 distribution 📦 to TestPyPI 104 | needs: 105 | - build 106 | runs-on: ubuntu-latest 107 | 108 | environment: 109 | name: testpypi 110 | url: https://test.pypi.org/project/threadpoolctl/ 111 | 112 | permissions: 113 | id-token: write # IMPORTANT: mandatory for trusted publishing 114 | 115 | steps: 116 | - name: Download all the dists 117 | uses: actions/download-artifact@v4 118 | with: 119 | name: python-package-distributions 120 | path: dist/ 121 | - name: Publish distribution 📦 to TestPyPI 122 | uses: pypa/gh-action-pypi-publish@release/v1 123 | with: 124 | verbose: true 125 | # skip-existing is required for .dev0 versions with fixed names but 126 | # different contents. 
See also: https://github.com/pypa/flit/issues/257 127 | skip-existing: true 128 | repository-url: https://test.pypi.org/legacy/ 129 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | permissions: 3 | contents: read 4 | 5 | on: 6 | push: 7 | branches: 8 | - master 9 | pull_request: 10 | branches: 11 | - master 12 | schedule: 13 | # Daily build at 1:00 AM UTC 14 | - cron: "0 1 * * *" 15 | 16 | # Cancel in-progress workflows when pushing 17 | # a new commit on the same branch 18 | concurrency: 19 | group: ${{ github.workflow }}-${{ github.ref }} 20 | cancel-in-progress: true 21 | 22 | jobs: 23 | 24 | linting: 25 | runs-on: ubuntu-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v3 29 | - name: Set up Python 30 | uses: actions/setup-python@v4 31 | with: 32 | python-version: "3.12" 33 | - name: Install black 34 | run: | 35 | pip install black==25.1.0 36 | - name: Run black 37 | run: | 38 | black --check --diff . 
39 | 40 | testing: 41 | needs: linting 42 | timeout-minutes: 30 43 | strategy: 44 | fail-fast: false 45 | matrix: 46 | include: 47 | 48 | # Windows env with numpy, scipy,MKL installed through conda-forge 49 | - name: pylatest_conda_forge_mkl 50 | os: windows-latest 51 | PYTHON_VERSION: "*" 52 | PACKAGER: "conda-forge" 53 | BLAS: "mkl" 54 | # Windows env with numpy, scipy, OpenBLAS installed through conda-forge 55 | - name: py311_conda_forge_openblas 56 | os: windows-latest 57 | PYTHON_VERSION: "3.11" 58 | PACKAGER: "conda-forge" 59 | BLAS: "openblas" 60 | # Windows env with numpy, scipy installed through conda 61 | - name: py310_conda 62 | os: windows-latest 63 | PYTHON_VERSION: "3.10" 64 | PACKAGER: "conda" 65 | # Windows env with numpy, scipy installed through pip 66 | - name: py39_pip 67 | os: windows-latest 68 | PYTHON_VERSION: "3.9" 69 | PACKAGER: "pip" 70 | 71 | # MacOS env with OpenMP installed through homebrew 72 | - name: py39_conda_homebrew_libomp 73 | os: macos-latest 74 | PYTHON_VERSION: "3.9" 75 | PACKAGER: "conda" 76 | BLAS: "openblas" 77 | CC_OUTER_LOOP: "clang" 78 | CC_INNER_LOOP: "clang" 79 | INSTALL_LIBOMP: "homebrew" 80 | # MacOS env with OpenBLAS and OpenMP installed through conda-forge compilers 81 | - name: pylatest_conda_forge_clang_openblas 82 | os: macos-latest 83 | PYTHON_VERSION: "*" 84 | PACKAGER: "conda-forge" 85 | BLAS: "openblas" 86 | INSTALL_LIBOMP: "conda-forge" 87 | # MacOS env with FlexiBLAS 88 | - name: pylatest_flexiblas 89 | os: macos-latest 90 | PYTHON_VERSION: "*" 91 | INSTALL_BLAS: "flexiblas" 92 | PLATFORM_SPECIFIC_PACKAGES: "llvm-openmp" 93 | 94 | # Linux environments to test that packages that comes with Ubuntu 22.04 95 | # are correctly handled. 
96 | - name: py39_ubuntu_atlas_gcc_gcc 97 | os: ubuntu-22.04 98 | PYTHON_VERSION: "3.9" 99 | PACKAGER: "ubuntu" 100 | APT_BLAS: "libatlas3-base libatlas-base-dev" 101 | CC_OUTER_LOOP: "gcc" 102 | CC_INNER_LOOP: "gcc" 103 | - name: py39_ubuntu_openblas_gcc_gcc 104 | os: ubuntu-22.04 105 | PYTHON_VERSION: "3.9" 106 | PACKAGER: "ubuntu" 107 | APT_BLAS: "libopenblas-base libopenblas-dev" 108 | CC_OUTER_LOOP: "gcc" 109 | CC_INNER_LOOP: "gcc" 110 | 111 | # Linux environment with development versions of numpy and scipy 112 | - name: pylatest_pip_dev 113 | os : ubuntu-latest 114 | PACKAGER: "pip-dev" 115 | PYTHON_VERSION: "*" 116 | CC_OUTER_LOOP: "gcc" 117 | CC_INNER_LOOP: "gcc" 118 | # Linux + Python 3.9 and homogeneous runtime nesting. 119 | - name: py39_conda_openblas_clang_clang 120 | os: ubuntu-latest 121 | PACKAGER: "conda" 122 | PYTHON_VERSION: "3.9" 123 | BLAS: "openblas" 124 | CC_OUTER_LOOP: "clang-18" 125 | CC_INNER_LOOP: "clang-18" 126 | # Linux environment with MKL and Clang (known to be unsafe for 127 | # threadpoolctl) to only test the warning from multiple OpenMP. 128 | - name: pylatest_conda_mkl_clang_gcc 129 | os: ubuntu-latest 130 | PYTHON_VERSION: "*" 131 | PACKAGER: "conda" 132 | BLAS: "mkl" 133 | CC_OUTER_LOOP: "clang-18" 134 | CC_INNER_LOOP: "gcc" 135 | TESTS: "libomp_libiomp_warning" 136 | # Linux environment with MKL, safe for threadpoolctl. 137 | - name: pylatest_conda_mkl_gcc_gcc 138 | os: ubuntu-latest 139 | PYTHON_VERSION: "*" 140 | PACKAGER: "conda" 141 | BLAS: "mkl" 142 | CC_OUTER_LOOP: "gcc" 143 | CC_INNER_LOOP: "gcc" 144 | MKL_THREADING_LAYER: "INTEL" 145 | # Linux + Python 3.11 with numpy / scipy installed with pip from PyPI 146 | # and heterogeneous OpenMP runtimes. 
147 | - name: py311_pip_openblas_gcc_clang 148 | os: ubuntu-latest 149 | PACKAGER: "pip" 150 | PYTHON_VERSION: "3.11" 151 | CC_OUTER_LOOP: "gcc" 152 | CC_INNER_LOOP: "clang-18" 153 | # Linux environment with numpy from conda-forge channel and openblas-openmp 154 | - name: pylatest_conda_forge 155 | os: ubuntu-latest 156 | PACKAGER: "conda-forge" 157 | PYTHON_VERSION: "*" 158 | BLAS: "openblas" 159 | OPENBLAS_THREADING_LAYER: "openmp" 160 | CC_OUTER_LOOP: "gcc" 161 | CC_INNER_LOOP: "gcc" 162 | # Linux environment with no numpy and heterogeneous OpenMP runtimes. 163 | - name: pylatest_conda_nonumpy_gcc_clang 164 | os: ubuntu-latest 165 | PACKAGER: "conda" 166 | PYTHON_VERSION: "*" 167 | NO_NUMPY: "true" 168 | CC_OUTER_LOOP: "gcc" 169 | CC_INNER_LOOP: "clang-18" 170 | 171 | # Linux environments with numpy linked to BLIS 172 | - name: pylatest_blis_gcc_clang_openmp 173 | os: ubuntu-latest 174 | PYTHON_VERSION: "*" 175 | INSTALL_BLAS: "blis" 176 | BLIS_NUM_THREAEDS: "4" 177 | CC_OUTER_LOOP: "gcc" 178 | CC_INNER_LOOP: "gcc" 179 | BLIS_CC: "clang-18" 180 | BLIS_ENABLE_THREADING: "openmp" 181 | - name: pylatest_blis_clang_gcc_pthreads 182 | os: ubuntu-latest 183 | PYTHON_VERSION: "*" 184 | INSTALL_BLAS: "blis" 185 | BLIS_NUM_THREADS: "4" 186 | CC_OUTER_LOOP: "clang-18" 187 | CC_INNER_LOOP: "clang-18" 188 | BLIS_CC: "gcc-12" 189 | BLIS_ENABLE_THREADING: "pthreads" 190 | - name: pylatest_blis_no_threading 191 | os: ubuntu-latest 192 | PYTHON_VERSION: "*" 193 | INSTALL_BLAS: "blis" 194 | BLIS_NUM_THREADS: "1" 195 | CC_OUTER_LOOP: "gcc" 196 | CC_INNER_LOOP: "gcc" 197 | BLIS_CC: "gcc-12" 198 | BLIS_ENABLE_THREADING: "no" 199 | 200 | # Linux env with FlexiBLAS 201 | - name: pylatest_flexiblas 202 | os: ubuntu-latest 203 | PYTHON_VERSION: "*" 204 | INSTALL_BLAS: "flexiblas" 205 | PLATFORM_SPECIFIC_PACKAGES: "mkl" 206 | CC_OUTER_LOOP: "gcc" 207 | CC_INNER_LOOP: "gcc" 208 | 209 | env: ${{ matrix }} 210 | 211 | runs-on: ${{ matrix.os }} 212 | 213 | defaults: 214 | run: 215 | # Need 
to use this shell to get conda working properly. 216 | # See https://github.com/marketplace/actions/setup-miniconda#important 217 | shell: ${{ matrix.os == 'windows-latest' && 'cmd /C CALL {0}' || 'bash -el {0}' }} 218 | 219 | 220 | steps: 221 | - name: Checkout code 222 | uses: actions/checkout@v3 223 | 224 | - name: Setup conda 225 | uses: conda-incubator/setup-miniconda@v3 226 | with: 227 | auto-activate-base: true 228 | auto-update-conda: true 229 | miniforge-version: latest 230 | 231 | - name: Install dependencies 232 | run: | 233 | bash -el continuous_integration/install.sh 234 | 235 | - name: Test library 236 | run: | 237 | bash -el continuous_integration/run_tests.sh 238 | 239 | - name: Upload test results 240 | uses: actions/upload-artifact@v4 241 | with: 242 | # Requires a unique name for each job in the matrix of this run 243 | name: test_result_${{github.run_id}}_${{ matrix.os }}_${{ matrix.name }} 244 | path: test_result.xml 245 | retention-days: 1 246 | 247 | - name: Upload to Codecov 248 | uses: codecov/codecov-action@v5 249 | with: 250 | files: coverage.xml 251 | 252 | # Meta-test to ensure that at least one of the above CI configurations had 253 | # the necessary platform settings to execute each test without raising 254 | # skipping. 
255 | meta_test: 256 | needs: testing 257 | runs-on: ubuntu-latest 258 | steps: 259 | - name: Set up Python 260 | uses: actions/setup-python@v4 261 | with: 262 | python-version: "3.12" 263 | 264 | - name: Checkout code 265 | uses: actions/checkout@v3 266 | 267 | - name: Download tests results 268 | uses: actions/download-artifact@v4 269 | with: 270 | path: test_results 271 | 272 | - name: Check no test always skipped 273 | run: | 274 | python continuous_integration/check_no_test_skipped.py test_results 275 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python generated files 2 | *.pyc 3 | __pycache__ 4 | .cache 5 | .pytest_cache 6 | 7 | # Cython/C generated files 8 | *.o 9 | *.so 10 | *.dylib 11 | tests/_openmp_test_helper/*.c 12 | 13 | # Python install files, build and release artifacts 14 | *.egg-info/ 15 | build 16 | dist 17 | 18 | # Coverage data 19 | .coverage 20 | /htmlcov 21 | 22 | # Developer tools 23 | .vscode 24 | 25 | # pytest 26 | .pytest_cache 27 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | 3.6.0 (2025-03-13) 2 | ================== 3 | 4 | - Added support for libraries with a path longer than 260 on Windows. The supported path 5 | length is now 10 times higher but not unlimited for security reasons. 6 | https://github.com/joblib/threadpoolctl/pull/189 7 | 8 | - Dropped official support for Python 3.8. 9 | https://github.com/joblib/threadpoolctl/pull/186 10 | https://github.com/joblib/threadpoolctl/pull/191 11 | 12 | 3.5.0 (2024-04-29) 13 | ================== 14 | 15 | - Added support for the Scientific Python version of OpenBLAS 16 | (https://github.com/MacPython/openblas-libs), which exposes symbols with different 17 | names than the ones of the original OpenBLAS library. 
18 | https://github.com/joblib/threadpoolctl/pull/175 19 | 20 | 3.4.0 (2024-03-20) 21 | ================== 22 | 23 | - Added support for Python interpreters statically linked against libc or linked against 24 | alternative implementations of libc like musl (on Alpine Linux for instance). 25 | https://github.com/joblib/threadpoolctl/pull/171 26 | 27 | - Added support for Pyodide 28 | https://github.com/joblib/threadpoolctl/pull/169 29 | 30 | 3.3.0 (2024-02-14) 31 | ================== 32 | 33 | - Extended FlexiBLAS support to be able to switch backend at runtime. 34 | https://github.com/joblib/threadpoolctl/pull/163 35 | 36 | - Added support for FlexiBLAS 37 | https://github.com/joblib/threadpoolctl/pull/156 38 | 39 | - Fixed a bug where an unsupported library would be detected because it shares a common 40 | prefix with one of the supported libraries. Now the symbols are also checked to 41 | identify the supported libraries. 42 | https://github.com/joblib/threadpoolctl/pull/151 43 | 44 | 3.2.0 (2023-07-13) 45 | ================== 46 | 47 | - Dropped support for Python 3.6 and 3.7. 48 | 49 | - Added support for custom library controllers. Custom controllers must inherit from 50 | the `threadpoolctl.LibController` class and be registered to threadpoolctl using the 51 | `threadpoolctl.register` function. 52 | https://github.com/joblib/threadpoolctl/pull/138 53 | 54 | - A warning is raised on macOS when threadpoolctl finds both Intel OpenMP and LLVM 55 | OpenMP runtimes loaded simultaneously by the same Python program. See details and 56 | workarounds at https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md. 57 | https://github.com/joblib/threadpoolctl/pull/142 58 | 59 | 3.1.0 (2022-01-31) 60 | ================== 61 | 62 | - Fixed a detection issue of the BLAS libraires packaged by conda-forge on Windows. 
63 | https://github.com/joblib/threadpoolctl/pull/112 64 | 65 | - `threadpool_limits` and `ThreadpoolController.limit` now accept the string 66 | "sequential_blas_under_openmp" for the `limits` parameter. It should only be used for 67 | the specific case when one wants to have sequential BLAS calls within an OpenMP 68 | parallel region. It takes into account the unexpected behavior of OpenBLAS with the 69 | OpenMP threading layer. 70 | https://github.com/joblib/threadpoolctl/pull/114 71 | 72 | 3.0.0 (2021-10-01) 73 | ================== 74 | 75 | - New object `threadpooctl.ThreadpoolController` which holds controllers for all the 76 | supported native libraries. The states of these libraries is accessible through the 77 | `info` method (equivalent to `threadpoolctl.threadpool_info()`) and their number of 78 | threads can be limited with the `limit` method which can be used as a context 79 | manager (equivalent to `threadpoolctl.threadpool_limits()`). This is especially useful 80 | to avoid searching through all loaded shared libraries each time. 81 | https://github.com/joblib/threadpoolctl/pull/95 82 | 83 | - Added support for OpenBLAS built for 64bit integers in Fortran. 84 | https://github.com/joblib/threadpoolctl/pull/101 85 | 86 | - Added the possibility to use `threadpoolctl.threadpool_limits` and 87 | `threadpooctl.ThreadpoolController` as decorators through their `wrap` method. 88 | https://github.com/joblib/threadpoolctl/pull/102 89 | 90 | - Fixed an attribute error when using old versions of OpenBLAS or BLIS that are 91 | missing version query functions. 92 | https://github.com/joblib/threadpoolctl/pull/88 93 | https://github.com/joblib/threadpoolctl/pull/91 94 | 95 | - Fixed an attribute error when python is run with -OO. 
96 | https://github.com/joblib/threadpoolctl/pull/87 97 | 98 | 2.2.0 (2021-07-09) 99 | ================== 100 | 101 | - `threadpoolctl.threadpool_info()` now reports the architecture of the CPU 102 | cores detected by OpenBLAS (via `openblas_get_corename`) and BLIS (via 103 | `bli_arch_query_id` and `bli_arch_string`). 104 | 105 | - Fixed a bug when the version of MKL was not found. The 106 | "version" field is now set to None in that case. 107 | https://github.com/joblib/threadpoolctl/pull/82 108 | 109 | 2.1.0 (2020-05-29) 110 | ================== 111 | 112 | - New commandline interface: 113 | 114 | python -m threadpoolctl -i numpy 115 | 116 | will try to import the `numpy` package and then return the output of 117 | `threadpoolctl.threadpool_info()` on STDOUT formatted using the JSON 118 | syntax. This makes it easier to quickly introspect a Python environment. 119 | 120 | 121 | 2.0.0 (2019-12-05) 122 | ================== 123 | 124 | - Expose MKL, BLIS and OpenBLAS threading layer in information displayed by 125 | `threadpool_info`. This information is referenced in the `threading_layer` 126 | field. 127 | https://github.com/joblib/threadpoolctl/pull/48 128 | https://github.com/joblib/threadpoolctl/pull/60 129 | 130 | - When threadpoolctl finds libomp (LLVM OpenMP) and libiomp (Intel OpenMP) 131 | both loaded, a warning is raised to recall that using threadpoolctl with 132 | this mix of OpenMP libraries may cause crashes or deadlocks. 133 | https://github.com/joblib/threadpoolctl/pull/49 134 | 135 | 1.1.0 (2019-09-12) 136 | ================== 137 | 138 | - Detect libraries referenced by symlinks (e.g. BLAS libraries from 139 | conda-forge). 140 | https://github.com/joblib/threadpoolctl/pull/34 141 | 142 | - Add support for BLIS. 
143 | https://github.com/joblib/threadpoolctl/pull/23 144 | 145 | - Breaking change: method `get_original_num_threads` on the `threadpool_limits` 146 | context manager to cheaply access the initial state of the runtime: 147 | - drop the `user_api` parameter; 148 | - instead return a dict `{user_api: num_threads}`; 149 | - fixed a bug when the limit parameter of `threadpool_limits` was set to 150 | `None`. 151 | 152 | https://github.com/joblib/threadpoolctl/pull/32 153 | 154 | 155 | 1.0.0 (2019-06-03) 156 | ================== 157 | 158 | Initial release. 159 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, threadpoolctl contributors 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of copyright holder nor the names of its contributors 12 | may be used to endorse or promote products derived from this software 13 | without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Thread-pool Controls [![Build Status](https://github.com/joblib/threadpoolctl/actions/workflows/test.yml/badge.svg?branch=master)](https://github.com/joblib/threadpoolctl/actions?query=branch%3Amaster) [![codecov](https://codecov.io/gh/joblib/threadpoolctl/branch/master/graph/badge.svg)](https://codecov.io/gh/joblib/threadpoolctl) 2 | 3 | Python helpers to limit the number of threads used in the 4 | threadpool-backed of common native libraries used for scientific 5 | computing and data science (e.g. BLAS and OpenMP). 6 | 7 | Fine control of the underlying thread-pool size can be useful in 8 | workloads that involve nested parallelism so as to mitigate 9 | oversubscription issues. 
10 | 11 | ## Installation 12 | 13 | - For users, install the last published version from PyPI: 14 | 15 | ```bash 16 | pip install threadpoolctl 17 | ``` 18 | 19 | - For contributors, install from the source repository in developer 20 | mode: 21 | 22 | ```bash 23 | pip install -r dev-requirements.txt 24 | flit install --symlink 25 | ``` 26 | 27 | then you run the tests with pytest: 28 | 29 | ```bash 30 | pytest 31 | ``` 32 | 33 | ## Usage 34 | 35 | ### Command Line Interface 36 | 37 | Get a JSON description of thread-pools initialized when importing python 38 | packages such as numpy or scipy for instance: 39 | 40 | ``` 41 | python -m threadpoolctl -i numpy scipy.linalg 42 | [ 43 | { 44 | "filepath": "/home/ogrisel/miniconda3/envs/tmp/lib/libmkl_rt.so", 45 | "prefix": "libmkl_rt", 46 | "user_api": "blas", 47 | "internal_api": "mkl", 48 | "version": "2019.0.4", 49 | "num_threads": 2, 50 | "threading_layer": "intel" 51 | }, 52 | { 53 | "filepath": "/home/ogrisel/miniconda3/envs/tmp/lib/libiomp5.so", 54 | "prefix": "libiomp", 55 | "user_api": "openmp", 56 | "internal_api": "openmp", 57 | "version": null, 58 | "num_threads": 4 59 | } 60 | ] 61 | ``` 62 | 63 | The JSON information is written on STDOUT. If some of the packages are missing, 64 | a warning message is displayed on STDERR. 
65 | 66 | ### Python Runtime Programmatic Introspection 67 | 68 | Introspect the current state of the threadpool-enabled runtime libraries 69 | that are loaded when importing Python packages: 70 | 71 | ```python 72 | >>> from threadpoolctl import threadpool_info 73 | >>> from pprint import pprint 74 | >>> pprint(threadpool_info()) 75 | [] 76 | 77 | >>> import numpy 78 | >>> pprint(threadpool_info()) 79 | [{'filepath': '/home/ogrisel/miniconda3/envs/tmp/lib/libmkl_rt.so', 80 | 'internal_api': 'mkl', 81 | 'num_threads': 2, 82 | 'prefix': 'libmkl_rt', 83 | 'threading_layer': 'intel', 84 | 'user_api': 'blas', 85 | 'version': '2019.0.4'}, 86 | {'filepath': '/home/ogrisel/miniconda3/envs/tmp/lib/libiomp5.so', 87 | 'internal_api': 'openmp', 88 | 'num_threads': 4, 89 | 'prefix': 'libiomp', 90 | 'user_api': 'openmp', 91 | 'version': None}] 92 | 93 | >>> import xgboost 94 | >>> pprint(threadpool_info()) 95 | [{'filepath': '/home/ogrisel/miniconda3/envs/tmp/lib/libmkl_rt.so', 96 | 'internal_api': 'mkl', 97 | 'num_threads': 2, 98 | 'prefix': 'libmkl_rt', 99 | 'threading_layer': 'intel', 100 | 'user_api': 'blas', 101 | 'version': '2019.0.4'}, 102 | {'filepath': '/home/ogrisel/miniconda3/envs/tmp/lib/libiomp5.so', 103 | 'internal_api': 'openmp', 104 | 'num_threads': 4, 105 | 'prefix': 'libiomp', 106 | 'user_api': 'openmp', 107 | 'version': None}, 108 | {'filepath': '/home/ogrisel/miniconda3/envs/tmp/lib/libgomp.so.1.0.0', 109 | 'internal_api': 'openmp', 110 | 'num_threads': 4, 111 | 'prefix': 'libgomp', 112 | 'user_api': 'openmp', 113 | 'version': None}] 114 | ``` 115 | 116 | In the above example, `numpy` was installed from the default anaconda channel and comes 117 | with MKL and its Intel OpenMP (`libiomp5`) implementation while `xgboost` was installed 118 | from pypi.org and links against GNU OpenMP (`libgomp`) so both OpenMP runtimes are 119 | loaded in the same Python program. 
120 | 121 | The state of these libraries is also accessible through the object oriented API: 122 | 123 | ```python 124 | >>> from threadpoolctl import ThreadpoolController, threadpool_info 125 | >>> from pprint import pprint 126 | >>> import numpy 127 | >>> controller = ThreadpoolController() 128 | >>> pprint(controller.info()) 129 | [{'architecture': 'Haswell', 130 | 'filepath': '/home/jeremie/miniconda/envs/dev/lib/libopenblasp-r0.3.17.so', 131 | 'internal_api': 'openblas', 132 | 'num_threads': 4, 133 | 'prefix': 'libopenblas', 134 | 'threading_layer': 'pthreads', 135 | 'user_api': 'blas', 136 | 'version': '0.3.17'}] 137 | 138 | >>> controller.info() == threadpool_info() 139 | True 140 | ``` 141 | 142 | ### Setting the Maximum Size of Thread-Pools 143 | 144 | Control the number of threads used by the underlying runtime libraries 145 | in specific sections of your Python program: 146 | 147 | ```python 148 | >>> from threadpoolctl import threadpool_limits 149 | >>> import numpy as np 150 | 151 | >>> with threadpool_limits(limits=1, user_api='blas'): 152 | ... # In this block, calls to blas implementation (like openblas or MKL) 153 | ... # will be limited to use only one thread. They can thus be used jointly 154 | ... # with thread-parallelism. 155 | ... a = np.random.randn(1000, 1000) 156 | ... a_squared = a @ a 157 | ``` 158 | 159 | The threadpools can also be controlled via the object oriented API, which is especially 160 | useful to avoid searching through all the loaded shared libraries each time. It will 161 | however not act on libraries loaded after the instantiation of the 162 | `ThreadpoolController`: 163 | 164 | ```python 165 | >>> from threadpoolctl import ThreadpoolController 166 | >>> import numpy as np 167 | >>> controller = ThreadpoolController() 168 | 169 | >>> with controller.limit(limits=1, user_api='blas'): 170 | ... a = np.random.randn(1000, 1000) 171 | ... 
a_squared = a @ a 172 | ``` 173 | 174 | ### Restricting the limits to the scope of a function 175 | 176 | `threadpool_limits` and `ThreadpoolController` can also be used as decorators to set 177 | the maximum number of threads used by the supported libraries at a function level. The 178 | decorators are accessible through their `wrap` method: 179 | 180 | ```python 181 | >>> from threadpoolctl import ThreadpoolController, threadpool_limits 182 | >>> import numpy as np 183 | >>> controller = ThreadpoolController() 184 | 185 | >>> @controller.wrap(limits=1, user_api='blas') 186 | ... # or @threadpool_limits.wrap(limits=1, user_api='blas') 187 | ... def my_func(): 188 | ... # Inside this function, calls to blas implementation (like openblas or MKL) 189 | ... # will be limited to use only one thread. 190 | ... a = np.random.randn(1000, 1000) 191 | ... a_squared = a @ a 192 | ... 193 | ``` 194 | 195 | ### Switching the FlexiBLAS backend 196 | 197 | `FlexiBLAS` is a BLAS wrapper for which the BLAS backend can be switched at runtime. 198 | `threadpoolctl` exposes python bindings for this feature. 
Here's an example but note 199 | that this part of the API is experimental and subject to change without deprecation: 200 | 201 | ```python 202 | >>> from threadpoolctl import ThreadpoolController 203 | >>> import numpy as np 204 | >>> controller = ThreadpoolController() 205 | 206 | >>> controller.info() 207 | [{'user_api': 'blas', 208 | 'internal_api': 'flexiblas', 209 | 'num_threads': 1, 210 | 'prefix': 'libflexiblas', 211 | 'filepath': '/usr/local/lib/libflexiblas.so.3.3', 212 | 'version': '3.3.1', 213 | 'available_backends': ['NETLIB', 'OPENBLASPTHREAD', 'ATLAS'], 214 | 'loaded_backends': ['NETLIB'], 215 | 'current_backend': 'NETLIB'}] 216 | 217 | # Retrieve the flexiblas controller 218 | >>> flexiblas_ct = controller.select(internal_api="flexiblas").lib_controllers[0] 219 | 220 | # Switch the backend with one predefined at build time (listed in "available_backends") 221 | >>> flexiblas_ct.switch_backend("OPENBLASPTHREAD") 222 | >>> controller.info() 223 | [{'user_api': 'blas', 224 | 'internal_api': 'flexiblas', 225 | 'num_threads': 4, 226 | 'prefix': 'libflexiblas', 227 | 'filepath': '/usr/local/lib/libflexiblas.so.3.3', 228 | 'version': '3.3.1', 229 | 'available_backends': ['NETLIB', 'OPENBLASPTHREAD', 'ATLAS'], 230 | 'loaded_backends': ['NETLIB', 'OPENBLASPTHREAD'], 231 | 'current_backend': 'OPENBLASPTHREAD'}, 232 | {'user_api': 'blas', 233 | 'internal_api': 'openblas', 234 | 'num_threads': 4, 235 | 'prefix': 'libopenblas', 236 | 'filepath': '/usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so', 237 | 'version': '0.3.8', 238 | 'threading_layer': 'pthreads', 239 | 'architecture': 'Haswell'}] 240 | 241 | # It's also possible to directly give the path to a shared library 242 | >>> flexiblas_controller.switch_backend("/home/jeremie/miniforge/envs/flexiblas_threadpoolctl/lib/libmkl_rt.so") 243 | >>> controller.info() 244 | [{'user_api': 'blas', 245 | 'internal_api': 'flexiblas', 246 | 'num_threads': 2, 247 | 'prefix': 'libflexiblas', 248 | 
'filepath': '/usr/local/lib/libflexiblas.so.3.3', 249 | 'version': '3.3.1', 250 | 'available_backends': ['NETLIB', 'OPENBLASPTHREAD', 'ATLAS'], 251 | 'loaded_backends': ['NETLIB', 252 | 'OPENBLASPTHREAD', 253 | '/home/jeremie/miniforge/envs/flexiblas_threadpoolctl/lib/libmkl_rt.so'], 254 | 'current_backend': '/home/jeremie/miniforge/envs/flexiblas_threadpoolctl/lib/libmkl_rt.so'}, 255 | {'user_api': 'openmp', 256 | 'internal_api': 'openmp', 257 | 'num_threads': 4, 258 | 'prefix': 'libomp', 259 | 'filepath': '/home/jeremie/miniforge/envs/flexiblas_threadpoolctl/lib/libomp.so', 260 | 'version': None}, 261 | {'user_api': 'blas', 262 | 'internal_api': 'openblas', 263 | 'num_threads': 4, 264 | 'prefix': 'libopenblas', 265 | 'filepath': '/usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so', 266 | 'version': '0.3.8', 267 | 'threading_layer': 'pthreads', 268 | 'architecture': 'Haswell'}, 269 | {'user_api': 'blas', 270 | 'internal_api': 'mkl', 271 | 'num_threads': 2, 272 | 'prefix': 'libmkl_rt', 273 | 'filepath': '/home/jeremie/miniforge/envs/flexiblas_threadpoolctl/lib/libmkl_rt.so.2', 274 | 'version': '2024.0-Product', 275 | 'threading_layer': 'gnu'}] 276 | ``` 277 | 278 | You can observe that the previously linked OpenBLAS shared object stays loaded by 279 | the Python program indefinitely, but FlexiBLAS itself no longer delegates BLAS calls 280 | to OpenBLAS as indicated by the `current_backend` attribute. 281 | ### Writing a custom library controller 282 | 283 | Currently, `threadpoolctl` has support for `OpenMP` and the main `BLAS` libraries. 284 | However it can also be used to control the threadpool of other native libraries, 285 | provided that they expose an API to get and set the limit on the number of threads. 286 | For that, one must implement a controller for this library and register it to 287 | `threadpoolctl`. 
288 | 289 | A custom controller must be a subclass of the `LibController` class and implement 290 | the attributes and methods described in the docstring of `LibController`. Then this 291 | new controller class must be registered using the `threadpoolctl.register` function. 292 | A complete example can be found [here]( 293 | https://github.com/joblib/threadpoolctl/blob/master/tests/_pyMylib/__init__.py). 294 | 295 | ### Sequential BLAS within OpenMP parallel region 296 | 297 | When one wants to have sequential BLAS calls within an OpenMP parallel region, it's 298 | safer to set `limits="sequential_blas_under_openmp"` since setting `limits=1` and 299 | `user_api="blas"` might not lead to the expected behavior in some configurations 300 | (e.g. OpenBLAS with the OpenMP threading layer 301 | https://github.com/xianyi/OpenBLAS/issues/2985). 302 | 303 | ### Known Limitations 304 | 305 | - `threadpool_limits` can fail to limit the number of inner threads when nesting 306 | parallel loops managed by distinct OpenMP runtime implementations (for instance 307 | libgomp from GCC and libomp from clang/llvm or libiomp from ICC). 308 | 309 | See the `test_openmp_nesting` function in [tests/test_threadpoolctl.py]( 310 | https://github.com/joblib/threadpoolctl/blob/master/tests/test_threadpoolctl.py) 311 | for an example. More information can be found at: 312 | https://github.com/jeremiedbb/Nested_OpenMP 313 | 314 | Note however that this problem does not happen when `threadpool_limits` is 315 | used to limit the number of threads used internally by BLAS calls that are 316 | themselves nested under OpenMP parallel loops. `threadpool_limits` works as 317 | expected, even if the inner BLAS implementation relies on a distinct OpenMP 318 | implementation. 319 | 320 | - Using Intel OpenMP (ICC) and LLVM OpenMP (clang) in the same Python program 321 | under Linux is known to cause problems.
See the following guide for more details 322 | and workarounds: 323 | https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md 324 | 325 | - Setting the maximum number of threads of the OpenMP and BLAS libraries has a global 326 | effect and impacts the whole Python process. There is no thread level isolation as 327 | these libraries do not offer thread-local APIs to configure the number of threads to 328 | use in nested parallel calls. 329 | 330 | 331 | ## Maintainers 332 | 333 | To make a release: 334 | 335 | - Create a PR to bump the version number (`__version__`) in `threadpoolctl.py` and 336 | update the release date in `CHANGES.md`. 337 | 338 | - Merge the PR and check that the `Publish threadpoolctl distribution to TestPyPI` job 339 | of the `publish-to-pypi.yml` workflow successfully uploaded the wheel and source 340 | distribution to Test PyPI. 341 | 342 | - If everything is fine create a tag for the release and push it to github: 343 | 344 | ```bash 345 | git tag -a X.Y.Z 346 | git push git@github.com:joblib/threadpoolctl.git X.Y.Z 347 | ``` 348 | 349 | - Check that the `Publish threadpoolctl distribution to PyPI` job of the 350 | `publish-to-pypi.yml` workflow successfully uploaded the wheel and source distribution 351 | to PyPI this time. 352 | 353 | - Create a PR for the release on the [conda-forge feedstock](https://github.com/conda-forge/threadpoolctl-feedstock) (or wait for the bot to make it). 354 | 355 | - Publish the release on github. 356 | 357 | If for some reason the steps above can't be achieved and a manual upload of the wheel 358 | and source distribution is needed: 359 | 360 | - Build the distribution archives: 361 | 362 | ```bash 363 | pip install flit 364 | flit build 365 | ``` 366 | 367 | - Upload the wheels and source distribution to PyPI using flit.
Since PyPI doesn't 368 | allow password authentication anymore, the username needs to be changed to the 369 | generic name `__token__`: 370 | 371 | ```bash 372 | FLIT_USERNAME=__token__ flit publish 373 | ``` 374 | 375 | and a PyPI token has to be passed in place of the password. 376 | 377 | ### Credits 378 | 379 | The initial dynamic library introspection code was written by @anton-malakhov 380 | for the smp package available at https://github.com/IntelPython/smp . 381 | 382 | threadpoolctl extends this for other operating systems. Contrary to smp, 383 | threadpoolctl does not attempt to limit the size of Python multiprocessing 384 | pools (threads or processes) or set operating system-level CPU affinity 385 | constraints: threadpoolctl only interacts with native libraries via their 386 | public runtime APIs. 387 | -------------------------------------------------------------------------------- /benchmarks/bench_context_manager_overhead.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | from pprint import pprint 4 | from statistics import mean, stdev 5 | from threadpoolctl import threadpool_info, threadpool_limits 6 | 7 | parser = ArgumentParser(description="Measure threadpool_limits call overhead.") 8 | parser.add_argument( 9 | "--import", 10 | dest="packages", 11 | default=[], 12 | nargs="+", 13 | help="Python packages to import to load threadpool enabled libraries.", 14 | ) 15 | parser.add_argument("--n-calls", type=int, default=100, help="Number of iterations") 16 | 17 | args = parser.parse_args() 18 | for package_name in args.packages: 19 | __import__(package_name) 20 | 21 | pprint(threadpool_info()) 22 | 23 | timings = [] 24 | for _ in range(args.n_calls): 25 | t = time.time() 26 | with threadpool_limits(limits=1): 27 | pass 28 | timings.append(time.time() - t) 29 | 30 | print(f"Overhead per call: {mean(timings) * 1e3:.3f} +/-{stdev(timings) * 1e3:.3f} ms") 31 | 
-------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | collect_ignore = ["tests/_openmp_test_helper"] 2 | -------------------------------------------------------------------------------- /continuous_integration/build_test_ext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 6 | pushd tests/_pyMylib 7 | rm -rf *.so *.o 8 | gcc -c -Wall -Werror -fpic -o my_threaded_lib.o my_threaded_lib.c 9 | gcc -shared -o my_threaded_lib.so my_threaded_lib.o 10 | popd 11 | fi 12 | 13 | pushd tests/_openmp_test_helper 14 | rm -rf *.c *.so *.dylib build/ 15 | python setup_inner.py build_ext -i 16 | python setup_outer.py build_ext -i 17 | 18 | # skip scipy required extension if no numpy 19 | if [[ "$NO_NUMPY" != "true" ]]; then 20 | python setup_nested_prange_blas.py build_ext -i 21 | fi 22 | popd 23 | -------------------------------------------------------------------------------- /continuous_integration/check_no_test_skipped.py: -------------------------------------------------------------------------------- 1 | """Check tests are not skipped in every ci job""" 2 | 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import xml.etree.ElementTree as ET 8 | 9 | base_dir = sys.argv[1] 10 | 11 | # dict {test: result} where result is False if the test was skipped in every 12 | # job and True otherwise. 
13 | always_skipped = {} 14 | 15 | for name in os.listdir(base_dir): 16 | # all test result files are in base_dir/test_result_*/ dirs 17 | if name.startswith("test_result_"): 18 | print(f"> processing test result from job {name.replace('test_result_', '')}") 19 | print(" > tests skipped:") 20 | result_file = os.path.join(base_dir, name, "test_result.xml") 21 | root = ET.parse(result_file).getroot() 22 | 23 | # All tests are identified by the xml tag testcase. 24 | for test in root.iter("testcase"): 25 | test_name = test.attrib["name"] 26 | skipped = any(child.tag == "skipped" for child in test) 27 | if skipped: 28 | print(" -", test_name) 29 | if test_name in always_skipped: 30 | always_skipped[test_name] &= skipped 31 | else: 32 | always_skipped[test_name] = skipped 33 | 34 | print("\n------------------------------------------------------------------\n") 35 | 36 | # List of tests that we don't want to fail the CI if they are skipped in 37 | # every job. This is useful for tests that depend on specific versions of 38 | # numpy or scipy and we don't want to pin old versions of these libraries. 
39 | SAFE_SKIPPED_TESTS = ["test_multiple_shipped_openblas"] 40 | 41 | fail = False 42 | for test, skipped in always_skipped.items(): 43 | if skipped: 44 | if test in SAFE_SKIPPED_TESTS: 45 | print(test, "was skipped in every job but it's fine to skip it") 46 | else: 47 | fail = True 48 | print(test, "was skipped in every job") 49 | 50 | if fail: 51 | sys.exit(1) 52 | -------------------------------------------------------------------------------- /continuous_integration/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # License: BSD 3-Clause 4 | 5 | set -xe 6 | 7 | UNAMESTR=`uname` 8 | 9 | 10 | # Install a recent version of clang and libomp if needed 11 | # Only applicable to linux jobs 12 | if [[ "$CC_OUTER_LOOP" == "clang-18" ]] || \ 13 | [[ "$CC_INNER_LOOP" == "clang-18" ]] || \ 14 | [[ "$BLIS_CC" == "clang-18" ]] 15 | then 16 | wget https://apt.llvm.org/llvm.sh 17 | chmod +x llvm.sh 18 | sudo ./llvm.sh 18 19 | sudo apt-get install libomp-dev 20 | fi 21 | 22 | 23 | make_conda() { 24 | CHANNEL="$1" 25 | TO_INSTALL="$2" 26 | if [[ "$UNAMESTR" == "Darwin" ]]; then 27 | if [[ "$INSTALL_LIBOMP" == "conda-forge" ]]; then 28 | # Install an OpenMP-enabled clang/llvm from conda-forge 29 | # assumes conda-forge is set on priority channel 30 | TO_INSTALL="$TO_INSTALL compilers llvm-openmp" 31 | 32 | elif [[ "$INSTALL_LIBOMP" == "homebrew" ]]; then 33 | # Install a compiler with a working openmp 34 | HOMEBREW_NO_AUTO_UPDATE=1 brew install libomp 35 | 36 | # enable OpenMP support for Apple-clang 37 | export CC=/usr/bin/clang 38 | export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" 39 | export CFLAGS="$CFLAGS -I/opt/homebrew/opt/libomp/include" 40 | export LDFLAGS="$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp" 41 | fi 42 | fi 43 | 44 | if [[ "$PYTHON_VERSION" == "*" ]]; then 45 | # Avoid installing free-threaded python 46 | TO_INSTALL="$TO_INSTALL python-gil" 47 | fi 
48 | 49 | # prevent mixing conda channels 50 | conda config --set channel_priority strict 51 | conda config --add channels $CHANNEL 52 | 53 | conda update -n base conda conda-libmamba-solver -q --yes 54 | conda config --set solver libmamba 55 | 56 | conda create -n testenv -q --yes python=$PYTHON_VERSION $TO_INSTALL 57 | conda activate testenv 58 | } 59 | 60 | 61 | if [[ "$PACKAGER" == "conda" ]]; then 62 | TO_INSTALL="" 63 | if [[ "$NO_NUMPY" != "true" ]]; then 64 | TO_INSTALL="$TO_INSTALL numpy scipy" 65 | if [[ -n "$BLAS" ]]; then 66 | TO_INSTALL="$TO_INSTALL blas=*=$BLAS" 67 | fi 68 | fi 69 | make_conda "defaults" "$TO_INSTALL" 70 | 71 | elif [[ "$PACKAGER" == "conda-forge" ]]; then 72 | TO_INSTALL="numpy scipy blas=*=$BLAS" 73 | if [[ "$BLAS" == "openblas" && "$OPENBLAS_THREADING_LAYER" == "openmp" ]]; then 74 | TO_INSTALL="$TO_INSTALL libopenblas=*=*openmp*" 75 | fi 76 | make_conda "conda-forge" "$TO_INSTALL" 77 | 78 | elif [[ "$PACKAGER" == "pip" ]]; then 79 | # Use conda to build an empty python env and then use pip to install 80 | # numpy and scipy 81 | make_conda "conda-forge" "" 82 | if [[ "$NO_NUMPY" != "true" ]]; then 83 | pip install numpy scipy 84 | fi 85 | 86 | elif [[ "$PACKAGER" == "pip-dev" ]]; then 87 | # Use conda to build an empty python env and then use pip to install 88 | # numpy and scipy dev versions 89 | make_conda "conda-forge" "" 90 | 91 | dev_anaconda_url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 92 | pip install --pre --upgrade --timeout=60 --extra-index $dev_anaconda_url numpy scipy 93 | 94 | elif [[ "$PACKAGER" == "ubuntu" ]]; then 95 | # Remove the ubuntu toolchain PPA that seems to be invalid: 96 | # https://github.com/scikit-learn/scikit-learn/pull/13934 97 | sudo add-apt-repository --remove ppa:ubuntu-toolchain-r/test 98 | sudo apt-get update 99 | sudo apt-get install python3-scipy python3-virtualenv $APT_BLAS 100 | python3 -m virtualenv --system-site-packages --python=python3 testenv 101 | source 
testenv/bin/activate 102 | 103 | elif [[ "$INSTALL_BLAS" == "blis" ]]; then 104 | TO_INSTALL="cython meson-python pkg-config" 105 | make_conda "conda-forge" "$TO_INSTALL" 106 | source ./continuous_integration/install_blis.sh 107 | 108 | elif [[ "$INSTALL_BLAS" == "flexiblas" ]]; then 109 | TO_INSTALL="cython openblas $PLATFORM_SPECIFIC_PACKAGES meson-python pkg-config compilers" 110 | make_conda "conda-forge" "$TO_INSTALL" 111 | source ./continuous_integration/install_flexiblas.sh 112 | 113 | fi 114 | 115 | python -m pip install -v -q -r dev-requirements.txt 116 | bash ./continuous_integration/build_test_ext.sh 117 | 118 | # Check which BLAS is linked (only available on linux) 119 | if [[ "$UNAMESTR" == "Linux" && "$NO_NUMPY" != "true" ]]; then 120 | ldd tests/_openmp_test_helper/nested_prange_blas.cpython*.so 121 | fi 122 | 123 | python --version 124 | python -c "import numpy; print(f'numpy {numpy.__version__}')" || echo "no numpy" 125 | python -c "import scipy; print(f'scipy {scipy.__version__}')" || echo "no scipy" 126 | 127 | python -m flit install --symlink 128 | -------------------------------------------------------------------------------- /continuous_integration/install_blis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | 6 | # Install gcc 12 to build BLIS 7 | if [[ "$BLIS_CC" == "gcc-12" ]]; then 8 | sudo apt install gcc-12 9 | fi 10 | 11 | # step outside of threadpoolctl directory 12 | pushd .. 
13 | ABS_PATH=$(pwd) 14 | 15 | # build & install blis 16 | mkdir BLIS_install 17 | git clone https://github.com/flame/blis.git 18 | pushd blis 19 | 20 | ./configure --prefix=$ABS_PATH/BLIS_install --enable-cblas --enable-threading=$BLIS_ENABLE_THREADING CC=$BLIS_CC auto 21 | make -j4 22 | make install 23 | popd 24 | 25 | # build & install numpy 26 | git clone https://github.com/numpy/numpy.git 27 | pushd numpy 28 | git submodule update --init 29 | 30 | echo "libdir=$ABS_PATH/BLIS_install/lib/ 31 | includedir=$ABS_PATH/BLIS_install/include/blis/ 32 | version=latest 33 | extralib=-lm -lpthread -lgfortran 34 | Name: blis 35 | Description: BLIS 36 | Version: \${version} 37 | Libs: -L\${libdir} -lblis 38 | Libs.private: \${extralib} 39 | Cflags: -I\${includedir}" > blis.pc 40 | 41 | PKG_CONFIG_PATH=$ABS_PATH/numpy/ pip install . -v --no-build-isolation -Csetup-args=-Dblas=blis 42 | 43 | export CFLAGS=-I$ABS_PATH/BLIS_install/include/blis 44 | export LDFLAGS="-L$ABS_PATH/BLIS_install/lib -Wl,-rpath,$ABS_PATH/BLIS_install/lib" 45 | 46 | popd 47 | 48 | # back to threadpoolctl directory 49 | popd 50 | -------------------------------------------------------------------------------- /continuous_integration/install_flexiblas.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | # step outside of threadpoolctl directory 6 | pushd .. 7 | ABS_PATH=$(pwd) 8 | 9 | # build & install FlexiBLAS 10 | mkdir flexiblas_install 11 | git clone https://github.com/mpimd-csc/flexiblas.git 12 | pushd flexiblas 13 | 14 | # Temporary ping Flexiblas commit to avoid openmp symbols not found at link time 15 | git checkout v3.4.2 16 | 17 | mkdir build 18 | pushd build 19 | 20 | EXTENSION=".so" 21 | if [[ $(uname) == "Darwin" ]]; then 22 | EXTENSION=".dylib" 23 | fi 24 | 25 | # We intentionally restrict the list of backends to make it easier to 26 | # write platform agnostic tests. 
In particular, we do not detect OS 27 | # provided backends such as macOS' Apple/Accelerate/vecLib nor plaftorm 28 | # specific BLAS implementations such as MKL that cannot be installed on 29 | # arm64 hardware. 30 | cmake ../ -DCMAKE_INSTALL_PREFIX=$ABS_PATH"/flexiblas_install" \ 31 | -DBLAS_AUTO_DETECT="OFF" \ 32 | -DEXTRA="OPENBLAS_CONDA" \ 33 | -DFLEXIBLAS_DEFAULT="OPENBLAS_CONDA" \ 34 | -DOPENBLAS_CONDA_LIBRARY=$CONDA_PREFIX"/lib/libopenblas"$EXTENSION \ 35 | make 36 | make install 37 | 38 | # Check that all 3 BLAS are listed in FlexiBLAS configuration 39 | $ABS_PATH/flexiblas_install/bin/flexiblas list 40 | popd 41 | popd 42 | 43 | # build & install numpy 44 | git clone https://github.com/numpy/numpy.git 45 | pushd numpy 46 | git submodule update --init 47 | 48 | echo "libdir=$ABS_PATH/flexiblas_install/lib/ 49 | includedir=$ABS_PATH/flexiblas_install/include/flexiblas/ 50 | version=3.3.1 51 | extralib=-lm -lpthread -lgfortran 52 | Name: flexiblas 53 | Description: FlexiBLAS - a BLAS wrapper 54 | Version: \${version} 55 | Libs: -L\${libdir} -lflexiblas 56 | Libs.private: \${extralib} 57 | Cflags: -I\${includedir}" > flexiblas.pc 58 | 59 | PKG_CONFIG_PATH=$ABS_PATH/numpy/ pip install . 
-v --no-build-isolation -Csetup-args=-Dblas=flexiblas -Csetup-args=-Dlapack=flexiblas 60 | 61 | export CFLAGS=-I$ABS_PATH/flexiblas_install/include/flexiblas \ 62 | export LDFLAGS="-L$ABS_PATH/flexiblas_install/lib -Wl,-rpath,$ABS_PATH/flexiblas_install/lib" \ 63 | 64 | popd 65 | 66 | # back to threadpoolctl directory 67 | popd 68 | -------------------------------------------------------------------------------- /continuous_integration/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xe 4 | 5 | if [[ "$PACKAGER" == conda* ]] || [[ -z "$PACKAGER" ]]; then 6 | conda activate testenv 7 | conda list 8 | elif [[ "$PACKAGER" == pip* ]]; then 9 | # we actually use conda to install the base environment: 10 | conda activate testenv 11 | pip list 12 | elif [[ "$PACKAGER" == "ubuntu" ]]; then 13 | source testenv/bin/activate 14 | pip list 15 | fi 16 | 17 | # Use the CLI to display the effective runtime environment prior to 18 | # launching the tests: 19 | python -m threadpoolctl -i numpy scipy.linalg tests._openmp_test_helper.openmp_helpers_inner 20 | 21 | pytest -vlrxXs -W error -k "$TESTS" --junitxml=test_result.xml --cov=threadpoolctl --cov-report xml 22 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | flit 2 | coverage 3 | pytest 4 | pytest-cov 5 | cython 6 | setuptools 7 | -------------------------------------------------------------------------------- /multiple_openmp.md: -------------------------------------------------------------------------------- 1 | # Multiple OpenMP Runtimes 2 | 3 | ## Context 4 | 5 | OpenMP is an API specification for parallel programming. 
There are many 6 | implementations of it, tied to a compiler most of the time: 7 | 8 | - `libgomp` for GCC (GNU C/C++ Compiler), 9 | - `libomp` for Clang (LLVM C/C++ Compiler), 10 | - `libiomp` for ICC (Intel C/C++ Compiler), 11 | - `vcomp` for MSVC (Microsoft Visual Studio C/C++ Compiler). 12 | 13 | In general, it is not advised to have different OpenMP runtime libraries (or 14 | even different copies of the same library) loaded at the same time in a 15 | program. It's considered an undefined behavior. Fortunately it is not as bad as 16 | it sounds in most situations. 17 | 18 | However this situation is frequent in the Python ecosystem since you can 19 | install packages compiled with different compilers (hence linked to different 20 | OpenMP implementations) and import them together in a Python program. 21 | 22 | A typical example is installing NumPy from Anaconda which is linked against MKL 23 | (Intel's math library) and another package that uses multi-threading with OpenMP 24 | directly in a compiled extension, as is the case in Scikit-learn (via Cython 25 | `prange`), LightGBM and XGBoost (via pragmas in the C++ source code). 26 | 27 | From our experience, **most OpenMP libraries can seamlessly coexist in a same 28 | program**. For instance, on Linux, we never observed any issue between 29 | `libgomp` and `libiomp`, which is the most common mix (NumPy with MKL + a 30 | package compiled with GCC, the most widely used C compiler on that platform). 31 | 32 | ## Incompatibility between Intel OpenMP and LLVM OpenMP 33 | 34 | The only unrecoverable incompatibility we encountered happens when loading a 35 | mix of compiled extensions linked with **`libomp` (LLVM/Clang) and `libiomp` 36 | (ICC), on Linux and macOS**, manifested by crashes or deadlocks. It can happen 37 | even with the simplest OpenMP calls like getting the maximum number of threads 38 | that will be used in a subsequent parallel region. 
A possible explanation is that 39 | `libomp` is actually a fork of `libiomp` causing name collisions for instance. 40 | Using `threadpoolctl` may crash your program in such a setting. 41 | 42 | **Fortunately this problem is very rare**: at the time of writing, all major 43 | binary distributions of Python packages for Linux use either GCC or ICC to 44 | build the Python scientific packages. Therefore this problem would only happen 45 | if some packagers decide to start shipping Python packages built with 46 | LLVM/Clang instead of GCC (this is the case for instance with conda's default channel). 47 | 48 | ## Workarounds for Intel OpenMP and LLVM OpenMP case 49 | 50 | As far as we know, the only workaround consists in making sure that only one of 51 | the two incompatible OpenMP libraries is loaded. For example: 52 | 53 | - Tell MKL (used by NumPy) to use another threading implementation instead of the Intel 54 | OpenMP runtime. It can be the GNU OpenMP runtime on Linux or TBB on Linux and MacOS 55 | for instance. This is done by setting the following environment variable: 56 | 57 | export MKL_THREADING_LAYER=GNU 58 | 59 | or, if TBB is installed: 60 | 61 | export MKL_THREADING_LAYER=TBB 62 | 63 | - Install a build of NumPy and SciPy linked against OpenBLAS instead of MKL. 64 | This can be done for instance by installing NumPy and SciPy from PyPI: 65 | 66 | pip install numpy scipy 67 | 68 | from the conda-forge conda channel: 69 | 70 | conda install -c conda-forge numpy scipy 71 | 72 | or from the default conda channel: 73 | 74 | conda install numpy scipy blas[build=openblas] 75 | 76 | - Re-build your OpenMP-enabled extensions from source with GCC (or ICC) instead 77 | of Clang if you want to keep on using NumPy/SciPy linked against MKL with the 78 | default `libiomp`-based threading layer.
79 | 80 | ## References 81 | 82 | The above incompatibility has been reported upstream to the LLVM and Intel 83 | developers on the following public issue trackers/forums along with a minimal 84 | reproducer written in C: 85 | 86 | - https://bugs.llvm.org/show_bug.cgi?id=43565 87 | - https://community.intel.com/t5/Intel-C-Compiler/Cannot-call-OpenMP-functions-from-libiomp-after-calling-from/m-p/1176406 88 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | # TODO: replace the following section by the standard [project] section to be able 6 | # to use flit v4. 7 | [tool.flit.metadata] 8 | module = "threadpoolctl" 9 | author = "Thomas Moreau" 10 | author-email = "thomas.moreau.2010@gmail.com" 11 | home-page = "https://github.com/joblib/threadpoolctl" 12 | description-file = "README.md" 13 | requires-python = ">=3.9" 14 | license = "BSD-3-Clause" 15 | classifiers = [ 16 | "Intended Audience :: Developers", 17 | "License :: OSI Approved :: BSD License", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: 3.13", 24 | "Topic :: Software Development :: Libraries :: Python Modules", 25 | ] 26 | 27 | [tool.black] 28 | line-length = 88 29 | target_version = ['py39', 'py310', 'py311', 'py312', 'py313'] 30 | preview = true 31 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joblib/threadpoolctl/1b35392f49e6486567165808e3a4d72d14b07940/tests/__init__.py 
-------------------------------------------------------------------------------- /tests/_openmp_test_helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joblib/threadpoolctl/1b35392f49e6486567165808e3a4d72d14b07940/tests/_openmp_test_helper/__init__.py -------------------------------------------------------------------------------- /tests/_openmp_test_helper/build_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | def set_cc_variables(var_name="CC"): 6 | cc_var = os.environ.get(var_name) 7 | if cc_var is not None: 8 | os.environ["CC"] = cc_var 9 | if sys.platform == "darwin": 10 | os.environ["LDSHARED"] = cc_var + " -bundle -undefined dynamic_lookup" 11 | else: 12 | os.environ["LDSHARED"] = cc_var + " -shared" 13 | 14 | return cc_var 15 | 16 | 17 | def get_openmp_flag(): 18 | if sys.platform == "win32": 19 | return ["/openmp"] 20 | elif sys.platform == "darwin" and "openmp" in os.getenv("CPPFLAGS", ""): 21 | return [] 22 | return ["-fopenmp"] 23 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/nested_prange_blas.pyx: -------------------------------------------------------------------------------- 1 | cimport openmp 2 | from cython.parallel import parallel, prange 3 | 4 | import numpy as np 5 | from scipy.linalg.cython_blas cimport dgemm 6 | 7 | from threadpoolctl import ThreadpoolController 8 | 9 | 10 | def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): 11 | """Run multithreaded BLAS calls within OpenMP parallel loop""" 12 | cdef: 13 | int m = A.shape[0] 14 | int n = B.shape[0] 15 | int k = A.shape[1] 16 | 17 | double[:, ::1] C = np.empty((m, n)) 18 | int n_chunks = 100 19 | int chunk_size = A.shape[0] // n_chunks 20 | 21 | char* trans = 't' 22 | char* no_trans = 'n' 23 | double alpha = 1.0 24 | double beta = 0.0 25 
| 26 | int i 27 | int prange_num_threads 28 | int *prange_num_threads_ptr = &prange_num_threads 29 | 30 | inner_info = [None] 31 | 32 | with nogil, parallel(num_threads=nthreads): 33 | if openmp.omp_get_thread_num() == 0: 34 | with gil: 35 | inner_info[0] = ThreadpoolController().info() 36 | 37 | prange_num_threads_ptr[0] = openmp.omp_get_num_threads() 38 | 39 | for i in prange(n_chunks): 40 | dgemm(trans, no_trans, &n, &chunk_size, &k, 41 | &alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k, 42 | &beta, &C[i * chunk_size, 0], &n) 43 | 44 | return np.asarray(C), prange_num_threads, inner_info[0] 45 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/nested_prange_blas_custom.pyx: -------------------------------------------------------------------------------- 1 | cimport openmp 2 | from cython.parallel import parallel, prange 3 | 4 | import numpy as np 5 | 6 | cdef extern from 'cblas.h' nogil: 7 | ctypedef enum CBLAS_ORDER: 8 | CblasRowMajor=101 9 | CblasColMajor=102 10 | ctypedef enum CBLAS_TRANSPOSE: 11 | CblasNoTrans=111 12 | CblasTrans=112 13 | CblasConjTrans=113 14 | void dgemm 'cblas_dgemm' ( 15 | CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA, 16 | CBLAS_TRANSPOSE TransB, int M, int N, 17 | int K, double alpha, double *A, int lda, 18 | double *B, int ldb, double beta, double *C, int ldc) 19 | 20 | from threadpoolctl import ThreadpoolController 21 | 22 | 23 | def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): 24 | """Run multithreaded BLAS calls within OpenMP parallel loop""" 25 | cdef: 26 | int m = A.shape[0] 27 | int n = B.shape[0] 28 | int k = A.shape[1] 29 | 30 | double[:, ::1] C = np.empty((m, n)) 31 | int n_chunks = 100 32 | int chunk_size = A.shape[0] // n_chunks 33 | 34 | double alpha = 1.0 35 | double beta = 0.0 36 | 37 | int i 38 | int prange_num_threads 39 | int *prange_num_threads_ptr = &prange_num_threads 40 | 41 | inner_info = [None] 42 | 43 | with nogil, 
parallel(num_threads=nthreads): 44 | if openmp.omp_get_thread_num() == 0: 45 | with gil: 46 | inner_info[0] = ThreadpoolController().info() 47 | 48 | prange_num_threads_ptr[0] = openmp.omp_get_num_threads() 49 | 50 | for i in prange(n_chunks): 51 | dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, 52 | chunk_size, n, k, alpha, &A[i * chunk_size, 0], k, 53 | &B[0, 0], k, beta, &C[i * chunk_size, 0], n) 54 | 55 | return np.asarray(C), prange_num_threads, inner_info[0] 56 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/openmp_helpers_inner.pxd: -------------------------------------------------------------------------------- 1 | cdef int inner_openmp_loop(int) noexcept nogil 2 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/openmp_helpers_inner.pyx: -------------------------------------------------------------------------------- 1 | cimport openmp 2 | from cython.parallel import prange 3 | 4 | 5 | def check_openmp_num_threads(int n): 6 | """Run a short parallel section with OpenMP 7 | 8 | Return the number of threads that where effectively used by the 9 | OpenMP runtime. 10 | """ 11 | cdef int num_threads = -1 12 | 13 | with nogil: 14 | num_threads = inner_openmp_loop(n) 15 | return num_threads 16 | 17 | 18 | cdef int inner_openmp_loop(int n) noexcept nogil: 19 | """Run a short parallel section with OpenMP 20 | 21 | Return the number of threads that where effectively used by the 22 | OpenMP runtime. 23 | 24 | This function is expected to run without the GIL and can be called 25 | by an outer OpenMP / prange loop written in Cython in another file. 
26 | """ 27 | cdef long n_sum = 0 28 | cdef int i, num_threads 29 | 30 | for i in prange(n): 31 | num_threads = openmp.omp_get_num_threads() 32 | n_sum += i 33 | 34 | if n_sum != (n - 1) * n / 2: 35 | # error 36 | return -1 37 | 38 | return num_threads 39 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/openmp_helpers_outer.pyx: -------------------------------------------------------------------------------- 1 | cimport openmp 2 | from cython.parallel import prange 3 | from openmp_helpers_inner cimport inner_openmp_loop 4 | 5 | 6 | def check_nested_openmp_loops(int n, nthreads=None): 7 | """Run a short parallel section with OpenMP with nested calls 8 | 9 | The inner OpenMP loop has not necessarily been built/linked with the 10 | same runtime OpenMP runtime. 11 | """ 12 | cdef: 13 | int outer_num_threads = -1 14 | int inner_num_threads = -1 15 | int num_threads = nthreads or openmp.omp_get_max_threads() 16 | int i 17 | 18 | for i in prange(n, num_threads=num_threads, nogil=True): 19 | inner_num_threads = inner_openmp_loop(n) 20 | outer_num_threads = openmp.omp_get_num_threads() 21 | 22 | return outer_num_threads, inner_num_threads 23 | -------------------------------------------------------------------------------- /tests/_openmp_test_helper/setup_inner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import Extension, setup 3 | from Cython.Build import cythonize 4 | 5 | from build_utils import set_cc_variables 6 | from build_utils import get_openmp_flag 7 | 8 | original_environ = os.environ.copy() 9 | try: 10 | # Make it possible to compile the 2 OpenMP enabled Cython extensions 11 | # with different compilers and therefore different OpenMP runtimes. 
import os
from setuptools import Extension, setup
from Cython.Build import cythonize

from build_utils import set_cc_variables
from build_utils import get_openmp_flag

# Snapshot the environment so the compiler-related variables changed by
# set_cc_variables below can be restored after the build, even on failure.
original_environ = os.environ.copy()
try:
    # Presumably configures the compiler from the CC_OUTER_LOOP environment
    # variable (see build_utils) — this mutates os.environ, hence the
    # snapshot/restore around this block.
    set_cc_variables("CC_OUTER_LOOP")
    openmp_flag = get_openmp_flag()

    # When INSTALL_BLAS is set it is used both as the name of the BLAS
    # library to link against and as a switch to compile the "_custom"
    # variant of the extension source.
    use_custom_blas = os.getenv("INSTALL_BLAS", False)
    libraries = [use_custom_blas] if use_custom_blas else []
    custom_suffix = "_custom" if use_custom_blas else ""
    filename = f"nested_prange_blas{custom_suffix}.pyx"

    ext_modules = [
        Extension(
            # The module is always importable as "nested_prange_blas",
            # whichever source-file variant gets compiled.
            "nested_prange_blas",
            [filename],
            extra_compile_args=openmp_flag,
            extra_link_args=openmp_flag,
            libraries=libraries,
        )
    ]

    setup(
        name="_openmp_test_helper_nested_prange_blas",
        ext_modules=cythonize(
            ext_modules,
            compiler_directives={
                "language_level": 3,
                "boundscheck": False,
                "wraparound": False,
            },
        ),
    )

finally:
    # Restore the snapshotted values. NOTE(review): update() only restores
    # overwritten variables; keys *added* during the build would remain —
    # confirm set_cc_variables does not introduce new keys.
    os.environ.update(original_environ)
class MyThreadedLibController(LibController):
    """Example LibController for the toy shared library ``my_threaded_lib``."""

    # Identifiers used by threadpoolctl's context filtering.
    user_api = "my_threaded_lib"
    internal_api = "my_threaded_lib"

    # A dynamic library whose filename starts with one of these prefixes and
    # is linked to the Python process is exposed as the ``dynlib`` attribute
    # of the LibController instance.
    filename_prefixes = ("my_threaded_lib",)

    # (Optional) Symbols the candidate library is expected to expose; used
    # together with ``filename_prefixes`` to avoid matching an unrelated
    # library.
    check_symbols = (
        "mylib_get_num_threads",
        "mylib_set_num_threads",
        "mylib_get_version",
    )

    def get_num_threads(self):
        # Current maximum number of threads, reported as "num_threads" by
        # `ThreadpoolController.info`.
        return self.dynlib.mylib_get_num_threads()

    def set_num_threads(self, num_threads):
        # Limit the maximum number of threads; invoked by
        # `ThreadpoolController.limit`.
        self.dynlib.mylib_set_num_threads(num_threads)

    def get_version(self):
        # Version of the linked library, reported as "version" by
        # `ThreadpoolController.info`.
        version_fn = self.dynlib.mylib_get_version
        version_fn.restype = ctypes.c_char_p
        return version_fn().decode("utf-8")

    def set_additional_attributes(self):
        # Called during LibController initialization; attributes set here are
        # exposed by `ThreadpoolController.info`.
        self.some_attr = "some_value"


register(MyThreadedLibController)
29 | """ 30 | if any( 31 | lib_controller.internal_api == "openblas" 32 | and lib_controller.threading_layer == "openmp" 33 | for lib_controller in ThreadpoolController().lib_controllers 34 | ): 35 | pytest.skip( 36 | "Setting a limit on OpenBLAS when using the OpenMP threading layer also " 37 | "impact the OpenMP library. They can't be controlled independently." 38 | ) 39 | 40 | 41 | def effective_num_threads(nthreads, max_threads): 42 | if nthreads is None or nthreads > max_threads: 43 | return max_threads 44 | return nthreads 45 | 46 | 47 | def test_threadpool_info(): 48 | # Check consistency between threadpool_info and ThreadpoolController 49 | function_info = threadpool_info() 50 | object_info = ThreadpoolController().lib_controllers 51 | 52 | for lib_info, lib_controller in zip(function_info, object_info): 53 | assert lib_info == lib_controller.info() 54 | 55 | 56 | def test_threadpool_controller_info(): 57 | # Check that all keys expected for the private api are in the dicts 58 | # returned by the `info` methods 59 | controller = ThreadpoolController() 60 | 61 | assert threadpool_info() == [ 62 | lib_controller.info() for lib_controller in controller.lib_controllers 63 | ] 64 | assert controller.info() == [ 65 | lib_controller.info() for lib_controller in controller.lib_controllers 66 | ] 67 | 68 | for lib_controller_dict in controller.info(): 69 | assert "user_api" in lib_controller_dict 70 | assert "internal_api" in lib_controller_dict 71 | assert "prefix" in lib_controller_dict 72 | assert "filepath" in lib_controller_dict 73 | assert "version" in lib_controller_dict 74 | assert "num_threads" in lib_controller_dict 75 | 76 | if lib_controller_dict["internal_api"] in ("mkl", "blis", "openblas"): 77 | assert "threading_layer" in lib_controller_dict 78 | 79 | 80 | def test_controller_info_actualized(): 81 | # Check that the num_threads attribute reflects the actual state of the threadpools 82 | controller = ThreadpoolController() 83 | original_info = 
@pytest.mark.parametrize(
    "kwargs",
    [
        {"user_api": "blas"},
        {"prefix": "libgomp"},
        {"internal_api": "openblas", "prefix": "libomp"},
        {"prefix": ["libgomp", "libomp", "libiomp"]},
    ],
)
def test_threadpool_controller_select(kwargs):
    # Check the behavior of the select method of ThreadpoolController
    selection = ThreadpoolController().select(**kwargs)
    if not selection:
        pytest.skip(f"Requires at least one of {list(kwargs.values())}.")

    # Every selected controller must match at least one of the criteria
    # (values may be given either as a scalar or as a list of accepted ones).
    for lib_controller in selection.lib_controllers:
        matches = []
        for key, val in kwargs.items():
            accepted = val if isinstance(val, list) else [val]
            matches.append(getattr(lib_controller, key) in accepted)
        assert any(matches)
def test_threadpool_limits_function_with_side_effect():
    # Check that threadpool_limits can be used as a plain function (with
    # global side effects) instead of as a context manager.
    initial_info = ThreadpoolController().info()

    threadpool_limits(limits=1)
    try:
        for lib_controller in ThreadpoolController().lib_controllers:
            if not is_old_openblas(lib_controller):
                assert lib_controller.num_threads == 1
    finally:
        # Undo the limit so this test does not leak state into other tests.
        threadpool_limits(limits=initial_info)

    assert ThreadpoolController().info() == initial_info
def test_threadpool_limits_manual_restore():
    # Check that a threadpool_limits object records the original state of the
    # threadpools and can put it back via its restore_original_limits method.
    initial_info = ThreadpoolController().info()

    limiter = threadpool_limits(limits=1)
    try:
        for lib_controller in ThreadpoolController().lib_controllers:
            if not is_old_openblas(lib_controller):
                assert lib_controller.num_threads == 1
    finally:
        # Restore explicitly so this test does not leak state into others.
        limiter.restore_original_limits()

    assert ThreadpoolController().info() == initial_info
def test_get_params_for_sequential_blas_under_openmp():
    # Test for the behavior of get_params_for_sequential_blas_under_openmp.
    controller = ThreadpoolController()
    original_info = controller.info()

    params = controller._get_params_for_sequential_blas_under_openmp()

    # When OpenBLAS is built with the OpenMP threading layer, limiting BLAS
    # threads would also limit the OpenMP library (they share the runtime),
    # so the expected params are no-ops in that configuration.
    if controller.select(
        internal_api="openblas", threading_layer="openmp"
    ).lib_controllers:
        assert params["limits"] is None
        assert params["user_api"] is None

        # The special "sequential_blas_under_openmp" limit must then leave
        # every threadpool untouched.
        with controller.limit(limits="sequential_blas_under_openmp"):
            assert controller.info() == original_info

    else:
        # Otherwise BLAS libraries can be limited independently of OpenMP:
        # expect a hard limit of 1 thread scoped to the "blas" user_api.
        assert params["limits"] == 1
        assert params["user_api"] == "blas"

        with controller.limit(limits="sequential_blas_under_openmp"):
            assert all(
                lib_info["num_threads"] == 1
                for lib_info in controller.info()
                if lib_info["user_api"] == "blas"
            )
def test_threadpool_limits_bad_input():
    # Check that appropriate errors are raised for invalid arguments

    # An unknown user_api must be rejected with a ValueError.
    expected_msg = re.escape(f"user_api must be either in {_ALL_USER_APIS} or None.")
    with pytest.raises(ValueError, match=expected_msg):
        threadpool_limits(limits=1, user_api="wrong")

    # A tuple is not an accepted container type for limits.
    with pytest.raises(
        TypeError, match="limits must either be an int, a list, a dict, or"
    ):
        threadpool_limits(limits=(1, 2, 3))
320 | not cython_extensions_compiled, reason="Requires cython extensions to be compiled" 321 | ) 322 | @pytest.mark.parametrize("nthreads_outer", [None, 1, 2, 4]) 323 | def test_openmp_nesting(nthreads_outer): 324 | # checks that OpenMP effectively uses the number of threads requested by 325 | # the context manager when nested in an outer OpenMP loop. 326 | import tests._openmp_test_helper.openmp_helpers_outer as omp_outer 327 | 328 | check_nested_openmp_loops = omp_outer.check_nested_openmp_loops 329 | 330 | # Find which OpenMP lib is used at runtime for inner loop 331 | inner_info = threadpool_info_from_subprocess( 332 | "tests._openmp_test_helper.openmp_helpers_inner" 333 | ) 334 | assert len(inner_info) == 1 335 | inner_omp = inner_info[0]["prefix"] 336 | 337 | # Find which OpenMP lib is used at runtime for outer loop 338 | outer_info = threadpool_info_from_subprocess( 339 | "tests._openmp_test_helper.openmp_helpers_outer" 340 | ) 341 | if len(outer_info) == 1: 342 | # Only 1 openmp loaded. It has to be this one. 343 | outer_omp = outer_info[0]["prefix"] 344 | else: 345 | # There are 2 openmp, the one from inner and the one from outer. 346 | assert len(outer_info) == 2 347 | # We already know the one from inner. It has to be the other one. 
348 | prefixes = {lib_info["prefix"] for lib_info in outer_info} 349 | outer_omp = prefixes - {inner_omp} 350 | 351 | outer_num_threads, inner_num_threads = check_nested_openmp_loops(10) 352 | original_info = ThreadpoolController().info() 353 | 354 | if inner_omp == outer_omp: 355 | # The OpenMP runtime should be shared by default, meaning that the 356 | # inner loop should automatically be run serially by the OpenMP runtime 357 | assert inner_num_threads == 1 358 | 359 | with threadpool_limits(limits=1) as threadpoolctx: 360 | max_threads = threadpoolctx.get_original_num_threads()["openmp"] 361 | nthreads = effective_num_threads(nthreads_outer, max_threads) 362 | 363 | # Ask outer loop to run on nthreads threads and inner loop run on 1 364 | # thread 365 | outer_num_threads, inner_num_threads = check_nested_openmp_loops(10, nthreads) 366 | 367 | # The state of the original state of all threadpools should have been 368 | # restored. 369 | assert ThreadpoolController().info() == original_info 370 | 371 | # The number of threads available in the outer loop should not have been 372 | # decreased: 373 | assert outer_num_threads == nthreads 374 | 375 | # The number of threads available in the inner loop should have been 376 | # set to 1 to avoid oversubscription and preserve performance: 377 | if inner_omp != outer_omp: 378 | if inner_num_threads != 1: 379 | # XXX: this does not always work when nesting independent openmp 380 | # implementations. 
See: https://github.com/jeremiedbb/Nested_OpenMP 381 | pytest.xfail( 382 | f"Inner OpenMP num threads was {inner_num_threads} instead of 1" 383 | ) 384 | assert inner_num_threads == 1 385 | 386 | 387 | def test_shipped_openblas(): 388 | # checks that OpenBLAS effectively uses the number of threads requested by 389 | # the context manager 390 | original_info = ThreadpoolController().info() 391 | openblas_controller = ThreadpoolController().select(internal_api="openblas") 392 | 393 | with threadpool_limits(1): 394 | for lib_controller in openblas_controller.lib_controllers: 395 | assert lib_controller.num_threads == 1 396 | 397 | assert ThreadpoolController().info() == original_info 398 | 399 | 400 | @pytest.mark.skipif( 401 | len(libopenblas_paths) < 2, reason="need at least 2 shipped openblas library" 402 | ) 403 | def test_multiple_shipped_openblas(): 404 | # This redundant test is meant to make it easier to see if the system 405 | # has 2 or more active openblas runtimes available just by reading the 406 | # pytest report (whether or not this test has been skipped). 407 | test_shipped_openblas() 408 | 409 | 410 | @pytest.mark.skipif( 411 | not cython_extensions_compiled, reason="Requires cython extensions to be compiled" 412 | ) 413 | @pytest.mark.skipif( 414 | check_nested_prange_blas is None, 415 | reason="Requires nested_prange_blas to be compiled", 416 | ) 417 | @pytest.mark.parametrize("nthreads_outer", [None, 1, 2, 4]) 418 | def test_nested_prange_blas(nthreads_outer): 419 | # Check that the BLAS uses the number of threads requested by the context manager 420 | # when nested in an outer OpenMP loop. 421 | # Remark: this test also works with sequential BLAS only because we limit the 422 | # number of threads for the BLAS to 1. 
423 | import numpy as np 424 | 425 | skip_if_openblas_openmp() 426 | 427 | original_info = ThreadpoolController().info() 428 | 429 | blas_controller = ThreadpoolController().select(user_api="blas") 430 | blis_controller = ThreadpoolController().select(internal_api="blis") 431 | 432 | # skip if the BLAS used by numpy is an old openblas. OpenBLAS 0.3.3 and 433 | # older are known to cause an unrecoverable deadlock at process shutdown 434 | # time (after pytest has exited). 435 | # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that 436 | # case this test will run BLIS gemm so no need to skip. 437 | if not blis_controller and any( 438 | is_old_openblas(lib_controller) 439 | for lib_controller in blas_controller.lib_controllers 440 | ): 441 | pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") 442 | 443 | A = np.ones((1000, 10)) 444 | B = np.ones((100, 10)) 445 | 446 | with threadpool_limits(limits=1) as threadpoolctx: 447 | max_threads = threadpoolctx.get_original_num_threads()["openmp"] 448 | nthreads = effective_num_threads(nthreads_outer, max_threads) 449 | 450 | result = check_nested_prange_blas(A, B, nthreads) 451 | C, prange_num_threads, inner_info = result 452 | 453 | assert np.allclose(C, np.dot(A, B.T)) 454 | assert prange_num_threads == nthreads 455 | 456 | nested_blas_info = select(inner_info, user_api="blas") 457 | assert len(nested_blas_info) == len(blas_controller.lib_controllers) 458 | assert all(lib_info["num_threads"] == 1 for lib_info in nested_blas_info) 459 | 460 | assert ThreadpoolController().info() == original_info 461 | 462 | 463 | # the method `get_original_num_threads` raises a UserWarning due to different 464 | # num_threads from libraries with the same `user_api`. It will be raised only 465 | # in the CI job with 2 openblas (py38_pip_openblas_gcc_clang). It is expected 466 | # so we can safely filter it. 
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("limit", [1, None])
def test_get_original_num_threads(limit):
    # Tests the method get_original_num_threads of the context manager
    with threadpool_limits(limits=2, user_api="blas") as ctx:
        # set different blas num threads to start with (when multiple openblas)
        if len(ctx._controller.select(user_api="blas")) > 1:
            ctx._controller.lib_controllers[0].set_num_threads(1)

        original_info = ThreadpoolController().info()
        with threadpool_limits(limits=limit, user_api="blas") as threadpoolctx:
            original_num_threads = threadpoolctx.get_original_num_threads()

            # Only the "blas" user_api was limited, so no "openmp" entry.
            assert "openmp" not in original_num_threads

            blas_info = select(original_info, user_api="blas")
            if blas_info:
                expected = min(lib_info["num_threads"] for lib_info in blas_info)
                assert original_num_threads["blas"] == expected
            else:
                assert original_num_threads["blas"] is None

            if len(libopenblas_paths) >= 2:
                # Fix: `pytest.warns(None, match=...)` is invalid and was
                # removed in pytest 7. The ambiguity warning raised when
                # multiple libraries with the same user_api report different
                # num_threads is a UserWarning, so assert for that.
                with pytest.warns(UserWarning, match="Multiple value possible"):
                    threadpoolctx.get_original_num_threads()


def test_mkl_threading_layer():
    # Check that threadpool_info correctly recovers the threading layer used
    # by mkl
    mkl_controller = ThreadpoolController().select(internal_api="mkl")
    expected_layer = os.getenv("MKL_THREADING_LAYER")

    if not (mkl_controller and expected_layer):
        pytest.skip("requires MKL and the environment variable MKL_THREADING_LAYER set")

    actual_layer = mkl_controller.lib_controllers[0].threading_layer
    assert actual_layer == expected_layer.lower()
os.getenv("BLIS_ENABLE_THREADING") 512 | if expected_layer == "no": 513 | expected_layer = "disabled" 514 | 515 | if not (blis_controller and expected_layer): 516 | pytest.skip( 517 | "requires BLIS and the environment variable BLIS_ENABLE_THREADING set" 518 | ) 519 | 520 | actual_layer = blis_controller.lib_controllers[0].threading_layer 521 | assert actual_layer == expected_layer 522 | 523 | 524 | @pytest.mark.skipif( 525 | not cython_extensions_compiled, reason="Requires cython extensions to be compiled" 526 | ) 527 | def test_libomp_libiomp_warning(recwarn): 528 | # Trigger the import of a potentially clang-compiled extension: 529 | import tests._openmp_test_helper.openmp_helpers_outer # noqa 530 | 531 | # Trigger the import of numpy to potentially import Intel OpenMP via MKL 532 | pytest.importorskip("numpy.linalg") 533 | 534 | # Check that a warning is raised when both libomp and libiomp are loaded 535 | # It should happen in one CI job (pylatest_conda_mkl_clang_gcc). 536 | controller = ThreadpoolController() 537 | prefixes = [lib_controller.prefix for lib_controller in controller.lib_controllers] 538 | 539 | if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): 540 | pytest.skip("Requires both libomp and libiomp loaded, on Linux") 541 | 542 | assert len(recwarn) == 1 543 | wm = recwarn[0] 544 | assert wm.category == RuntimeWarning 545 | assert "Found Intel" in str(wm.message) 546 | assert "LLVM" in str(wm.message) 547 | assert "multiple_openmp.md" in str(wm.message) 548 | 549 | 550 | def test_command_line_empty_or_system_openmp(): 551 | # When the command line is called without arguments, no library should be 552 | # detected. The only exception is a system OpenMP library that can be 553 | # linked to the Python interpreter, for instance via the libb2.so BLAKE2 554 | # library that can itself be linked to an OpenMP runtime on Gentoo. 
def test_command_line_command_flag():
    # The -c flag runs the given python command before inspecting the
    # libraries loaded in the subprocess.
    pytest.importorskip("numpy")
    raw_output = subprocess.check_output(
        [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"]
    )
    cli_info = json.loads(raw_output.decode("utf-8"))

    # Every library reported by the CLI must also be visible in this process.
    current_info = threadpool_info()
    assert all(lib_info in current_info for lib_info in cli_info)


@pytest.mark.skipif(
    sys.version_info < (3, 7), reason="need recent subprocess.run options"
)
def test_command_line_import_flag():
    # The -i flag imports the listed modules before inspecting the libraries
    # loaded in the subprocess; unimportable modules produce warnings.
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "threadpoolctl",
            "-i",
            "numpy",
            "scipy.linalg",
            "invalid_package",
            "numpy.invalid_sumodule",
        ],
        capture_output=True,
        check=True,
        encoding="utf-8",
    )
    cli_info = json.loads(result.stdout)

    current_info = threadpool_info()
    assert all(lib_info in current_info for lib_info in cli_info)

    # Import failures are reported on stderr, one warning per line.
    warnings = [line.strip() for line in result.stderr.splitlines()]
    assert "WARNING: could not import invalid_package" in warnings
    assert "WARNING: could not import numpy.invalid_sumodule" in warnings
    scipy_warning = "WARNING: could not import scipy.linalg"
    if scipy is None:
        assert scipy_warning in warnings
    else:
        assert scipy_warning not in warnings
def test_openblas_threading_layer():
    # Check that threadpool_info correctly recovers the threading layer used by openblas
    openblas_controller = ThreadpoolController().select(internal_api="openblas")

    if not openblas_controller:
        pytest.skip("requires OpenBLAS.")

    threading_layer = openblas_controller.lib_controllers[0].threading_layer

    if threading_layer == "unknown":
        # If we never recover an acceptable value for the threading layer,
        # this test would always be skipped and caught by
        # check_no_test_always_skipped.
        pytest.skip("Unknown OpenBLAS threading layer.")

    assert threading_layer in ("openmp", "pthreads", "disabled")
666 | flexiblas_controller = ThreadpoolController().select(internal_api="flexiblas") 667 | 668 | if not flexiblas_controller: 669 | pytest.skip("requires FlexiBLAS.") 670 | flexiblas_controller = flexiblas_controller.lib_controllers[0] 671 | 672 | expected_backends = {"NETLIB", "OPENBLAS_CONDA"} 673 | expected_backends_loaded = {"OPENBLAS_CONDA"} 674 | expected_current_backend = "OPENBLAS_CONDA" # set as default at build time 675 | 676 | flexiblas_backends = flexiblas_controller.available_backends 677 | flexiblas_backends_loaded = flexiblas_controller.loaded_backends 678 | current_backend = flexiblas_controller.current_backend 679 | 680 | assert set(flexiblas_backends) == expected_backends 681 | assert set(flexiblas_backends_loaded) == expected_backends_loaded 682 | assert current_backend == expected_current_backend 683 | 684 | 685 | def test_flexiblas_switch_error(): 686 | # Check that an error is raised when trying to switch to an invalid backend. 687 | flexiblas_controller = ThreadpoolController().select(internal_api="flexiblas") 688 | 689 | if not flexiblas_controller: 690 | pytest.skip("requires FlexiBLAS.") 691 | flexiblas_controller = flexiblas_controller.lib_controllers[0] 692 | 693 | with pytest.raises(RuntimeError, match="Failed to load backend"): 694 | flexiblas_controller.switch_backend("INVALID_BACKEND") 695 | 696 | 697 | # skip test if not run in a azure pipelines job since it relies on a specific flexiblas 698 | # installation. 699 | @pytest.mark.skipif( 700 | "GITHUB_ACTIONS" not in os.environ, reason="not running in azure pipelines" 701 | ) 702 | def test_flexiblas_switch(): 703 | # Check that the backend can be switched. 
704 | controller = ThreadpoolController() 705 | fb_controller = controller.select(internal_api="flexiblas") 706 | 707 | if not fb_controller: 708 | pytest.skip("requires FlexiBLAS.") 709 | fb_controller = fb_controller.lib_controllers[0] 710 | 711 | # at first mkl is not loaded in the CI jobs where this test runs 712 | assert len(controller.select(internal_api="mkl").lib_controllers) == 0 713 | 714 | # at first, only "OPENBLAS_CONDA" is loaded 715 | assert fb_controller.current_backend == "OPENBLAS_CONDA" 716 | assert fb_controller.loaded_backends == ["OPENBLAS_CONDA"] 717 | 718 | fb_controller.switch_backend("NETLIB") 719 | assert fb_controller.current_backend == "NETLIB" 720 | assert fb_controller.loaded_backends == ["OPENBLAS_CONDA", "NETLIB"] 721 | 722 | if sys.platform == "linux": 723 | mkl_path = f"{os.getenv('CONDA_PREFIX')}/lib/libmkl_rt.so" 724 | fb_controller.switch_backend(mkl_path) 725 | assert fb_controller.current_backend == mkl_path 726 | assert fb_controller.loaded_backends == ["OPENBLAS_CONDA", "NETLIB", mkl_path] 727 | # switching the backend triggered a new search for loaded shared libs 728 | assert len(controller.select(internal_api="mkl").lib_controllers) == 1 729 | 730 | # switch back to default to avoid side effects 731 | fb_controller.switch_backend("OPENBLAS_CONDA") 732 | 733 | 734 | def test_threadpool_controller_as_decorator(): 735 | # Check that using the decorator can be nested and is restricted to the scope of 736 | # the decorated function. 
737 | controller = ThreadpoolController() 738 | original_info = controller.info() 739 | 740 | if any(info["num_threads"] < 2 for info in original_info): 741 | pytest.skip("Test requires at least 2 CPUs on host machine") 742 | if not controller.select(user_api="blas"): 743 | pytest.skip("Requires a blas runtime.") 744 | 745 | def check_blas_num_threads(expected_num_threads): 746 | blas_controller = ThreadpoolController().select(user_api="blas") 747 | assert all( 748 | lib_controller.num_threads == expected_num_threads 749 | for lib_controller in blas_controller.lib_controllers 750 | ) 751 | 752 | @controller.wrap(limits=1, user_api="blas") 753 | def outer_func(): 754 | check_blas_num_threads(expected_num_threads=1) 755 | inner_func() 756 | check_blas_num_threads(expected_num_threads=1) 757 | 758 | @controller.wrap(limits=2, user_api="blas") 759 | def inner_func(): 760 | check_blas_num_threads(expected_num_threads=2) 761 | 762 | outer_func() 763 | 764 | assert ThreadpoolController().info() == original_info 765 | 766 | 767 | def test_custom_controller(): 768 | # Check that a custom controller can be used to change the number of threads 769 | # used by a library. 
770 | try: 771 | import tests._pyMylib # noqa 772 | except: 773 | pytest.skip("requires my_thread_lib to be compiled") 774 | 775 | controller = ThreadpoolController() 776 | original_info = controller.info() 777 | 778 | mylib_controller = controller.select(user_api="my_threaded_lib") 779 | 780 | # my_threaded_lib has been found and there's 1 matching shared library 781 | assert len(mylib_controller.lib_controllers) == 1 782 | mylib_controller = mylib_controller.lib_controllers[0] 783 | 784 | # we linked against my_threaded_lib v2.0 and by default it uses 42 thread 785 | assert mylib_controller.version == "2.0" 786 | assert mylib_controller.num_threads == 42 787 | 788 | # my_threaded_lib exposes an additional info "some_attr": 789 | assert mylib_controller.info()["some_attr"] == "some_value" 790 | 791 | with controller.limit(limits=1, user_api="my_threaded_lib"): 792 | assert mylib_controller.num_threads == 1 793 | 794 | assert ThreadpoolController().info() == original_info 795 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import sys 4 | import threadpoolctl 5 | from glob import glob 6 | from os.path import dirname, normpath 7 | from subprocess import check_output 8 | 9 | # Path to shipped openblas for libraries such as numpy or scipy 10 | libopenblas_patterns = [] 11 | 12 | 13 | try: 14 | # make sure the mkl/blas are loaded for test_threadpool_limits 15 | import numpy as np 16 | 17 | np.dot(np.ones(1000), np.ones(1000)) 18 | 19 | libopenblas_patterns.append(os.path.join(np.__path__[0], ".libs", "libopenblas*")) 20 | except ImportError: 21 | pass 22 | 23 | 24 | try: 25 | import scipy 26 | import scipy.linalg # noqa: F401 27 | 28 | scipy.linalg.svd([[1, 2], [3, 4]]) 29 | 30 | libopenblas_patterns.append( 31 | os.path.join(scipy.__path__[0], ".libs", "libopenblas*") 32 | ) 33 | except ImportError: 34 | 
def threadpool_info_from_subprocess(module):
    """Utility to call threadpool_info in a subprocess

    `module` is imported before calling threadpool_info
    """
    # Extend PYTHONPATH so the subprocess can import both the package itself
    # and the compiled test helpers from the source tree.
    base_dir = normpath(dirname(threadpoolctl.__file__))
    helper_dir = os.path.join(base_dir, "tests", "_openmp_test_helper")
    extra_paths = os.pathsep.join([base_dir, helper_dir])

    env = os.environ.copy()
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        extra_paths if existing is None else os.pathsep.join([extra_paths, existing])
    )

    raw = check_output(
        [sys.executable, "-m", "threadpoolctl", "-i", module], env=env
    )
    return json.loads(raw.decode("utf-8"))


def select(info, **kwargs):
    """Select a subset of the list of library info matching the request"""
    # Normalize each requested value to a list so the membership tests below
    # are uniform.
    criteria = {
        key: vals if isinstance(vals, list) else [vals]
        for key, vals in kwargs.items()
    }

    # NOTE: a library entry matches as soon as ANY criterion matches
    # (OR semantics across keyword arguments).
    return [
        lib_info
        for lib_info in info
        if any(lib_info.get(key) in vals for key, vals in criteria.items())
    ]
provides utilities to introspect native libraries that relies on 4 | thread pools (notably BLAS and OpenMP implementations) and dynamically set the 5 | maximal number of threads they can use. 6 | """ 7 | 8 | # License: BSD 3-Clause 9 | 10 | # The code to introspect dynamically loaded libraries on POSIX systems is 11 | # adapted from code by Intel developer @anton-malakhov available at 12 | # https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation) 13 | # and also published under the BSD 3-Clause license 14 | import os 15 | import re 16 | import sys 17 | import ctypes 18 | import itertools 19 | import textwrap 20 | from typing import final 21 | import warnings 22 | from ctypes.util import find_library 23 | from abc import ABC, abstractmethod 24 | from functools import lru_cache 25 | from contextlib import ContextDecorator 26 | 27 | __version__ = "3.7.0.dev0" 28 | __all__ = [ 29 | "threadpool_limits", 30 | "threadpool_info", 31 | "ThreadpoolController", 32 | "LibController", 33 | "register", 34 | ] 35 | 36 | 37 | # One can get runtime errors or even segfaults due to multiple OpenMP libraries 38 | # loaded simultaneously which can happen easily in Python when importing and 39 | # using compiled extensions built with different compilers and therefore 40 | # different OpenMP runtimes in the same program. In particular libiomp (used by 41 | # Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for 42 | # instance when calling BLAS inside a prange. Setting the following environment 43 | # variable allows multiple OpenMP libraries to be loaded. It should not degrade 44 | # performances since we manually take care of potential over-subscription 45 | # performance issues, in sections of the code where nested OpenMP loops can 46 | # happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily 47 | # disable it while under the scope of the outer OpenMP parallel section. 
# Allow several OpenMP runtimes to coexist in the same process (see the
# rationale in the comment above): we manage over-subscription ourselves.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")

# Structure to cast the info on dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32
_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16


class _dl_phdr_info(ctypes.Structure):
    # Mirrors the C `struct dl_phdr_info` passed to dl_iterate_phdr callbacks.
    _fields_ = [
        ("dlpi_addr", _SYSTEM_UINT),  # Base address of object
        ("dlpi_name", ctypes.c_char_p),  # path to the library
        ("dlpi_phdr", ctypes.c_void_p),  # pointer on dlpi_headers
        ("dlpi_phnum", _SYSTEM_UINT_HALF),  # number of elements in dlpi_phdr
    ]


# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows.
try:
    _RTLD_NOLOAD = os.RTLD_NOLOAD
except AttributeError:
    _RTLD_NOLOAD = ctypes.DEFAULT_MODE


class LibController(ABC):
    """Abstract base class for the individual library controllers

    A library controller must expose the following class attributes:
    - user_api : str
        Usually the name of the library or generic specification the library
        implements, e.g. "blas" is a specification with different implementations.
    - internal_api : str
        Usually the name of the library or concrete implementation of some
        specification, e.g. "openblas" is an implementation of the "blas"
        specification.
    - filename_prefixes : tuple
        Possible prefixes of the shared library's filename that allow to
        identify the library. e.g. "libopenblas" for libopenblas.so.

    and implement the following methods: `get_num_threads`, `set_num_threads` and
    `get_version`.

    Threadpoolctl loops through all the loaded shared libraries and tries to match
    the filename of each library with the `filename_prefixes`. If a match is found, a
    controller is instantiated and a handler to the library is stored in the `dynlib`
    attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
    of the shared library to implement the above methods.

    The following information will be exposed in the info dictionary:
    - user_api : standardized API, if any, or a copy of internal_api.
    - internal_api : implementation-specific API.
    - num_threads : the current thread limit.
    - prefix : prefix of the shared library's filename.
    - filepath : path to the loaded shared library.
    - version : version of the library (if available).

    In addition, each library controller may expose internal API specific entries. They
    must be set as attributes in the `set_additional_attributes` method.
    """

    @final
    def __init__(self, *, filepath=None, prefix=None, parent=None):
        """This is not meant to be overriden by subclasses."""
        self.parent = parent
        self.prefix = prefix
        self.filepath = filepath
        # RTLD_NOLOAD (where available) only returns a handle on an already
        # loaded library and never loads a new one as a side effect.
        self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
        self._symbol_prefix, self._symbol_suffix = self._find_affixes()
        self.version = self.get_version()
        self.set_additional_attributes()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # Implementation details that should not leak into the public info.
        hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
        return {
            "user_api": self.user_api,
            "internal_api": self.internal_api,
            "num_threads": self.num_threads,
            **{k: v for k, v in vars(self).items() if k not in hidden_attrs},
        }

    def set_additional_attributes(self):
        """Set additional attributes meant to be exposed in the info dict"""

    @property
    def num_threads(self):
        """Exposes the current thread limit as a dynamic property

        This is not meant to be used or overriden by subclasses.
        """
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Return the maximum number of threads available to use"""

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum number of threads to use"""

    @abstractmethod
    def get_version(self):
        """Return the version of the shared library"""

    def _find_affixes(self):
        """Return the affixes for the symbols of the shared library"""
        return "", ""

    def _get_symbol(self, name):
        """Return the symbol of the shared library accounting for the affixes"""
        return getattr(
            self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
        )


class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

    _symbol_prefixes = ("", "scipy_")
    _symbol_suffixes = ("", "64_", "_64")

    # All variations of "openblas_get_num_threads", accounting for the affixes
    check_symbols = tuple(
        f"{prefix}openblas_get_num_threads{suffix}"
        for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
    )

    def _find_affixes(self):
        """Return the (prefix, suffix) pair decorating this build's symbols."""
        for prefix, suffix in itertools.product(
            self._symbol_prefixes, self._symbol_suffixes
        ):
            if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
                return prefix, suffix
        # Bug fix: previously this method implicitly returned None when no
        # variant of the symbol was found, making the tuple unpacking in
        # LibController.__init__ fail with a TypeError. Fall back to the bare
        # symbol names instead, matching the base class contract.
        return "", ""

    def set_additional_attributes(self):
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        get_num_threads_func = self._get_symbol("openblas_get_num_threads")
        if get_num_threads_func is not None:
            return get_num_threads_func()
        return None

    def set_num_threads(self, num_threads):
        set_num_threads_func = self._get_symbol("openblas_set_num_threads")
        if set_num_threads_func is not None:
            return set_num_threads_func(num_threads)
        return None

    def get_version(self):
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_version_func = self._get_symbol("openblas_get_config")
        if get_version_func is not None:
            get_version_func.restype = ctypes.c_char_p
            config = get_version_func().split()
            if config[0] == b"OpenBLAS":
                return config[1].decode("utf-8")
            return None
        return None

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_threading_layer_func = self._get_symbol("openblas_get_parallel")
        if get_threading_layer_func is not None:
            threading_layer = get_threading_layer_func()
            if threading_layer == 2:
                return "openmp"
            elif threading_layer == 1:
                return "pthreads"
            return "disabled"
        return "unknown"

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_architecture_func = self._get_symbol("openblas_get_corename")
        if get_architecture_func is not None:
            get_architecture_func.restype = ctypes.c_char_p
            return get_architecture_func().decode("utf-8")
        return None
class FlexiBLASController(LibController):
    """Controller class for FlexiBLAS"""

    user_api = "blas"
    internal_api = "flexiblas"
    filename_prefixes = ("libflexiblas",)
    check_symbols = (
        "flexiblas_get_num_threads",
        "flexiblas_set_num_threads",
        "flexiblas_get_version",
        "flexiblas_list",
        "flexiblas_list_loaded",
        "flexiblas_current_backend",
    )

    @property
    def loaded_backends(self):
        """Backends currently loaded by FlexiBLAS (queried dynamically)."""
        return self._get_backend_list(loaded=True)

    @property
    def current_backend(self):
        """Name of the backend FlexiBLAS currently dispatches to."""
        return self._get_current_backend()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # We override the info method because the loaded and current backends
        # are dynamic properties
        exposed_attrs = super().info()
        exposed_attrs["loaded_backends"] = self.loaded_backends
        exposed_attrs["current_backend"] = self.current_backend

        return exposed_attrs

    def set_additional_attributes(self):
        self.available_backends = self._get_backend_list(loaded=False)

    def get_num_threads(self):
        get_func = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None)
        num_threads = get_func()
        # flexiblas_get_num_threads can return -1 for a single-threaded
        # configuration. We map it to 1 for consistency with other libraries.
        # (The previous comment mentioned BLIS: it was copy-pasted from
        # BLISController; this is the FlexiBLAS controller.)
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        set_func = getattr(
            self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        get_version_ = getattr(self.dynlib, "flexiblas_get_version", None)
        if get_version_ is None:
            return None

        # flexiblas_get_version fills three ints passed by reference.
        major = ctypes.c_int()
        minor = ctypes.c_int()
        patch = ctypes.c_int()
        get_version_(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
        return f"{major.value}.{minor.value}.{patch.value}"

    def _get_backend_list(self, loaded=False):
        """Return the list of available backends for FlexiBLAS.

        If loaded is False, return the list of available backends from the FlexiBLAS
        configuration. If loaded is True, return the list of actually loaded backends.
        """
        func_name = f"flexiblas_list{'_loaded' if loaded else ''}"
        get_backend_list_ = getattr(self.dynlib, func_name, None)
        if get_backend_list_ is None:
            return None

        # Calling with a NULL buffer returns the number of backends.
        n_backends = get_backend_list_(None, 0, 0)

        backends = []
        for i in range(n_backends):
            backend_name = ctypes.create_string_buffer(1024)
            get_backend_list_(backend_name, 1024, i)
            if backend_name.value.decode("utf-8") != "__FALLBACK__":
                # We don't know when to expect __FALLBACK__ but it is not a real
                # backend and does not show up when running flexiblas list.
                backends.append(backend_name.value.decode("utf-8"))
        return backends

    def _get_current_backend(self):
        """Return the backend of FlexiBLAS"""
        get_backend_ = getattr(self.dynlib, "flexiblas_current_backend", None)
        if get_backend_ is None:
            return None

        backend = ctypes.create_string_buffer(1024)
        get_backend_(backend, ctypes.sizeof(backend))
        return backend.value.decode("utf-8")

    def switch_backend(self, backend):
        """Switch the backend of FlexiBLAS

        Parameters
        ----------
        backend : str
            The name or the path to the shared library of the backend to switch to. If
            the backend is not already loaded, it will be loaded first.
        """
        if backend not in self.loaded_backends:
            if backend in self.available_backends:
                load_func = getattr(self.dynlib, "flexiblas_load_backend", lambda _: -1)
            else:  # assume backend is a path to a shared library
                load_func = getattr(
                    self.dynlib, "flexiblas_load_backend_library", lambda _: -1
                )
            res = load_func(str(backend).encode("utf-8"))
            if res == -1:
                raise RuntimeError(
                    f"Failed to load backend {backend!r}. It must either be the name of"
                    " a backend available in the FlexiBLAS configuration "
                    f"{self.available_backends} or the path to a valid shared library."
                )

            # Trigger a new search of loaded shared libraries since loading a new
            # backend caused a dlopen.
            self.parent._load_libraries()

        switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1)
        idx = self.loaded_backends.index(backend)
        res = switch_func(idx)
        if res == -1:
            raise RuntimeError(f"Failed to switch to backend {backend!r}.")
class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
    check_symbols = (
        "omp_get_max_threads",
        "omp_get_num_threads",
    )

    def get_num_threads(self):
        """Return omp_get_max_threads(), or None if the symbol is missing."""
        omp_get_max_threads = getattr(self.dynlib, "omp_get_max_threads", None)
        return None if omp_get_max_threads is None else omp_get_max_threads()

    def set_num_threads(self, num_threads):
        """Call omp_set_num_threads; a no-op returning None if the symbol is missing."""
        omp_set_num_threads = getattr(self.dynlib, "omp_set_num_threads", None)
        if omp_set_num_threads is None:
            return None
        return omp_set_num_threads(num_threads)

    def get_version(self):
        # There is no way to get the version number programmatically in OpenMP.
        return None


# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
    FlexiBLASController,
]

# Helpers for the doc and test names
_ALL_USER_APIS = list({controller.user_api for controller in _ALL_CONTROLLERS})
_ALL_INTERNAL_APIS = [controller.internal_api for controller in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    {
        prefix
        for controller in _ALL_CONTROLLERS
        for prefix in controller.filename_prefixes
    }
)
_ALL_BLAS_LIBRARIES = [
    controller.internal_api
    for controller in _ALL_CONTROLLERS
    if controller.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes


def register(controller):
    """Register a new controller"""
    # Keep all the module-level lookup helpers in sync with the new entry.
    _ALL_CONTROLLERS.append(controller)
    _ALL_USER_APIS.append(controller.user_api)
    _ALL_INTERNAL_APIS.append(controller.internal_api)
    _ALL_PREFIXES.extend(controller.filename_prefixes)


def _format_docstring(*args, **kwargs):
    """Return a decorator that str.format()s the docstring of its target."""

    def decorator(obj):
        if obj.__doc__ is not None:
            obj.__doc__ = obj.__doc__.format(*args, **kwargs)
        return obj

    return decorator


@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Small caching wrapper around os.path.realpath to limit system calls"""
    return os.path.realpath(filepath)
    - "version": version of the library (if available).
    - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    return ThreadpoolController().info()


class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be used as
    a callable. Setting the limits can be delayed by using the `wrap` class method such
    that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        # Snapshot the limits currently in effect so that they can be restored
        # when exiting the context manager (or calling restore_original_limits).
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        return _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        # `info()` returns one dict per lib controller, in the same order as
        # `lib_controllers`, so a plain zip matches each controller with its
        # pre-limit snapshot.
        for lib_controller, original_info in zip(
            self._controller.lib_controllers, self._original_info
        ):
            lib_controller.set_num_threads(original_info["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`.
        """
        num_threads = {}
        warning_apis = []

        for user_api in self._user_api:
            limits = [
                lib_info["num_threads"]
                for lib_info in self._original_info
                if lib_info["user_api"] == user_api
            ]
            limits = set(limits)
            n_limits = len(limits)

            if n_limits == 1:
                # All the libraries of this user_api agree on a single limit.
                limit = limits.pop()
            elif n_limits == 0:
                # No loaded library matches this user_api.
                limit = None
            else:
                # Conflicting limits among libraries of the same user_api:
                # return the minimum and warn the caller below.
                limit = min(limits)
                warning_apis.append(user_api)

            num_threads[user_api] = limit

        if warning_apis:
            warnings.warn(
                "Multiple value possible for following user apis: "
                + ", ".join(warning_apis)
                + ". Returning the minimum."
            )

        return num_threads

    def _check_params(self, limits, user_api):
        """Suitable values for the _limits, _user_api and _prefixes attributes"""

        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            # Special preset: delegate the choice of limits/user_api to the
            # controller (it depends on which BLAS/threading layer is loaded).
            (
                limits,
                user_api,
            ) = self._controller._get_params_for_sequential_blas_under_openmp().values()

        if limits is None or isinstance(limits, int):
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                # Apply the same global limit to every selected user_api.
                limits = {api: limits for api in user_api}
            prefixes = []
        else:
            if isinstance(limits, list):
                # This should be a list of dicts of library info, for
                # compatibility with the result from threadpool_info.
                limits = {
                    lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
                }
            elif isinstance(limits, ThreadpoolController):
                # To set the limits from the library controllers of a
                # ThreadpoolController object.
                limits = {
                    lib_controller.prefix: lib_controller.num_threads
                    for lib_controller in limits.lib_controllers
                }

            if not isinstance(limits, dict):
                raise TypeError(
                    "limits must either be an int, a list, a dict, or "
                    f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
                )

            # With a dictionary, can set both specific limit for given
            # libraries and global limit for user_api. Fetch each separately.
            prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
            user_api = [api for api in limits if api in _ALL_USER_APIS]

        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Return a list with all the supported libraries that have been found
        matching `self._prefixes` and `self._user_api`.
        """
        if self._limits is None:
            return

        for lib_controller in self._controller.lib_controllers:
            # self._limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib_controller.prefix in self._limits:
                num_threads = self._limits[lib_controller.prefix]
            elif lib_controller.user_api in self._limits:
                num_threads = self._limits[lib_controller.user_api]
            else:
                continue

            if num_threads is not None:
                lib_controller.set_num_threads(num_threads)


class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # Note: intentionally does NOT call the parent __init__, which would
        # set the limits immediately; see __enter__ below.
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        self._controller = controller

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self


@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limit`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        super().__init__(ThreadpoolController(), limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api)


class ThreadpoolController:
    """Collection of LibController objects for all loaded supported libraries

    Attributes
    ----------
    lib_controllers : list of `LibController` objects
        The list of library controllers of all loaded supported libraries.
    """

    # Cache for libc under POSIX and a few system libraries under Windows.
    # We use a class level cache instead of an instance level cache because
    # it's very unlikely that a shared library will be unloaded and reloaded
    # during the lifetime of a program.
    _system_libraries = dict()

    def __init__(self):
        self.lib_controllers = []
        self._load_libraries()
        self._warn_if_incompatible_openmp()

    @classmethod
    def _from_controllers(cls, lib_controllers):
        # Alternate constructor that bypasses __init__ (and thus library
        # discovery) to build a controller around an existing subset.
        new_controller = cls.__new__(cls)
        new_controller.lib_controllers = lib_controllers
        return new_controller

    def info(self):
        """Return lib_controllers info as a list of dicts"""
        return [lib_controller.info() for lib_controller in self.lib_controllers]

    def select(self, **kwargs):
        """Return a ThreadpoolController containing a subset of its current
        library controllers

        It will select all libraries matching at least one pair (key, value) from kwargs
        where key is an entry of the library info dict (like "user_api", "internal_api",
        "prefix", ...) and value is the value or a list of acceptable values for that
        entry.

        For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
        will select all library controllers whose internal_api is either "blis" or
        "openblas".
        """
        # Normalize scalar values to single-element lists so that the
        # membership test below works uniformly.
        for key, vals in kwargs.items():
            kwargs[key] = [vals] if not isinstance(vals, list) else vals

        lib_controllers = [
            lib_controller
            for lib_controller in self.lib_controllers
            if any(
                getattr(lib_controller, key, None) in vals
                for key, vals in kwargs.items()
            )
        ]

        return ThreadpoolController._from_controllers(lib_controllers)

    def _get_params_for_sequential_blas_under_openmp(self):
        """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

        This function takes into account the unexpected behavior of OpenBLAS with the
        OpenMP threading layer.
        """
        if self.select(
            internal_api="openblas", threading_layer="openmp"
        ).lib_controllers:
            return {"limits": None, "user_api": None}
        return {"limits": 1, "user_api": "blas"}

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def limit(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used either as a callable (the
        construction of this object limits the number of threads) or as a context
        manager, in a `with` block to automatically restore the original state of the
        controlled libraries when exiting the block.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        This effect is global and impacts the whole Python process. There is no thread
        level isolation as these libraries do not offer thread-local APIs to configure
        the number of threads to use in nested parallel calls.

        Parameters
        ----------
        limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If 'sequential_blas_under_openmp', it will choose the appropriate
              `limits` and `user_api` parameters for the specific use case of
              sequential BLAS calls within an OpenMP parallel region. The `user_api`
              parameter is ignored.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter(self, limits=limits, user_api=user_api)

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def wrap(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used as a decorator.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        Parameters
        ----------
        limits : int, dict or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)

    def __len__(self):
        return len(self.lib_controllers)

    def _load_libraries(self):
        """Loop through loaded shared libraries and store the supported ones"""
        # Dispatch on the platform-specific way of enumerating the shared
        # libraries loaded in the current process.
        if sys.platform == "darwin":
            self._find_libraries_with_dyld()
        elif sys.platform == "win32":
            self._find_libraries_with_enum_process_module_ex()
        elif "pyodide" in sys.modules:
            self._find_libraries_pyodide()
        else:
            self._find_libraries_with_dl_iterate_phdr()

    def _find_libraries_with_dl_iterate_phdr(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on POSIX system only.
        This code is adapted from code by Intel developer @anton-malakhov
        available at https://github.com/IntelPython/smp

        Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
        license
        """
        libc = self._get_libc()
        if not hasattr(libc, "dl_iterate_phdr"):  # pragma: no cover
            warnings.warn(
                "Could not find dl_iterate_phdr in the C standard library.",
                RuntimeWarning,
            )
            return []

        # Callback function for `dl_iterate_phdr` which is called for every
        # library loaded in the current process until it returns 1.
        def match_library_callback(info, size, data):
            # Get the path of the current library
            filepath = info.contents.dlpi_name
            if filepath:
                filepath = filepath.decode("utf-8")

                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
            # Returning 0 tells dl_iterate_phdr to keep iterating.
            return 0

        c_func_signature = ctypes.CFUNCTYPE(
            ctypes.c_int,  # Return type
            ctypes.POINTER(_dl_phdr_info),
            ctypes.c_size_t,
            ctypes.c_char_p,
        )
        c_match_library_callback = c_func_signature(match_library_callback)

        data = ctypes.c_char_p(b"")
        libc.dl_iterate_phdr(c_match_library_callback, data)

    def _find_libraries_with_dyld(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on OSX system only
        """
        libc = self._get_libc()
        if not hasattr(libc, "_dyld_image_count"):  # pragma: no cover
            warnings.warn(
                "Could not find _dyld_image_count in the C standard library.",
                RuntimeWarning,
            )
            return []

        n_dyld = libc._dyld_image_count()
        libc._dyld_get_image_name.restype = ctypes.c_char_p

        for i in range(n_dyld):
            filepath = ctypes.string_at(libc._dyld_get_image_name(i))
            filepath = filepath.decode("utf-8")

            # Store the library controller if it is supported and selected
            self._make_controller_from_path(filepath)

    def _find_libraries_with_enum_process_module_ex(self):
        """Loop through loaded libraries and return binders on supported ones

        This function is expected to work on windows system only.
        This code is adapted from code by Philipp Hagemeister @phihag available
        at https://stackoverflow.com/questions/17474574
        """
        from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

        PROCESS_QUERY_INFORMATION = 0x0400
        PROCESS_VM_READ = 0x0010

        LIST_LIBRARIES_ALL = 0x03

        ps_api = self._get_windll("Psapi")
        kernel_32 = self._get_windll("kernel32")

        h_process = kernel_32.OpenProcess(
            PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
        )
        if not h_process:  # pragma: no cover
            raise OSError(f"Could not open PID {os.getpid()}")

        try:
            buf_count = 256
            needed = DWORD()
            # Grow the buffer until it becomes large enough to hold all the
            # module headers
            while True:
                buf = (HMODULE * buf_count)()
                buf_size = ctypes.sizeof(buf)
                if not ps_api.EnumProcessModulesEx(
                    h_process,
                    ctypes.byref(buf),
                    buf_size,
                    ctypes.byref(needed),
                    LIST_LIBRARIES_ALL,
                ):
                    raise OSError("EnumProcessModulesEx failed")
                if buf_size >= needed.value:
                    break
                # buf_size // buf_count is the size of one HMODULE; compute
                # the number of entries the next buffer must hold.
                buf_count = needed.value // (buf_size // buf_count)

            count = needed.value // (buf_size // buf_count)
            h_modules = map(HMODULE, buf[:count])

            # Loop through all the module headers and get the library path
            # Allocate a buffer for the path 10 times the size of MAX_PATH to take
            # into account long path names.
            max_path = 10 * MAX_PATH
            buf = ctypes.create_unicode_buffer(max_path)
            n_size = DWORD()
            for h_module in h_modules:
                # Get the path of the current module
                if not ps_api.GetModuleFileNameExW(
                    h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size)
                ):
                    raise OSError("GetModuleFileNameEx failed")
                filepath = buf.value

                if len(filepath) == max_path:  # pragma: no cover
                    warnings.warn(
                        "Could not get the full path of a dynamic library (path too "
                        "long). This library will be ignored and threadpoolctl might "
                        "not be able to control or display information about all "
                        f"loaded libraries. Here's the truncated path: {filepath!r}",
                        RuntimeWarning,
                    )
                else:
                    # Store the library controller if it is supported and selected
                    self._make_controller_from_path(filepath)
        finally:
            kernel_32.CloseHandle(h_process)

    def _find_libraries_pyodide(self):
        """Pyodide specific implementation for finding loaded libraries.

        Adapted from suggestion in https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449.

        One day, we may have a simpler solution. libc dl_iterate_phdr needs to
        be implemented in Emscripten and exposed in Pyodide, see
        https://github.com/emscripten-core/emscripten/issues/21354 for more
        details.
        """
        try:
            from pyodide_js._module import LDSO
        except ImportError:
            warnings.warn(
                "Unable to import LDSO from pyodide_js._module. This should never "
                "happen."
            )
            return

        for filepath in LDSO.loadedLibsByName.as_object_map():
            # Some libraries are duplicated by Pyodide and do not exist in the
            # filesystem, so we first check for the existence of the file. For
            # more details, see
            # https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1947946728
            if os.path.exists(filepath):
                self._make_controller_from_path(filepath)

    def _make_controller_from_path(self, filepath):
        """Store a library controller if it is supported and selected"""
        # Required to resolve symlinks
        filepath = _realpath(filepath)
        # `lower` required to take account of OpenMP dll case on Windows
        # (vcomp, VCOMP, Vcomp, ...)
        filename = os.path.basename(filepath).lower()

        # Loop through supported libraries to find if this filename corresponds
        # to a supported one.
        for controller_class in _ALL_CONTROLLERS:
            # check if filename matches a supported prefix
            prefix = self._check_prefix(filename, controller_class.filename_prefixes)

            # filename does not match any of the prefixes of the candidate
            # library. move to next library.
            if prefix is None:
                continue

            # workaround for BLAS libraries packaged by conda-forge on windows, which
            # are all renamed "libblas.dll". We thus have to check to which BLAS
            # implementation it actually corresponds looking for implementation
            # specific symbols.
            if prefix == "libblas":
                if filename.endswith(".dll"):
                    libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
                    if not any(
                        hasattr(libblas, func)
                        for func in controller_class.check_symbols
                    ):
                        continue
                else:
                    # We ignore libblas on other platforms than windows because there
                    # might be a libblas dso coming with openblas for instance that
                    # can't be used to instantiate a pertinent LibController (many
                    # symbols are missing) and would create confusion by making a
                    # duplicate entry in threadpool_info.
                    continue

            # filename matches a prefix. Now we check if the library has the symbols we
            # are looking for. If none of the symbols exists, it's very likely not the
            # expected library (e.g. a library having a common prefix with one of
            # our supported libraries). Otherwise, create and store the library
            # controller.
            lib_controller = controller_class(
                filepath=filepath, prefix=prefix, parent=self
            )

            if filepath in (lib.filepath for lib in self.lib_controllers):
                # We already have a controller for this library.
                continue

            if not hasattr(controller_class, "check_symbols") or any(
                hasattr(lib_controller.dynlib, func)
                for func in controller_class.check_symbols
            ):
                self.lib_controllers.append(lib_controller)

    def _check_prefix(self, library_basename, filename_prefixes):
        """Return the prefix library_basename starts with

        Return None if none matches.
        """
        for prefix in filename_prefixes:
            if library_basename.startswith(prefix):
                return prefix
        return None

    def _warn_if_incompatible_openmp(self):
        """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
        prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
        msg = textwrap.dedent(
            """
            Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
            the same time. Both libraries are known to be incompatible and this
            can cause random crashes or deadlocks on Linux when loaded in the
            same Python program.
            Using threadpoolctl may cause crashes or deadlocks. For more
            information and possible workarounds, please see
                https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
            """
        )
        if "libomp" in prefixes and "libiomp" in prefixes:
            warnings.warn(msg, RuntimeWarning)

    @classmethod
    def _get_libc(cls):
        """Load the lib-C for unix systems."""
        libc = cls._system_libraries.get("libc")
        if libc is None:
            # Remark: If libc is statically linked or if Python is linked against an
            # alternative implementation of libc like musl, find_library will return
            # None and CDLL will load the main program itself which should contain the
            # libc symbols. We still name it libc for convenience.
            # If the main program does not contain the libc symbols, it's ok because
            # we check their presence later anyway.
            libc = ctypes.CDLL(find_library("c"), mode=_RTLD_NOLOAD)
            cls._system_libraries["libc"] = libc
        return libc

    @classmethod
    def _get_windll(cls, dll_name):
        """Load a windows DLL"""
        dll = cls._system_libraries.get(dll_name)
        if dll is None:
            dll = ctypes.WinDLL(f"{dll_name}.dll")
            cls._system_libraries[dll_name] = dll
        return dll


def _main():
    """Commandline interface to display thread-pool information and exit."""
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
        description="Display thread-pool information and exit.",
    )
    parser.add_argument(
        "-i",
        "--import",
        dest="modules",
        nargs="*",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    options = parser.parse_args(sys.argv[1:])
    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            print("WARNING: could not import", module, file=sys.stderr)

    if options.command:
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))


if __name__ == "__main__":
    _main()