├── .editorconfig ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── dev.yml │ ├── pre-commit-autoupdate.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── api.md ├── contributing.md ├── index.md ├── installation.md ├── logo.png └── usage.md ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── requests_oauth2client ├── __init__.py ├── api_client.py ├── auth.py ├── authorization_request.py ├── backchannel_authentication.py ├── client.py ├── client_authentication.py ├── device_authorization.py ├── discovery.py ├── dpop.py ├── exceptions.py ├── flask │ ├── __init__.py │ └── auth.py ├── pooling.py ├── py.typed ├── tokens.py ├── utils.py └── vendor_specific │ ├── __init__.py │ ├── auth0.py │ └── ping.py ├── tests ├── .coveragerc ├── __init__.py ├── conftest.py ├── test_authorization_code.py ├── test_client_credentials.py ├── test_device_authorization.py ├── test_examples.py ├── test_oidc.py ├── test_refresh_token.py ├── test_token_exchange.py └── unit_tests │ ├── __init__.py │ ├── conftest.py │ ├── test_api_client.py │ ├── test_auth.py │ ├── test_authorization_request.py │ ├── test_backchannel_authentication.py │ ├── test_client.py │ ├── test_client_authentication.py │ ├── test_device_authorization.py │ ├── test_discovery.py │ ├── test_dpop.py │ ├── test_flask.py │ ├── test_oidc.py │ ├── test_pkce.py │ ├── test_tokens.py │ ├── test_utils.py │ └── vendor_specific │ ├── __init__.py │ ├── test_auth0.py │ └── test_ping.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | 
insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | - `requests_oauth2client` version: 2 | - Python version: 3 | - Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the code you ran and the output. 14 | If there was a crash, please include the full traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | # Check for updates to GitHub Actions every week 8 | interval: "weekly" 9 | 10 | - package-ecosystem: "pip" 11 | directory: "/" 12 | schedule: 13 | # Check for updates to Python (pip) dependencies every week 14 | interval: "weekly" 15 | -------------------------------------------------------------------------------- /.github/workflows/dev.yml: -------------------------------------------------------------------------------- 1 | # Development workflow: runs the test suite on every supported Python version and publishes a dev build to Test PyPI 2 | 3 | name: dev workflow 4 | 5 | # Controls when the action will run.
6 | on: 7 | # Triggers the workflow on push or pull request events, but only for the main and release branches 8 | push: 9 | branches: [ main,release ] 10 | pull_request: 11 | branches: [ main,release ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 17 | jobs: 18 | # The "test" job runs the test suite and builds the docs on every supported Python version 19 | test: 20 | # The type of runner that the job will run on 21 | strategy: 22 | matrix: 23 | python-versions: ['3.9', '3.10', '3.11', '3.12', '3.13'] 24 | os: [ubuntu-latest] 25 | runs-on: ${{ matrix.os }} 26 | 27 | # Steps represent a sequence of tasks that will be executed as part of the job 28 | steps: 29 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 30 | - uses: actions/checkout@v4 31 | - uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python-versions }} 34 | allow-prereleases: true 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install poetry tox tox-gh-actions 40 | 41 | - name: test with tox 42 | run: 43 | tox 44 | 45 | - name: Build documentation 46 | run: | 47 | poetry install 48 | poetry run mkdocs build 49 | 50 | - name: list files 51 | run: ls -l . 52 | 53 | publish_dev_build: 54 | # if test failed, we should not publish 55 | needs: test 56 | runs-on: ubuntu-latest 57 | steps: 58 | - uses: actions/checkout@v4 59 | - uses: actions/setup-python@v5 60 | with: 61 | python-version: '3.10' 62 | 63 | - name: Install dependencies 64 | run: | 65 | python -m pip install --upgrade pip 66 | pip install poetry tox tox-gh-actions 67 | 68 | - name: test with tox 69 | run: 70 | tox 71 | 72 | - name: list files 73 | run: ls -l .
74 | 75 | - uses: codecov/codecov-action@v5 76 | with: 77 | token: ${{ secrets.CODECOV_TOKEN }} 78 | fail_ci_if_error: true 79 | files: coverage.xml 80 | verbose: true 81 | 82 | - name: Build wheels and source tarball 83 | run: | 84 | poetry version $(poetry version --short)-dev.$GITHUB_RUN_NUMBER 85 | poetry version --short 86 | poetry build 87 | 88 | - name: publish to Test PyPI 89 | uses: pypa/gh-action-pypi-publish@release/v1 90 | with: 91 | user: __token__ 92 | password: ${{ secrets.TEST_PYPI_API_TOKEN}} 93 | repository-url: https://test.pypi.org/legacy/ 94 | skip-existing: true 95 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-autoupdate.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit auto-update 2 | 3 | on: 4 | schedule: 5 | - cron: '0 2 * * 1' 6 | 7 | permissions: 8 | contents: write 9 | pull-requests: write 10 | 11 | jobs: 12 | auto-update: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: 3.12 21 | 22 | - name: Install pre-commit 23 | run: pip install pre-commit 24 | 25 | - name: Run pre-commit autoupdate 26 | run: pre-commit autoupdate 27 | 28 | - name: Create Pull Request 29 | uses: peter-evans/create-pull-request@v7 30 | with: 31 | token: ${{ secrets.CPR_GITHUB_TOKEN }} 32 | branch: update/pre-commit-autoupdate 33 | title: Auto-update pre-commit hooks 34 | commit-message: Auto-update pre-commit hooks 35 | body: | 36 | Update versions of tools in pre-commit 37 | configs to latest version 38 | labels: dependencies 39 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # Publish package on release branch if it's tagged with 'v*' 2 | 3 | name: release & publish 
workflow 4 | 5 | # Controls when the action will run. 6 | on: 7 | # Triggers the workflow on pushes to the release branch and on pushes of tags matching 'v*' 8 | push: 9 | branches: [ release ] 10 | tags: 11 | - 'v*' 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 17 | jobs: 18 | # This workflow contains a single job called "release" 19 | release: 20 | name: Create Release 21 | runs-on: ubuntu-latest 22 | 23 | strategy: 24 | matrix: 25 | python-versions: ['3.12'] 26 | 27 | # Steps represent a sequence of tasks that will be executed as part of the job 28 | steps: 29 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 30 | - uses: actions/checkout@v4 31 | 32 | - name: Generate Changelog 33 | if: ${{ false }} 34 | uses: heinrichreimer/github-changelog-generator-action@v2.4 35 | with: 36 | token: ${{ secrets.GITHUB_TOKEN }} 37 | issues: true 38 | issuesWoLabels: true 39 | pullRequests: true 40 | prWoLabels: true 41 | unreleased: true 42 | addSections: '{"documentation":{"prefix":"**Documentation:**","labels":["documentation"]}}' 43 | output: CHANGELOG.md 44 | 45 | - uses: actions/setup-python@v5 46 | with: 47 | python-version: ${{ matrix.python-versions }} 48 | 49 | - name: Install dependencies 50 | run: | 51 | python -m pip install --upgrade pip 52 | pip install tox-gh-actions poetry 53 | 54 | - name: Build documentation 55 | run: | 56 | poetry install 57 | poetry run mkdocs build 58 | 59 | - name: Publish documentation 60 | uses: peaceiris/actions-gh-pages@v4 61 | with: 62 | personal_token: ${{ secrets.PERSONAL_TOKEN }} 63 | publish_dir: ./site 64 | 65 | - name: Build wheels and source tarball 66 | run: >- 67 | poetry build 68 | 69 | - name: Show temporary files 70 | run: >- 71 | ls -l 72 | 73 | - name: Create GitHub release 74 | id: create_release 75 | uses: softprops/action-gh-release@v2 76 | env:
77 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 78 | with: 79 | files: dist/*.whl 80 | draft: false 81 | prerelease: false 82 | 83 | - name: Publish on PyPi 84 | uses: pypa/gh-action-pypi-publish@release/v1 85 | with: 86 | user: __token__ 87 | password: ${{ secrets.PYPI_API_TOKEN }} 88 | skip-existing: true 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | .idea/ 107 | 108 | # mkdocs build dir 109 | site/ 110 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-yaml 7 | args: [--unsafe] 8 | - id: no-commit-to-branch 9 | - id: end-of-file-fixer 10 | - repo: https://github.com/pre-commit/pygrep-hooks 11 | rev: v1.10.0 12 | hooks: 13 | - id: python-use-type-annotations 14 | - id: text-unicode-replacement-char 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | rev: v0.11.7 17 | hooks: 18 | - id: ruff-format 19 | - id: ruff 20 | args: [ --fix ] 21 | - id: ruff-format 22 | - repo: 
https://github.com/pre-commit/mirrors-mypy 23 | rev: v1.15.0 24 | hooks: 25 | - id: mypy 26 | args: 27 | - --strict 28 | - --show-error-codes 29 | - --show-error-context 30 | - --show-column-numbers 31 | additional_dependencies: 32 | - attrs 33 | - pytest_examples 34 | - pytest-mock 35 | - pytest-mypy 36 | - pytest-freezer 37 | - jwskate 38 | - types-requests 39 | - requests_mock 40 | - flask 41 | - furl 42 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome, and they are greatly appreciated! Every little bit 4 | helps, and credit will always be given. 5 | 6 | You can contribute in many ways: 7 | 8 | ## Types of Contributions 9 | 10 | ### Report Bugs 11 | 12 | Report bugs at https://github.com/guillp/requests_oauth2client/issues. 13 | 14 | If you are reporting a bug, please include: 15 | 16 | - Detailed steps to reproduce the bug. 17 | - _Full_ error message whenever there is one 18 | - Your Python version, operating system name and version. 19 | - Any details about your local setup that might be helpful in troubleshooting. 20 | 21 | ### Fix Bugs 22 | 23 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 24 | wanted" is open to whoever wants to implement it. 25 | 26 | ### Implement Features 27 | 28 | Look through the GitHub issues for features. Anything tagged with "enhancement" 29 | and "help wanted" is open to whoever wants to implement it. 30 | 31 | ### Write Documentation 32 | 33 | `requests_oauth2client` could always use more documentation, whether as part of the 34 | official requests_oauth2client docs, in docstrings, or even on the web in blog posts, 35 | articles, and such. 36 | 37 | ### Submit Feedback 38 | 39 | The best way to send feedback is to file an issue at https://github.com/guillp/requests_oauth2client/issues. 
40 | 41 | If you are proposing a feature: 42 | 43 | - Explain in detail how it would work. 44 | - Keep the scope as narrow as possible, to make it easier to implement. 45 | - Remember that this is a volunteer-driven project, and that contributions 46 | are welcome :) 47 | 48 | ## Get Started! 49 | 50 | Ready to contribute? Here's how to set up `requests_oauth2client` for local development. 51 | 52 | 1. Fork the `requests_oauth2client` repo on GitHub. 53 | 2. Clone your fork locally: 54 | 55 | ``` 56 | $ git clone git@github.com:your_name_here/requests_oauth2client.git 57 | ``` 58 | 59 | 3. Ensure [poetry](https://python-poetry.org/docs/) is installed. 60 | 4. Install the project and its development dependencies (the `dev` dependency group is installed by default) into a Poetry-managed virtualenv: 61 | 62 | ``` 63 | $ poetry install 64 | ``` 65 | 66 | 5. Create a branch for local development: 67 | 68 | ``` 69 | $ git checkout -b name-of-your-bugfix-or-feature 70 | ``` 71 | 72 | Now you can make your changes locally. 73 | 74 | 6. When you're done making changes, check that your changes pass the 75 | tests, including testing other Python versions, with tox: 76 | 77 | ``` 78 | $ tox 79 | ``` 80 | 81 | 7. Commit your changes and push your branch to GitHub: 82 | 83 | ``` 84 | $ git add . 85 | $ git commit -m "Your detailed description of your changes." 86 | $ git push origin name-of-your-bugfix-or-feature 87 | ``` 88 | 89 | 8. Submit a pull request through the GitHub website. 90 | 91 | ## Pull Request Guidelines 92 | 93 | Before you submit a pull request, check that it meets these guidelines: 94 | 95 | 1. The pull request should include tests. 96 | 2. If the pull request adds functionality, the docs should be updated. Put 97 | your new functionality into a function with a docstring, and add the 98 | feature to the list in README.md. 99 | 3. The pull request should work for Python 3.9+ and for PyPy. Check 100 | https://github.com/guillp/requests_oauth2client/actions 101 | and make sure that the tests pass for all supported Python versions.
102 | 103 | ## Tips 104 | 105 | ``` 106 | $ pytest tests/test_client_credentials.py 107 | ``` 108 | 109 | To run a subset of tests, pass the path of a test module (or directory) to `pytest`, as above. 110 | 111 | ## Deploying 112 | 113 | A reminder for the maintainers on how to deploy. 114 | Make sure all your changes are committed (including an entry in HISTORY.md). 115 | Then run: 116 | 117 | ``` 118 | $ poetry version patch # possible: major / minor / patch 119 | $ git push 120 | $ git push --tags 121 | ``` 122 | 123 | The `release & publish` GitHub Actions workflow will then build the package and publish it to PyPI when the `release` branch or a `v*` tag is pushed. 124 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache Software License 2.0 2 | 3 | Copyright (c) 2021, Guillaume Pujol 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License.
16 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include VERSION 3 | include requirements.txt 4 | include requirements-dev.txt 5 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ::: requests_oauth2client 2 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | {% 2 | include-markdown "../CONTRIBUTING.md" 3 | %} 4 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | {% 2 | include-markdown "../README.md" 3 | %} 4 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Stable release 4 | 5 | To install requests_oauth2client, run this command in your 6 | terminal: 7 | 8 | ```console 9 | $ pip install requests_oauth2client 10 | ``` 11 | 12 | This is the preferred method to install requests_oauth2client, as it will always install the most recent stable release. 13 | 14 | If you don't have [pip] installed, this [Python installation guide] 15 | can guide you through the process. 16 | 17 | ## From source 18 | 19 | The source for requests_oauth2client can be downloaded from 20 | the [Github repo]. 
21 | 22 | You can either clone the public repository: 23 | 24 | ```console 25 | $ git clone https://github.com/guillp/requests_oauth2client.git 26 | ``` 27 | 28 | Or download the [tarball]: 29 | 30 | ```console 31 | $ curl -OJL https://github.com/guillp/requests_oauth2client/tarball/main 32 | ``` 33 | 34 | Once you have a copy of the source, you can install it with: 35 | 36 | ```console 37 | $ pip install . 38 | ``` 39 | 40 | [github repo]: https://github.com/guillp/requests_oauth2client 41 | [pip]: https://pip.pypa.io 42 | [python installation guide]: https://docs.python-guide.org/starting/installation/ 43 | [tarball]: https://github.com/guillp/requests_oauth2client/tarball/main 44 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guillp/requests_oauth2client/810e7b6099ac89742adbf7877d7a8b0f785c4016/docs/logo.png -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | To use requests_oauth2client in a project: 4 | 5 | ```python 6 | from requests_oauth2client import * 7 | ``` 8 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: requests_oauth2client 2 | repo_url: https://github.com/guillp/requests_oauth2client 3 | repo_name: requests_oauth2client 4 | strict: true 5 | watch: 6 | - requests_oauth2client 7 | - README.md 8 | nav: 9 | - Home: index.md 10 | - Installation: installation.md 11 | - API: api.md 12 | - Contributing: contributing.md 13 | theme: 14 | name: material 15 | language: en 16 | #logo:
assets/logo.png 17 | palette: 18 | primary: light blue 19 | features: 20 | - navigation.indexes 21 | - navigation.tabs 22 | - navigation.instant 23 | - navigation.tabs.sticky 24 | - navigation.footer 25 | - content.code.copy 26 | - content.action.view 27 | markdown_extensions: 28 | - pymdownx.emoji: 29 | emoji_index: !!python/name:material.extensions.emoji.twemoji 30 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 31 | - pymdownx.critic 32 | - pymdownx.caret 33 | - pymdownx.mark 34 | - pymdownx.tilde 35 | - pymdownx.tabbed 36 | - attr_list 37 | - pymdownx.arithmatex: 38 | generic: true 39 | - pymdownx.highlight: 40 | linenums: true 41 | anchor_linenums: true 42 | line_spans: __span 43 | pygments_lang_class: true 44 | - pymdownx.inlinehilite 45 | - pymdownx.superfences 46 | - pymdownx.details 47 | - admonition 48 | - toc: 49 | baselevel: 2 50 | permalink: true 51 | slugify: !!python/object/apply:pymdownx.slugs.slugify {kwds: {case: lower}} 52 | - meta 53 | plugins: 54 | - include-markdown 55 | - search: 56 | lang: en 57 | - autorefs 58 | - mkdocstrings: 59 | default_handler: python 60 | handlers: 61 | python: 62 | options: 63 | #extensions: 64 | #- griffe_fieldz: {include_inherited: true} 65 | filters: 66 | - "!^_" 67 | - "^__init__" 68 | - "!^utils" 69 | members_order: source 70 | show_root_heading: true 71 | show_submodules: true 72 | import: 73 | - https://requests.readthedocs.io/en/master/objects.inv 74 | - https://guillp.github.io/jwskate/objects.inv 75 | extra: 76 | social: 77 | - icon: fontawesome/brands/github 78 | link: https://github.com/guillp/requests_oauth2client 79 | name: Github 80 | - icon: material/email 81 | link: "mailto:guill.p.linux@gmail.com" 82 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool] 2 | [tool.poetry] 3 | name = "requests_oauth2client" 4 | version = "1.7.0" 5 | homepage =
"https://github.com/guillp/requests_oauth2client" 6 | description = "An OAuth2.x client based on `requests`." 7 | authors = ["Guillaume Pujol "] 8 | readme = "README.md" 9 | license = "Apache-2.0" 10 | classifiers = [ 11 | 'Development Status :: 4 - Beta', 12 | 'Intended Audience :: Developers', 13 | 'Topic :: Security', 14 | 'License :: OSI Approved :: Apache Software License', 15 | 'Programming Language :: Python :: 3', 16 | 'Programming Language :: Python :: 3.9', 17 | 'Programming Language :: Python :: 3.10', 18 | 'Programming Language :: Python :: 3.11', 19 | 'Programming Language :: Python :: 3.12', 20 | 'Programming Language :: Python :: 3.13', 21 | ] 22 | packages = [ 23 | { include = "requests_oauth2client" }, 24 | { include = "tests", format = "sdist" }, 25 | ] 26 | 27 | [tool.poetry.dependencies] 28 | python = ">=3.9" 29 | 30 | requests = ">=2.19.0" 31 | binapy = ">=0.8" 32 | furl = ">=2.1.2" 33 | jwskate = ">=0.11.1" 34 | attrs = ">=23.2.0" 35 | 36 | 37 | [tool.poetry.group.dev.dependencies] 38 | coverage = ">=7.8.0" 39 | flask = ">=3.0.3" 40 | livereload = ">=2.6.3" 41 | mypy = ">=1.8" 42 | mkdocs = ">=1.3.1" 43 | mkdocs-autorefs = ">=0.3.0" 44 | mkdocs-include-markdown-plugin = ">=6" 45 | mkdocs-material = ">=9.6.11" 46 | mkdocs-material-extensions = ">=1.0.1" 47 | mkdocstrings = { version = ">=0.29.1", extras = ["python"] } 48 | pre-commit = ">=3.5.0" 49 | pytest = ">=7.0.1" 50 | pytest-cov = ">=5.0.0" 51 | pytest-freezer = ">=0.4.8" 52 | pytest-mock = "^3.14.0" 53 | pytest-mypy = ">=1.0.0" 54 | requests-mock = ">=1.9.3" 55 | toml = ">=0.10.2" 56 | tox = ">=4" 57 | types-requests = ">=2.25.10" 58 | types-cryptography = ">=3.3.15" 59 | virtualenv = ">=20.30.0" 60 | pytest-examples = ">=0.0.17" 61 | 62 | 63 | [tool.poetry.extras] 64 | test = ["pytest", "pytest-cov"] 65 | doc = [ 66 | "mdformat", 67 | "mkdocs", 68 | "mkdocs-autorefs", 69 | "mkdocs-include-markdown-plugin", 70 | "mkdocs-material", 71 | "mkdocs-material-extensions", 72 | "mkdocstrings" 73 
| ] 74 | 75 | [build-system] 76 | requires = ["poetry-core>=1.0.0"] 77 | build-backend = "poetry.core.masonry.api" 78 | 79 | [tool.coverage.run] 80 | source = ["requests_oauth2client"] 81 | 82 | [tool.coverage.report] 83 | exclude_also = [ 84 | "def __repr__", 85 | "if self.debug:", 86 | "if settings.DEBUG", 87 | "raise AssertionError", 88 | "raise NotImplementedError", 89 | "if 0:", 90 | "if __name__ == .__main__.:", 91 | "def main", 92 | "if TYPE_CHECKING:", 93 | ] 94 | 95 | [tool.docformatter] 96 | black = true 97 | recursive = true 98 | wrap-summaries = 120 99 | wrap-descriptions = 120 100 | blank = true 101 | 102 | [tool.ruff] 103 | target-version = "py39" 104 | line-length = 120 105 | 106 | 107 | [tool.ruff.format] 108 | docstring-code-format = true 109 | line-ending = "lf" 110 | 111 | [tool.ruff.lint] 112 | select = ["ALL"] 113 | ignore = [ 114 | "ANN401", # any-type in function args 115 | "N818", # Exception names should be named with an Error suffix 116 | "PLR0912", # Too many branches 117 | "D105", # Undocumented magic method 118 | "D107", # Missing docstring in `__init__` 119 | "S105", # Possible hardcoded password 120 | "COM812", 121 | "ISC001", 122 | ] 123 | 124 | [tool.ruff.lint.per-file-ignores] 125 | "tests/**.py" = ["ARG001", "B018", "D100", "D101", "D102", "D103", "D104", "F821", "PGH005", "PLR0913", "PLR0915", "PLR2004", "S101", "S106", "S113", 126 | "PT011", "E501"] 127 | 128 | [tool.ruff.lint.pylint] 129 | max-args = 10 130 | 131 | [tool.ruff.lint.pydocstyle] 132 | convention = "google" 133 | ignore-decorators = ['override'] 134 | 135 | [tool.mypy] 136 | strict = true 137 | show_error_context = true 138 | show_column_numbers = true 139 | show_error_codes = true 140 | pretty = true 141 | warn_unused_configs = true 142 | warn_unused_ignores = true 143 | warn_redundant_casts = true 144 | 145 | 146 | [tool.pytest.ini_options] 147 | requests_mock_case_sensitive = true 148 | markers = [ 149 | "slow: marks tests as slow" 150 | ] 151 | filterwarnings = 
[ 152 | "ignore::DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead." 153 | ] 154 | -------------------------------------------------------------------------------- /requests_oauth2client/__init__.py: -------------------------------------------------------------------------------- 1 | """Main module for `requests_oauth2client`. 2 | 3 | You can import any class from any submodule directly from this main module. 4 | 5 | """ 6 | 7 | import requests 8 | from jwskate import EncryptionAlgs, KeyManagementAlgs, SignatureAlgs 9 | 10 | from .api_client import ApiClient, InvalidBoolFieldsParam, InvalidPathParam 11 | from .auth import ( 12 | NonRenewableTokenError, 13 | OAuth2AccessTokenAuth, 14 | OAuth2AuthorizationCodeAuth, 15 | OAuth2ClientCredentialsAuth, 16 | OAuth2DeviceCodeAuth, 17 | OAuth2ResourceOwnerPasswordAuth, 18 | ) 19 | from .authorization_request import ( 20 | AuthorizationRequest, 21 | AuthorizationRequestSerializer, 22 | AuthorizationResponse, 23 | CodeChallengeMethods, 24 | InvalidCodeVerifierParam, 25 | InvalidMaxAgeParam, 26 | MissingIssuerParam, 27 | PkceUtils, 28 | RequestParameterAuthorizationRequest, 29 | RequestUriParameterAuthorizationRequest, 30 | ResponseTypes, 31 | UnsupportedCodeChallengeMethod, 32 | UnsupportedResponseTypeParam, 33 | ) 34 | from .backchannel_authentication import ( 35 | BackChannelAuthenticationPoolingJob, 36 | BackChannelAuthenticationResponse, 37 | ) 38 | from .client import ( 39 | Endpoints, 40 | GrantTypes, 41 | InvalidAcrValuesParam, 42 | InvalidBackchannelAuthenticationRequestHintParam, 43 | InvalidDiscoveryDocument, 44 | InvalidEndpointUri, 45 | InvalidIssuer, 46 | InvalidParam, 47 | InvalidScopeParam, 48 | MissingAuthRequestId, 49 | MissingDeviceCode, 50 | MissingEndpointUri, 51 | MissingIdTokenEncryptedResponseAlgParam, 52 | MissingRefreshToken, 53 | OAuth2Client, 54 | UnknownActorTokenType, 55 | UnknownSubjectTokenType, 56 | UnknownTokenType, 57 | ) 58 | from 
.client_authentication import ( 59 | BaseClientAssertionAuthenticationMethod, 60 | BaseClientAuthenticationMethod, 61 | ClientSecretBasic, 62 | ClientSecretJwt, 63 | ClientSecretPost, 64 | InvalidClientAssertionSigningKeyOrAlg, 65 | InvalidRequestForClientAuthentication, 66 | PrivateKeyJwt, 67 | PublicApp, 68 | UnsupportedClientCredentials, 69 | ) 70 | from .device_authorization import ( 71 | DeviceAuthorizationPoolingJob, 72 | DeviceAuthorizationResponse, 73 | ) 74 | from .discovery import ( 75 | oauth2_discovery_document_url, 76 | oidc_discovery_document_url, 77 | well_known_uri, 78 | ) 79 | from .dpop import ( 80 | DPoPKey, 81 | DPoPToken, 82 | InvalidDPoPAccessToken, 83 | InvalidDPoPAlg, 84 | InvalidDPoPKey, 85 | InvalidDPoPProof, 86 | InvalidUseDPoPNonceResponse, 87 | MissingDPoPNonce, 88 | RepeatedDPoPNonce, 89 | validate_dpop_proof, 90 | ) 91 | from .exceptions import ( 92 | AccessDenied, 93 | AccountSelectionRequired, 94 | AuthorizationPending, 95 | AuthorizationResponseError, 96 | BackChannelAuthenticationError, 97 | ConsentRequired, 98 | DeviceAuthorizationError, 99 | EndpointError, 100 | ExpiredToken, 101 | InteractionRequired, 102 | IntrospectionError, 103 | InvalidAuthResponse, 104 | InvalidBackChannelAuthenticationResponse, 105 | InvalidClient, 106 | InvalidDeviceAuthorizationResponse, 107 | InvalidGrant, 108 | InvalidPushedAuthorizationResponse, 109 | InvalidRequest, 110 | InvalidScope, 111 | InvalidTarget, 112 | InvalidTokenResponse, 113 | LoginRequired, 114 | MismatchingIssuer, 115 | MismatchingState, 116 | MissingAuthCode, 117 | MissingIssuer, 118 | OAuth2Error, 119 | RevocationError, 120 | ServerError, 121 | SessionSelectionRequired, 122 | SlowDown, 123 | TokenEndpointError, 124 | UnauthorizedClient, 125 | UnknownIntrospectionError, 126 | UnknownTokenEndpointError, 127 | UnsupportedTokenType, 128 | UseDPoPNonce, 129 | ) 130 | from .pooling import ( 131 | BaseTokenEndpointPoolingJob, 132 | ) 133 | from .tokens import ( 134 | BearerToken, 135 | 
BearerTokenSerializer, 136 | ExpiredAccessToken, 137 | ExpiredIdToken, 138 | IdToken, 139 | InvalidIdToken, 140 | MismatchingIdTokenAcr, 141 | MismatchingIdTokenAlg, 142 | MismatchingIdTokenAudience, 143 | MismatchingIdTokenAzp, 144 | MismatchingIdTokenIssuer, 145 | MismatchingIdTokenNonce, 146 | MissingIdToken, 147 | ) 148 | from .utils import ( 149 | InvalidUri, 150 | validate_endpoint_uri, 151 | validate_issuer_uri, 152 | ) 153 | 154 | __all__ = [ 155 | "AccessDenied", 156 | "AccountSelectionRequired", 157 | "ApiClient", 158 | "AuthorizationPending", 159 | "AuthorizationRequest", 160 | "AuthorizationRequestSerializer", 161 | "AuthorizationResponse", 162 | "AuthorizationResponseError", 163 | "BackChannelAuthenticationError", 164 | "BackChannelAuthenticationPoolingJob", 165 | "BackChannelAuthenticationResponse", 166 | "BaseClientAssertionAuthenticationMethod", 167 | "BaseClientAuthenticationMethod", 168 | "BaseTokenEndpointPoolingJob", 169 | "BearerToken", 170 | "BearerTokenSerializer", 171 | "ClientSecretBasic", 172 | "ClientSecretJwt", 173 | "ClientSecretPost", 174 | "CodeChallengeMethods", 175 | "ConsentRequired", 176 | "DPoPKey", 177 | "DPoPToken", 178 | "DeviceAuthorizationError", 179 | "DeviceAuthorizationPoolingJob", 180 | "DeviceAuthorizationResponse", 181 | "EncryptionAlgs", 182 | "EndpointError", 183 | "Endpoints", 184 | "ExpiredAccessToken", 185 | "ExpiredIdToken", 186 | "ExpiredToken", 187 | "GrantTypes", 188 | "IdToken", 189 | "InteractionRequired", 190 | "IntrospectionError", 191 | "InvalidAcrValuesParam", 192 | "InvalidAuthResponse", 193 | "InvalidBackChannelAuthenticationResponse", 194 | "InvalidBackchannelAuthenticationRequestHintParam", 195 | "InvalidBoolFieldsParam", 196 | "InvalidClient", 197 | "InvalidClientAssertionSigningKeyOrAlg", 198 | "InvalidCodeVerifierParam", 199 | "InvalidDPoPAccessToken", 200 | "InvalidDPoPAlg", 201 | "InvalidDPoPKey", 202 | "InvalidDPoPProof", 203 | "InvalidDeviceAuthorizationResponse", 204 | 
"InvalidDiscoveryDocument", 205 | "InvalidEndpointUri", 206 | "InvalidGrant", 207 | "InvalidIdToken", 208 | "InvalidIssuer", 209 | "InvalidMaxAgeParam", 210 | "InvalidParam", 211 | "InvalidPathParam", 212 | "InvalidPushedAuthorizationResponse", 213 | "InvalidRequest", 214 | "InvalidRequestForClientAuthentication", 215 | "InvalidScope", 216 | "InvalidScopeParam", 217 | "InvalidTarget", 218 | "InvalidTokenResponse", 219 | "InvalidUri", 220 | "InvalidUseDPoPNonceResponse", 221 | "KeyManagementAlgs", 222 | "LoginRequired", 223 | "MismatchingIdTokenAcr", 224 | "MismatchingIdTokenAlg", 225 | "MismatchingIdTokenAudience", 226 | "MismatchingIdTokenAzp", 227 | "MismatchingIdTokenIssuer", 228 | "MismatchingIdTokenNonce", 229 | "MismatchingIssuer", 230 | "MismatchingState", 231 | "MissingAuthCode", 232 | "MissingAuthRequestId", 233 | "MissingDPoPNonce", 234 | "MissingDeviceCode", 235 | "MissingEndpointUri", 236 | "MissingIdToken", 237 | "MissingIdTokenEncryptedResponseAlgParam", 238 | "MissingIssuer", 239 | "MissingIssuerParam", 240 | "MissingRefreshToken", 241 | "NonRenewableTokenError", 242 | "OAuth2AccessTokenAuth", 243 | "OAuth2AuthorizationCodeAuth", 244 | "OAuth2Client", 245 | "OAuth2ClientCredentialsAuth", 246 | "OAuth2DeviceCodeAuth", 247 | "OAuth2Error", 248 | "OAuth2ResourceOwnerPasswordAuth", 249 | "PkceUtils", 250 | "PrivateKeyJwt", 251 | "PublicApp", 252 | "RepeatedDPoPNonce", 253 | "RequestParameterAuthorizationRequest", 254 | "RequestUriParameterAuthorizationRequest", 255 | "ResponseTypes", 256 | "RevocationError", 257 | "ServerError", 258 | "SessionSelectionRequired", 259 | "SignatureAlgs", 260 | "SlowDown", 261 | "TokenEndpointError", 262 | "UnauthorizedClient", 263 | "UnknownActorTokenType", 264 | "UnknownIntrospectionError", 265 | "UnknownSubjectTokenType", 266 | "UnknownTokenEndpointError", 267 | "UnknownTokenType", 268 | "UnsupportedClientCredentials", 269 | "UnsupportedCodeChallengeMethod", 270 | "UnsupportedResponseTypeParam", 271 | 
"UnsupportedTokenType", 272 | "UseDPoPNonce", 273 | "oauth2_discovery_document_url", 274 | "oidc_discovery_document_url", 275 | "requests", 276 | "validate_dpop_proof", 277 | "validate_endpoint_uri", 278 | "validate_issuer_uri", 279 | "well_known_uri", 280 | ] 281 | -------------------------------------------------------------------------------- /requests_oauth2client/auth.py: -------------------------------------------------------------------------------- 1 | """This module contains `requests`-compatible Auth Handlers that implement OAuth 2.0.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | import requests 8 | from attrs import define, field, setters 9 | from typing_extensions import override 10 | 11 | from .tokens import BearerToken 12 | 13 | if TYPE_CHECKING: 14 | from .authorization_request import AuthorizationResponse 15 | from .client import OAuth2Client 16 | from .device_authorization import DeviceAuthorizationResponse 17 | 18 | 19 | class NonRenewableTokenError(Exception): 20 | """Raised when attempting to renew a token non-interactively when missing renewing material.""" 21 | 22 | 23 | @define(init=False) 24 | class OAuth2AccessTokenAuth(requests.auth.AuthBase): 25 | """Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens. 26 | 27 | This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as 28 | Bearer or DPoP token, and can automatically refresh it when expired, if a refresh token is available. 29 | 30 | Token can be a simple `str` containing a raw access token value, or a 31 | [BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a `refresh_token`. 32 | 33 | In addition to adding a properly formatted `Authorization` header, this will obtain a new token 34 | once the current token is expired. Expiration is detected based on the `expires_in` hint 35 | returned by the AS. 
A configurable `leeway`, in number of seconds, will make sure that a new 36 | token is obtained some seconds before the actual expiration is reached. This may help in 37 | situations where the client, AS and RS have slightly offset clocks. 38 | 39 | Args: 40 | client: the client to use to refresh tokens. 41 | token: an initial Access Token, if you have one already. In most cases, leave `None`. 42 | leeway: expiration leeway, in number of seconds. 43 | **token_kwargs: additional kwargs to pass to the token endpoint. 44 | 45 | Example: 46 | ```python 47 | from requests_oauth2client import BearerToken, OAuth2Client, OAuth2AccessTokenAuth, requests 48 | 49 | client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) 50 | # obtain a BearerToken any way you see fit, optionally including a refresh token 51 | # for this example, the token value is hardcoded 52 | token = BearerToken(access_token="access_token", expires_in=600, refresh_token="refresh_token") 53 | auth = OAuth2AccessTokenAuth(client, token, scope="my_scope") 54 | resp = requests.post("https://my.api.local/resource", auth=auth) 55 | ``` 56 | 57 | """ 58 | 59 | client: OAuth2Client = field(on_setattr=setters.frozen) 60 | token: BearerToken | None 61 | leeway: int = field(on_setattr=setters.frozen) 62 | token_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) 63 | 64 | def __init__( 65 | self, client: OAuth2Client, token: str | BearerToken, *, leeway: int = 20, **token_kwargs: Any 66 | ) -> None: 67 | if isinstance(token, str): 68 | token = BearerToken(token) 69 | self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs) 70 | 71 | def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 72 | """Add the Access Token to the request. 73 | 74 | If Access Token is not specified or expired, obtain a new one first. 
75 | 76 | Raises: 77 | NonRenewableTokenError: if the token is not renewable 78 | 79 | """ 80 | if self.token is None or self.token.is_expired(self.leeway): 81 | self.renew_token() 82 | if self.token is None: 83 | raise NonRenewableTokenError # pragma: no cover 84 | return self.token(request) 85 | 86 | def renew_token(self) -> None: 87 | """Obtain a new Bearer Token. 88 | 89 | This will try to use the `refresh_token`, if there is one. 90 | 91 | """ 92 | if self.token is not None and self.token.refresh_token is not None: 93 | self.token = self.client.refresh_token(refresh_token=self.token, **self.token_kwargs) 94 | 95 | def forget_token(self) -> None: 96 | """Forget the current token, forcing a renewal on the next HTTP request.""" 97 | self.token = None 98 | 99 | 100 | @define(init=False) 101 | class OAuth2ClientCredentialsAuth(OAuth2AccessTokenAuth): 102 | """An Auth Handler for the [Client Credentials grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.4). 103 | 104 | This [requests AuthBase][requests.auth.AuthBase] automatically gets Access Tokens from an OAuth 105 | 2.0 Token Endpoint with the Client Credentials grant, and will get a new one once the current 106 | one is expired. 107 | 108 | Args: 109 | client: the [OAuth2Client][requests_oauth2client.client.OAuth2Client] to use to obtain Access Tokens. 110 | token: an initial Access Token, if you have one already. In most cases, leave `None`. 111 | leeway: expiration leeway, in number of seconds 112 | **token_kwargs: extra kw parameters to pass to the Token Endpoint. May include `scope`, `resource`, etc. 
113 | 114 | Example: 115 | ```python 116 | from requests_oauth2client import OAuth2Client, OAuth2ClientCredentialsAuth, requests 117 | 118 | client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) 119 | oauth2cc = OAuth2ClientCredentialsAuth(client, scope="my_scope") 120 | resp = requests.post("https://my.api.local/resource", auth=oauth2cc) 121 | ``` 122 | 123 | """ 124 | 125 | def __init__( 126 | self, client: OAuth2Client, *, leeway: int = 20, token: str | BearerToken | None = None, **token_kwargs: Any 127 | ) -> None: 128 | if isinstance(token, str): 129 | token = BearerToken(token) 130 | self.__attrs_init__(client=client, token=token, leeway=leeway, token_kwargs=token_kwargs) 131 | 132 | @override 133 | def renew_token(self) -> None: 134 | """Obtain a new token for use within this Auth Handler.""" 135 | self.token = self.client.client_credentials(**self.token_kwargs) 136 | 137 | 138 | @define(init=False) 139 | class OAuth2AuthorizationCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override] 140 | """Authentication handler for the [Authorization Code grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.1). 141 | 142 | This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges an Authorization 143 | Code for an access token, then automatically refreshes it once it is expired. 144 | 145 | Args: 146 | client: the client to use to obtain Access Tokens. 147 | code: an Authorization Code that has been obtained from the AS. 148 | token: an initial Access Token, if you have one already. In most cases, leave `None`. 149 | leeway: expiration leeway, in number of seconds. 150 | **token_kwargs: additional kwargs to pass to the token endpoint. 
151 | 152 | Example: 153 | ```python 154 | from requests_oauth2client import ApiClient, OAuth2Client, OAuth2AuthorizationCodeAuth 155 | 156 | client = OAuth2Client(token_endpoint="https://myas.local/token", auth=("client_id", "client_secret")) 157 | code = "my_code" # you must obtain this code yourself 158 | api = ApiClient("https://my.api.local/resource", auth=OAuth2AuthorizationCodeAuth(client, code)) 159 | ``` 160 | 161 | """ 162 | 163 | code: str | AuthorizationResponse | None 164 | 165 | def __init__( 166 | self, 167 | client: OAuth2Client, 168 | code: str | AuthorizationResponse | None, 169 | *, 170 | leeway: int = 20, 171 | token: str | BearerToken | None = None, 172 | **token_kwargs: Any, 173 | ) -> None: 174 | if isinstance(token, str): 175 | token = BearerToken(token) 176 | self.__attrs_init__( 177 | client=client, 178 | token=token, 179 | code=code, 180 | leeway=leeway, 181 | token_kwargs=token_kwargs, 182 | ) 183 | 184 | def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 185 | """Implement the Authorization Code grant as an Authentication Handler. 186 | 187 | This exchanges an Authorization Code for an access token and adds it in the request. 
188 | 189 | Args: 190 | request: the request 191 | 192 | Returns: 193 | the request, with an Access Token added in Authorization Header 194 | 195 | """ 196 | if self.token is None or self.token.is_expired(): 197 | self.exchange_code_for_token() 198 | return super().__call__(request) 199 | 200 | def exchange_code_for_token(self) -> None: 201 | """Exchange the authorization code for an access token.""" 202 | if self.code: # pragma: no branch 203 | self.token = self.client.authorization_code(code=self.code, **self.token_kwargs) 204 | self.code = None 205 | 206 | 207 | @define(init=False) 208 | class OAuth2ResourceOwnerPasswordAuth(OAuth2AccessTokenAuth): # type: ignore[override] 209 | """Authentication Handler for the [Resource Owner Password Credentials Flow](https://www.rfc-editor.org/rfc/rfc6749#section-4.3). 210 | 211 | This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges the user 212 | credentials for an Access Token, then automatically repeats the process to get a new one 213 | once the current one is expired. 214 | 215 | Note that this flow is considered *deprecated*, and the Authorization Code flow should be 216 | used whenever possible. 217 | Among other bad things, ROPC: 218 | 219 | - does not support SSO between multiple apps, 220 | - does not support MFA or risk-based adaptative authentication, 221 | - depends on the user typing its credentials directly inside the application, instead of on a 222 | dedicated, centralized login page managed by the AS, which makes it totally insecure for 3rd party apps. 223 | 224 | It needs the username and password and an 225 | [OAuth2Client][requests_oauth2client.client.OAuth2Client] to be able to get a token from 226 | the AS Token Endpoint just before the first request using this Auth Handler is being sent. 
227 | 228 | Args: 229 | client: the client to use to obtain Access Tokens 230 | username: the username 231 | password: the user password 232 | leeway: an amount of time, in seconds 233 | token: an initial Access Token, if you have one already. In most cases, leave `None`. 234 | **token_kwargs: additional kwargs to pass to the token endpoint 235 | 236 | Example: 237 | ```python 238 | from requests_oauth2client import ApiClient, OAuth2Client, OAuth2ResourceOwnerPasswordAuth 239 | 240 | client = OAuth2Client( 241 | token_endpoint="https://myas.local/token", 242 | auth=("client_id", "client_secret"), 243 | ) 244 | username = "my_username" 245 | password = "my_password" # you must obtain those credentials from the user 246 | auth = OAuth2ResourceOwnerPasswordAuth(client, username=username, password=password) 247 | api = ApiClient("https://myapi.local", auth=auth) 248 | ``` 249 | """ 250 | 251 | username: str 252 | password: str 253 | 254 | def __init__( 255 | self, 256 | client: OAuth2Client, 257 | *, 258 | username: str, 259 | password: str, 260 | leeway: int = 20, 261 | token: str | BearerToken | None = None, 262 | **token_kwargs: Any, 263 | ) -> None: 264 | if isinstance(token, str): 265 | token = BearerToken(token) 266 | self.__attrs_init__( 267 | client=client, 268 | token=token, 269 | leeway=leeway, 270 | token_kwargs=token_kwargs, 271 | username=username, 272 | password=password, 273 | ) 274 | 275 | @override 276 | def renew_token(self) -> None: 277 | """Exchange the user credentials for an Access Token.""" 278 | self.token = self.client.resource_owner_password( 279 | username=self.username, 280 | password=self.password, 281 | **self.token_kwargs, 282 | ) 283 | 284 | 285 | @define(init=False) 286 | class OAuth2DeviceCodeAuth(OAuth2AccessTokenAuth): # type: ignore[override] 287 | """Authentication Handler for the [Device Code Flow](https://www.rfc-editor.org/rfc/rfc8628). 
288 | 289 | This [Requests Auth handler][requests.auth.AuthBase] implementation exchanges a Device Code for 290 | an Access Token, then automatically refreshes it once it is expired. 291 | 292 | It needs a Device Code and an [OAuth2Client][requests_oauth2client.client.OAuth2Client] to be 293 | able to get a token from the AS Token Endpoint just before the first request using this Auth 294 | Handler is being sent. 295 | 296 | Args: 297 | client: the [OAuth2Client][requests_oauth2client.client.OAuth2Client] to use to obtain Access Tokens. 298 | device_code: a Device Code obtained from the AS. 299 | interval: the interval to use to pool the Token Endpoint, in seconds. 300 | expires_in: the lifetime of the token, in seconds. 301 | token: an initial Access Token, if you have one already. In most cases, leave `None`. 302 | leeway: expiration leeway, in number of seconds. 303 | **token_kwargs: additional kwargs to pass to the token endpoint. 304 | 305 | Example: 306 | ```python 307 | from requests_oauth2client import OAuth2Client, OAuth2DeviceCodeAuth, requests 308 | 309 | client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) 310 | device_code = client.device_authorization() 311 | auth = OAuth2DeviceCodeAuth(client, device_code) 312 | resp = requests.post("https://my.api.local/resource", auth=auth) 313 | ``` 314 | 315 | """ 316 | 317 | device_code: str | DeviceAuthorizationResponse | None 318 | interval: int 319 | expires_in: int 320 | 321 | def __init__( 322 | self, 323 | client: OAuth2Client, 324 | *, 325 | device_code: str | DeviceAuthorizationResponse, 326 | leeway: int = 20, 327 | interval: int = 5, 328 | expires_in: int = 360, 329 | token: str | BearerToken | None = None, 330 | **token_kwargs: Any, 331 | ) -> None: 332 | if isinstance(token, str): 333 | token = BearerToken(token) 334 | self.__attrs_init__( 335 | client=client, 336 | token=token, 337 | leeway=leeway, 338 | token_kwargs=token_kwargs, 339 | 
device_code=device_code, 340 | interval=interval, 341 | expires_in=expires_in, 342 | ) 343 | 344 | @override 345 | def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 346 | """Implement the Device Code grant as a request Authentication Handler. 347 | 348 | This exchanges a Device Code for an access token and adds it in HTTP requests. 349 | 350 | Args: 351 | request: a [requests.PreparedRequest][] 352 | 353 | Returns: 354 | a [requests.PreparedRequest][] with an Access Token added in Authorization Header 355 | 356 | """ 357 | if self.token is None: 358 | self.exchange_device_code_for_token() 359 | return super().__call__(request) 360 | 361 | def exchange_device_code_for_token(self) -> None: 362 | """Exchange the Device Code for an access token. 363 | 364 | This will poll the Token Endpoint until the user finishes the authorization process. 365 | 366 | """ 367 | from .device_authorization import DeviceAuthorizationPoolingJob 368 | 369 | if self.device_code: # pragma: no branch 370 | pooling_job = DeviceAuthorizationPoolingJob( 371 | client=self.client, 372 | device_code=self.device_code, 373 | interval=self.interval, 374 | ) 375 | token = None 376 | while token is None: 377 | token = pooling_job() 378 | self.token = token 379 | self.device_code = None 380 | -------------------------------------------------------------------------------- /requests_oauth2client/backchannel_authentication.py: -------------------------------------------------------------------------------- 1 | """Implementation of CIBA. 2 | 3 | CIBA stands for Client Initiated BackChannel Authentication and is standardised by the OpenID Fundation. 4 | https://openid.net/specs/openid-client-initiated-backchannel- 5 | authentication-core-1_0.html. 
6 | 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from datetime import datetime, timedelta, timezone 12 | from math import ceil 13 | from typing import TYPE_CHECKING, Any 14 | 15 | from attrs import define 16 | 17 | from .pooling import BaseTokenEndpointPoolingJob 18 | from .utils import accepts_expires_in 19 | 20 | if TYPE_CHECKING: 21 | from .client import OAuth2Client 22 | from .tokens import BearerToken 23 | 24 | 25 | class BackChannelAuthenticationResponse: 26 | """Represent a BackChannel Authentication Response. 27 | 28 | This contains all the parameters that are returned by the AS as a result of a BackChannel 29 | Authentication Request, such as `auth_req_id` (required), and the optional `expires_at`, 30 | `interval`, and/or any custom parameters. 31 | 32 | Args: 33 | auth_req_id: the `auth_req_id` as returned by the AS. 34 | expires_at: the date when the `auth_req_id` expires. 35 | Note that this request also accepts an `expires_in` parameter, in seconds. 36 | interval: the Token Endpoint pooling interval, in seconds, as returned by the AS. 37 | **kwargs: any additional custom parameters as returned by the AS. 38 | 39 | """ 40 | 41 | @accepts_expires_in 42 | def __init__( 43 | self, 44 | auth_req_id: str, 45 | expires_at: datetime | None = None, 46 | interval: int | None = 20, 47 | **kwargs: Any, 48 | ) -> None: 49 | self.auth_req_id = auth_req_id 50 | self.expires_at = expires_at 51 | self.interval = interval 52 | self.other = kwargs 53 | 54 | def is_expired(self, leeway: int = 0) -> bool | None: 55 | """Return `True` if the `auth_req_id` within this response is expired. 56 | 57 | Expiration is evaluated at the time of the call. If there is no "expires_at" hint (which is 58 | derived from the `expires_in` hint returned by the AS BackChannel Authentication endpoint), 59 | this will return `None`. 60 | 61 | Returns: 62 | `True` if the auth_req_id is expired, `False` if it is still valid, `None` if there is 63 | no `expires_in` hint. 
64 | 65 | """ 66 | if self.expires_at: 67 | return datetime.now(tz=timezone.utc) - timedelta(seconds=leeway) > self.expires_at 68 | return None 69 | 70 | @property 71 | def expires_in(self) -> int | None: 72 | """Number of seconds until expiration.""" 73 | if self.expires_at: 74 | return ceil((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds()) 75 | return None 76 | 77 | def __getattr__(self, key: str) -> Any: 78 | """Return attributes from this `BackChannelAuthenticationResponse`. 79 | 80 | Allows accessing response parameters with `token_response.expires_in` or 81 | `token_response.any_custom_attribute`. 82 | 83 | Args: 84 | key: a key 85 | 86 | Returns: 87 | the associated value in this token response 88 | 89 | Raises: 90 | AttributeError: if the attribute is not present in the response 91 | 92 | """ 93 | return self.other.get(key) or super().__getattribute__(key) 94 | 95 | 96 | @define(init=False) 97 | class BackChannelAuthenticationPoolingJob(BaseTokenEndpointPoolingJob): 98 | """A pooling job for the BackChannel Authentication flow. 99 | 100 | This will poll the Token Endpoint until the user finishes with its authentication. 101 | 102 | Args: 103 | client: an OAuth2Client that will be used to pool the token endpoint. 104 | auth_req_id: an `auth_req_id` as `str` or a `BackChannelAuthenticationResponse`. 105 | interval: The pooling interval, in seconds, to use. This overrides 106 | the one in `auth_req_id` if it is a `BackChannelAuthenticationResponse`. 107 | Defaults to 5 seconds. 108 | slow_down_interval: Number of seconds to add to the pooling interval when the AS returns 109 | a slow down request. 110 | requests_kwargs: Additional parameters for the underlying calls to [requests.request][]. 111 | **token_kwargs: Additional parameters for the token request. 
112 | 113 | Example: 114 | ```python 115 | client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) 116 | pool_job = BackChannelAuthenticationPoolingJob( 117 | client=client, 118 | auth_req_id="my_auth_req_id", 119 | ) 120 | 121 | token = None 122 | while token is None: 123 | token = pool_job() 124 | ``` 125 | 126 | """ 127 | 128 | auth_req_id: str 129 | 130 | def __init__( 131 | self, 132 | client: OAuth2Client, 133 | auth_req_id: str | BackChannelAuthenticationResponse, 134 | *, 135 | interval: int | None = None, 136 | slow_down_interval: int = 5, 137 | requests_kwargs: dict[str, Any] | None = None, 138 | **token_kwargs: Any, 139 | ) -> None: 140 | if isinstance(auth_req_id, BackChannelAuthenticationResponse): 141 | interval = interval or auth_req_id.interval 142 | auth_req_id = auth_req_id.auth_req_id 143 | 144 | self.__attrs_init__( 145 | client=client, 146 | auth_req_id=auth_req_id, 147 | interval=interval or 5, 148 | slow_down_interval=slow_down_interval, 149 | requests_kwargs=requests_kwargs or {}, 150 | token_kwargs=token_kwargs, 151 | ) 152 | 153 | def token_request(self) -> BearerToken: 154 | """Implement the CIBA token request. 155 | 156 | This actually calls [OAuth2Client.ciba(auth_req_id)] on `client`. 157 | 158 | Returns: 159 | a [BearerToken][requests_oauth2client.tokens.BearerToken] 160 | 161 | """ 162 | return self.client.ciba(self.auth_req_id, requests_kwargs=self.requests_kwargs, **self.token_kwargs) 163 | -------------------------------------------------------------------------------- /requests_oauth2client/device_authorization.py: -------------------------------------------------------------------------------- 1 | """Implements the Device Authorization Flow as defined in RFC8628. 2 | 3 | See [RFC8628](https://datatracker.ietf.org/doc/html/rfc8628). 
4 | 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from datetime import datetime, timedelta, timezone 10 | from typing import TYPE_CHECKING, Any 11 | 12 | from attrs import define 13 | 14 | from .pooling import BaseTokenEndpointPoolingJob 15 | from .utils import accepts_expires_in 16 | 17 | if TYPE_CHECKING: 18 | from .client import OAuth2Client 19 | from .tokens import BearerToken 20 | 21 | 22 | class DeviceAuthorizationResponse: 23 | """Represent a response returned by the device Authorization Endpoint. 24 | 25 | All parameters are those returned by the AS as response to a Device Authorization Request. 26 | 27 | Args: 28 | device_code: the `device_code` as returned by the AS. 29 | user_code: the `device_code` as returned by the AS. 30 | verification_uri: the `device_code` as returned by the AS. 31 | verification_uri_complete: the `device_code` as returned by the AS. 32 | expires_at: the expiration date for the device_code. 33 | Also accepts an `expires_in` parameter, as a number of seconds in the future. 34 | interval: the pooling `interval` as returned by the AS. 35 | **kwargs: additional parameters as returned by the AS. 36 | 37 | """ 38 | 39 | @accepts_expires_in 40 | def __init__( 41 | self, 42 | device_code: str, 43 | user_code: str, 44 | verification_uri: str, 45 | verification_uri_complete: str | None = None, 46 | expires_at: datetime | None = None, 47 | interval: int | None = None, 48 | **kwargs: Any, 49 | ) -> None: 50 | self.device_code = device_code 51 | self.user_code = user_code 52 | self.verification_uri = verification_uri 53 | self.verification_uri_complete = verification_uri_complete 54 | self.expires_at = expires_at 55 | self.interval = interval 56 | self.other = kwargs 57 | 58 | def is_expired(self, leeway: int = 0) -> bool | None: 59 | """Check if the `device_code` within this response is expired. 60 | 61 | Returns: 62 | `True` if the device_code is expired, `False` if it is still valid, `None` if there is 63 | no `expires_in` hint. 
64 | 65 | """ 66 | if self.expires_at: 67 | return datetime.now(tz=timezone.utc) - timedelta(seconds=leeway) > self.expires_at 68 | return None 69 | 70 | 71 | @define(init=False) 72 | class DeviceAuthorizationPoolingJob(BaseTokenEndpointPoolingJob): 73 | """A Token Endpoint pooling job for the Device Authorization Flow. 74 | 75 | This periodically checks if the user has finished with his authorization in a Device 76 | Authorization flow. 77 | 78 | Args: 79 | client: an OAuth2Client that will be used to pool the token endpoint. 80 | device_code: a `device_code` as `str` or a `DeviceAuthorizationResponse`. 81 | interval: The pooling interval to use. This overrides the one in `auth_req_id` if it is 82 | a `BackChannelAuthenticationResponse`. 83 | slow_down_interval: Number of seconds to add to the pooling interval when the AS returns 84 | a slow-down request. 85 | requests_kwargs: Additional parameters for the underlying calls to [requests.request][]. 86 | **token_kwargs: Additional parameters for the token request. 
87 | 88 | Example: 89 | ```python 90 | from requests_oauth2client import DeviceAuthorizationPoolingJob, OAuth2Client 91 | 92 | client = OAuth2Client(token_endpoint="https://my.as.local/token", auth=("client_id", "client_secret")) 93 | pooler = DeviceAuthorizationPoolingJob(client=client, device_code="my_device_code") 94 | 95 | token = None 96 | while token is None: 97 | token = pooler() 98 | ``` 99 | 100 | """ 101 | 102 | device_code: str 103 | 104 | def __init__( 105 | self, 106 | client: OAuth2Client, 107 | device_code: str | DeviceAuthorizationResponse, 108 | interval: int | None = None, 109 | slow_down_interval: int = 5, 110 | requests_kwargs: dict[str, Any] | None = None, 111 | **token_kwargs: Any, 112 | ) -> None: 113 | if isinstance(device_code, DeviceAuthorizationResponse): 114 | interval = interval or device_code.interval 115 | device_code = device_code.device_code 116 | 117 | self.__attrs_init__( 118 | client=client, 119 | device_code=device_code, 120 | interval=interval or 5, 121 | slow_down_interval=slow_down_interval, 122 | requests_kwargs=requests_kwargs or {}, 123 | token_kwargs=token_kwargs, 124 | ) 125 | 126 | def token_request(self) -> BearerToken: 127 | """Implement the Device Code token request. 128 | 129 | This actually calls [OAuth2Client.device_code(device_code)][requests_oauth2client.OAuth2Client.device_code] 130 | on `self.client`. 131 | 132 | Returns: 133 | a [BearerToken][requests_oauth2client.tokens.BearerToken] 134 | 135 | """ 136 | return self.client.device_code(self.device_code, requests_kwargs=self.requests_kwargs, **self.token_kwargs) 137 | -------------------------------------------------------------------------------- /requests_oauth2client/discovery.py: -------------------------------------------------------------------------------- 1 | """Implements Metadata discovery documents URLS. 
2 | 3 | This is as defined in [RFC8615](https://datatracker.ietf.org/doc/html/rfc8615) and [OpenID Connect 4 | Discovery 1.0](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). 5 | 6 | """ 7 | 8 | from furl import Path, furl # type: ignore[import-untyped] 9 | 10 | 11 | def well_known_uri(origin: str, name: str, *, at_root: bool = True) -> str: 12 | """Return the location of a well-known document on an origin url. 13 | 14 | See [RFC8615](https://datatracker.ietf.org/doc/html/rfc8615) and [OIDC 15 | Discovery](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). 16 | 17 | Args: 18 | origin: origin to use to build the well-known uri. 19 | name: document name to use to build the well-known uri. 20 | at_root: if `True`, assume the well-known document is at root level (as defined in [RFC8615](https://datatracker.ietf.org/doc/html/rfc8615)). 21 | If `False`, assume the well-known location is per-directory, as defined in [OpenID 22 | Connect Discovery 23 | 1.0](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). 24 | 25 | Returns: 26 | the well-know uri, relative to origin, where the well-known document named `name` should be 27 | found. 28 | 29 | """ 30 | url = furl(origin) 31 | if at_root: 32 | url.path = Path(".well-known") / url.path / name 33 | else: 34 | url.path.add(Path(".well-known") / name) 35 | return str(url) 36 | 37 | 38 | def oidc_discovery_document_url(issuer: str) -> str: 39 | """Construct the OIDC discovery document url for a given `issuer`. 40 | 41 | Given an `issuer` identifier, return the standardised URL where the OIDC discovery document can 42 | be retrieved. 43 | 44 | The returned URL is biuilt as specified in [OpenID Connect Discovery 45 | 1.0](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). 46 | 47 | Args: 48 | issuer: an OIDC Authentication Server `issuer` 49 | 50 | Returns: 51 | the standardised discovery document URL. 
Note that no attempt to fetch this document is 52 | made. 53 | 54 | """ 55 | return well_known_uri(issuer, "openid-configuration", at_root=False) 56 | 57 | 58 | def oauth2_discovery_document_url(issuer: str) -> str: 59 | """Construct the standardised OAuth 2.0 discovery document url for a given `issuer`. 60 | 61 | Based an `issuer` identifier, returns the standardised URL where the OAuth20 server metadata can 62 | be retrieved. 63 | 64 | The returned URL is built as specified in 65 | [RFC8414](https://datatracker.ietf.org/doc/html/rfc8414). 66 | 67 | Args: 68 | issuer: an OAuth20 Authentication Server `issuer` 69 | 70 | Returns: 71 | the standardised discovery document URL. Note that no attempt to fetch this document is 72 | made. 73 | 74 | """ 75 | return well_known_uri(issuer, "oauth-authorization-server", at_root=True) 76 | -------------------------------------------------------------------------------- /requests_oauth2client/exceptions.py: -------------------------------------------------------------------------------- 1 | """This module contains all exception classes from `requests_oauth2client`.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | import requests 9 | 10 | from requests_oauth2client.authorization_request import AuthorizationRequest 11 | from requests_oauth2client.client import OAuth2Client 12 | 13 | 14 | class OAuth2Error(Exception): 15 | """Base class for Exceptions raised when a backend endpoint returns an error. 
16 | 17 | Args: 18 | response: the HTTP response containing the error 19 | client : the OAuth2Client used to send the request 20 | description: description of the error 21 | 22 | """ 23 | 24 | def __init__(self, response: requests.Response, client: OAuth2Client, description: str | None = None) -> None: 25 | super().__init__(f"The remote endpoint returned an error: {description or 'no description provided'}") 26 | self.response = response 27 | self.client = client 28 | self.description = description 29 | 30 | @property 31 | def request(self) -> requests.PreparedRequest: 32 | """The request leading to the error.""" 33 | return self.response.request 34 | 35 | 36 | class EndpointError(OAuth2Error): 37 | """Base class for exceptions raised from backend endpoint errors. 38 | 39 | This contains the error message, description and uri that are returned 40 | by the AS in the OAuth 2.0 standardised way. 41 | 42 | Args: 43 | response: the raw response containing the error. 44 | error: the `error` identifier as returned by the AS. 45 | description: the `error_description` as returned by the AS. 46 | uri: the `error_uri` as returned by the AS. 
47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | response: requests.Response, 53 | client: OAuth2Client, 54 | error: str, 55 | description: str | None = None, 56 | uri: str | None = None, 57 | ) -> None: 58 | super().__init__(response=response, client=client, description=description) 59 | self.error = error 60 | self.uri = uri 61 | 62 | 63 | class InvalidTokenResponse(OAuth2Error): 64 | """Raised when the Token Endpoint returns a non-standard response.""" 65 | 66 | 67 | class UnknownTokenEndpointError(EndpointError): 68 | """Raised when the token endpoint returns an otherwise unknown error.""" 69 | 70 | 71 | class ServerError(EndpointError): 72 | """Raised when the token endpoint returns `error = server_error`.""" 73 | 74 | 75 | class TokenEndpointError(EndpointError): 76 | """Base class for errors that are specific to the token endpoint.""" 77 | 78 | 79 | class InvalidRequest(TokenEndpointError): 80 | """Raised when the Token Endpoint returns `error = invalid_request`.""" 81 | 82 | 83 | class InvalidClient(TokenEndpointError): 84 | """Raised when the Token Endpoint returns `error = invalid_client`.""" 85 | 86 | 87 | class InvalidScope(TokenEndpointError): 88 | """Raised when the Token Endpoint returns `error = invalid_scope`.""" 89 | 90 | 91 | class InvalidTarget(TokenEndpointError): 92 | """Raised when the Token Endpoint returns `error = invalid_target`.""" 93 | 94 | 95 | class InvalidGrant(TokenEndpointError): 96 | """Raised when the Token Endpoint returns `error = invalid_grant`.""" 97 | 98 | 99 | class UseDPoPNonce(TokenEndpointError): 100 | """Raised when the Token Endpoint raises error = use_dpop_nonce`.""" 101 | 102 | 103 | class AccessDenied(EndpointError): 104 | """Raised when the Authorization Server returns `error = access_denied`.""" 105 | 106 | 107 | class UnauthorizedClient(EndpointError): 108 | """Raised when the Authorization Server returns `error = unauthorized_client`.""" 109 | 110 | 111 | class RevocationError(EndpointError): 112 | """Base 
class for Revocation Endpoint errors.""" 113 | 114 | 115 | class UnsupportedTokenType(RevocationError): 116 | """Raised when the Revocation endpoint returns `error = unsupported_token_type`.""" 117 | 118 | 119 | class IntrospectionError(EndpointError): 120 | """Base class for Introspection Endpoint errors.""" 121 | 122 | 123 | class UnknownIntrospectionError(OAuth2Error): 124 | """Raised when the Introspection Endpoint returns a non-standard error.""" 125 | 126 | 127 | class DeviceAuthorizationError(EndpointError): 128 | """Base class for Device Authorization Endpoint errors.""" 129 | 130 | 131 | class AuthorizationPending(TokenEndpointError): 132 | """Raised when the Token Endpoint returns `error = authorization_pending`.""" 133 | 134 | 135 | class SlowDown(TokenEndpointError): 136 | """Raised when the Token Endpoint returns `error = slow_down`.""" 137 | 138 | 139 | class ExpiredToken(TokenEndpointError): 140 | """Raised when the Token Endpoint returns `error = expired_token`.""" 141 | 142 | 143 | class InvalidDeviceAuthorizationResponse(OAuth2Error): 144 | """Raised when the Device Authorization Endpoint returns a non-standard error response.""" 145 | 146 | 147 | class AuthorizationResponseError(Exception): 148 | """Base class for error responses returned by the Authorization endpoint. 149 | 150 | An `AuthorizationResponseError` contains the error message, description and uri that are 151 | returned by the AS. 
152 | 153 | Args: 154 | error: the `error` identifier as returned by the AS 155 | description: the `error_description` as returned by the AS 156 | uri: the `error_uri` as returned by the AS 157 | 158 | """ 159 | 160 | def __init__( 161 | self, 162 | request: AuthorizationRequest, 163 | response: str, 164 | error: str, 165 | description: str | None = None, 166 | uri: str | None = None, 167 | ) -> None: 168 | self.error = error 169 | self.description = description 170 | self.uri = uri 171 | self.request = request 172 | self.response = response 173 | 174 | 175 | class InteractionRequired(AuthorizationResponseError): 176 | """Raised when the Authorization Endpoint returns `error = interaction_required`.""" 177 | 178 | 179 | class LoginRequired(InteractionRequired): 180 | """Raised when the Authorization Endpoint returns `error = login_required`.""" 181 | 182 | 183 | class AccountSelectionRequired(InteractionRequired): 184 | """Raised when the Authorization Endpoint returns `error = account_selection_required`.""" 185 | 186 | 187 | class SessionSelectionRequired(InteractionRequired): 188 | """Raised when the Authorization Endpoint returns `error = session_selection_required`.""" 189 | 190 | 191 | class ConsentRequired(InteractionRequired): 192 | """Raised when the Authorization Endpoint returns `error = consent_required`.""" 193 | 194 | 195 | class InvalidAuthResponse(ValueError): 196 | """Raised when the Authorization Endpoint returns an invalid response.""" 197 | 198 | def __init__(self, message: str, request: AuthorizationRequest, response: str) -> None: 199 | super().__init__(f"The Authorization Response is invalid: {message}") 200 | self.request = request 201 | self.response = response 202 | 203 | 204 | class MissingAuthCode(InvalidAuthResponse): 205 | """Raised when the Authorization Endpoint does not return the mandatory `code`. 
206 | 207 | This happens when the Authorization Endpoint does not return an error, but does not return an 208 | authorization `code` either. 209 | 210 | """ 211 | 212 | def __init__(self, request: AuthorizationRequest, response: str) -> None: 213 | super().__init__("missing `code` query parameter in response", request, response) 214 | 215 | 216 | class MissingIssuer(InvalidAuthResponse): 217 | """Raised when the Authorization Endpoint does not return an `iss` parameter as expected. 218 | 219 | The Authorization Server advertises its support with a flag 220 | `authorization_response_iss_parameter_supported` in its discovery document. If it is set to 221 | `true`, it must include an `iss` parameter in its authorization responses, containing its issuer 222 | identifier. 223 | 224 | """ 225 | 226 | def __init__(self, request: AuthorizationRequest, response: str) -> None: 227 | super().__init__("missing `iss` query parameter in response", request, response) 228 | 229 | 230 | class MismatchingState(InvalidAuthResponse): 231 | """Raised on mismatching `state` value. 232 | 233 | This happens when the Authorization Endpoints returns a 'state' parameter that doesn't match the value passed in the 234 | Authorization Request. 235 | 236 | """ 237 | 238 | def __init__(self, received: str, expected: str, request: AuthorizationRequest, response: str) -> None: 239 | super().__init__(f"mismatching `state` (received '{received}', expected '{expected}')", request, response) 240 | self.received = received 241 | self.expected = expected 242 | 243 | 244 | class MismatchingIssuer(InvalidAuthResponse): 245 | """Raised on mismatching `iss` value. 246 | 247 | This happens when the Authorization Endpoints returns an 'iss' that doesn't match the expected value. 
248 | 249 | """ 250 | 251 | def __init__(self, received: str, expected: str, request: AuthorizationRequest, response: str) -> None: 252 | super().__init__(f"mismatching `iss` (received '{received}', expected '{expected}')", request, response) 253 | self.received = received 254 | self.expected = expected 255 | 256 | 257 | class BackChannelAuthenticationError(EndpointError): 258 | """Base class for errors returned by the BackChannel Authentication endpoint.""" 259 | 260 | 261 | class InvalidBackChannelAuthenticationResponse(OAuth2Error): 262 | """Raised when the BackChannel Authentication endpoint returns a non-standard response.""" 263 | 264 | 265 | class InvalidPushedAuthorizationResponse(OAuth2Error): 266 | """Raised when the Pushed Authorization Endpoint returns an error.""" 267 | -------------------------------------------------------------------------------- /requests_oauth2client/flask/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains helper classes for the Flask Framework. 2 | 3 | See [Flask framework](https://flask.palletsprojects.com). 4 | 5 | """ 6 | 7 | from .auth import FlaskOAuth2ClientCredentialsAuth 8 | 9 | __all__ = ["FlaskOAuth2ClientCredentialsAuth"] 10 | -------------------------------------------------------------------------------- /requests_oauth2client/flask/auth.py: -------------------------------------------------------------------------------- 1 | """Helper classes for the [Flask](https://flask.palletsprojects.com) framework.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from flask import session 8 | 9 | from requests_oauth2client.auth import OAuth2ClientCredentialsAuth 10 | from requests_oauth2client.tokens import BearerToken, BearerTokenSerializer 11 | 12 | 13 | class FlaskSessionAuthMixin: 14 | """A Mixin for auth handlers to store their tokens in Flask session. 
15 | 16 | Storing tokens in Flask session does ensure that each user of a Flask application has a 17 | different access token, and that tokens used for backend API access will be persisted between 18 | multiple requests to the front-end Flask app. 19 | 20 | Args: 21 | session_key: the key that will be used to store the access token in session. 22 | serializer: the serializer that will be used to store the access token in session. 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | session_key: str, 29 | serializer: BearerTokenSerializer | None = None, 30 | *args: Any, 31 | **token_kwargs: Any, 32 | ) -> None: 33 | self.serializer = serializer or BearerTokenSerializer() 34 | self.session_key = session_key 35 | super().__init__(*args, **token_kwargs) 36 | 37 | @property 38 | def token(self) -> BearerToken | None: 39 | """Return the Access Token stored in session. 40 | 41 | Returns: 42 | The current `BearerToken` for this session, if any. 43 | 44 | """ 45 | serialized_token = session.get(self.session_key) 46 | if serialized_token is None: 47 | return None 48 | return self.serializer.loads(serialized_token) 49 | 50 | @token.setter 51 | def token(self, token: BearerToken | str | None) -> None: 52 | """Store an Access Token in session. 53 | 54 | Args: 55 | token: the token to store 56 | 57 | """ 58 | if isinstance(token, str): 59 | token = BearerToken(token) # pragma: no cover 60 | if token: 61 | serialized_token = self.serializer.dumps(token) 62 | session[self.session_key] = serialized_token 63 | elif session and self.session_key in session: 64 | session.pop(self.session_key, None) 65 | 66 | 67 | class FlaskOAuth2ClientCredentialsAuth(FlaskSessionAuthMixin, OAuth2ClientCredentialsAuth): # type: ignore[misc] 68 | """A `requests` Auth handler for CC grant that stores its token in Flask session. 
69 | 70 | It will automatically get Access Tokens from an OAuth 2.x AS with the Client Credentials grant 71 | (and can get a new one once the first one is expired), and stores the retrieved token, 72 | serialized in Flask `session`, so that each user has a different access token. 73 | 74 | """ 75 | -------------------------------------------------------------------------------- /requests_oauth2client/pooling.py: -------------------------------------------------------------------------------- 1 | """Contains base classes for pooling jobs.""" 2 | 3 | from __future__ import annotations 4 | 5 | import time 6 | from typing import TYPE_CHECKING, Any 7 | 8 | from attrs import define, field, setters 9 | 10 | from .exceptions import AuthorizationPending, SlowDown 11 | 12 | if TYPE_CHECKING: 13 | from .client import OAuth2Client 14 | from .tokens import BearerToken 15 | 16 | 17 | @define 18 | class BaseTokenEndpointPoolingJob: 19 | """Base class for Token Endpoint pooling jobs. 20 | 21 | This is used for decoupled flows like CIBA or Device Authorization. 22 | 23 | This class must be subclassed to implement actual BackChannel flows. This needs an 24 | [OAuth2Client][requests_oauth2client.client.OAuth2Client] that will be used to pool the token 25 | endpoint. The initial pooling `interval` is configurable. 26 | 27 | """ 28 | 29 | client: OAuth2Client = field(on_setattr=setters.frozen) 30 | requests_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) 31 | token_kwargs: dict[str, Any] = field(on_setattr=setters.frozen) 32 | slow_down_interval: int = field(on_setattr=setters.frozen) 33 | interval: int 34 | 35 | def __call__(self) -> BearerToken | None: 36 | """Wrap the actual Token Endpoint call with a pooling interval. 37 | 38 | Everytime this method is called, it will wait for the entire duration of the pooling 39 | interval before calling 40 | [token_request()][requests_oauth2client.pooling.TokenEndpointPoolingJob.token_request]. 
So 41 | you can call it immediately after initiating the BackChannel flow, and it will wait before 42 | initiating the first call. 43 | 44 | This implements the logic to handle 45 | [AuthorizationPending][requests_oauth2client.exceptions.AuthorizationPending] or 46 | [SlowDown][requests_oauth2client.exceptions.SlowDown] requests by the AS. 47 | 48 | Returns: 49 | a `BearerToken` if the AS returns one, or `None` if the Authorization is still pending. 50 | 51 | """ 52 | self.sleep() 53 | try: 54 | return self.token_request() 55 | except SlowDown: 56 | self.slow_down() 57 | except AuthorizationPending: 58 | self.authorization_pending() 59 | return None 60 | 61 | def sleep(self) -> None: 62 | """Implement the wait between two requests of the token endpoint. 63 | 64 | By default, relies on time.sleep(). 65 | 66 | """ 67 | time.sleep(self.interval) 68 | 69 | def slow_down(self) -> None: 70 | """Implement the behavior when receiving a 'slow_down' response from the AS. 71 | 72 | By default, it increases the pooling interval by the slow down interval. 73 | 74 | """ 75 | self.interval += self.slow_down_interval 76 | 77 | def authorization_pending(self) -> None: 78 | """Implement the behavior when receiving an 'authorization_pending' response from the AS. 79 | 80 | By default, it does nothing. 81 | 82 | """ 83 | 84 | def token_request(self) -> BearerToken: 85 | """Abstract method for the token endpoint call. 86 | 87 | Subclasses must implement this. This method must raise 88 | [AuthorizationPending][requests_oauth2client.exceptions.AuthorizationPending] to retry after 89 | the pooling interval, or [SlowDown][requests_oauth2client.exceptions.SlowDown] to increase 90 | the pooling interval by `slow_down_interval` seconds. 
91 | 92 | Returns: 93 | a [BearerToken][requests_oauth2client.tokens.BearerToken] 94 | 95 | """ 96 | raise NotImplementedError 97 | -------------------------------------------------------------------------------- /requests_oauth2client/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guillp/requests_oauth2client/810e7b6099ac89742adbf7877d7a8b0f785c4016/requests_oauth2client/py.typed -------------------------------------------------------------------------------- /requests_oauth2client/utils.py: -------------------------------------------------------------------------------- 1 | """Various utilities used in multiple places. 2 | 3 | This module contains helper methods that are used in multiple places. 4 | 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from contextlib import suppress 10 | from datetime import datetime, timedelta, timezone 11 | from functools import wraps 12 | from typing import TYPE_CHECKING, Any, Callable 13 | 14 | from furl import furl # type: ignore[import-untyped] 15 | 16 | if TYPE_CHECKING: 17 | from collections.abc import Iterator 18 | 19 | 20 | class InvalidUri(ValueError): 21 | """Raised when a URI does not pass validation by `validate_endpoint_uri()`.""" 22 | 23 | def __init__( 24 | self, url: str, *, https: bool, no_credentials: bool, no_port: bool, no_fragment: bool, path: bool 25 | ) -> None: 26 | super().__init__("Invalid endpoint uri.") 27 | self.url = url 28 | self.https = https 29 | self.no_credentials = no_credentials 30 | self.no_port = no_port 31 | self.no_fragment = no_fragment 32 | self.path = path 33 | 34 | def errors(self) -> Iterator[str]: 35 | """Iterate over all error descriptions, as str.""" 36 | if self.https: 37 | yield "must use https" 38 | if self.no_credentials: 39 | yield "must not contain basic credentials" 40 | if self.no_port: 41 | yield "no custom port number allowed" 42 | if self.no_fragment: 43 | yield "must not contain a uri 
fragment" 44 | if self.path: 45 | yield "must include a path other than /" 46 | 47 | def __str__(self) -> str: 48 | all_errors = ", ".join(self.errors()) 49 | return f"Invalid URI: {all_errors}" 50 | 51 | 52 | def validate_endpoint_uri( 53 | uri: str, 54 | *, 55 | https: bool = True, 56 | no_credentials: bool = True, 57 | no_port: bool = False, 58 | no_fragment: bool = True, 59 | path: bool = True, 60 | ) -> str: 61 | """Validate that a URI is suitable as an endpoint URI. 62 | 63 | It checks: 64 | 65 | - that the scheme is `https` 66 | - that no custom port number is being used 67 | - that no username or password are included 68 | - that no fragment is included 69 | - that a path is present 70 | 71 | Those checks can be individually disabled by using the parameters. 72 | 73 | Args: 74 | uri: the uri 75 | https: if `True`, check that the uri is https 76 | no_port: if `True`, check that no custom port number is included 77 | no_credentials: if ` True`, check that no username/password are included 78 | no_fragment: if `True`, check that the uri contains no fragment 79 | path: if `True`, check that the uri contains a path component 80 | 81 | Raises: 82 | ValueError: if the supplied url is not suitable 83 | 84 | Returns: 85 | the endpoint URI, if all checks passed 86 | 87 | """ 88 | url = furl(uri) 89 | if https and url.scheme == "https": 90 | https = False 91 | if no_port and url.port == 443: # noqa: PLR2004 92 | no_port = False 93 | if no_credentials and not url.username and not url.password: 94 | no_credentials = False 95 | if no_fragment and not url.fragment: 96 | no_fragment = False 97 | if path and url.path and url.path != "/": 98 | path = False 99 | 100 | if https or no_port or no_credentials or no_fragment or path: 101 | raise InvalidUri( 102 | uri, https=https, no_port=no_port, no_credentials=no_credentials, no_fragment=no_fragment, path=path 103 | ) 104 | 105 | return uri 106 | 107 | 108 | def validate_issuer_uri(uri: str) -> str: 109 | """Validate that an 
Issuer Identifier URI is valid. 110 | 111 | This is almost the same as a valid endpoint URI, but a path is not mandatory. 112 | 113 | """ 114 | return validate_endpoint_uri(uri, path=False) 115 | 116 | 117 | def accepts_expires_in(f: Callable[..., Any]) -> Callable[..., Any]: 118 | """Decorate methods to handle both `expires_at` and `expires_in`. 119 | 120 | This decorates methods that accept an `expires_at` datetime parameter, to also allow an 121 | `expires_in` parameter in seconds. 122 | 123 | If supplied, `expires_in` will be converted to a datetime `expires_in` seconds in the future, 124 | and passed as `expires_at` in the decorated method. 125 | 126 | Args: 127 | f: the method to decorate, with an `expires_at` parameter 128 | 129 | Returns: 130 | a decorated method that accepts either `expires_in` or `expires_at`. 131 | 132 | """ 133 | 134 | @wraps(f) 135 | def decorator( 136 | *args: Any, 137 | expires_in: int | str | None = None, 138 | expires_at: datetime | None = None, 139 | **kwargs: Any, 140 | ) -> Any: 141 | if expires_in is None and expires_at is None: 142 | return f(*args, **kwargs) 143 | if expires_in and isinstance(expires_in, str): 144 | with suppress(ValueError): 145 | expires_at = datetime.now(tz=timezone.utc).replace(microsecond=0) + timedelta(seconds=int(expires_in)) 146 | elif expires_in and isinstance(expires_in, int): 147 | expires_at = datetime.now(tz=timezone.utc).replace(microsecond=0) + timedelta(seconds=expires_in) 148 | return f(*args, expires_at=expires_at, **kwargs) 149 | 150 | return decorator 151 | -------------------------------------------------------------------------------- /requests_oauth2client/vendor_specific/__init__.py: -------------------------------------------------------------------------------- 1 | """Vendor-specific utilities. 2 | 3 | This module contains vendor-specific subclasses of [requests_oauth2client] classes, that make it easier to work with 4 | specific OAuth 2.x providers and/or fix compatibility issues. 
5 | 6 | """ 7 | 8 | from .auth0 import Auth0 9 | from .ping import Ping 10 | 11 | __all__ = ["Auth0", "Ping"] 12 | -------------------------------------------------------------------------------- /requests_oauth2client/vendor_specific/auth0.py: -------------------------------------------------------------------------------- 1 | """Implements subclasses for [Auth0](https://auth0.com).""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING, Any 6 | 7 | from requests_oauth2client import ApiClient, OAuth2Client, OAuth2ClientCredentialsAuth 8 | 9 | if TYPE_CHECKING: 10 | import requests 11 | from jwskate import Jwk 12 | 13 | 14 | class Auth0: 15 | """Auth0-related utilities.""" 16 | 17 | @classmethod 18 | def tenant(cls, tenant: str) -> str: 19 | """Given a short tenant name, returns the full tenant FQDN.""" 20 | if not tenant: 21 | msg = "You must specify a tenant name." 22 | raise ValueError(msg) 23 | if "." not in tenant or tenant.endswith((".eu", ".us", ".au", ".jp")): 24 | tenant = f"{tenant}.auth0.com" 25 | if "://" in tenant: 26 | if tenant.startswith("https://"): 27 | return tenant[8:] 28 | msg = ( 29 | "Invalid tenant name. " 30 | "It must be a tenant name like 'mytenant.myregion' " 31 | "or a full FQDN like 'mytenant.myregion.auth0.com'." 
32 | "or an issuer like 'https://mytenant.myregion.auth0.com'" 33 | ) 34 | raise ValueError(msg) 35 | return tenant 36 | 37 | @classmethod 38 | def client( 39 | cls, 40 | tenant: str, 41 | auth: ( 42 | requests.auth.AuthBase | tuple[str, str] | tuple[str, Jwk] | tuple[str, dict[str, Any]] | str | None 43 | ) = None, 44 | *, 45 | client_id: str | None = None, 46 | client_secret: str | None = None, 47 | private_jwk: Any | None = None, 48 | session: requests.Session | None = None, 49 | **kwargs: Any, 50 | ) -> OAuth2Client: 51 | """Initialise an OAuth2Client for an Auth0 tenant.""" 52 | tenant = cls.tenant(tenant) 53 | issuer = f"https://{tenant}" 54 | token_endpoint = f"{issuer}/oauth/token" 55 | authorization_endpoint = f"{issuer}/authorize" 56 | revocation_endpoint = f"{issuer}/oauth/revoke" 57 | userinfo_endpoint = f"{issuer}/userinfo" 58 | jwks_uri = f"{issuer}/.well-known/jwks.json" 59 | 60 | return OAuth2Client( 61 | auth=auth, 62 | client_id=client_id, 63 | client_secret=client_secret, 64 | private_jwk=private_jwk, 65 | session=session, 66 | token_endpoint=token_endpoint, 67 | authorization_endpoint=authorization_endpoint, 68 | revocation_endpoint=revocation_endpoint, 69 | userinfo_endpoint=userinfo_endpoint, 70 | issuer=issuer, 71 | jwks_uri=jwks_uri, 72 | **kwargs, 73 | ) 74 | 75 | @classmethod 76 | def management_api_client( 77 | cls, 78 | tenant: str, 79 | auth: ( 80 | requests.auth.AuthBase | tuple[str, str] | tuple[str, Jwk] | tuple[str, dict[str, Any]] | str | None 81 | ) = None, 82 | *, 83 | client_id: str | None = None, 84 | client_secret: str | None = None, 85 | private_jwk: Any | None = None, 86 | session: requests.Session | None = None, 87 | **kwargs: Any, 88 | ) -> ApiClient: 89 | """Initialize a client for the Auth0 Management API. 90 | 91 | See [Auth0 Management API v2](https://auth0.com/docs/api/management/v2). You must provide the 92 | target tenant name and the credentials for a client that is allowed access to the Management 93 | API. 
94 | 95 | Args: 96 | tenant: the tenant name. 97 | Same definition as for [Auth0.client][requests_oauth2client.vendor_specific.auth0.Auth0.client] 98 | auth: client credentials. 99 | Same definition as for [OAuth2Client][requests_oauth2client.client.OAuth2Client] 100 | client_id: the Client ID. 101 | Same definition as for [OAuth2Client][requests_oauth2client.client.OAuth2Client] 102 | client_secret: the Client Secret. 103 | Same definition as for [OAuth2Client][requests_oauth2client.client.OAuth2Client] 104 | private_jwk: the private key to use for client authentication. 105 | Same definition as for [OAuth2Client][requests_oauth2client.client.OAuth2Client] 106 | session: requests session. 107 | Same definition as for [OAuth2Client][requests_oauth2client.client.OAuth2Client] 108 | **kwargs: additional kwargs to pass to the ApiClient base class 109 | 110 | Example: 111 | ```python 112 | from requests_oauth2client.vendor_specific import Auth0 113 | 114 | a0mgmt = Auth0.management_api_client("mytenant.eu", client_id=client_id, client_secret=client_secret) 115 | users = a0mgmt.get("users", params={"page": 0, "per_page": 100}) 116 | ``` 117 | 118 | """ 119 | tenant = cls.tenant(tenant) 120 | client = cls.client( 121 | tenant, 122 | auth=auth, 123 | client_id=client_id, 124 | client_secret=client_secret, 125 | private_jwk=private_jwk, 126 | session=session, 127 | ) 128 | audience = f"https://{tenant}/api/v2/" 129 | api_auth = OAuth2ClientCredentialsAuth(client, audience=audience) 130 | return ApiClient( 131 | base_url=audience, 132 | auth=api_auth, 133 | session=session, 134 | **kwargs, 135 | ) 136 | -------------------------------------------------------------------------------- /requests_oauth2client/vendor_specific/ping.py: -------------------------------------------------------------------------------- 1 | """PingID specific client.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | import requests # noqa: TC002 8 | 9 | from 
requests_oauth2client import OAuth2Client 10 | 11 | 12 | class Ping: 13 | """Ping Identity related utilities.""" 14 | 15 | @classmethod 16 | def client( 17 | cls, 18 | issuer: str, 19 | auth: requests.auth.AuthBase | tuple[str, str] | str | None = None, 20 | client_id: str | None = None, 21 | client_secret: str | None = None, 22 | private_jwk: Any = None, 23 | session: requests.Session | None = None, 24 | ) -> OAuth2Client: 25 | """Initialize an OAuth2Client for PingFederate. 26 | 27 | This will configure all endpoints with PingID specific urls, without using the metadata. 28 | Excepted for avoiding a round-trip to get the metadata url, this does not provide any advantage 29 | over using `OAuth2Client.from_discovery_endpoint(issuer="https://myissuer.domain.tld")`. 30 | 31 | """ 32 | if not issuer.startswith("https://"): 33 | if "://" in issuer: 34 | msg = "Invalid issuer. It must be an https:// url or a domain name without a scheme." 35 | raise ValueError(msg) 36 | issuer = f"https://{issuer}" 37 | if "." not in issuer: 38 | msg = "Invalid issuer. It must contain at least a dot in the domain name." 
39 | raise ValueError(msg) 40 | 41 | return OAuth2Client( 42 | authorization_endpoint=f"{issuer}/as/authorization.oauth2", 43 | token_endpoint=f"{issuer}/as/token.oauth2", 44 | revocation_endpoint=f"{issuer}/as/revoke_token.oauth2", 45 | userinfo_endpoint=f"{issuer}/idp/userinfo.openid", 46 | introspection_endpoint=f"{issuer}/as/introspect.oauth2", 47 | jwks_uri=f"{issuer}/pf/JWKS", 48 | registration_endpoint=f"{issuer}/as/clients.oauth2", 49 | ping_revoked_sris_endpoint=f"{issuer}/pf-ws/rest/sessionMgmt/revokedSris", 50 | ping_session_management_sris_endpoint=f"{issuer}/pf-ws/rest/sessionMgmt/sessions", 51 | ping_session_management_users_endpoint=f"{issuer}/pf-ws/rest/sessionMgmt/users", 52 | ping_end_session_endpoint=f"{issuer}/idp/startSLO.ping", 53 | device_authorization_endpoint=f"{issuer}/as/device_authz.oauth2", 54 | auth=auth, 55 | client_id=client_id, 56 | client_secret=client_secret, 57 | private_jwk=private_jwk, 58 | session=session, 59 | ) 60 | -------------------------------------------------------------------------------- /tests/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = requests_oauth2client 4 | omit = 5 | 6 | 7 | [report] 8 | # Regexes for lines to exclude from consideration 9 | exclude_lines = 10 | # Have to re-enable the standard pragma 11 | pragma: no cover 12 | 13 | # Don't complain about missing debug-only code: 14 | def __repr__ 15 | if self\.debug 16 | 17 | # Don't complain if tests don't hit defensive assertion code: 18 | raise AssertionError 19 | raise NotImplementedError 20 | 21 | # Don't complain if non-runnable code isn't run: 22 | if 0: 23 | if __name__ == .__main__.: 24 | if TYPE_CHECKING: 25 | 26 | ignore_errors = True 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guillp/requests_oauth2client/810e7b6099ac89742adbf7877d7a8b0f785c4016/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_authorization_code.py:
--------------------------------------------------------------------------------
import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone

import requests
from freezegun import freeze_time
from furl import Query, furl  # type: ignore[import-untyped]
from jwskate import Jwk
from requests_mock import Mocker

from requests_oauth2client import (
    AuthorizationRequest,
    BearerToken,
    ClientSecretPost,
    IdToken,
    OAuth2Client,
    oidc_discovery_document_url,
)


@freeze_time()
def test_authorization_code(
    session: requests.Session,
    requests_mock: Mocker,
    issuer: str,
    token_endpoint: str,
    authorization_endpoint: str,
    jwks_uri: str,
    discovery_document: str,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
    scope: str,
    audience: str,
) -> None:
    """Run the Authorization Code flow end to end, using discovery, PKCE and a signed ID Token.

    The client is built from the (mocked) discovery document. The test then:

    - checks the generated authorization request;
    - validates a simulated callback;
    - exchanges the code for a token whose ID Token carries `c_hash`, `at_hash` and `s_hash`;
    - checks the parameters sent to the token endpoint.

    """
    id_token_sig_alg = "ES256"
    id_token_signing_key = Jwk.generate(alg=id_token_sig_alg).with_kid_thumbprint()

    # Mock discovery and the JWKS so that the client can verify the ID Token signature.
    requests_mock.get(issuer + "/.well-known/openid-configuration", json=discovery_document)
    requests_mock.get(jwks_uri, json={"keys": [id_token_signing_key.public_jwk().to_dict()]})
    client = OAuth2Client.from_discovery_endpoint(
        issuer=issuer,
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri,
        id_token_signed_response_alg=id_token_sig_alg,
    )
    authorization_request = client.authorization_request(scope=scope, audience=audience)
    assert authorization_request.authorization_endpoint == authorization_endpoint
    assert authorization_request.client_id == client_id
    assert authorization_request.response_type == "code"
    assert authorization_request.redirect_uri == redirect_uri
    assert authorization_request.scope is not None
    assert " ".join(authorization_request.scope) == scope
    assert authorization_request.state is not None
    assert authorization_request.nonce is not None
    assert authorization_request.audience == audience
    assert authorization_request.code_challenge_method == "S256"
    assert authorization_request.code_challenge is not None

    authorization_code = secrets.token_urlsafe()
    state = authorization_request.state

    # Simulated redirect back to the client, as the AS would send it.
    authorization_response = furl(redirect_uri, query={"code": authorization_code, "state": state}).url

    access_token = secrets.token_urlsafe()

    # Hashes binding the ID Token to the code, the access token and the state.
    c_hash = IdToken.hash_method(id_token_signing_key)(authorization_code)
    at_hash = IdToken.hash_method(id_token_signing_key)(access_token)
    s_hash = IdToken.hash_method(id_token_signing_key)(state)

    id_token = IdToken.sign(
        {
            "iss": issuer,
            "sub": "248289761001",
            "aud": client_id,
            "nonce": authorization_request.nonce,
            "iat": IdToken.timestamp(),
            "exp": IdToken.timestamp(60),
            "c_hash": c_hash,
            "at_hash": at_hash,
            "s_hash": s_hash,
            "auth_time": IdToken.timestamp(),
        },
        key=id_token_signing_key,
    )
    code_verifier = authorization_request.code_verifier
    assert code_verifier is not None
    # PKCE S256: the challenge is the unpadded base64url SHA-256 of the verifier.
    assert (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=")
        == authorization_request.code_challenge.encode()
    )

    auth_response = authorization_request.validate_callback(authorization_response)

    requests_mock.post(
        token_endpoint,
        json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600, "id_token": str(id_token)},
    )
    token = client.authorization_code(code=auth_response)

    assert isinstance(token, BearerToken)
    assert token.access_token == access_token
    assert not token.is_expired()
    assert token.expires_at is not None
    # Time is frozen, so `expires_in` converts to an exact, predictable `expires_at`.
    assert token.expires_at == datetime.now(tz=timezone.utc).replace(microsecond=0) + timedelta(seconds=3600)

    # The last mocked request is the token request; check its form-encoded body.
    assert requests_mock.last_request is not None
    params = Query(requests_mock.last_request.text).params
    assert params.get("client_id") == client_id
    assert params.get("client_secret") == client_secret
    assert params.get("grant_type") == "authorization_code"
    assert params.get("code") == authorization_code


@freeze_time()
def test_authorization_code_legacy(
    session: requests.Session,
    requests_mock: Mocker,
    issuer: str,
    discovery_document: str,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
    scope: str,
    audience: str,
) -> None:
    """Run the Authorization Code flow with manually wired components.

    Endpoints are read from the discovery document by hand. The `AuthorizationRequest` is built
    directly, and the authorization redirect is simulated with a plain `requests.get`. The code
    is then exchanged by an `OAuth2Client` using `ClientSecretPost`, with `validate=False`.

    """
    discovery_url = oidc_discovery_document_url(issuer)
    requests_mock.get(discovery_url, json=discovery_document)
    discovery = session.get(discovery_url).json()
    authorization_endpoint = discovery.get("authorization_endpoint")
    assert authorization_endpoint
    token_endpoint = discovery.get("token_endpoint")
    assert token_endpoint

    authorization_request = AuthorizationRequest(
        authorization_endpoint,
        client_id=client_id,
        redirect_uri=redirect_uri,
        scope=scope,
        audience=audience,
    )

    authorization_code = secrets.token_urlsafe()

    state = authorization_request.state

    # The mocked authorization endpoint answers with a 302 redirect carrying the code and state.
    authorization_response = furl(redirect_uri, query={"code": authorization_code, "state": state}).url
    requests_mock.get(
        authorization_request.uri,
        status_code=302,
        headers={"Location": authorization_response},
    )
    resp = requests.get(authorization_request.uri, allow_redirects=False)
    assert resp.status_code == 302
    location = resp.headers.get("Location")
    assert location == authorization_response
    # Inspect the query string actually sent to the authorization endpoint.
    assert requests_mock.last_request is not None
    qs = Query(requests_mock.last_request.qs).params
    assert qs.get("client_id") == client_id
    assert qs.get("response_type") == "code"
    assert qs.get("redirect_uri") == redirect_uri
    assert qs.get("state") == state
    code_challenge = qs.get("code_challenge")
    assert code_challenge
    code_challenge_method = requests_mock.last_request.qs.get("code_challenge_method")
    code_verifier = authorization_request.code_verifier
    assert code_verifier is not None
    assert code_challenge_method == ["S256"]
    assert (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=")
        == code_challenge.encode()
    )

    auth_response = authorization_request.validate_callback(location)

    client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret))

    access_token = secrets.token_urlsafe()

    requests_mock.post(
        token_endpoint,
        json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600},
    )
    # No ID Token is returned here, hence validate=False.
    token = client.authorization_code(code=auth_response, redirect_uri=redirect_uri, validate=False)

    assert isinstance(token, BearerToken)
    assert token.access_token == access_token
    assert not token.is_expired()
    assert token.expires_at is not None
    assert token.expires_at == datetime.now(tz=timezone.utc).replace(microsecond=0) + timedelta(seconds=3600)

    assert requests_mock.last_request is not None
    params = Query(requests_mock.last_request.text).params
    assert params.get("client_id") == client_id
    assert params.get("client_secret") == client_secret
    assert params.get("grant_type") == "authorization_code"
    assert params.get("code") == authorization_code
--------------------------------------------------------------------------------
/tests/test_client_credentials.py:
# ------------------------------------------------------------------------------
from urllib.parse import parse_qs

import requests
from requests_mock import Mocker

from requests_oauth2client import (
    ClientSecretPost,
    OAuth2Client,
    OAuth2ClientCredentialsAuth,
)


def test_client_credentials_get_token(
    requests_mock: Mocker,
    client_id: str,
    client_secret: str,
    token_endpoint: str,
    target_api: str,
    access_token: str,
) -> None:
    """A client_credentials grant returns the access token served by the token endpoint.

    The token request must carry client_id / client_secret in its form body
    (ClientSecretPost) along with grant_type=client_credentials.
    """
    requests_mock.post(
        token_endpoint,
        json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600},
    )
    oauth2client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret))
    assert oauth2client.client_credentials().access_token == access_token

    token_request = requests_mock.last_request
    assert token_request is not None
    # keep the first value of each form field
    form = {field: values[0] for field, values in parse_qs(token_request.text).items()}
    assert form["client_id"] == client_id
    assert form["client_secret"] == client_secret
    assert form["grant_type"] == "client_credentials"


def test_client_credentials_api(
    requests_mock: Mocker,
    access_token: str,
    token_endpoint: str,
    client_id: str,
    client_secret: str,
    target_api: str,
) -> None:
    """OAuth2ClientCredentialsAuth obtains a token, then sends it to the target API.

    Exactly two requests are expected: the token request first, then the API call
    carrying the obtained access token as a Bearer Authorization header.
    """
    auth = OAuth2ClientCredentialsAuth(OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret)))
    expected_header = f"Bearer {access_token}"

    requests_mock.post(
        token_endpoint,
        json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600},
    )
    requests_mock.get(target_api, request_headers={"Authorization": expected_header})

    assert requests.get(target_api, auth=auth).ok
    assert len(requests_mock.request_history) == 2
    token_request, api_request = requests_mock.request_history

    form = parse_qs(token_request.text)
    assert form.get("client_id") == [client_id]
    assert form.get("client_secret") == [client_secret]
    assert form.get("grant_type") == ["client_credentials"]

    assert api_request.headers.get("Authorization") == expected_header


# ------------------------------------------------------------------------------
# /tests/test_device_authorization.py:
# ------------------------------------------------------------------------------
import secrets

import pytest
from furl import Query  # type: ignore[import-untyped]
from requests_mock import Mocker

from requests_oauth2client import (
    BearerToken,
    ClientSecretBasic,
    DeviceAuthorizationError,
    DeviceAuthorizationPoolingJob,
    InvalidDeviceAuthorizationResponse,
    OAuth2Client,
    PublicApp,
)
from tests.conftest import FixtureRequest, join_url


@pytest.fixture(params=["device", "oauth/device"])
def device_authorization_endpoint(request: FixtureRequest, issuer: str) -> str:
    """Device Authorization Endpoint URL: the issuer joined with each parametrized path."""
    path = request.param
    return join_url(issuer, path)


@pytest.mark.slow
def test_device_authorization(
    requests_mock: Mocker,
    device_authorization_endpoint: str,
    token_endpoint: str,
    client_id: str,
    client_secret: str,
) -> None:
    """Device Authorization flow: device request, then polling until a token is issued.

    The token endpoint answers `authorization_pending`, then `slow_down`, then
    delivers a Bearer token. Every request must carry the client credentials in its body.
    """

    def assert_client_credentials_sent() -> None:
        # the most recent request body must contain both client_id and client_secret
        last = requests_mock.last_request
        assert last is not None
        body = Query(last.text).params
        assert body.get("client_id") == client_id
        assert body.get("client_secret") == client_secret

    device_code = secrets.token_urlsafe()
    user_code = secrets.token_urlsafe(6)

    client = OAuth2Client(
        token_endpoint=token_endpoint,
        device_authorization_endpoint=device_authorization_endpoint,
        auth=(client_id, client_secret),
    )
    requests_mock.post(
        device_authorization_endpoint,
        json={
            "device_code": device_code,
            "user_code": user_code,
            "verification_uri": "https://test.com/verify_device",
            "expires_in": 3600,
            "interval": 1,
        },
    )

    device_auth_resp = client.authorize_device()
    assert device_auth_resp.device_code
    assert device_auth_resp.user_code
    assert device_auth_resp.verification_uri
    assert not device_auth_resp.is_expired()
    assert_client_credentials_sent()

    access_token = secrets.token_urlsafe()
    requests_mock.post(
        token_endpoint,
        [
            {"json": {"error": "authorization_pending"}, "status_code": 400},
            {"json": {"error": "slow_down"}, "status_code": 400},
            {"json": {"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}},
        ],
    )

    pooling_job = DeviceAuthorizationPoolingJob(client, device_auth_resp, interval=1, slow_down_interval=2)

    # 1st poll: authorization_pending -> no token yet, interval unchanged
    result = pooling_job()
    assert_client_credentials_sent()
    assert pooling_job.interval == 1
    assert result is None

    # 2nd poll: slow_down -> no token yet, interval increased by slow_down_interval
    result = pooling_job()
    assert_client_credentials_sent()
    assert pooling_job.interval == 3
    assert result is None

    # 3rd poll: the access token is delivered
    result = pooling_job()
    assert isinstance(result, BearerToken)
    assert_client_credentials_sent()
    assert not result.is_expired()


def test_auth_handler(
    token_endpoint: str,
    device_authorization_endpoint: str,
    client_id: str,
    client_secret: str,
) -> None:
    """`auth` accepts either an auth handler instance or a bare client_id string."""
    secret_basic = ClientSecretBasic(client_id, client_secret)
    confidential = OAuth2Client(
        token_endpoint=token_endpoint,
        device_authorization_endpoint=device_authorization_endpoint,
        auth=secret_basic,
    )
    assert confidential.auth == secret_basic

    # a plain string becomes the client_id of a PublicApp
    public = OAuth2Client(
        token_endpoint=token_endpoint,
        device_authorization_endpoint=device_authorization_endpoint,
        auth=client_id,
    )
    assert isinstance(public.auth, PublicApp)
    assert public.auth.client_id == client_id


def test_invalid_response(
    requests_mock: Mocker,
    token_endpoint: str,
    device_authorization_endpoint: str,
    client_id: str,
    client_secret: str,
) -> None:
    """authorize_device() raises a dedicated exception on HTTP 500 responses.

    A body with an `error` field gives DeviceAuthorizationError; a body without
    one gives InvalidDeviceAuthorizationResponse.
    """
    client = OAuth2Client(
        token_endpoint=token_endpoint,
        device_authorization_endpoint=device_authorization_endpoint,
        auth=(client_id, client_secret),
    )

    for payload, expected_exception in (
        ({"error": "unknown_error"}, DeviceAuthorizationError),
        ({"foo": "bar"}, InvalidDeviceAuthorizationResponse),
    ):
        requests_mock.post(device_authorization_endpoint, status_code=500, json=payload)
        with pytest.raises(expected_exception):
            client.authorize_device()


# ------------------------------------------------------------------------------
# /tests/test_examples.py:
# ------------------------------------------------------------------------------
import pytest
from pytest_examples import CodeExample, EvalExample, find_examples


@pytest.mark.parametrize("example", find_examples("README.md"), ids=str)
def test_readme(example: CodeExample, eval_example: EvalExample) -> None:
    """Lint each README.md code example, or reformat it when examples are being updated."""
    eval_example.set_config(line_length=120, ruff_ignore=["D", "E402", "ERA001", "F", "S", "T"])
    if not eval_example.update_examples:
        eval_example.lint(example)
        return
    eval_example.format(example)


# ------------------------------------------------------------------------------
# /tests/test_oidc.py:
# ------------------------------------------------------------------------------
from freezegun import freeze_time
from furl import furl  # type: ignore[import-untyped]
from jwskate import EncryptionAlgs, Jwk, Jwt

from requests_oauth2client import IdToken, OAuth2Client
from tests.conftest import RequestsMocker


@freeze_time("2024-01-01 00:00:00")
def test_encrypted_id_token(requests_mock: RequestsMocker) -> None:
    """A signed-then-encrypted ID token is decrypted and verified by the client.

    The token endpoint returns the encrypted ID token. The client holds the
    decryption key and the signature verification key, and must expose an IdToken
    with the original claims.
    """
    # encryption key pair: the public half encrypts, the private half goes to the client
    id_token_decryption_key = Jwk(
        {
            "kty": "EC",
            "crv": "P-256",
            "x": "GNWWCtwaKIdNjsz_ypPKEX1If_yL5w_mJeAepqEDNdk",
            "y": "qjfk0Og-Ov9cWxtuR3-Oxcr4MqW9LB4FLkQuo-ryUWE",
            "d": "y-ndvYzmafoeY9AlnUkoXIiNe5xf_h_23NEEATYKoY4",
            "alg": "ECDH-ES+A256KW",
            "kid": "RvIJrxavhz4CLxA9woSdt4szQkvBIxJtR_s8huPIfIQ",
        }
    )
    id_token_encryption_key = id_token_decryption_key.public_jwk()

    # signature key pair: the private half signs, the public half is the AS JWKS
    id_token_signature_key = Jwk(
        {
            "kty": "EC",
            "crv": "P-256",
            "x": "Q9nRvw5sxTnl93FWc3oHvvbfREUt_1on0WVucVqSPvw",
            "y": "2dNrVWA0LHTwC8vOChVR29HbesoLCwbvaHwHcqKQSG4",
            "d": "pfpik5SEnMh6NcegGPrI0XOlf2YIx4wB7hws6-kO1fE",
            "alg": "ES256",
            "kid": "uiSjaT2_mswJWSBQ6Oj78RjpPnAQVz0iDkyLZHEkFvc",
        }
    )
    id_token_verification_key = id_token_signature_key.public_jwk()

    # client authentication key
    private_key = Jwk(
        {
            "kty": "EC",
            "crv": "P-256",
            "x": "mKV-T7IbQJwt6sakGn9kN3dCyMWIa3XqA_EyIUs_jzc",
            "y": "8sy4p5BzWwDjAULMokrgkCJwaPWNICTozriOUUA_KQ8",
            "d": "xitL_m0Y1lxjoOQINYcynNTJU-EopW4NiBeiMWE-3O8",
            "alg": "ES256",
            "kid": "Vs6sw5LGsEYfeiAs3rwiOwXKJpw4S926IaOpefvm-Ec",
        }
    )

    issuer = "https://issuer"
    token_endpoint = "https://as.local/token"
    authorization_endpoint = "https://as.local/authorize"
    redirect_uri = "http://localhost:12345/callback"
    client_id = "myclientid"

    subject = "user1"
    nonce = "mynonce"
    state = "mystate"
    authorization_code = "authorization_code"
    access_token = "my_access_token"

    claims = {"iss": issuer, "iat": Jwt.timestamp(), "exp": Jwt.timestamp(60), "sub": subject, "nonce": nonce}
    id_token = Jwt.sign_and_encrypt(
        claims, sign_key=id_token_signature_key, enc_key=id_token_encryption_key, enc=EncryptionAlgs.A256CBC_HS512
    )

    client = OAuth2Client(
        client_id=client_id,
        private_key=private_key,
        issuer=issuer,
        token_endpoint=token_endpoint,
        authorization_endpoint=authorization_endpoint,
        redirect_uri=redirect_uri,
        id_token_signed_response_alg="ES256",
        id_token_decryption_key=id_token_decryption_key,
        authorization_server_jwks=id_token_verification_key.as_jwks(),
    )

    authorization_request = client.authorization_request(scope="openid", state=state, nonce=nonce)
    callback_uri = furl(redirect_uri).add(args={"code": authorization_code, "state": state, "iss": issuer})
    authorization_response = authorization_request.validate_callback(callback_uri)

    requests_mock.post(
        token_endpoint,
        json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600, "id_token": str(id_token)},
    )
    token_resp = client.authorization_code(authorization_response, validate=True)
    assert isinstance(token_resp.id_token, IdToken)
    assert token_resp.id_token.claims == claims


# ------------------------------------------------------------------------------
# /tests/test_refresh_token.py:
# ------------------------------------------------------------------------------
import secrets

from requests_oauth2client import OAuth2Client

from .conftest import RequestsMocker, RequestValidatorType


def test_refresh_token(
    requests_mock: RequestsMocker,
    token_endpoint: str,
    revocation_endpoint: str,
    refresh_token: str,
    client_secret_post_auth_validator: RequestValidatorType,
    client_id: str,
    client_secret: str,
    refresh_token_grant_validator: RequestValidatorType,
    revocation_request_validator: RequestValidatorType,
) -> None:
    """Refresh a token, then revoke both the new access token and the new refresh token.

    Client authentication on every request is checked with the
    `client_secret_post_auth_validator` fixture.
    """
    client = OAuth2Client(token_endpoint, revocation_endpoint=revocation_endpoint, auth=(client_id, client_secret))

    def check_client_auth() -> None:
        client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret)

    new_access_token = secrets.token_urlsafe()
    new_refresh_token = secrets.token_urlsafe()
    requests_mock.post(
        token_endpoint,
        json={
            "access_token": new_access_token,
            "refresh_token": new_refresh_token,
            "token_type": "Bearer",
            "expires_in": 3600,
        },
    )

    refreshed = client.refresh_token(refresh_token)
    assert not refreshed.is_expired()
    assert refreshed.access_token == new_access_token
    assert refreshed.refresh_token == new_refresh_token
    refresh_token_grant_validator(requests_mock.last_request, refresh_token=refresh_token)
    check_client_auth()

    requests_mock.post(revocation_endpoint)

    assert client.revoke_access_token(refreshed.access_token) is True
    revocation_request_validator(requests_mock.last_request, new_access_token, "access_token")
    check_client_auth()

    assert client.revoke_refresh_token(refreshed.refresh_token) is True
    revocation_request_validator(requests_mock.last_request, new_refresh_token, "refresh_token")
    check_client_auth()


# ------------------------------------------------------------------------------
# /tests/test_token_exchange.py:
# ------------------------------------------------------------------------------
import secrets

import pytest
from freezegun import freeze_time
from furl import Query  # type: ignore[import-untyped]

from requests_oauth2client import BearerToken, ClientSecretPost, IdToken, OAuth2Client, UnknownTokenType
from tests.conftest import RequestsMocker


@freeze_time()
def test_token_exchange(
    requests_mock: RequestsMocker,
    client_id: str,
    client_secret: str,
    token_endpoint: str,
) -> None:
    """Exchange a Bearer access token, then check both the response and the request that was sent."""
    access_token = secrets.token_urlsafe()
    access_token_type = "urn:ietf:params:oauth:token-type:access_token"

    client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret))

    requests_mock.post(
        token_endpoint,
        json={
            "access_token": access_token,
            "token_type": "Bearer",
            "issued_token_type": access_token_type,
            "expires_in": 60,
        },
    )

    subject_token = "accVkjcJyb4BWCxGsndESCJQbdFMogUC5PbRDqceLTC"
    resource = "https://backend.example.com/api"
    token_response = client.token_exchange(subject_token=BearerToken(subject_token), resource=resource)

    assert token_response.access_token == access_token
    assert token_response.issued_token_type == access_token_type
    assert token_response.token_type == "Bearer"
    assert token_response.expires_in == 60

    assert requests_mock.last_request is not None
    params = Query(requests_mock.last_request.text).params
    # pop() each expected field, so the final check proves that nothing else was sent
    assert params.pop("client_id") == client_id
    assert params.pop("client_secret") == client_secret
    assert params.pop("grant_type") == "urn:ietf:params:oauth:grant-type:token-exchange"
    assert params.pop("subject_token") == subject_token
    assert params.pop("subject_token_type") == access_token_type
    assert params.pop("resource") == resource
    assert not params


def test_token_type() -> None:
    """get_token_type() normalizes token type identifiers and checks them against a supplied token."""
    urn_prefix = "urn:ietf:params:oauth:token-type:"

    for short_name in ("access_token", "refresh_token", "id_token", "saml1", "saml2", "jwt"):
        urn = urn_prefix + short_name
        # full URNs are returned unchanged...
        assert OAuth2Client.get_token_type(urn) == urn
        # ...and short names are expanded to their URN
        assert OAuth2Client.get_token_type(short_name) == urn

    # unknown identifiers are passed through as-is
    assert OAuth2Client.get_token_type("foobar") == "foobar"

    # with no token_type, a BearerToken is treated as an access token
    assert OAuth2Client.get_token_type(token=BearerToken("mytoken")) == urn_prefix + "access_token"
    assert (
        OAuth2Client.get_token_type(
            token_type="refresh_token",
            token=BearerToken("mytoken", refresh_token="myrefreshtoken"),
        )
        == urn_prefix + "refresh_token"
    )
    assert OAuth2Client.get_token_type("id_token", token="foo") == urn_prefix + "id_token"

    # an IdToken is rejected when an access_token is requested
    with pytest.raises(TypeError, match="token is of type ''") as exc:
        OAuth2Client.get_token_type(
            token_type="access_token",
            token=IdToken(
                "eyJraWQiOiIxZTlnZGs3IiwiYWxnIjoiUlMyNTYifQ.ewogImlz"
                "cyI6ICJodHRwOi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4"
                "Mjg5NzYxMDAxIiwKICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAi"
                "bi0wUzZfV3pBMk1qIiwKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEz"
                "MTEyODA5NzAsCiAibmFtZSI6ICJKYW5lIERvZSIsCiAiZ2l2ZW5fbmFtZSI6"
                "ICJKYW5lIiwKICJmYW1pbHlfbmFtZSI6ICJEb2UiLAogImdlbmRlciI6ICJm"
                "ZW1hbGUiLAogImJpcnRoZGF0ZSI6ICIwMDAwLTEwLTMxIiwKICJlbWFpbCI6"
                "ICJqYW5lZG9lQGV4YW1wbGUuY29tIiwKICJwaWN0dXJlIjogImh0dHA6Ly9l"
                "eGFtcGxlLmNvbS9qYW5lZG9lL21lLmpwZyIKfQ.rHQjEmBqn9Jre0OLykYNn"
                "spA10Qql2rvx4FsD00jwlB0Sym4NzpgvPKsDjn_wMkHxcp6CilPcoKrWHcip"
                "R2iAjzLvDNAReF97zoJqq880ZD1bwY82JDauCXELVR9O6_B0w3K-E7yM2mac"
                "AAgNCUwtik6SjoSUZRcf-O5lygIyLENx882p6MtmwaL1hd6qn5RZOQ0TLrOY"
                "u0532g9Exxcm-ChymrB4xLykpDj3lUivJt63eEGGN6DH5K6o33TcxkIjNrCD"
                "4XB1CKKumZvCedgHHF3IAK4dVEDSUoGlH9z4pP_eWYNXvqQOjGs-rDaQzUHl"
                "6cQQWNiDpWOl_lxXjQEvQ"
            ),
        )
    assert exc.type is UnknownTokenType

    # a refresh_token is requested but the BearerToken does not hold one
    with pytest.raises(TypeError, match="does not contain a refresh_token") as exc:
        OAuth2Client.get_token_type(token_type="refresh_token", token=BearerToken("mytoken"))
    assert exc.type is UnknownTokenType

    # a BearerToken is rejected when an id_token is requested
    with pytest.raises(TypeError, match="token is of type '") as exc2:
        OAuth2Client.get_token_type(token_type="id_token", token=BearerToken("mytoken"))
    assert exc2.type is UnknownTokenType


# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# /tests/unit_tests/__init__.py:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/guillp/requests_oauth2client/810e7b6099ac89742adbf7877d7a8b0f785c4016/tests/unit_tests/__init__.py
# ------------------------------------------------------------------------------
# /tests/unit_tests/test_api_client.py:
# ------------------------------------------------------------------------------
from urllib.parse import urljoin

import pytest
import requests
from requests import HTTPError

from requests_oauth2client import ApiClient, BearerToken, InvalidBoolFieldsParam, InvalidPathParam
from tests.conftest import RequestsMocker, RequestValidatorType, join_url


def _check_http_method(
    method: str,
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """Send one `method` request through `api` and check what reached the mocked server.

    `method` is the lowercase name shared by the requests_mock registration method
    and the ApiClient method (e.g. "get", "post").
    """
    target_uri = join_url(target_api, target_path)
    getattr(requests_mock, method)(target_uri)
    response = getattr(api, method)(target_path)

    assert response.ok
    last_request = requests_mock.last_request
    assert last_request is not None
    assert last_request.method == method.upper()
    assert last_request.url == target_uri
    bearer_auth_validator(last_request, access_token=access_token)


def test_session_at_init() -> None:
    """A requests.Session passed at init is used as-is by the ApiClient."""
    own_session = requests.Session()
    assert ApiClient("https://test.local", session=own_session).session == own_session


def test_get(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """GET goes to target_api joined with target_path, with Bearer authentication."""
    _check_http_method("get", requests_mock, api, access_token, target_api, target_path, bearer_auth_validator)


def test_post(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """POST goes to target_api joined with target_path, with Bearer authentication."""
    _check_http_method("post", requests_mock, api, access_token, target_api, target_path, bearer_auth_validator)


def test_patch(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """PATCH goes to target_api joined with target_path, with Bearer authentication."""
    _check_http_method("patch", requests_mock, api, access_token, target_api, target_path, bearer_auth_validator)


def test_put(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """PUT goes to target_api joined with target_path, with Bearer authentication."""
    _check_http_method("put", requests_mock, api, access_token, target_api, target_path, bearer_auth_validator)


def test_delete(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    target_path: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """DELETE goes to target_api joined with target_path, with Bearer authentication."""
    _check_http_method("delete", requests_mock, api, access_token, target_api, target_path, bearer_auth_validator)


def test_fail(
    requests_mock: RequestsMocker,
    api: ApiClient,
    access_token: str,
    target_api: str,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """A 400 response raises HTTPError; the request itself was still sent with Bearer auth."""
    requests_mock.get(target_api, status_code=400)
    with pytest.raises(HTTPError):
        api.get()

    failed_request = requests_mock.last_request
    assert failed_request is not None
    assert failed_request.method == "GET"
    assert failed_request.url == target_api
    bearer_auth_validator(failed_request, access_token=access_token)


def test_url_as_bytes(requests_mock: RequestsMocker, target_api: str) -> None:
    """Paths may be given as bytes, either as one value or as a tuple of segments."""
    api = ApiClient(target_api)
    requests_mock.get(urljoin(target_api, "foo/bar"))

    assert api.get((b"foo", b"bar")).ok
    assert api.get(b"foo/bar").ok


def test_url_as_iterable(requests_mock: RequestsMocker, target_api: str) -> None:
    """Paths may be iterables of str/bytes/int segments; a bad segment raises InvalidPathParam."""
    api = ApiClient(target_api)
    target_uri = join_url(target_api, "/resource/1234/foo")
    requests_mock.get(target_uri)

    def assert_last_get() -> None:
        last_request = requests_mock.last_request
        assert last_request is not None
        assert last_request.method == "GET"
        assert last_request.url == target_uri

    assert api.get(["resource", "1234", "foo"]).ok
    assert_last_get()

    assert api.get(["resource", b"1234", "foo"]).ok
    assert_last_get()

    # int segments and a leading "/" on a segment resolve to the same URL
    assert api.get(["resource", 1234, "/foo"]).ok
    assert_last_get()

    class Unprintable:
        def __str__(self) -> str:
            raise ValueError

    with pytest.raises(TypeError, match="Unexpected path") as exc:
        api.get(("resource", Unprintable()))  # type: ignore[arg-type]
    assert exc.type is InvalidPathParam


def test_raise_for_status(requests_mock: RequestsMocker, target_api: str) -> None:
    """raise_for_status set at init can be overridden on each request."""
    lenient = ApiClient(target_api, raise_for_status=False)
    requests_mock.get(target_api, status_code=400, json={"status": "error"})

    assert not lenient.get().ok
    with pytest.raises(HTTPError):
        lenient.get(raise_for_status=True)

    strict = ApiClient(target_api, raise_for_status=True)
    with pytest.raises(HTTPError):
        strict.get()
    assert not strict.get(raise_for_status=False).ok


def test_other_api(
    access_token: str,
    bearer_token: BearerToken,
    bearer_auth_validator: RequestValidatorType,
) -> None:
    """An absolute URL outside the ApiClient base URL raises ValueError."""
    foreign_url = "https://other.api/somethingelse"
    api = ApiClient("https://some.api/foo", auth=bearer_token)
    with pytest.raises(ValueError):
        api.get(foreign_url)


def test_url_type(target_api: str) -> None:
    """A bool is not a valid path: InvalidPathParam (a TypeError) is raised."""
    with pytest.raises(TypeError) as exc:
        ApiClient(target_api).get(True)  # type: ignore[arg-type]
    assert exc.type is InvalidPathParam


def test_additional_kwargs(target_api: str) -> None:
    """Extra init kwargs configure the session (proxies) and the client (timeout)."""
    proxies = {"https": "http://localhost:8888"}
    api = ApiClient(target_api, proxies=proxies, timeout=10)
    assert api.timeout == 10
    assert api.session.proxies == proxies


def test_none_fields(requests_mock: RequestsMocker, target_api: str) -> None:
    """`none_fields` controls how None values are serialized in JSON and form bodies.

    - "exclude" (the default) drops them from both;
    - "include" keeps them in JSON, while the form body still omits them;
    - "empty" serializes them as empty strings.
    """
    requests_mock.post(target_api)

    def check(client: ApiClient, expected_json: object, expected_form: str) -> None:
        client.post(json={"foo": "bar", "none": None})
        assert requests_mock.last_request is not None
        assert requests_mock.last_request.json() == expected_json

        client.post(data={"foo": "bar", "none": None})
        assert requests_mock.last_request is not None
        assert requests_mock.last_request.text == expected_form

    default_client = ApiClient(target_api)
    assert default_client.none_fields == "exclude"
    check(default_client, {"foo": "bar"}, "foo=bar")

    check(ApiClient(target_api, none_fields="include"), {"foo": "bar", "none": None}, "foo=bar")
    check(ApiClient(target_api, none_fields="empty"), {"foo": "bar", "none": ""}, "foo=bar&none=")


def test_bool_fields(requests_mock: RequestsMocker, target_api: str) -> None:
    """`bool_fields` controls how True/False are rendered in query strings and form bodies."""
    requests_mock.post(target_api)

    def assert_sent(rendered: str) -> None:
        # both the query string and the form body must be rendered identically
        last_request = requests_mock.last_request
        assert last_request is not None
        assert last_request.query == rendered
        assert last_request.text == rendered

    # default: lowercase "true"/"false"
    default_client = ApiClient(target_api)
    default_client.post(
        data={"foo": "bar", "true": True, "false": False},
        params={"foo": "bar", "true": True, "false": False},
    )
    assert_sent("foo=bar&true=true&false=false")

    # per-request override
    default_client.post(
        data={"foo": "bar", "true": True, "false": False},
        params={"foo": "bar", "true": True, "false": False},
        bool_fields=("OK", "KO"),
    )
    assert_sent("foo=bar&true=OK&false=KO")

    # None: no conversion, so the default behaviour of requests applies (str(True) == "True")
    ApiClient(target_api, bool_fields=None).post(
        data={"foo": "bar", "true": True, "false": False},
        params={"foo": "bar", "true": True, "false": False},
    )
    assert_sent("foo=bar&true=True&false=False")

    ApiClient(target_api, bool_fields=("yes", "no")).post(
        data={"foo": "bar", "true": True, "false": False},
        params={"foo": "bar", "true": True, "false": False},
    )
    assert_sent("foo=bar&true=yes&false=no")

    ApiClient(target_api, bool_fields=(1, 0)).post(
        data={"foo": "bar", "true": True, "false": False},
        params={"foo": "bar", "true": True, "false": False},
    )
    assert_sent("foo=bar&true=1&false=0")

    # a 3-item bool_fields is rejected
    with pytest.raises(ValueError, match="Invalid value for `bool_fields`") as exc:
        ApiClient(target_api).get(bool_fields=(1, 2, 3))
    assert exc.type is InvalidBoolFieldsParam


def test_getattr(requests_mock: RequestsMocker, target_api: str) -> None:
    """Attribute access on an ApiClient targets the matching sub-path."""
    api = ApiClient(target_api)

    requests_mock.post(target_api)
    assert api.post().ok
    assert requests_mock.last_request is not None

    for segment in ("foo", "bar"):
        requests_mock.reset_mock()
        requests_mock.post(urljoin(target_api, segment))
        assert getattr(api, segment).post().ok
        assert requests_mock.last_request is not None


def test_getitem(requests_mock: RequestsMocker, target_api: str) -> None:
    """Item access on an ApiClient targets the matching sub-path."""
    api = ApiClient(target_api)

    requests_mock.post(target_api)
    assert api.post().ok
    assert requests_mock.last_request is not None

    for segment in ("foo", "bar"):
        requests_mock.reset_mock()
        requests_mock.post(urljoin(target_api, segment))
        assert api[segment].post().ok
        assert requests_mock.last_request is not None


def test_contextmanager(requests_mock: RequestsMocker, target_api: str) -> None:
    """An ApiClient can be used as a context manager to send requests."""
    requests_mock.post(target_api)

    with ApiClient(target_api) as api:
        api.post()

    assert requests_mock.last_request is not None


def test_cookies_and_headers(target_api: str) -> None:
    """Cookies, headers and the user agent given at init are applied to the session."""
    cookies = {"cookie1": "value1", "cookie2": "value2"}
    headers = {"header1": "value1", "header2": "value2"}
    user_agent = "My User Agent"
    session = ApiClient(target_api, cookies=cookies, headers=headers, user_agent=user_agent).session

    assert session.cookies == cookies
    assert all(session.headers[name] == value for name, value in headers.items())
    assert session.headers["User-Agent"] == user_agent

    # user_agent=None leaves no User-Agent header at all
    assert "User-Agent" not in ApiClient(target_api, user_agent=None).session.headers


# ------------------------------------------------------------------------------
# /tests/unit_tests/test_auth.py:
# ------------------------------------------------------------------------------
from datetime import datetime, timedelta, timezone
from urllib.parse import parse_qs
import pytest 5 | import requests 6 | 7 | from requests_oauth2client import ( 8 | BearerToken, 9 | ExpiredAccessToken, 10 | NonRenewableTokenError, 11 | OAuth2AccessTokenAuth, 12 | OAuth2AuthorizationCodeAuth, 13 | OAuth2Client, 14 | OAuth2ClientCredentialsAuth, 15 | OAuth2DeviceCodeAuth, 16 | OAuth2ResourceOwnerPasswordAuth, 17 | ) 18 | from tests.conftest import RequestsMocker 19 | 20 | 21 | @pytest.fixture 22 | def minutes_ago() -> datetime: 23 | return datetime.now(tz=timezone.utc) - timedelta(minutes=3) 24 | 25 | 26 | def test_bearer_auth( 27 | requests_mock: RequestsMocker, 28 | target_api: str, 29 | bearer_token: BearerToken, 30 | access_token: str, 31 | ) -> None: 32 | requests_mock.post(target_api) 33 | response = requests.post(target_api, auth=bearer_token) 34 | assert response.ok 35 | assert requests_mock.last_request is not None 36 | assert requests_mock.last_request.headers.get("Authorization") == f"Bearer {access_token}" 37 | 38 | 39 | def test_expired_token(minutes_ago: datetime) -> None: 40 | token = BearerToken(access_token="foo", expires_at=minutes_ago) 41 | with pytest.raises(ExpiredAccessToken): 42 | requests.post("http://localhost/test", auth=token) 43 | 44 | 45 | def test_access_token_auth( 46 | requests_mock: RequestsMocker, 47 | target_uri: str, 48 | token_endpoint: str, 49 | client_id: str, 50 | client_secret: str, 51 | minutes_ago: datetime, 52 | ) -> None: 53 | access_token = "access_token" 54 | refresh_token = "refresh_token" 55 | new_access_token = "new_access_token" 56 | new_refresh_token = "new_refresh_token" 57 | 58 | token = BearerToken(access_token=access_token, refresh_token=refresh_token, expires_at=minutes_ago) 59 | client = OAuth2Client(token_endpoint, (client_id, client_secret)) 60 | auth = OAuth2AccessTokenAuth(client, token) 61 | 62 | assert auth.client is client 63 | assert auth.token is token 64 | 65 | requests_mock.post(target_uri) 66 | requests_mock.post( 67 | token_endpoint, 68 | json={ 69 | "access_token": 
new_access_token, 70 | "refresh_token": new_refresh_token, 71 | "expires_in": 3600, 72 | "token_type": "Bearer", 73 | }, 74 | ) 75 | requests.post(target_uri, auth=auth) 76 | 77 | assert len(requests_mock.request_history) == 2 78 | refresh_request = requests_mock.request_history[0] 79 | api_request = requests_mock.request_history[-1] 80 | 81 | assert refresh_request.url == token_endpoint 82 | refresh_params = parse_qs(refresh_request.body) 83 | assert refresh_params["grant_type"] == ["refresh_token"] 84 | assert refresh_params["refresh_token"] == [refresh_token] 85 | assert refresh_params["client_id"] == [client_id] 86 | assert refresh_params["client_secret"] == [client_secret] 87 | 88 | assert api_request.url == target_uri 89 | assert api_request.headers.get("Authorization") == f"Bearer {new_access_token}" 90 | 91 | assert auth.token is not None 92 | assert auth.token.access_token == new_access_token 93 | assert auth.token.refresh_token == new_refresh_token 94 | 95 | assert OAuth2AccessTokenAuth(client, token=access_token).token == BearerToken(access_token) 96 | 97 | 98 | def test_client_credentials_auth( 99 | requests_mock: RequestsMocker, 100 | target_api: str, 101 | token_endpoint: str, 102 | client_id: str, 103 | client_secret: str, 104 | access_token: str, 105 | refresh_token: str, 106 | ) -> None: 107 | client = OAuth2Client(token_endpoint, (client_id, client_secret)) 108 | auth = OAuth2ClientCredentialsAuth(client) 109 | 110 | assert auth.client is client 111 | assert auth.token is None 112 | 113 | requests_mock.post(target_api) 114 | requests_mock.post( 115 | token_endpoint, 116 | json={ 117 | "access_token": access_token, 118 | "refresh_token": refresh_token, 119 | "expires_in": 3600, 120 | "token_type": "Bearer", 121 | }, 122 | ) 123 | requests.post(target_api, auth=auth) 124 | 125 | assert len(requests_mock.request_history) == 2 126 | cc_request = requests_mock.request_history[0] 127 | api_request = requests_mock.request_history[-1] 128 | 129 | assert 
cc_request.url == token_endpoint 130 | refresh_params = parse_qs(cc_request.body) 131 | assert refresh_params["grant_type"] == ["client_credentials"] 132 | assert refresh_params["client_id"] == [client_id] 133 | assert refresh_params["client_secret"] == [client_secret] 134 | 135 | assert api_request.url == target_api 136 | assert api_request.headers.get("Authorization") == f"Bearer {access_token}" 137 | 138 | assert auth.token is not None 139 | assert auth.token.access_token == access_token 140 | assert auth.token.refresh_token == refresh_token 141 | 142 | requests_mock.reset_mock() 143 | requests.post(target_api, auth=auth) 144 | assert len(requests_mock.request_history) == 1 145 | 146 | assert OAuth2ClientCredentialsAuth(client, token=access_token).token == BearerToken(access_token) 147 | assert OAuth2ClientCredentialsAuth(client, token=BearerToken(access_token)).token == BearerToken(access_token) 148 | 149 | 150 | def test_authorization_code_auth( 151 | requests_mock: RequestsMocker, 152 | target_api: str, 153 | token_endpoint: str, 154 | client_id: str, 155 | client_secret: str, 156 | authorization_code: str, 157 | access_token: str, 158 | refresh_token: str, 159 | ) -> None: 160 | client = OAuth2Client(token_endpoint, (client_id, client_secret)) 161 | auth = OAuth2AuthorizationCodeAuth(client, authorization_code) 162 | 163 | assert auth.client is client 164 | assert auth.code is authorization_code 165 | assert auth.token is None 166 | 167 | requests_mock.post(target_api) 168 | requests_mock.post( 169 | token_endpoint, 170 | json={ 171 | "access_token": access_token, 172 | "refresh_token": refresh_token, 173 | "expires_in": 3600, 174 | "token_type": "Bearer", 175 | }, 176 | ) 177 | requests.post(target_api, auth=auth) 178 | 179 | assert len(requests_mock.request_history) == 2 180 | code_request = requests_mock.request_history[0] 181 | api_request = requests_mock.request_history[-1] 182 | 183 | assert code_request.url == token_endpoint 184 | refresh_params = 
parse_qs(code_request.body) 185 | assert refresh_params["grant_type"] == ["authorization_code"] 186 | assert refresh_params["code"] == [authorization_code] 187 | assert refresh_params["client_id"] == [client_id] 188 | assert refresh_params["client_secret"] == [client_secret] 189 | 190 | assert api_request.url == target_api 191 | assert api_request.headers.get("Authorization") == f"Bearer {access_token}" 192 | 193 | assert auth.token is not None 194 | assert auth.token.access_token == access_token 195 | assert auth.token.refresh_token == refresh_token 196 | 197 | requests_mock.reset_mock() 198 | requests.post(target_api, auth=auth) 199 | assert len(requests_mock.request_history) == 1 200 | 201 | assert OAuth2AuthorizationCodeAuth(client, code=authorization_code, token=access_token).token == BearerToken( 202 | access_token 203 | ) 204 | 205 | 206 | def test_ropc_auth( 207 | requests_mock: RequestsMocker, 208 | target_api: str, 209 | token_endpoint: str, 210 | client_id: str, 211 | client_secret: str, 212 | access_token: str, 213 | refresh_token: str, 214 | ) -> None: 215 | oauth2client = OAuth2Client( 216 | token_endpoint=token_endpoint, 217 | client_id=client_id, 218 | client_secret=client_secret, 219 | ) 220 | username = "my_user1" 221 | password = "T0t@lly_5eCur3!" 
222 | 223 | auth = OAuth2ResourceOwnerPasswordAuth(client=oauth2client, username=username, password=password) 224 | 225 | assert auth.client is oauth2client 226 | assert auth.username is username 227 | assert auth.password is password 228 | assert auth.token is None 229 | 230 | requests_mock.post( 231 | token_endpoint, 232 | json={ 233 | "access_token": access_token, 234 | "refresh_token": refresh_token, 235 | "token_type": "Bearer", 236 | "expires_in": "3600", 237 | }, 238 | ) 239 | requests_mock.post(target_api) 240 | 241 | assert requests.post(target_api, auth=auth).ok 242 | 243 | assert len(requests_mock.request_history) == 2 244 | 245 | token_request = requests_mock.request_history[0] 246 | token_params = parse_qs(token_request.body) 247 | assert token_params["grant_type"] == ["password"] 248 | assert token_params["username"] == [username] 249 | assert token_params["password"] == [password] 250 | 251 | api_request = requests_mock.request_history[1] 252 | assert api_request.url == target_api 253 | assert api_request.headers.get("Authorization") == f"Bearer {access_token}" 254 | 255 | assert auth.token is not None 256 | assert auth.token.access_token == access_token 257 | assert auth.token.refresh_token == refresh_token 258 | 259 | requests_mock.reset_mock() 260 | requests.post(target_api, auth=auth) 261 | assert requests_mock.last_request is not None 262 | assert requests_mock.last_request.url == target_api 263 | 264 | assert OAuth2ResourceOwnerPasswordAuth( 265 | oauth2client, username=username, password=password, token=access_token 266 | ).token == BearerToken(access_token) 267 | 268 | 269 | def test_device_code_auth( 270 | requests_mock: RequestsMocker, 271 | target_api: str, 272 | device_authorization_endpoint: str, 273 | token_endpoint: str, 274 | client_id: str, 275 | client_secret: str, 276 | device_code: str, 277 | user_code: str, 278 | verification_uri: str, 279 | verification_uri_complete: str, 280 | access_token: str, 281 | refresh_token: str, 282 | 
) -> None: 283 | oauth2client = OAuth2Client( 284 | token_endpoint=token_endpoint, 285 | device_authorization_endpoint=device_authorization_endpoint, 286 | auth=(client_id, client_secret), 287 | ) 288 | requests_mock.post( 289 | device_authorization_endpoint, 290 | json={ 291 | "device_code": device_code, 292 | "user_code": user_code, 293 | "verification_uri": verification_uri, 294 | "verification_uri_complete": verification_uri_complete, 295 | "expires_in": 300, 296 | "interval": 1, 297 | }, 298 | ) 299 | 300 | da_resp = oauth2client.authorize_device() 301 | 302 | requests_mock.reset_mock() 303 | requests_mock.post( 304 | token_endpoint, 305 | json={ 306 | "access_token": access_token, 307 | "expires_in": 60, 308 | "refresh_token": refresh_token, 309 | }, 310 | ) 311 | requests_mock.post(target_api) 312 | 313 | auth = OAuth2DeviceCodeAuth(client=oauth2client, device_code=da_resp.device_code, interval=1, expires_in=60) 314 | assert auth.client is oauth2client 315 | assert auth.device_code is da_resp.device_code 316 | assert auth.token is None 317 | 318 | assert requests.post(target_api, auth=auth) 319 | assert len(requests_mock.request_history) == 2 320 | device_code_request = requests_mock.request_history[0] 321 | api_request = requests_mock.request_history[1] 322 | 323 | assert device_code_request.url == token_endpoint 324 | da_params = parse_qs(device_code_request.body) 325 | assert da_params["grant_type"] == ["urn:ietf:params:oauth:grant-type:device_code"] 326 | assert da_params["device_code"] == [device_code] 327 | assert da_params["client_id"] == [client_id] 328 | assert da_params["client_secret"] == [client_secret] 329 | 330 | assert api_request.url == target_api 331 | assert api_request.headers.get("Authorization") == f"Bearer {access_token}" 332 | 333 | assert auth.token is not None 334 | assert auth.token.access_token == access_token 335 | assert auth.token.refresh_token == refresh_token 336 | 337 | requests_mock.reset_mock() 338 | 
requests.post(target_api, auth=auth) 339 | assert requests_mock.last_request is not None 340 | assert requests_mock.last_request.url == target_api 341 | 342 | auth.forget_token() 343 | with pytest.raises(NonRenewableTokenError): 344 | requests.post(target_api, auth=auth) 345 | 346 | assert OAuth2DeviceCodeAuth(oauth2client, device_code=device_code, token=access_token).token == BearerToken( 347 | access_token 348 | ) 349 | -------------------------------------------------------------------------------- /tests/unit_tests/test_authorization_request.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any 4 | 5 | import jwskate 6 | import pytest 7 | from freezegun import freeze_time 8 | from jwskate import JweCompact, Jwk, Jwt, SignedJwt 9 | 10 | from requests_oauth2client import ( 11 | AuthorizationRequest, 12 | AuthorizationRequestSerializer, 13 | AuthorizationResponse, 14 | AuthorizationResponseError, 15 | DPoPKey, 16 | InvalidMaxAgeParam, 17 | MismatchingIssuer, 18 | MismatchingState, 19 | MissingAuthCode, 20 | MissingIssuer, 21 | RequestParameterAuthorizationRequest, 22 | RequestUriParameterAuthorizationRequest, 23 | UnsupportedResponseTypeParam, 24 | ) 25 | 26 | if TYPE_CHECKING: 27 | from furl import furl # type: ignore[import-untyped] 28 | 29 | 30 | def test_authorization_url(authorization_request: AuthorizationRequest) -> None: 31 | url = authorization_request.furl 32 | assert dict(url.args) == {key: val for key, val in authorization_request.args.items() if val is not None} 33 | 34 | 35 | def test_authorization_signed_request( 36 | authorization_request: AuthorizationRequest, private_jwk: Jwk, public_jwk: Jwk, auth_request_kwargs: dict[str, Any] 37 | ) -> None: 38 | args = {key: value for key, value in authorization_request.args.items() if value is not None} 39 | signed_request = authorization_request.sign(private_jwk, custom_attr="custom_value") 40 
| assert isinstance(signed_request, RequestParameterAuthorizationRequest) 41 | assert isinstance(signed_request.uri, str) 42 | assert signed_request.custom_attr == "custom_value" 43 | url = signed_request.furl 44 | request = url.args.get("request") 45 | jwt = Jwt(request) 46 | assert isinstance(jwt, SignedJwt) 47 | assert jwt.verify_signature(public_jwk) 48 | assert jwt.claims == args 49 | 50 | 51 | @freeze_time("2022-10-10 13:37:00") 52 | def test_authorization_signed_request_with_lifetime( 53 | authorization_request: AuthorizationRequest, private_jwk: Jwk, public_jwk: Jwk 54 | ) -> None: 55 | args = {key: value for key, value in authorization_request.args.items() if value is not None} 56 | args["iat"] = 1665409020 57 | args["exp"] = 1665409080 58 | signed_request = authorization_request.sign(private_jwk, lifetime=60) 59 | assert isinstance(signed_request.uri, str) 60 | 61 | url = signed_request.furl 62 | request = url.args.get("request") 63 | jwt = Jwt(request) 64 | assert isinstance(jwt, SignedJwt) 65 | assert jwt.verify_signature(public_jwk) 66 | assert jwt.claims == args 67 | 68 | 69 | @pytest.fixture(scope="session") 70 | def enc_jwk() -> Jwk: 71 | return Jwk.generate_for_alg(jwskate.KeyManagementAlgs.RSA_OAEP_256) 72 | 73 | 74 | @freeze_time("2022-10-10 13:37:00") 75 | def test_authorization_signed_and_encrypted_request( 76 | authorization_request: AuthorizationRequest, private_jwk: Jwk, public_jwk: Jwk, enc_jwk: Jwk 77 | ) -> None: 78 | args = {key: value for key, value in authorization_request.args.items() if value is not None} 79 | args["iat"] = 1665409020 80 | args["exp"] = 1665409080 81 | signed_and_encrypted_request = authorization_request.sign_and_encrypt( 82 | sign_jwk=private_jwk, enc_jwk=enc_jwk.public_jwk(), lifetime=60 83 | ) 84 | assert isinstance(signed_and_encrypted_request.uri, str) 85 | url = signed_and_encrypted_request.furl 86 | request = url.args.get("request") 87 | jwt = Jwt(request) 88 | assert isinstance(jwt, JweCompact) 89 | assert 
Jwt.decrypt_and_verify(jwt, enc_jwk, public_jwk).claims == args 90 | 91 | 92 | @pytest.mark.parametrize("request_uri", ["this_is_a_request_uri", "https://foo.bar/request_uri"]) 93 | def test_request_uri_authorization_request(authorization_endpoint: str, client_id: str, request_uri: str) -> None: 94 | request_uri_azr = RequestUriParameterAuthorizationRequest( 95 | authorization_endpoint=authorization_endpoint, 96 | client_id=client_id, 97 | request_uri=request_uri, 98 | custom_param="custom_value", 99 | ) 100 | assert isinstance(request_uri_azr.uri, str) 101 | url = request_uri_azr.furl 102 | assert url.origin + str(url.path) == authorization_endpoint 103 | assert url.args == {"client_id": client_id, "request_uri": request_uri, "custom_param": "custom_value"} 104 | assert request_uri_azr.custom_param == "custom_value" 105 | 106 | 107 | def test_request_uri_authorization_request_with_custom_param(authorization_endpoint: str) -> None: 108 | request_uri = "request_uri" 109 | custom_attr = "custom_attr" 110 | client_id = "client_id" 111 | request_uri_azr = RequestUriParameterAuthorizationRequest( 112 | authorization_endpoint=authorization_endpoint, 113 | client_id=client_id, 114 | request_uri=request_uri, 115 | custom_attr=custom_attr, 116 | ) 117 | assert isinstance(request_uri_azr.uri, str) 118 | url = request_uri_azr.furl 119 | assert url.origin + str(url.path) == authorization_endpoint 120 | assert url.args == {"client_id": client_id, "request_uri": request_uri, "custom_attr": custom_attr} 121 | 122 | 123 | @pytest.mark.parametrize("error", ["consent_required"]) 124 | def test_error_response( 125 | authorization_request: AuthorizationRequest, 126 | authorization_response_uri: furl, 127 | error: str, 128 | ) -> None: 129 | authorization_response_uri.args.pop("code") 130 | authorization_response_uri.args.add("error", error) 131 | with pytest.raises(AuthorizationResponseError): 132 | authorization_request.validate_callback(authorization_response_uri) 133 | 134 | 135 | 
def test_missing_code(authorization_request: AuthorizationRequest, authorization_response_uri: furl) -> None: 136 | authorization_response_uri.args.pop("code") 137 | with pytest.raises(MissingAuthCode): 138 | authorization_request.validate_callback(authorization_response_uri) 139 | 140 | # A callback string that is not a valid URL raises ValueError. 141 | def test_not_an_url(authorization_request: AuthorizationRequest) -> None: 142 | auth_response = "https://...$cz\\1.3ada////:@+++++z/" 143 | with pytest.raises(ValueError): 144 | authorization_request.validate_callback(auth_response) 145 | 146 | # A callback with a different state raises MismatchingState (asserted only when the `state` fixture is truthy). 147 | def test_mismatching_state( 148 | authorization_request: AuthorizationRequest, 149 | authorization_response_uri: furl, 150 | state: None | bool | str, 151 | ) -> None: 152 | authorization_response_uri.args["state"] = "foo" 153 | if state: 154 | with pytest.raises(MismatchingState): 155 | authorization_request.validate_callback(authorization_response_uri) 156 | 157 | # A callback without state raises MismatchingState (asserted only when the `state` fixture is truthy). 158 | def test_missing_state( 159 | authorization_request: AuthorizationRequest, 160 | authorization_response_uri: furl, 161 | state: None | bool | str, 162 | ) -> None: 163 | authorization_response_uri.args.pop("state", None) 164 | if state: 165 | with pytest.raises(MismatchingState): 166 | authorization_request.validate_callback(authorization_response_uri) 167 | 168 | # A callback with a different iss raises MismatchingIssuer (asserted only when `expected_issuer` is truthy). 169 | def test_mismatching_iss( 170 | authorization_request: AuthorizationRequest, 171 | authorization_response_uri: furl, 172 | expected_issuer: str | bool | None, 173 | ) -> None: 174 | authorization_response_uri.args["iss"] = "foo" 175 | if expected_issuer: 176 | with pytest.raises(MismatchingIssuer): 177 | authorization_request.validate_callback(authorization_response_uri) 178 | 179 | # A callback without iss raises MissingIssuer (asserted only when `expected_issuer` is truthy). 180 | def test_missing_issuer( 181 | authorization_request: AuthorizationRequest, 182 | authorization_response_uri: furl, 183 | expected_issuer: str | bool | None, 184 | ) -> None: 185 | authorization_response_uri.args.pop("iss", None) 186 | if expected_issuer: 187 | with pytest.raises(MissingIssuer): 188 |
authorization_request.validate_callback(authorization_response_uri) 189 | 190 | # dumps() followed by loads() round-trips an AuthorizationRequest to an equal object. 191 | def test_authorization_request_serializer(authorization_request: AuthorizationRequest) -> None: 192 | serializer = AuthorizationRequestSerializer() 193 | serialized = serializer.dumps(authorization_request) 194 | assert serializer.loads(serialized) == authorization_request 195 | 196 | # Serialization round-trips an attached DPoPKey with the same private key. 197 | def test_authorization_request_serializer_with_dpop_key() -> None: 198 | dpop_key = DPoPKey.generate() 199 | authorization_request = AuthorizationRequest( 200 | "https://as.local/authorize", 201 | client_id="foo", 202 | redirect_uri="http://localhost/local", 203 | scope="openid", 204 | dpop_key=dpop_key, 205 | ) 206 | 207 | serializer = AuthorizationRequestSerializer() 208 | 209 | serialized = serializer.dumps(authorization_request) 210 | deserialized_request = serializer.loads(serialized) 211 | 212 | assert isinstance(deserialized_request.dpop_key, DPoPKey) 213 | assert deserialized_request.dpop_key.private_key == dpop_key.private_key 214 | 215 | 216 | def test_request_acr_values() -> None: 217 | # acr_values may be given as a space-separated string or as a sequence; both yield the same tuple 218 | assert AuthorizationRequest( 219 | "https://as.local/authorize", 220 | client_id="foo", 221 | redirect_uri="http://localhost/local", 222 | scope="openid", 223 | acr_values="1 2 3", 224 | ).acr_values == ("1", "2", "3") 225 | assert AuthorizationRequest( 226 | "https://as.local/authorize", 227 | client_id="foo", 228 | redirect_uri="http://localhost/local", 229 | scope="openid", 230 | acr_values=("1", "2", "3"), 231 | ).acr_values == ("1", "2", "3") 232 | 233 | 234 | def test_code_challenge() -> None: 235 | # passing a code_challenge directly raises ValueError; the original code_verifier must be provided instead 236 | with pytest.raises(ValueError): 237 | AuthorizationRequest( 238 | "https://as.local/authorize", 239 | client_id="foo", 240 | redirect_uri="http://localhost/local", 241 | scope="openid", 242 | code_challenge="my_code_challenge",
243 | ) 244 | 245 | # Requiring the `iss` response parameter without an issuer raises a ValueError mentioning 'issuer'. 246 | def test_issuer_parameter() -> None: 247 | with pytest.raises(ValueError, match="issuer"): 248 | AuthorizationRequest( 249 | "https://as.local/authorize", 250 | client_id="foo", 251 | redirect_uri="http://localhost/local", 252 | authorization_response_iss_parameter_supported=True, 253 | scope="openid", 254 | ) 255 | 256 | # max_age=-1 raises InvalidMaxAgeParam, which is a ValueError subclass. 257 | def test_invalid_max_age() -> None: 258 | with pytest.raises(ValueError, match="Invalid 'max_age' parameter") as exc: 259 | AuthorizationRequest( 260 | "https://as.local/authorize", 261 | client_id="foo", 262 | redirect_uri="http://localhost/local", 263 | scope="openid", 264 | max_age=-1, 265 | ) 266 | assert exc.type is InvalidMaxAgeParam 267 | 268 | # AuthorizationResponse turns acr_values given as a list or as a space-separated string into a tuple. 269 | def test_acr_values() -> None: 270 | acr_values = ("reinforced", "strong") 271 | assert ( 272 | AuthorizationResponse( 273 | code="code", 274 | client_id="foo", 275 | redirect_uri="http://localhost/local", 276 | scope="openid", 277 | acr_values=list(acr_values), 278 | ).acr_values 279 | == acr_values 280 | ) 281 | assert ( 282 | AuthorizationResponse( 283 | code="code", 284 | client_id="foo", 285 | redirect_uri="http://localhost/local", 286 | scope="openid", 287 | acr_values=" ".join(acr_values), 288 | ).acr_values 289 | == acr_values 290 | ) 291 | 292 | # Unknown keyword arguments to AuthorizationResponse are exposed as attributes. 293 | def test_custom_attrs() -> None: 294 | custom = "foobar" 295 | azresp = AuthorizationResponse( 296 | code="code", client_id="foo", redirect_uri="http://localhost/local", scope="openid", custom=custom 297 | ) 298 | assert azresp.custom == custom 299 | 300 | # as_dict() returns the request's attributes, including defaults (response_type, code_challenge_method, dpop_key) and custom ones. 301 | def test_request_as_dict() -> None: 302 | assert AuthorizationRequest( 303 | "https://authorization.endpoint", 304 | client_id="foo", 305 | redirect_uri="http://localhost/local", 306 | scope="openid", 307 | acr_values="1 2 3", 308 | customattr="customvalue", 309 | code_verifier="Jdvs0V61iQz3TGoPP_wjwPUIUHPZ7KYDXnQVKJ3f63MvDFhKFMLusp2JOZKoHEUizGvC5xUWlr4m8FemSvo7gERO8b3G87hB-oOGogPiqmTh_c_ISiDpFENXiFNDaAH3", 310 | nonce="mynonce", 311 |
state="mystate", 312 | issuer="https://my.issuer", 313 | authorization_response_iss_parameter_supported=True, 314 | max_age=0, 315 | ).as_dict() == { 316 | "authorization_endpoint": "https://authorization.endpoint", 317 | "client_id": "foo", 318 | "redirect_uri": "http://localhost/local", 319 | "response_type": "code", 320 | "scope": ("openid",), 321 | "acr_values": ("1", "2", "3"), 322 | "code_verifier": "Jdvs0V61iQz3TGoPP_wjwPUIUHPZ7KYDXnQVKJ3f63MvDFhKFMLusp2JOZKoHEUizGvC5xUWlr4m8FemSvo7gERO8b3G87hB-oOGogPiqmTh_c_ISiDpFENXiFNDaAH3", 323 | "code_challenge_method": "S256", 324 | "nonce": "mynonce", 325 | "state": "mystate", 326 | "issuer": "https://my.issuer", 327 | "authorization_response_iss_parameter_supported": True, 328 | "max_age": 0, 329 | "customattr": "customvalue", 330 | "dpop_key": None, 331 | } 332 | 333 | # response_type="token" raises UnsupportedResponseTypeParam. 334 | def test_unsupported_response_type() -> None: 335 | with pytest.raises(UnsupportedResponseTypeParam): 336 | AuthorizationRequest("https://as.local/authorize", client_id="client_id", response_type="token") 337 | -------------------------------------------------------------------------------- /tests/unit_tests/test_backchannel_authentication.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import time 4 | from datetime import datetime 5 | from typing import TYPE_CHECKING 6 | 7 | import pytest 8 | from freezegun import freeze_time 9 | 10 | from requests_oauth2client import ( 11 | BackChannelAuthenticationPoolingJob, 12 | BackChannelAuthenticationResponse, 13 | BaseClientAuthenticationMethod, 14 | BearerToken, 15 | InvalidAcrValuesParam, 16 | InvalidBackChannelAuthenticationResponse, 17 | OAuth2Client, 18 | UnauthorizedClient, 19 | ) 20 | 21 | if TYPE_CHECKING: 22 | from freezegun.api import FrozenDateTimeFactory 23 | from jwskate import Jwk 24 | from pytest_mock import MockerFixture 25 | 26 | from tests.conftest import RequestsMocker, RequestValidatorType 27 | 28 |
29 | @freeze_time() 30 | def test_backchannel_authentication_response(auth_req_id: str) -> None: 31 | bca_resp = BackChannelAuthenticationResponse(auth_req_id=auth_req_id, expires_in=10, interval=10, foo="bar") 32 | 33 | assert bca_resp.auth_req_id == auth_req_id 34 | assert bca_resp.interval == 10 35 | assert not bca_resp.is_expired() 36 | assert isinstance(bca_resp.expires_at, datetime) 37 | assert isinstance(bca_resp.expires_in, int) 38 | assert bca_resp.expires_in == 10 39 | assert bca_resp.foo == "bar" 40 | with pytest.raises(AttributeError): 41 | bca_resp.notfound 42 | 43 | 44 | def test_backchannel_authentication_response_defaults(auth_req_id: str) -> None: 45 | bca_resp = BackChannelAuthenticationResponse( 46 | auth_req_id=auth_req_id, 47 | ) 48 | 49 | assert bca_resp.auth_req_id == auth_req_id 50 | assert bca_resp.interval == 20 51 | assert not bca_resp.is_expired() 52 | assert bca_resp.expires_at is None 53 | assert bca_resp.expires_in is None 54 | 55 | 56 | @pytest.fixture 57 | def bca_client( 58 | token_endpoint: str, 59 | backchannel_authentication_endpoint: str, 60 | client_auth_method: BaseClientAuthenticationMethod, 61 | ) -> OAuth2Client: 62 | bca_client = OAuth2Client( 63 | token_endpoint=token_endpoint, 64 | backchannel_authentication_endpoint=backchannel_authentication_endpoint, 65 | auth=client_auth_method, 66 | ) 67 | assert bca_client.backchannel_authentication_endpoint == backchannel_authentication_endpoint 68 | assert bca_client.auth == client_auth_method 69 | 70 | return bca_client 71 | 72 | 73 | @freeze_time() 74 | def test_backchannel_authentication( 75 | requests_mock: RequestsMocker, 76 | backchannel_authentication_endpoint: str, 77 | bca_client: OAuth2Client, 78 | auth_req_id: str, 79 | scope: None | str | list[str], 80 | backchannel_auth_request_validator: RequestValidatorType, 81 | ciba_request_validator: RequestValidatorType, 82 | token_endpoint: str, 83 | access_token: str, 84 | ) -> None: 85 | requests_mock.post( 86 | 
backchannel_authentication_endpoint, 87 | json={"auth_req_id": auth_req_id, "expires_in": 360, "interval": 3}, 88 | ) 89 | bca_resp = bca_client.backchannel_authentication_request(scope=scope, login_hint="user@example.com") 90 | 91 | assert requests_mock.called_once 92 | backchannel_auth_request_validator(requests_mock.last_request, scope=scope, login_hint="user@example.com") 93 | 94 | assert isinstance(bca_resp, BackChannelAuthenticationResponse) 95 | assert bca_resp.expires_in == 360 96 | 97 | requests_mock.post(token_endpoint, json={"access_token": access_token, "token_type": "Bearer"}) 98 | 99 | token_resp = bca_client.ciba(bca_resp) 100 | assert isinstance(token_resp, BearerToken) 101 | ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) 102 | 103 | requests_mock.reset() 104 | bca_client.ciba(BackChannelAuthenticationResponse(auth_req_id=auth_req_id)) 105 | assert requests_mock.called_once 106 | ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) 107 | 108 | 109 | @freeze_time() 110 | def test_backchannel_authentication_scope_acr_values_as_list( 111 | requests_mock: RequestsMocker, 112 | backchannel_authentication_endpoint: str, 113 | bca_client: OAuth2Client, 114 | auth_req_id: str, 115 | backchannel_auth_request_validator: RequestValidatorType, 116 | ) -> None: 117 | scope = ("openid", "email", "profile") 118 | acr_values = ("reinforced", "strong") 119 | 120 | requests_mock.post( 121 | backchannel_authentication_endpoint, 122 | json={"auth_req_id": auth_req_id, "expires_in": 360, "interval": 3}, 123 | ) 124 | bca_resp = bca_client.backchannel_authentication_request( 125 | scope=scope, acr_values=acr_values, login_hint="user@example.com" 126 | ) 127 | 128 | assert requests_mock.called_once 129 | backchannel_auth_request_validator( 130 | requests_mock.last_request, scope=scope, acr_values=acr_values, login_hint="user@example.com" 131 | ) 132 | 133 | assert isinstance(bca_resp, BackChannelAuthenticationResponse) 
134 | assert bca_resp.expires_in == 360 135 | 136 | with pytest.raises(ValueError, match="Invalid 'acr_values'") as exc: 137 | bca_client.backchannel_authentication_request(login_hint="user@example.net", acr_values=1.44) # type: ignore[arg-type] 138 | assert exc.type is InvalidAcrValuesParam 139 | 140 | 141 | def test_backchannel_authentication_invalid_response( 142 | requests_mock: RequestsMocker, 143 | backchannel_authentication_endpoint: str, 144 | bca_client: OAuth2Client, 145 | scope: None | str | list[str], 146 | backchannel_auth_request_validator: RequestValidatorType, 147 | ) -> None: 148 | requests_mock.post( 149 | backchannel_authentication_endpoint, 150 | json={"foo": "bar"}, 151 | ) 152 | with pytest.raises(InvalidBackChannelAuthenticationResponse): 153 | bca_client.backchannel_authentication_request(scope=scope, login_hint="user@example.com") 154 | 155 | assert requests_mock.called_once 156 | backchannel_auth_request_validator(requests_mock.last_request, scope=scope, login_hint="user@example.com") 157 | 158 | 159 | def test_backchannel_authentication_jwt( 160 | requests_mock: RequestsMocker, 161 | backchannel_authentication_endpoint: str, 162 | bca_client: OAuth2Client, 163 | private_jwk: Jwk, 164 | public_jwk: Jwk, 165 | auth_req_id: str, 166 | scope: None | str | list[str], 167 | backchannel_auth_request_jwt_validator: RequestValidatorType, 168 | ) -> None: 169 | requests_mock.post( 170 | backchannel_authentication_endpoint, 171 | json={"auth_req_id": auth_req_id, "expires_in": 360, "interval": 3}, 172 | ) 173 | bca_resp = bca_client.backchannel_authentication_request( 174 | private_jwk=private_jwk, scope=scope, login_hint="user@example.com", alg="RS256" 175 | ) 176 | 177 | assert requests_mock.called_once 178 | backchannel_auth_request_jwt_validator( 179 | requests_mock.last_request, 180 | public_jwk=public_jwk, 181 | alg="RS256", 182 | scope=scope, 183 | login_hint="user@example.com", 184 | ) 185 | 186 | assert isinstance(bca_resp, 
BackChannelAuthenticationResponse) 187 | 188 | 189 | def test_backchannel_authentication_error( 190 | requests_mock: RequestsMocker, 191 | backchannel_authentication_endpoint: str, 192 | bca_client: OAuth2Client, 193 | scope: None | str | list[str], 194 | backchannel_auth_request_validator: RequestValidatorType, 195 | ) -> None: 196 | requests_mock.post( 197 | backchannel_authentication_endpoint, 198 | status_code=400, 199 | json={"error": "unauthorized_client"}, 200 | ) 201 | with pytest.raises(UnauthorizedClient): 202 | bca_client.backchannel_authentication_request(scope=scope, login_hint="user@example.com") 203 | 204 | assert requests_mock.called_once 205 | backchannel_auth_request_validator(requests_mock.last_request, scope=scope, login_hint="user@example.com") 206 | 207 | 208 | def test_backchannel_authentication_invalid_error( 209 | requests_mock: RequestsMocker, 210 | backchannel_authentication_endpoint: str, 211 | bca_client: OAuth2Client, 212 | scope: None | str | list[str], 213 | backchannel_auth_request_validator: RequestValidatorType, 214 | ) -> None: 215 | requests_mock.post( 216 | backchannel_authentication_endpoint, 217 | status_code=400, 218 | json={"foo": "bar"}, 219 | ) 220 | with pytest.raises(InvalidBackChannelAuthenticationResponse): 221 | bca_client.backchannel_authentication_request(scope=scope, login_hint="user@example.com") 222 | 223 | assert requests_mock.called_once 224 | backchannel_auth_request_validator(requests_mock.last_request, scope=scope, login_hint="user@example.com") 225 | 226 | 227 | def test_backchannel_authentication_not_json_error( 228 | requests_mock: RequestsMocker, 229 | backchannel_authentication_endpoint: str, 230 | bca_client: OAuth2Client, 231 | scope: None | str | list[str], 232 | backchannel_auth_request_validator: RequestValidatorType, 233 | ) -> None: 234 | requests_mock.post( 235 | backchannel_authentication_endpoint, 236 | status_code=400, 237 | text="Error!", 238 | ) 239 | with 
pytest.raises(InvalidBackChannelAuthenticationResponse): 240 | bca_client.backchannel_authentication_request(scope=scope, login_hint="user@example.com") 241 | 242 | assert requests_mock.called_once 243 | backchannel_auth_request_validator(requests_mock.last_request, scope=scope, login_hint="user@example.com") 244 | 245 | 246 | def test_backchannel_authentication_missing_hint( 247 | bca_client: OAuth2Client, 248 | scope: None | str | list[str], 249 | ) -> None: 250 | with pytest.raises(ValueError): 251 | bca_client.backchannel_authentication_request(scope=scope) 252 | 253 | with pytest.raises(ValueError): 254 | bca_client.backchannel_authentication_request( 255 | scope=scope, login_hint="user@example.net", login_hint_token="ABCDEF" 256 | ) 257 | 258 | 259 | def test_backchannel_authentication_invalid_scope(bca_client: OAuth2Client) -> None: 260 | with pytest.raises(ValueError): 261 | bca_client.backchannel_authentication_request( 262 | scope=1.44, # type: ignore[arg-type] 263 | login_hint="user@example.net", 264 | ) 265 | 266 | 267 | def test_pooling_job( 268 | requests_mock: RequestsMocker, 269 | bca_client: OAuth2Client, 270 | token_endpoint: str, 271 | auth_req_id: str, 272 | ciba_request_validator: RequestValidatorType, 273 | access_token: str, 274 | freezer: FrozenDateTimeFactory, 275 | mocker: MockerFixture, 276 | ) -> None: 277 | interval = 20 278 | job = BackChannelAuthenticationPoolingJob(client=bca_client, auth_req_id=auth_req_id, interval=interval) 279 | assert job.interval == interval 280 | assert job.slow_down_interval == 5 281 | 282 | assert job == BackChannelAuthenticationPoolingJob( 283 | bca_client, 284 | BackChannelAuthenticationResponse(auth_req_id, interval=interval), 285 | ) 286 | 287 | requests_mock.post(token_endpoint, status_code=401, json={"error": "authorization_pending"}) 288 | mocker.patch("time.sleep") 289 | 290 | assert job() is None 291 | time.sleep.assert_called_once_with(job.interval) # type: ignore[attr-defined] 292 | 
time.sleep.reset_mock() # type: ignore[attr-defined] 293 | assert requests_mock.called_once 294 | assert job.interval == interval 295 | 296 | ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) 297 | 298 | freezer.tick(job.interval) 299 | requests_mock.reset_mock() 300 | requests_mock.post(token_endpoint, status_code=401, json={"error": "slow_down"}) 301 | 302 | assert job() is None 303 | time.sleep.assert_called_once_with(interval) # type: ignore[attr-defined] 304 | time.sleep.reset_mock() # type: ignore[attr-defined] 305 | assert requests_mock.called_once 306 | assert job.interval == interval + job.slow_down_interval 307 | ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) 308 | 309 | freezer.tick(job.interval) 310 | requests_mock.reset_mock() 311 | requests_mock.post(token_endpoint, json={"access_token": access_token}) 312 | 313 | token = job() 314 | time.sleep.assert_called_once_with(interval + job.slow_down_interval) # type: ignore[attr-defined] 315 | time.sleep.reset_mock() # type: ignore[attr-defined] 316 | assert requests_mock.called_once 317 | assert job.interval == interval + job.slow_down_interval 318 | ciba_request_validator(requests_mock.last_request, auth_req_id=auth_req_id) 319 | assert isinstance(token, BearerToken) 320 | assert token.access_token == access_token 321 | 322 | 323 | def test_missing_backchannel_authentication_endpoint(token_endpoint: str, client_id: str, client_secret: str) -> None: 324 | client = OAuth2Client(token_endpoint, (client_id, client_secret)) 325 | with pytest.raises(AttributeError): 326 | client.backchannel_authentication_request(login_hint="username@foo.bar") 327 | -------------------------------------------------------------------------------- /tests/unit_tests/test_client_authentication.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | from jwskate import InvalidJwk, Jwk 4 | from requests_mock 
import ANY 5 | 6 | from requests_oauth2client import ( 7 | ClientSecretBasic, 8 | ClientSecretJwt, 9 | ClientSecretPost, 10 | InvalidClientAssertionSigningKeyOrAlg, 11 | InvalidRequestForClientAuthentication, 12 | OAuth2Client, 13 | PrivateKeyJwt, 14 | UnsupportedClientCredentials, 15 | ) 16 | from tests.conftest import RequestsMocker, RequestValidatorType 17 | 18 | 19 | def test_client_secret_post( 20 | requests_mock: RequestsMocker, 21 | access_token: str, 22 | token_endpoint: str, 23 | client_id: str, 24 | client_secret: str, 25 | client_secret_post_auth_validator: RequestValidatorType, 26 | ) -> None: 27 | client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret)) 28 | 29 | requests_mock.post( 30 | token_endpoint, 31 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 32 | ) 33 | 34 | assert client.client_credentials() 35 | assert requests_mock.called_once 36 | client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 37 | 38 | 39 | def test_client_secret_basic( 40 | requests_mock: RequestsMocker, 41 | access_token: str, 42 | token_endpoint: str, 43 | client_id: str, 44 | client_secret: str, 45 | client_secret_basic_auth_validator: RequestValidatorType, 46 | ) -> None: 47 | client = OAuth2Client(token_endpoint, ClientSecretBasic(client_id, client_secret)) 48 | 49 | requests_mock.post( 50 | token_endpoint, 51 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 52 | ) 53 | 54 | assert client.client_credentials() 55 | assert requests_mock.called_once 56 | client_secret_basic_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 57 | 58 | 59 | def test_private_key_jwt( 60 | requests_mock: RequestsMocker, 61 | access_token: str, 62 | token_endpoint: str, 63 | client_id: str, 64 | private_jwk: Jwk, 65 | private_key_jwt_auth_validator: RequestValidatorType, 66 | public_jwk: Jwk, 67 | ) -> 
None: 68 | client = OAuth2Client(token_endpoint, PrivateKeyJwt(client_id, private_jwk=private_jwk)) 69 | 70 | requests_mock.post( 71 | token_endpoint, 72 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 73 | ) 74 | 75 | assert client.client_credentials() 76 | assert requests_mock.called_once 77 | private_key_jwt_auth_validator( 78 | requests_mock.last_request, 79 | client_id=client_id, 80 | public_jwk=public_jwk, 81 | endpoint=token_endpoint, 82 | ) 83 | 84 | with pytest.raises(ValueError, match="asymmetric private signing key") as exc: 85 | PrivateKeyJwt(client_id, private_jwk.public_jwk()) 86 | assert exc.type is InvalidClientAssertionSigningKeyOrAlg 87 | 88 | 89 | def test_private_key_jwt_with_kid( 90 | requests_mock: RequestsMocker, 91 | access_token: str, 92 | token_endpoint: str, 93 | client_id: str, 94 | private_jwk: Jwk, 95 | private_key_jwt_auth_validator: RequestValidatorType, 96 | public_jwk: Jwk, 97 | ) -> None: 98 | client = OAuth2Client(token_endpoint, PrivateKeyJwt(client_id, private_jwk=private_jwk)) 99 | 100 | requests_mock.post( 101 | token_endpoint, 102 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 103 | ) 104 | 105 | assert client.client_credentials() 106 | assert requests_mock.called_once 107 | private_key_jwt_auth_validator( 108 | requests_mock.last_request, 109 | client_id=client_id, 110 | public_jwk=public_jwk, 111 | endpoint=token_endpoint, 112 | ) 113 | 114 | 115 | def test_client_secret_jwt( 116 | requests_mock: RequestsMocker, 117 | access_token: str, 118 | token_endpoint: str, 119 | client_id: str, 120 | client_secret: str, 121 | client_secret_jwt_auth_validator: RequestValidatorType, 122 | ) -> None: 123 | client = OAuth2Client(token_endpoint, ClientSecretJwt(client_id, client_secret)) 124 | 125 | requests_mock.post( 126 | token_endpoint, 127 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 128 | ) 129 | 130 | assert 
client.client_credentials() 131 | assert requests_mock.called_once 132 | client_secret_jwt_auth_validator( 133 | requests_mock.last_request, 134 | client_id=client_id, 135 | client_secret=client_secret, 136 | endpoint=token_endpoint, 137 | ) 138 | 139 | 140 | def test_public_client( 141 | requests_mock: RequestsMocker, 142 | access_token: str, 143 | token_endpoint: str, 144 | client_id: str, 145 | target_api: str, 146 | public_app_auth_validator: RequestValidatorType, 147 | ) -> None: 148 | client = OAuth2Client(token_endpoint, client_id) 149 | 150 | requests_mock.post( 151 | token_endpoint, 152 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 153 | ) 154 | 155 | assert client.client_credentials() 156 | assert requests_mock.called_once 157 | public_app_auth_validator(requests_mock.last_request, client_id=client_id) 158 | 159 | 160 | def test_invalid_request(requests_mock: RequestsMocker, client_id: str, client_secret: str) -> None: 161 | requests_mock.get(ANY) 162 | with pytest.raises(RuntimeError) as exc: 163 | requests.get("http://localhost", auth=ClientSecretBasic(client_id, client_secret)) 164 | assert exc.type is InvalidRequestForClientAuthentication 165 | 166 | 167 | def test_private_key_jwt_missing_alg(client_id: str, private_jwk: Jwk) -> None: 168 | private_jwk_without_alg = dict(private_jwk) 169 | private_jwk_without_alg.pop("alg") 170 | with pytest.raises(ValueError) as exc: 171 | PrivateKeyJwt(client_id=client_id, private_jwk=private_jwk_without_alg, alg=None) 172 | assert exc.type is InvalidClientAssertionSigningKeyOrAlg 173 | 174 | 175 | def test_private_key_jwt_unsupported_alg(client_id: str, private_jwk: Jwk) -> None: 176 | private_jwk_without_alg = dict(private_jwk) 177 | private_jwk_without_alg.pop("alg") 178 | with pytest.raises(ValueError) as exc: 179 | PrivateKeyJwt(client_id=client_id, private_jwk=private_jwk_without_alg, alg="FOO") 180 | assert exc.type is InvalidClientAssertionSigningKeyOrAlg 181 | 182 | 183 
| def test_private_key_jwt_missing_kid(client_id: str, private_jwk: Jwk) -> None: 184 | private_jwk_without_kid = dict(private_jwk) 185 | private_jwk_without_kid.pop("kid") 186 | with pytest.raises(ValueError) as exc: 187 | PrivateKeyJwt(client_id=client_id, private_jwk=private_jwk_without_kid) 188 | assert exc.type is InvalidClientAssertionSigningKeyOrAlg 189 | 190 | 191 | def test_init_auth(token_endpoint: str, client_id: str, client_secret: str, private_jwk: Jwk) -> None: 192 | csp_client = OAuth2Client(token_endpoint, (client_id, client_secret)) 193 | assert isinstance(csp_client.auth, ClientSecretPost) 194 | assert csp_client.auth.client_id == client_id 195 | assert csp_client.auth.client_secret == client_secret 196 | 197 | pkj_client = OAuth2Client(token_endpoint, (client_id, dict(private_jwk))) 198 | assert isinstance(pkj_client.auth, PrivateKeyJwt) 199 | assert pkj_client.auth.client_id == client_id 200 | assert pkj_client.auth.private_jwk == private_jwk 201 | 202 | with pytest.raises(ValueError, match="must have a Key Type") as exc: 203 | OAuth2Client(token_endpoint, (client_id, {"foo": "bar"})) 204 | assert exc.type is InvalidJwk 205 | 206 | with pytest.raises(TypeError, match="not supported") as exc2: 207 | OAuth2Client(token_endpoint, (client_id, object())) # type: ignore[arg-type] 208 | assert exc2.type is UnsupportedClientCredentials 209 | -------------------------------------------------------------------------------- /tests/unit_tests/test_device_authorization.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime, timezone 3 | 4 | import pytest 5 | from freezegun import freeze_time 6 | from freezegun.api import FrozenDateTimeFactory 7 | from pytest_mock import MockerFixture 8 | 9 | from requests_oauth2client import ( 10 | BearerToken, 11 | ClientSecretPost, 12 | DeviceAuthorizationError, 13 | DeviceAuthorizationPoolingJob, 14 | DeviceAuthorizationResponse, 15 | 
InvalidDeviceAuthorizationResponse, 16 | OAuth2Client, 17 | UnauthorizedClient, 18 | ) 19 | from tests.conftest import RequestsMocker, RequestValidatorType 20 | 21 | 22 | def test_device_authorization_response( 23 | device_code: str, 24 | user_code: str, 25 | verification_uri: str, 26 | verification_uri_complete: str, 27 | ) -> None: 28 | response = DeviceAuthorizationResponse( 29 | device_code=device_code, 30 | user_code=user_code, 31 | verification_uri=verification_uri, 32 | verification_uri_complete=verification_uri_complete, 33 | expires_in=180, 34 | interval=10, 35 | ) 36 | 37 | assert not response.is_expired() 38 | assert response.device_code == device_code 39 | assert response.user_code == user_code 40 | assert response.verification_uri == verification_uri 41 | assert response.verification_uri_complete == verification_uri_complete 42 | assert isinstance(response.expires_at, datetime) 43 | assert response.expires_at > datetime.now(tz=timezone.utc) 44 | assert response.interval == 10 45 | 46 | 47 | def test_device_authorization_response_expires_at( 48 | device_code: str, 49 | user_code: str, 50 | verification_uri: str, 51 | verification_uri_complete: str, 52 | ) -> None: 53 | expires_at = datetime(year=2021, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) 54 | response = DeviceAuthorizationResponse( 55 | device_code=device_code, 56 | user_code=user_code, 57 | verification_uri=verification_uri, 58 | verification_uri_complete=verification_uri_complete, 59 | expires_at=expires_at, 60 | interval=10, 61 | ) 62 | 63 | assert response.is_expired() 64 | assert response.device_code == device_code 65 | assert response.user_code == user_code 66 | assert response.verification_uri == verification_uri 67 | assert response.verification_uri_complete == verification_uri_complete 68 | assert response.expires_at == expires_at 69 | assert response.interval == 10 70 | 71 | 72 | def test_device_authorization_response_no_expiration( 73 | device_code: str, 74 | 
user_code: str, 75 | verification_uri: str, 76 | verification_uri_complete: str, 77 | ) -> None: 78 | response = DeviceAuthorizationResponse( 79 | device_code=device_code, 80 | user_code=user_code, 81 | verification_uri=verification_uri, 82 | verification_uri_complete=verification_uri_complete, 83 | interval=10, 84 | ) 85 | 86 | assert not response.is_expired() 87 | assert response.device_code == device_code 88 | assert response.user_code == user_code 89 | assert response.verification_uri == verification_uri 90 | assert response.verification_uri_complete == verification_uri_complete 91 | assert response.expires_at is None 92 | assert response.interval == 10 93 | 94 | 95 | @pytest.fixture 96 | def device_authorization_client( 97 | token_endpoint: str, 98 | device_authorization_endpoint: str, 99 | client_id: str, 100 | client_secret: str, 101 | ) -> OAuth2Client: 102 | client = OAuth2Client( 103 | token_endpoint=token_endpoint, 104 | device_authorization_endpoint=device_authorization_endpoint, 105 | auth=(client_id, client_secret), 106 | ) 107 | 108 | assert client.device_authorization_endpoint == device_authorization_endpoint 109 | assert isinstance(client.auth, ClientSecretPost) 110 | assert client.auth.client_id == client_id 111 | assert client.auth.client_secret == client_secret 112 | 113 | return client 114 | 115 | 116 | def test_device_authorization_client( 117 | requests_mock: RequestsMocker, 118 | device_authorization_client: OAuth2Client, 119 | device_authorization_endpoint: str, 120 | device_code: str, 121 | user_code: str, 122 | verification_uri: str, 123 | verification_uri_complete: str, 124 | client_secret_post_auth_validator: RequestValidatorType, 125 | client_id: str, 126 | client_secret: str, 127 | ) -> None: 128 | requests_mock.post( 129 | device_authorization_endpoint, 130 | json={ 131 | "device_code": device_code, 132 | "user_code": user_code, 133 | "verification_uri": verification_uri, 134 | "verification_uri_complete": verification_uri_complete, 
135 | "expires_in": 300, 136 | "interval": 7, 137 | }, 138 | ) 139 | 140 | device_authorization_client.authorize_device() 141 | assert requests_mock.called_once 142 | client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 143 | 144 | 145 | def test_device_authorization_client_error( 146 | requests_mock: RequestsMocker, 147 | device_authorization_client: OAuth2Client, 148 | device_authorization_endpoint: str, 149 | client_secret_post_auth_validator: RequestValidatorType, 150 | client_id: str, 151 | client_secret: str, 152 | ) -> None: 153 | requests_mock.post( 154 | device_authorization_endpoint, 155 | status_code=400, 156 | json={ 157 | "error": "unauthorized_client", 158 | }, 159 | ) 160 | 161 | with pytest.raises(UnauthorizedClient): 162 | device_authorization_client.authorize_device() 163 | assert requests_mock.called_once 164 | client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 165 | 166 | 167 | def test_device_authorization_invalid_errors( 168 | requests_mock: RequestsMocker, 169 | device_authorization_client: OAuth2Client, 170 | device_authorization_endpoint: str, 171 | client_secret_post_auth_validator: RequestValidatorType, 172 | client_id: str, 173 | client_secret: str, 174 | ) -> None: 175 | requests_mock.post( 176 | device_authorization_endpoint, 177 | status_code=400, 178 | json={ 179 | "error": "foo", 180 | }, 181 | ) 182 | 183 | with pytest.raises(DeviceAuthorizationError): 184 | device_authorization_client.authorize_device() 185 | assert requests_mock.called_once 186 | client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 187 | 188 | requests_mock.reset_mock() 189 | requests_mock.post( 190 | device_authorization_endpoint, 191 | status_code=400, 192 | json={ 193 | "foo": "bar", 194 | }, 195 | ) 196 | 197 | with pytest.raises(InvalidDeviceAuthorizationResponse): 198 | 
device_authorization_client.authorize_device() 199 | assert requests_mock.called_once 200 | client_secret_post_auth_validator(requests_mock.last_request, client_id=client_id, client_secret=client_secret) 201 | 202 | 203 | @freeze_time() 204 | def test_device_authorization_pooling_job( 205 | requests_mock: RequestsMocker, 206 | token_endpoint: str, 207 | client_id: str, 208 | client_secret: str, 209 | device_code: str, 210 | device_code_grant_validator: RequestValidatorType, 211 | access_token: str, 212 | freezer: FrozenDateTimeFactory, 213 | mocker: MockerFixture, 214 | ) -> None: 215 | interval = 20 216 | client = OAuth2Client(token_endpoint, auth=(client_id, client_secret)) 217 | job = DeviceAuthorizationPoolingJob( 218 | client=client, 219 | device_code=device_code, 220 | interval=interval, 221 | ) 222 | assert job.interval == interval 223 | assert job.slow_down_interval == 5 224 | 225 | assert job == DeviceAuthorizationPoolingJob( 226 | client, 227 | DeviceAuthorizationResponse( 228 | device_code=device_code, user_code="foo", verification_uri="https://foo.bar", interval=interval 229 | ), 230 | ) 231 | 232 | requests_mock.post(token_endpoint, status_code=401, json={"error": "authorization_pending"}) 233 | mocker.patch("time.sleep") 234 | 235 | assert job() is None 236 | time.sleep.assert_called_once_with(interval) # type: ignore[attr-defined] 237 | assert requests_mock.called_once 238 | assert job.interval == interval 239 | device_code_grant_validator(requests_mock.last_request, device_code=device_code) 240 | 241 | requests_mock.reset_mock() 242 | requests_mock.post(token_endpoint, status_code=401, json={"error": "slow_down"}) 243 | time.sleep.reset_mock() # type: ignore[attr-defined] 244 | 245 | assert job() is None 246 | time.sleep.assert_called_once_with(interval) # type: ignore[attr-defined] 247 | assert requests_mock.called_once 248 | assert job.interval == interval + job.slow_down_interval 249 | device_code_grant_validator(requests_mock.last_request, 
device_code=device_code) 250 | 251 | requests_mock.reset_mock() 252 | requests_mock.post(token_endpoint, json={"access_token": access_token}) 253 | time.sleep.reset_mock() # type: ignore[attr-defined] 254 | 255 | token = job() 256 | time.sleep.assert_called_once_with(interval + job.slow_down_interval) # type: ignore[attr-defined] 257 | assert requests_mock.called_once 258 | assert isinstance(token, BearerToken) 259 | assert token.access_token == access_token 260 | 261 | 262 | def test_no_device_authorization_endpoint(token_endpoint: str, client_id: str, client_secret: str) -> None: 263 | client = OAuth2Client(token_endpoint, (client_id, client_secret)) 264 | with pytest.raises(AttributeError): 265 | client.authorize_device() 266 | -------------------------------------------------------------------------------- /tests/unit_tests/test_discovery.py: -------------------------------------------------------------------------------- 1 | from requests_oauth2client import ( 2 | oauth2_discovery_document_url, 3 | oidc_discovery_document_url, 4 | well_known_uri, 5 | ) 6 | 7 | 8 | def test_well_known_uri() -> None: 9 | assert well_known_uri("http://www.example.com", "example") == "http://www.example.com/.well-known/example" 10 | assert well_known_uri("http://www.example.com/", "example") == "http://www.example.com/.well-known/example" 11 | 12 | assert well_known_uri("http://www.example.com/foo", "example") == "http://www.example.com/.well-known/foo/example" 13 | assert well_known_uri("http://www.example.com/foo/", "example") == "http://www.example.com/.well-known/foo/example" 14 | 15 | assert ( 16 | well_known_uri("http://www.example.com/foo/bar", "example") 17 | == "http://www.example.com/.well-known/foo/bar/example" 18 | ) 19 | assert ( 20 | well_known_uri("http://www.example.com/foo/bar/", "example") 21 | == "http://www.example.com/.well-known/foo/bar/example" 22 | ) 23 | 24 | 25 | def test_oidc_discovery() -> None: 26 | assert 
oidc_discovery_document_url("https://issuer.com") == "https://issuer.com/.well-known/openid-configuration" 27 | assert ( 28 | oidc_discovery_document_url("https://issuer.com/oidc") 29 | == "https://issuer.com/oidc/.well-known/openid-configuration" 30 | ) 31 | assert ( 32 | oidc_discovery_document_url("https://issuer.com/oidc/") 33 | == "https://issuer.com/oidc/.well-known/openid-configuration" 34 | ) 35 | 36 | 37 | def test_oauth20_discovery() -> None: 38 | assert ( 39 | oauth2_discovery_document_url("https://issuer.com") 40 | == "https://issuer.com/.well-known/oauth-authorization-server" 41 | ) 42 | assert ( 43 | oauth2_discovery_document_url("https://issuer.com/oauth2") 44 | == "https://issuer.com/.well-known/oauth2/oauth-authorization-server" 45 | ) 46 | assert ( 47 | oauth2_discovery_document_url("https://issuer.com/oauth2/") 48 | == "https://issuer.com/.well-known/oauth2/oauth-authorization-server" 49 | ) 50 | -------------------------------------------------------------------------------- /tests/unit_tests/test_flask.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | from typing import Any 3 | from urllib.parse import parse_qs 4 | 5 | import pytest 6 | 7 | from requests_oauth2client import ApiClient, ClientSecretPost, OAuth2Client 8 | from tests.conftest import RequestsMocker 9 | 10 | session_key = "session_key" 11 | 12 | 13 | def test_flask( 14 | requests_mock: RequestsMocker, 15 | token_endpoint: str, 16 | client_id: str, 17 | client_secret: str, 18 | scope: str, 19 | target_api: str, 20 | ) -> None: 21 | try: 22 | from flask import Flask, request 23 | 24 | from requests_oauth2client.flask import FlaskOAuth2ClientCredentialsAuth 25 | except ImportError: 26 | pytest.skip("Flask is not available") 27 | 28 | oauth_client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret)) 29 | auth = FlaskOAuth2ClientCredentialsAuth( 30 | session_key=session_key, 31 | 
scope=scope, 32 | client=oauth_client, 33 | ) 34 | api_client = ApiClient(target_api, auth=auth) 35 | 36 | assert isinstance(api_client.auth, FlaskOAuth2ClientCredentialsAuth) 37 | 38 | app = Flask("testapp") 39 | app.config["TESTING"] = True 40 | app.config["SECRET_KEY"] = "thisissecret" 41 | 42 | @app.route("/api") 43 | def get() -> Any: 44 | return api_client.get(params=request.args).json() 45 | 46 | access_token = "access_token" 47 | json_resp = {"status": "success"} 48 | requests_mock.post( 49 | token_endpoint, 50 | json={"access_token": access_token, "token_type": "Bearer", "expires_in": 3600}, 51 | ) 52 | requests_mock.get(target_api, json=json_resp) 53 | 54 | with app.test_client() as client: 55 | resp = client.get("/api?call=1") 56 | assert resp.json == json_resp 57 | resp = client.get("/api?call=2") 58 | assert resp.json == json_resp 59 | api_client.auth.forget_token() 60 | # assert api_client.session.auth.token is None 61 | with client.session_transaction() as sess: 62 | sess.pop(auth.session_key) 63 | # this should trigger a new token request then the API request 64 | resp = client.get("/api?call=3") 65 | assert resp.json == json_resp 66 | 67 | token_request1 = requests_mock.request_history[0] 68 | assert token_request1.url == token_endpoint 69 | token_params = parse_qs(token_request1.text) 70 | assert token_params.get("client_id") == [client_id] 71 | if not scope: 72 | assert token_params.get("scope") is None 73 | elif isinstance(scope, str): 74 | assert token_params.get("scope") == [scope] 75 | elif isinstance(scope, Iterable): 76 | assert token_params.get("scope") == [" ".join(scope)] 77 | assert token_params.get("client_secret") == [client_secret] 78 | 79 | api_request1 = requests_mock.request_history[1] 80 | assert api_request1.url == "https://myapi.local/root/?call=1" 81 | assert api_request1.headers.get("Authorization") == f"Bearer {access_token}" 82 | 83 | api_request2 = requests_mock.request_history[2] 84 | assert api_request2.url == 
"https://myapi.local/root/?call=2" 85 | assert api_request2.headers.get("Authorization") == f"Bearer {access_token}" 86 | 87 | token_request2 = requests_mock.request_history[3] 88 | assert token_request2.url == token_endpoint 89 | token_params = parse_qs(token_request2.text) 90 | assert token_params.get("client_id") == [client_id] 91 | if not scope: 92 | assert token_params.get("scope") is None 93 | elif isinstance(scope, str): 94 | assert token_params.get("scope") == [scope] 95 | elif isinstance(scope, Iterable): 96 | assert token_params.get("scope") == [" ".join(scope)] 97 | assert token_params.get("client_secret") == [client_secret] 98 | 99 | api_request3 = requests_mock.request_history[4] 100 | assert api_request3.url == "https://myapi.local/root/?call=3" 101 | assert api_request3.headers.get("Authorization") == f"Bearer {access_token}" 102 | 103 | 104 | def test_flask_token_kwarg() -> None: 105 | try: 106 | from flask import Flask 107 | 108 | from requests_oauth2client.flask import FlaskOAuth2ClientCredentialsAuth 109 | except ImportError: 110 | pytest.skip("Flask is not available") 111 | 112 | app = Flask("testapp") 113 | app.config["TESTING"] = True 114 | app.config["SECRET_KEY"] = "thisissecret" 115 | 116 | with app.test_request_context("/"): 117 | auth = FlaskOAuth2ClientCredentialsAuth( 118 | client=None, 119 | session_key=session_key, 120 | token="xyz", 121 | ) 122 | assert auth.token 123 | assert auth.token.access_token == "xyz" 124 | -------------------------------------------------------------------------------- /tests/unit_tests/test_pkce.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import string 4 | 5 | import pytest 6 | 7 | from requests_oauth2client import PkceUtils 8 | 9 | 10 | def test_generate_code_verifier_and_challenge() -> None: 11 | verifier, challenge = PkceUtils.generate_code_verifier_and_challenge() 12 | assert isinstance(verifier, str) 13 | assert 43 <= 
len(verifier) <= 128 14 | assert set(verifier).issubset(set(string.ascii_letters + string.digits + "_-~.")) 15 | 16 | assert isinstance(challenge, str) 17 | assert len(challenge) == 43 18 | assert set(verifier).issubset(set(string.ascii_letters + string.digits + "_-")) 19 | 20 | assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge 21 | 22 | assert PkceUtils.validate_code_verifier(verifier, challenge) 23 | 24 | 25 | def test_unsupported_challenge_method() -> None: 26 | verifier = PkceUtils.generate_code_verifier() 27 | with pytest.raises(ValueError): 28 | PkceUtils.derive_challenge(verifier, method="foo") 29 | 30 | 31 | def test_challenge_method_plain() -> None: 32 | verifier = PkceUtils.generate_code_verifier() 33 | challenge = PkceUtils.derive_challenge(verifier, method="plain") 34 | assert challenge == verifier 35 | 36 | 37 | def test_invalid_verifier() -> None: 38 | with pytest.raises(ValueError): 39 | PkceUtils.derive_challenge("foo") 40 | 41 | 42 | def test_verifier_bytes() -> None: 43 | challenge = PkceUtils.derive_challenge(b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMOPQRSTUVWXYZ1234567890") 44 | assert challenge == "FYKCx6MubiaOxWp8-ciyDkkkOapyAjR9sxikqOSXLdw" 45 | -------------------------------------------------------------------------------- /tests/unit_tests/test_tokens.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta, timezone 2 | 3 | import jwskate 4 | import pytest 5 | from freezegun import freeze_time 6 | from freezegun.api import FrozenDateTimeFactory 7 | from jwskate import ( 8 | ExpiredJwt, 9 | InvalidClaim, 10 | InvalidJwt, 11 | InvalidSignature, 12 | Jwk, 13 | Jwt, 14 | SignatureAlgs, 15 | SignedJwt, 16 | ) 17 | 18 | from requests_oauth2client import BearerToken, BearerTokenSerializer, IdToken 19 | 20 | ID_TOKEN = ( 21 | 
"eyJhbGciOiJSUzI1NiIsImtpZCI6Im15X2tleSJ9.eyJhY3IiOiIyIiwiYW1yIjpbInB3ZCIsIm90cCJdLCJhdWQiOiJjbGllbnRfaWQiL" 22 | "CJhdXRoX3RpbWUiOjE2MjkyMDQ1NjAsImV4cCI6MTYyOTIwNDYyMCwiaWF0IjoxNjI5MjA0NTYwLCJpc3MiOiJodHRwczovL215YXMubG9" 23 | "jYWwiLCJub25jZSI6Im5vbmNlIiwic3ViIjoiMTIzNDU2In0.wUfjMyjlOSdvbFGFP8O8wGcNBK7akeyOUBMvYcNZclFUtokOyxhLUPxmo" 24 | "1THo1DV1BHUVd6AWfeKUnyTxl_8-G3E_a9u5wJfDyfghPDhCmfkYARvqQnnV_3aIbfTfUBC4f0bHr08d_q0fED88RLu77wESIPCVqQYy2b" 25 | "k4FLucc63yGBvaCskqzthZ85DbBJYWLlR8qBUk_NA8bWATYEtjwTrxoZe-uA-vB6NwUv1h8DKRsDF-9HSVHeWXXAeoG9UW7zgxoY3KbDIV" 26 | "zemvGzs2R9OgDBRRafBBVeAkDV6CdbdMNJDmHzcjase5jX6LE-3YCy7c7AMM1uWRCnK3f-azA" 27 | ) 28 | PUBLIC_JWK = Jwk( 29 | { 30 | "kty": "RSA", 31 | "alg": "RS256", 32 | "kid": "my_key", 33 | "n": "2m4QVSHdUo2DFSbGY24cJbxE10KbgdkSCtm0YZ1q0Zmna8pJg8YhaWCJHV7D5AxQ_L1b1PK0jsdpGYWc5-Pys0FB2hyABGPxXIdg1mjxn6geHLpWzsA3MHD29oqfl0Rt7g6AFc5St3lBgJCyWtci6QYBmBkX9oIMOx9pgv4BaT6y1DdrNh27-oSMXZ0a58KwnC6jbCpdA3V3Eume-Be1Tx9lJN3j6S8ydT7CGY1Xd-sc3oB8pXfkr1_EYf0Sgb9EwOJfqlNK_kVjT3GZ-1JJMKJ6zkU7H0yXe2SKXAzfayvJaIcYrk-sYwmf-u7yioOLLvjlGjysN7SOSM8socACcw", 34 | "e": "AQAB", 35 | } 36 | ) 37 | 38 | 39 | def test_bearer_token_simple() -> None: 40 | token = BearerToken(access_token="foo") 41 | assert token.access_token == "foo" 42 | assert token.refresh_token is None 43 | assert token.scope is None 44 | assert token.token_type == "Bearer" 45 | assert token.expires_at is None 46 | assert token.expires_in is None 47 | with pytest.raises(AttributeError): 48 | token.foo 49 | 50 | assert token.as_dict() == { 51 | "access_token": "foo", 52 | "token_type": "Bearer", 53 | } 54 | 55 | assert str(token) == "foo" 56 | assert repr(token) 57 | 58 | assert str(token) == "foo" 59 | assert token != 1.2 # type: ignore[comparison-overlap] 60 | 61 | 62 | @freeze_time("2021-08-17 12:50:18") 63 | def test_bearer_token_complete() -> None: 64 | id_token = IdToken.sign( 65 | { 66 | "iss": "https://issuer.local", 67 | "iat": IdToken.timestamp(), 68 | "exp": 
IdToken.timestamp(60), 69 | "sub": "myuserid", 70 | }, 71 | Jwk.generate_for_alg(SignatureAlgs.RS256), 72 | ) 73 | token = BearerToken( 74 | access_token="foo", 75 | expires_in=180, 76 | scope="myscope1 myscope2", 77 | refresh_token="refresh_token", 78 | custom_attr="custom_value", 79 | id_token=str(id_token), 80 | ) 81 | assert token.access_token == "foo" 82 | assert token.refresh_token == "refresh_token" 83 | assert token.scope == "myscope1 myscope2" 84 | assert token.token_type == "Bearer" 85 | assert token.expires_in == 180 86 | assert token.custom_attr == "custom_value" 87 | assert token.id_token == id_token 88 | assert token.expires_at == datetime(year=2021, month=8, day=17, hour=12, minute=53, second=18, tzinfo=timezone.utc) 89 | with pytest.raises(AttributeError): 90 | token.foo 91 | 92 | assert token.as_dict() == { 93 | "access_token": "foo", 94 | "token_type": "Bearer", 95 | "refresh_token": "refresh_token", 96 | "expires_in": 180, 97 | "scope": "myscope1 myscope2", 98 | "custom_attr": "custom_value", 99 | "id_token": str(id_token), 100 | } 101 | 102 | assert str(token) == "foo" 103 | assert repr(token) 104 | 105 | 106 | @freeze_time("2021-08-17 12:50:18") 107 | def test_nearly_expired_token() -> None: 108 | token = BearerToken( 109 | access_token="foo", 110 | expires_at=datetime(year=2021, month=8, day=17, hour=12, minute=50, second=20, tzinfo=timezone.utc), 111 | ) 112 | assert not token.is_expired() 113 | assert token.is_expired(3) 114 | 115 | 116 | @freeze_time("2021-08-17 12:50:21") 117 | def test_recently_expired_token() -> None: 118 | token = BearerToken( 119 | access_token="foo", 120 | expires_at=datetime(year=2021, month=8, day=17, hour=12, minute=50, second=20, tzinfo=timezone.utc), 121 | ) 122 | assert token.is_expired() 123 | assert token.is_expired(3) 124 | assert not token.is_expired(-3) 125 | 126 | 127 | def test_invalid_token_type() -> None: 128 | with pytest.raises(ValueError): 129 | BearerToken(access_token="foo", token_type="bar") 130 | 
131 | 132 | def test_empty_jwt() -> None: 133 | jwt = SignedJwt( 134 | "eyJhbGciOiJSUzI1NiIsImtpZCI6Im15X2tleSJ9.e30.qoopspKRRo0LvRHcBVAjGNOVAnGkfgOmcSTwhRv46RUuEPvoDoodtLq5hINC3TvRm8GidshIU2e-lHZ033Ja4KE5DQSL8pPItjwUxFIQ9qUYhF625bOisufNoE9YK0qDup_jcawRaBWoxkJB9oPSFaV9sCXLBX_szrUI87PPs7GDxXfgpgnztazFizizIdNf29f_FKTKRwldiQz1zaB9D_svOOThQm3ECk0PFbjqlfn7uYxe5l_GDmdgvV479rkySHhgNEC-HrGYD18Kc7Zsl1avvuLV8X-qzj-I8N06Wst8kEVnrGcCm0S4K3HfG4xHzohPQFoIuwdVzDIjSVEfCQ" 135 | ) 136 | 137 | assert jwt.verify_signature(PUBLIC_JWK) 138 | assert jwt.expires_at is None 139 | assert jwt.issued_at is None 140 | assert jwt.not_before is None 141 | assert jwt.issuer is None 142 | assert jwt.alg == "RS256" 143 | 144 | with pytest.raises(InvalidClaim): 145 | jwt.validate(key=PUBLIC_JWK, issuer="foo") 146 | 147 | with pytest.raises(InvalidClaim): 148 | jwt.validate(key=PUBLIC_JWK, audience="foo") 149 | 150 | 151 | def test_jwt_iat_exp_nbf() -> None: 152 | jwt = SignedJwt( 153 | "eyJhbGciOiJSUzI1NiIsImtpZCI6Im15X2tleSJ9.eyJleHAiOjE2MjkzODQ5ODgsImlhdCI6MTYyOTM4NDkyOCwibmJmIjoxNjI5Mzg0ODY4fQ.k_0abUntpK5yVOvalZGnhEhUuq1lmtoRQfKmEJuQpYiHCb3x9buYWclQCMNGzHikiyGtrRqN0RcyUPeGI9QN7hasvj1ItzrhsdXJDO968y3VXjfPnOz2lDPUKJjsTdWXbCGDZD82d4OX8E9WFaOwwutMb_5ismEBvttNAmwHJG433TzEO2rFhno9X3RPo8IqOJg_HSw8Q0BLsub7Ak9I0eGDsb8x5J8_fp6zqGkZaqL35DkLPZSHdLzYalmH4ksH69SVWu-7rD-W1brGxVpJg8unV9fy_1AmiQu-8tIedo68br2Tg0oNekwT-lXMTjmiJkYv8hpnECbtFXMRQSGcvQ" 154 | ) 155 | 156 | assert jwt.verify_signature(PUBLIC_JWK, alg="RS256") 157 | assert jwt.issued_at == datetime(year=2021, month=8, day=19, hour=14, minute=55, second=28, tzinfo=timezone.utc) 158 | assert jwt.expires_at == datetime(year=2021, month=8, day=19, hour=14, minute=56, second=28, tzinfo=timezone.utc) 159 | assert jwt.not_before == datetime(year=2021, month=8, day=19, hour=14, minute=54, second=28, tzinfo=timezone.utc) 160 | 161 | assert jwt.iat == 1629384928 162 | assert jwt.exp == 1629384988 163 | assert jwt.nbf == 1629384868 164 | 165 | 166 | def 
test_id_token() -> None: 167 | issuer = "https://myas.local" 168 | audience = "client_id" 169 | nonce = "nonce" 170 | acr = "2" 171 | id_token = IdToken( 172 | "eyJhbGciOiJSUzI1NiIsImtpZCI6Im15X2tleSJ9.eyJhY3IiOiIyIiwiYW1yIjpbInB3ZCIsIm90cCJdLCJhdWQiOiJjbGllbnRfaWQiLCJhdXRoX3RpbWUiOjE2MjkyMDQ1NjAsImV4cCI6MTYyOTIwNDYyMCwiaWF0IjoxNjI5MjA0NTYwLCJpc3MiOiJodHRwczovL215YXMubG9jYWwiLCJub25jZSI6Im5vbmNlIiwic3ViIjoiMTIzNDU2In0.wUfjMyjlOSdvbFGFP8O8wGcNBK7akeyOUBMvYcNZclFUtokOyxhLUPxmo1THo1DV1BHUVd6AWfeKUnyTxl_8-G3E_a9u5wJfDyfghPDhCmfkYARvqQnnV_3aIbfTfUBC4f0bHr08d_q0fED88RLu77wESIPCVqQYy2bk4FLucc63yGBvaCskqzthZ85DbBJYWLlR8qBUk_NA8bWATYEtjwTrxoZe-uA-vB6NwUv1h8DKRsDF-9HSVHeWXXAeoG9UW7zgxoY3KbDIVzemvGzs2R9OgDBRRafBBVeAkDV6CdbdMNJDmHzcjase5jX6LE-3YCy7c7AMM1uWRCnK3f-azA" 173 | ) 174 | 175 | with pytest.raises(AttributeError): 176 | id_token.attr_not_found 177 | 178 | id_token.validate( 179 | PUBLIC_JWK, 180 | issuer=issuer, 181 | audience=audience, 182 | nonce=nonce, 183 | check_exp=False, 184 | acr=acr, 185 | ) 186 | 187 | with pytest.raises(ExpiredJwt): 188 | id_token.validate(PUBLIC_JWK, issuer=issuer, audience=audience, nonce=nonce, check_exp=True) 189 | 190 | assert id_token.alg == "RS256" 191 | assert id_token.kid == "my_key" 192 | assert id_token.aud == audience 193 | assert id_token.is_expired() 194 | assert id_token.is_expired(1000) 195 | assert id_token.expires_at == datetime(2021, 8, 17, 12, 50, 20, tzinfo=timezone.utc) 196 | assert id_token.issued_at == datetime(2021, 8, 17, 12, 49, 20, tzinfo=timezone.utc) 197 | 198 | 199 | def test_invalid_jwt() -> None: 200 | issuer = "https://myas.local" 201 | audience = "client_id" 202 | nonce = "nonce" 203 | 204 | id_token = IdToken(ID_TOKEN) 205 | modified_id_token = IdToken( 206 | ID_TOKEN[:-4] + "abcd" # strips a few chars from the signature # replace them with arbitrary data 207 | ) 208 | 209 | # invalid signature 210 | with pytest.raises(InvalidSignature): 211 | modified_id_token.validate(PUBLIC_JWK, issuer=issuer, 
audience=audience, nonce=nonce, check_exp=False) 212 | 213 | # invalid issuer 214 | with pytest.raises(InvalidClaim): 215 | id_token.validate(PUBLIC_JWK, issuer="foo", audience=audience, nonce=nonce, check_exp=False) 216 | 217 | # invalid audience 218 | with pytest.raises(InvalidClaim): 219 | id_token.validate(PUBLIC_JWK, issuer=issuer, audience="foo", nonce=nonce, check_exp=False) 220 | 221 | # invalid nonce 222 | with pytest.raises(InvalidClaim): 223 | id_token.validate(PUBLIC_JWK, issuer=issuer, audience=audience, nonce="foo", check_exp=False) 224 | 225 | # invalid claim 226 | with pytest.raises(InvalidClaim): 227 | id_token.validate( 228 | PUBLIC_JWK, 229 | issuer=issuer, 230 | audience=audience, 231 | nonce=nonce, 232 | check_exp=False, 233 | acr="4", 234 | ) 235 | 236 | # missing claim 237 | with pytest.raises(InvalidClaim): 238 | id_token.validate( 239 | PUBLIC_JWK, 240 | issuer=issuer, 241 | audience=audience, 242 | nonce=nonce, 243 | check_exp=False, 244 | foo="bar", 245 | ) 246 | 247 | 248 | def test_invalid_token() -> None: 249 | with pytest.raises(InvalidJwt): 250 | IdToken("foo.bar") 251 | 252 | 253 | def test_id_token_eq() -> None: 254 | id_token = IdToken(ID_TOKEN) 255 | 256 | assert id_token == ID_TOKEN 257 | assert id_token != "foo" 258 | assert id_token != 13.37 259 | 260 | 261 | def test_id_token_attributes() -> None: 262 | bad_id_token = IdToken(Jwt.sign({"azp": 1234, "auth_time": -3000}, Jwk.generate(alg="HS256")).value) 263 | with pytest.raises(AttributeError): 264 | bad_id_token.authorized_party 265 | 266 | with pytest.raises(AttributeError): 267 | bad_id_token.auth_datetime 268 | 269 | good_id_token = IdToken(Jwt.sign({"azp": "valid", "auth_time": 1725529281}, Jwk.generate(alg="HS256")).value) 270 | assert good_id_token.authorized_party == "valid" 271 | assert good_id_token.auth_datetime == datetime(2024, 9, 5, 9, 41, 21, tzinfo=timezone.utc) 272 | 273 | 274 | @pytest.mark.parametrize( 275 | "token", 276 | [ 277 | 
BearerToken("access_token"), 278 | # note that "expires_at" is calculated when the test is ran, so before `freezer` takes effect 279 | BearerToken("access_token", expires_in=60), 280 | BearerToken("access_token", expires_in=-60), 281 | ], 282 | ) 283 | def test_token_serializer(token: BearerToken, freezer: FrozenDateTimeFactory) -> None: 284 | freezer.move_to("2024-08-01") 285 | serializer = BearerTokenSerializer() 286 | candidate = serializer.dumps(token) 287 | freezer.move_to(datetime.now(tz=timezone.utc) + timedelta(days=365)) 288 | assert serializer.loads(candidate) == token 289 | 290 | 291 | @freeze_time() 292 | def test_expires_in_as_str() -> None: 293 | assert BearerToken("access_token", expires_in=60) == BearerToken("access_token", expires_in="60") 294 | assert BearerToken("access_token", expires_in=-60) == BearerToken("access_token", expires_in="-60") 295 | assert BearerToken("access_token", expires_in="foo") == BearerToken("access_token") 296 | 297 | 298 | def test_access_token_jwt() -> None: 299 | assert isinstance( 300 | BearerToken( 301 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" 302 | ).access_token_jwt, 303 | jwskate.SignedJwt, 304 | ) 305 | 306 | with pytest.raises(jwskate.InvalidJwt): 307 | BearerToken("not.a.jwt").access_token_jwt 308 | -------------------------------------------------------------------------------- /tests/unit_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime, timezone 4 | 5 | import pytest 6 | 7 | from requests_oauth2client import InvalidUri, validate_endpoint_uri 8 | from requests_oauth2client.utils import accepts_expires_in 9 | 10 | 11 | def test_validate_uri() -> None: 12 | validate_endpoint_uri("https://myas.local/token") 13 | validate_endpoint_uri("https://myas.local:443/token", 
no_port=True) 14 | with pytest.raises(ValueError, match="https") as exc: 15 | validate_endpoint_uri("http://myas.local/token") 16 | assert exc.type is InvalidUri 17 | with pytest.raises(ValueError, match="path") as exc: 18 | validate_endpoint_uri("https://myas.local") 19 | assert exc.type is InvalidUri 20 | with pytest.raises(ValueError, match="fragment") as exc: 21 | validate_endpoint_uri("https://myas.local/token#foo") 22 | assert exc.type is InvalidUri 23 | with pytest.raises(ValueError, match="credentials") as exc: 24 | validate_endpoint_uri("https://user:passwd@myas.local/token") 25 | assert exc.type is InvalidUri 26 | with pytest.raises(ValueError, match="port") as exc: 27 | validate_endpoint_uri("https://myas.local:1234/token", no_port=True) 28 | assert exc.type is InvalidUri 29 | 30 | 31 | @pytest.mark.parametrize("expires_in", [10, "10"]) 32 | def test_accepts_expires_in(expires_in: int | str) -> None: 33 | @accepts_expires_in 34 | def foo(expires_at: datetime | None = None) -> datetime | None: 35 | return expires_at 36 | 37 | now = datetime.now(tz=timezone.utc) 38 | assert foo(expires_at=now) == now 39 | assert foo(now) == now 40 | assert isinstance(foo(expires_in=expires_in), datetime) 41 | assert foo() is None 42 | -------------------------------------------------------------------------------- /tests/unit_tests/vendor_specific/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guillp/requests_oauth2client/810e7b6099ac89742adbf7877d7a8b0f785c4016/tests/unit_tests/vendor_specific/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/vendor_specific/test_auth0.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from requests_oauth2client import OAuth2ClientCredentialsAuth 4 | from requests_oauth2client.vendor_specific import Auth0 5 | 6 | 7 | def 
test_auth0_management() -> None: 8 | auth0api = Auth0.management_api_client("test.eu.auth0.com", ("client_id", "client_secret")) 9 | assert auth0api.auth is not None 10 | assert isinstance(auth0api.auth, OAuth2ClientCredentialsAuth) 11 | assert auth0api.auth.client is not None 12 | assert auth0api.auth.client.token_endpoint == "https://test.eu.auth0.com/oauth/token" 13 | assert auth0api.auth.token_kwargs == {"audience": "https://test.eu.auth0.com/api/v2/"} 14 | 15 | 16 | def test_auth0_client() -> None: 17 | auth0client = Auth0.client("test.eu.auth0.com", ("client_id", "client_secret")) 18 | assert auth0client.token_endpoint == "https://test.eu.auth0.com/oauth/token" 19 | assert auth0client.revocation_endpoint == "https://test.eu.auth0.com/oauth/revoke" 20 | assert auth0client.userinfo_endpoint == "https://test.eu.auth0.com/userinfo" 21 | assert auth0client.jwks_uri == "https://test.eu.auth0.com/.well-known/jwks.json" 22 | 23 | 24 | def test_auth0_client_short_tenant_name() -> None: 25 | auth0client = Auth0.client("test.eu", ("client_id", "client_secret")) 26 | assert auth0client.token_endpoint == "https://test.eu.auth0.com/oauth/token" 27 | assert auth0client.revocation_endpoint == "https://test.eu.auth0.com/oauth/revoke" 28 | assert auth0client.userinfo_endpoint == "https://test.eu.auth0.com/userinfo" 29 | assert auth0client.jwks_uri == "https://test.eu.auth0.com/.well-known/jwks.json" 30 | 31 | 32 | def test_tenant() -> None: 33 | assert Auth0.tenant("https://mytenant.eu.auth0.com") == "mytenant.eu.auth0.com" 34 | assert Auth0.tenant("mytenant.eu") == "mytenant.eu.auth0.com" 35 | with pytest.raises(ValueError): 36 | Auth0.tenant("ftp://mytenant.eu") 37 | with pytest.raises(ValueError): 38 | Auth0.tenant("") 39 | -------------------------------------------------------------------------------- /tests/unit_tests/vendor_specific/test_ping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
requests_oauth2client.vendor_specific import Ping 4 | 5 | 6 | def test_ping_client() -> None: 7 | ping_client = Ping.client("mydomain.tld", auth=("client_id", "client_secret")) 8 | assert ping_client.token_endpoint == "https://mydomain.tld/as/token.oauth2" 9 | assert ping_client.authorization_endpoint == "https://mydomain.tld/as/authorization.oauth2" 10 | assert ping_client.token_endpoint == "https://mydomain.tld/as/token.oauth2" 11 | assert ping_client.revocation_endpoint == "https://mydomain.tld/as/revoke_token.oauth2" 12 | assert ping_client.userinfo_endpoint == "https://mydomain.tld/idp/userinfo.openid" 13 | assert ping_client.introspection_endpoint == "https://mydomain.tld/as/introspect.oauth2" 14 | assert ping_client.jwks_uri == "https://mydomain.tld/pf/JWKS" 15 | assert ping_client.extra_metadata["registration_endpoint"] == "https://mydomain.tld/as/clients.oauth2" 16 | assert ( 17 | ping_client.extra_metadata["ping_revoked_sris_endpoint"] 18 | == "https://mydomain.tld/pf-ws/rest/sessionMgmt/revokedSris" 19 | ) 20 | assert ( 21 | ping_client.extra_metadata["ping_session_management_sris_endpoint"] 22 | == "https://mydomain.tld/pf-ws/rest/sessionMgmt/sessions" 23 | ) 24 | assert ( 25 | ping_client.extra_metadata["ping_session_management_users_endpoint"] 26 | == "https://mydomain.tld/pf-ws/rest/sessionMgmt/users" 27 | ) 28 | assert ping_client.extra_metadata["ping_end_session_endpoint"] == "https://mydomain.tld/idp/startSLO.ping" 29 | assert ping_client.device_authorization_endpoint == "https://mydomain.tld/as/device_authz.oauth2" 30 | 31 | 32 | def test_ping_invalid_domain() -> None: 33 | with pytest.raises(ValueError): 34 | Ping.client("foo") 35 | with pytest.raises(ValueError): 36 | Ping.client("ftp://foo.bar") 37 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | isolated_build = true 3 | envlist = py39, py310, py311, 
py312, py313, lint 4 | 5 | [gh-actions] 6 | python = 7 | 3.13: py313 8 | 3.12: py312 9 | 3.11: py311 10 | 3.10: py310 11 | 3.9: py39 12 | 13 | [testenv:lint] 14 | whitelist_externals = 15 | isort 16 | black 17 | flake8 18 | poetry 19 | mkdocs 20 | twine 21 | extras = 22 | test 23 | doc 24 | dev 25 | commands = 26 | mdformat --wrap 120 README.md 27 | isort requests_oauth2client 28 | black requests_oauth2client tests 29 | flake8 requests_oauth2client tests 30 | mypy requests_oauth2client 31 | poetry build 32 | mkdocs build 33 | twine check dist/* 34 | 35 | [testenv] 36 | allowlist_externals = 37 | poetry 38 | commands_pre = 39 | poetry install --no-root --sync -E test 40 | passenv = * 41 | setenv = 42 | PYTHONPATH = {toxinidir} 43 | PYTHONWARNINGS = ignore 44 | commands = 45 | poetry run pytest -s --cov=requests_oauth2client --cov-append --cov-report=xml --cov-report term-missing tests 46 | --------------------------------------------------------------------------------