├── tests ├── __init__.py ├── requirements.txt ├── util.py ├── test_redirect_uri_config.py ├── test_user_session.py ├── test_auth_response_handler.py ├── test_provider_configuration.py ├── test_pyoidc_facade.py └── test_flask_pyoidc.py ├── example ├── __init__.py ├── app.py └── test_example_app.py ├── src └── flask_pyoidc │ ├── __init__.py │ ├── message_factory.py │ ├── parse_fragment.html │ ├── redirect_uri_config.py │ ├── user_session.py │ ├── auth_response_handler.py │ ├── provider_configuration.py │ ├── pyoidc_facade.py │ └── flask_pyoidc.py ├── .bumpversion.cfg ├── .readthedocs.yml ├── .gitignore ├── .coveragerc ├── docs ├── index.rst ├── api.rst ├── Makefile ├── make.bat ├── conf.py ├── configuration.md └── quickstart.md ├── tox.ini ├── setup.py ├── .github └── workflows │ ├── release.yml │ └── ci.yml ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | responses 4 | -------------------------------------------------------------------------------- /src/flask_pyoidc/__init__.py: -------------------------------------------------------------------------------- 1 | from .flask_pyoidc import OIDCAuthentication 2 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 3.14.3 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | -------------------------------------------------------------------------------- 
/.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | python: 7 | install: 8 | - method: pip 9 | path: . 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | build/ 4 | dist/ 5 | .idea/ 6 | .tox/ 7 | .cache/ 8 | .coverage* 9 | coverage.xml 10 | .pytest_cache 11 | .venv/ 12 | docs/_build 13 | docs/_static 14 | docs/_templates 15 | venv 16 | -------------------------------------------------------------------------------- /src/flask_pyoidc/message_factory.py: -------------------------------------------------------------------------------- 1 | from oic.oauth2.message import AccessTokenResponse, CCAccessTokenRequest, MessageTuple, OauthMessageFactory 2 | 3 | 4 | class CCMessageFactory(OauthMessageFactory): 5 | """Client Credential Request Factory.""" 6 | token_endpoint = MessageTuple(CCAccessTokenRequest, AccessTokenResponse) 7 | -------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | from jwkest.jwk import SYMKey 2 | from oic import rndstr 3 | from oic.oic import IdToken 4 | 5 | 6 | def signed_id_token(claims): 7 | id_token = IdToken(**claims) 8 | signing_key = SYMKey(alg='HS256', key=rndstr()) 9 | jws = id_token.to_jwt(key=[signing_key], algorithm=signing_key.alg) 10 | return jws, signing_key 11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = src/flask_pyoidc 4 | 5 | [paths] 6 | source = 7 | src/flask_pyoidc 8 | */site-packages/flask_pyoidc 9 | 10 | [report] 11 | exclude_lines = 12 | raise 
NotImplementedError() 13 | if __name__ == .__main__.: 14 | ignore_errors = True 15 | omit = 16 | tests/* 17 | setup.py 18 | show_missing = true 19 | precision = 2 20 | -------------------------------------------------------------------------------- /src/flask_pyoidc/parse_fragment.html: -------------------------------------------------------------------------------- 1 | 2 | 15 | 16 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to flask-pyoidc's documentation! 2 | ======================================== 3 | 4 | In addition to this documentation, you may have a look on some 5 | `example code`_ 6 | 7 | .. _example code: https://github.com/zamzterz/Flask-pyoidc/tree/master/example 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | quickstart 14 | configuration 15 | api 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = clean,py36,py37,py38,py39,py310,py311 3 | 4 | [testenv] 5 | commands = pytest --cov={envsitepackagesdir}/flask_pyoidc --cov-append --cov-report=term-missing tests/ example/ 6 | deps = -rtests/requirements.txt 7 | setenv = COVERAGE_FILE = .coverage.{envname} 8 | 9 | [testenv:clean] 10 | deps = coverage 11 | skip_install = true 12 | commands = coverage erase 13 | 14 | [flake8] 15 | max_line_length = 120 16 | 17 | [gh-actions] 18 | python = 19 | 3.7: py37 20 | 3.8: py38 21 | 3.9: py39 22 | 3.10: py310 23 | 3.11: py311 24 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. 
_api: 2 | 3 | API 4 | === 5 | 6 | Flask-pyoidc extension 7 | ---------------------- 8 | .. autoclass:: flask_pyoidc.OIDCAuthentication 9 | :members: 10 | 11 | Configuration 12 | ------------- 13 | .. automodule:: flask_pyoidc.provider_configuration 14 | :members: 15 | 16 | User session handling 17 | --------------------- 18 | .. automodule:: flask_pyoidc.user_session 19 | :members: 20 | 21 | Internals 22 | --------- 23 | .. automodule:: flask_pyoidc.auth_response_handler 24 | :members: 25 | 26 | .. automodule:: flask_pyoidc.pyoidc_facade 27 | :members: 28 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('README.md') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='Flask-pyoidc', 8 | version='3.14.3', 9 | packages=['flask_pyoidc'], 10 | package_dir={'': 'src'}, 11 | url='https://github.com/zamzterz/flask-pyoidc', 12 | license='Apache 2.0', 13 | author='Samuel Gulliksson', 14 | author_email='samuel.gulliksson@gmail.com', 15 | description='Flask extension for OpenID Connect authentication.', 16 | install_requires=[ 17 | 'oic==1.6.1', 18 | 'Flask', 19 | 'requests', 20 | 'importlib_resources' 21 | ], 22 | package_data={'flask_pyoidc': ['parse_fragment.html']}, 23 | long_description=long_description, 24 | long_description_content_type='text/markdown', 25 | ) 26 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | tags: ['*'] 5 | 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions/setup-python@v3 13 | with: 14 | python-version: '3.10' 15 | - name: Install pypa/build 16 | run: python -m pip install build --user 17 | - name: Build a binary wheel and a source tarball 18 | run: >- 19 | python -m 20 | build 21 | --sdist 22 | --wheel 23 | --outdir dist/ 24 | . 
class TestRedirectUriConfig:
    """Tests for how ``RedirectUriConfig.from_config`` resolves the redirect URI.

    The legacy configuration keys (``SERVER_NAME``, ``OIDC_REDIRECT_DOMAIN``,
    ``OIDC_REDIRECT_ENDPOINT``) are deprecated, so every test going through that
    path also asserts that the ``DeprecationWarning`` is emitted.
    """

    LEGACY_CONFIG = {'SERVER_NAME': 'example.com', 'PREFERRED_URL_SCHEME': 'http'}

    def test_legacy_config_defaults(self):
        with pytest.warns(DeprecationWarning):
            config = RedirectUriConfig.from_config(self.LEGACY_CONFIG)
        assert config.endpoint == 'redirect_uri'
        assert config.full_uri == 'http://example.com/redirect_uri'

    def test_legacy_config_endpoint(self):
        with pytest.warns(DeprecationWarning):
            config = RedirectUriConfig.from_config({'OIDC_REDIRECT_ENDPOINT': '/foo', **self.LEGACY_CONFIG})
        assert config.endpoint == 'foo'

    def test_legacy_config_domain(self):
        config = {
            'OIDC_REDIRECT_DOMAIN': 'other.example.com:6000',  # should be preferred over SERVER_NAME
            **self.LEGACY_CONFIG
        }
        with pytest.warns(DeprecationWarning):
            redirect_uri_config = RedirectUriConfig.from_config(config)
        assert redirect_uri_config.full_uri == 'http://other.example.com:6000/redirect_uri'

    def test_redirect_uri_config(self):
        config = {
            'OIDC_REDIRECT_URI': 'https://myexample.com:6000/callback',  # should be preferred over all other config
            'OIDC_REDIRECT_DOMAIN': 'other.example.com:6000',
            **self.LEGACY_CONFIG
        }
        redirect_uri_config = RedirectUriConfig.from_config(config)
        assert redirect_uri_config.full_uri == 'https://myexample.com:6000/callback'
        assert redirect_uri_config.endpoint == 'callback'
        assert repr(redirect_uri_config) == f'({redirect_uri_config.full_uri}, {redirect_uri_config.endpoint})'

    def test_should_raise_if_missing_all_config(self):
        with pytest.raises(ValueError) as exc_info:
            RedirectUriConfig.from_config({})
        # Checked after the block: statements following the raising call inside it would never run.
        assert 'OIDC_REDIRECT_URI' in str(exc_info.value)
name: Run tox targets for ${{ matrix.python-version }} 25 | run: "python -m tox" 26 | - name: Upload coverage data 27 | uses: actions/upload-artifact@v2 28 | with: 29 | name: coverage-data 30 | path: ".coverage.*" 31 | if-no-files-found: ignore 32 | 33 | coverage: 34 | name: Combine & check coverage. 35 | runs-on: ubuntu-latest 36 | needs: tests 37 | steps: 38 | - uses: actions/checkout@v2 39 | - uses: actions/setup-python@v2 40 | with: 41 | # Use latest, so it understands all syntax. 42 | python-version: "3.10" 43 | 44 | - run: python -m pip install --upgrade coverage[toml] 45 | 46 | - name: Download coverage data. 47 | uses: actions/download-artifact@v2 48 | with: 49 | name: coverage-data 50 | 51 | - name: Combine coverage & fail if it's <100%. 52 | run: | 53 | python -m coverage combine 54 | python -m coverage html --skip-covered --skip-empty 55 | python -m coverage report --fail-under=97 56 | 57 | - name: Upload HTML report if check failed. 58 | uses: actions/upload-artifact@v2 59 | with: 60 | name: html-report 61 | path: htmlcov 62 | if: ${{ failure() }} 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flask-pyoidc 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/flask-pyoidc.svg)](https://pypi.python.org/pypi/Flask-pyoidc) 4 | [![codecov.io](https://codecov.io/github/zamzterz/Flask-pyoidc/coverage.svg?branch=master)](https://codecov.io/github/its-dirg/Flask-pyoidc?branch=master) 5 | [![Build Status](https://travis-ci.org/zamzterz/Flask-pyoidc.svg?branch=master)](https://travis-ci.org/zamzterz/Flask-pyoidc) 6 | 7 | This Flask extension provides simple OpenID Connect authentication, backed by [pyoidc](https://github.com/rohe/pyoidc). 
8 | 9 | ["Authorization Code Flow"](http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth), 10 | ["Implicit Flow"](https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth), 11 | ["Hybrid Flow"](https://openid.net/specs/openid-connect-core-1_0.html#HybridFlowAuth), 12 | ["Client Credentials Flow"](https://oauth.net/2/grant-types/client-credentials/) are supported. 13 | 14 | ## Getting started 15 | Read [the documentation](https://flask-pyoidc.readthedocs.io/) or have a look at the 16 | [example Flask app](example/app.py) for a full example of how to use this extension. 17 | 18 | Below is a basic example of how to get started: 19 | ```python 20 | app = Flask(__name__) 21 | app.config.update( 22 | OIDC_REDIRECT_URI = 'https://example.com/redirect_uri', 23 | SECRET_KEY = ... 24 | ) 25 | 26 | # Static Client Registration 27 | client_metadata = ClientMetadata( 28 | client_id='client1', 29 | client_secret='secret1', 30 | post_logout_redirect_uris=['https://example.com/logout']) 31 | 32 | 33 | provider_config = ProviderConfiguration(issuer='', 34 | client_metadata=client_metadata) 35 | 36 | auth = OIDCAuthentication({'default': provider_config}, app) 37 | 38 | @app.route('/') 39 | @auth.oidc_auth('default') # endpoint will require login 40 | def index(): 41 | user_session = UserSession(flask.session) 42 | return jsonify(access_token=user_session.access_token, 43 | id_token=user_session.id_token, 44 | userinfo=user_session.userinfo) 45 | ``` 46 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'flask-pyoidc' 21 | copyright = '2022, Samuel Gulliksson' 22 | author = 'Samuel Gulliksson' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.napoleon', 33 | 'recommonmark' 34 | ] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 43 | 44 | autoclass_content = 'both' 45 | master_doc = 'index' 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = 'default' 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. 
They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | html_static_path = ['_static'] 58 | -------------------------------------------------------------------------------- /src/flask_pyoidc/redirect_uri_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Samuel Gulliksson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class RedirectUriConfig:
    """The redirect URI of the application, as resolved from its configuration.

    ``full_uri`` is the complete redirect URI and ``endpoint`` is its path
    without the leading slash.
    """

    def __init__(self, full_uri, endpoint):
        """
        Args:
            full_uri (str): complete redirect URI, e.g. ``https://example.com/redirect_uri``
            endpoint (str): path of the redirect URI without leading ``/``
        """
        self.full_uri = full_uri
        self.endpoint = endpoint

    def __eq__(self, other):
        """Compare by ``full_uri`` and ``endpoint``.

        Returns ``NotImplemented`` for objects lacking those attributes, so that
        comparing with an unrelated object yields ``False`` instead of raising
        ``AttributeError``.
        """
        try:
            return self.full_uri == other.full_uri and self.endpoint == other.endpoint
        except AttributeError:
            return NotImplemented

    def __str__(self):
        return f'({self.full_uri}, {self.endpoint})'

    def __repr__(self):
        return str(self)

    @classmethod
    def from_config(cls, config):
        """Create the redirect URI configuration from the application config.

        ``OIDC_REDIRECT_URI`` is preferred over all other configuration. If it is
        missing or falsy (``None`` or empty), the deprecated legacy keys are used.

        Args:
            config (Mapping[str, Any]): application configuration

        Returns:
            RedirectUriConfig: the resolved configuration

        Raises:
            ValueError: if neither ``OIDC_REDIRECT_URI`` nor a legacy domain is configured
        """
        # A key that exists but holds None/'' (e.g. an unset environment variable)
        # is treated as "not configured" rather than being passed on to urlparse.
        redirect_uri = config.get('OIDC_REDIRECT_URI')
        if redirect_uri:
            return cls(*cls._parse_redirect_uri(redirect_uri))

        return cls(*cls._parse_legacy_config(config))

    @staticmethod
    def _parse_redirect_uri(redirect_uri):
        """Return ``(full_uri, endpoint)`` where the endpoint is the URI path without leading slashes."""
        endpoint = urlparse(redirect_uri).path.lstrip('/')
        return redirect_uri, endpoint

    @staticmethod
    def _parse_legacy_config(config):
        """Build ``(full_uri, endpoint)`` from the deprecated configuration keys.

        The domain is read from ``OIDC_REDIRECT_DOMAIN`` (falling back to
        ``SERVER_NAME`` when that key is absent), the scheme from
        ``PREFERRED_URL_SCHEME`` (default ``http``) and the endpoint from
        ``OIDC_REDIRECT_ENDPOINT`` (default ``redirect_uri``).

        Raises:
            ValueError: if no domain is configured
        """
        redirect_domain = config.get('OIDC_REDIRECT_DOMAIN', config.get('SERVER_NAME'))
        if not redirect_domain:
            raise ValueError("'OIDC_REDIRECT_URI' must be configured.")

        scheme = config.get('PREFERRED_URL_SCHEME', 'http')

        warnings.warn(
            "Please use 'OIDC_REDIRECT_URI' to configure the redirect_uri for flask-pyoidc. "
            "'OIDC_REDIRECT_DOMAIN' and 'OIDC_REDIRECT_ENDPOINT' have been deprecated.",
            DeprecationWarning,
            stacklevel=2
        )

        endpoint = config.get('OIDC_REDIRECT_ENDPOINT', 'redirect_uri').lstrip('/')
        full_uri = f'{scheme}://{redirect_domain}/{endpoint}'

        return full_uri, endpoint
14 | 'PERMANENT_SESSION_LIFETIME': datetime.timedelta(days=7).total_seconds(), 15 | 'DEBUG': True}) 16 | 17 | ISSUER1 = 'https://provider1.example.com' 18 | CLIENT1 = 'client@provider1' 19 | PROVIDER_NAME1 = 'provider1' 20 | PROVIDER_CONFIG1 = ProviderConfiguration(issuer=ISSUER1, 21 | client_metadata=ClientMetadata(CLIENT1, 'secret1')) 22 | ISSUER2 = 'https://provider2.example.com' 23 | CLIENT2 = 'client@provider2' 24 | PROVIDER_NAME2 = 'provider2' 25 | PROVIDER_CONFIG2 = ProviderConfiguration(issuer=ISSUER2, 26 | client_metadata=ClientMetadata(CLIENT2, 'secret2')) 27 | auth = OIDCAuthentication({PROVIDER_NAME1: PROVIDER_CONFIG1, PROVIDER_NAME2: PROVIDER_CONFIG2}) 28 | 29 | 30 | @app.route('/') 31 | @auth.oidc_auth(PROVIDER_NAME1) 32 | def login1(): 33 | user_session = UserSession(flask.session) 34 | return jsonify(access_token=user_session.access_token, 35 | id_token=user_session.id_token, 36 | userinfo=user_session.userinfo) 37 | 38 | 39 | @app.route('/login2') 40 | @auth.oidc_auth(PROVIDER_NAME2) 41 | def login2(): 42 | user_session = UserSession(flask.session) 43 | return jsonify(access_token=user_session.access_token, 44 | id_token=user_session.id_token, 45 | userinfo=user_session.userinfo) 46 | 47 | 48 | @app.route('/api') 49 | @auth.token_auth(PROVIDER_NAME1, 50 | scopes_required=['read', 'write']) 51 | def api(): 52 | current_token_identity = auth.current_token_identity 53 | return current_token_identity 54 | 55 | 56 | @app.route('/profile') 57 | @auth.access_control(PROVIDER_NAME1, 58 | scopes_required=['read', 'write']) 59 | def profile(): 60 | if auth.current_token_identity: 61 | return auth.current_token_identity 62 | else: 63 | user_session = UserSession(flask.session) 64 | return jsonify(access_token=user_session.access_token, 65 | id_token=user_session.id_token, 66 | userinfo=user_session.userinfo) 67 | 68 | 69 | @app.route('/logout') 70 | @auth.oidc_logout 71 | def logout(): 72 | return "You've been successfully logged out!" 
73 | 74 | 75 | @auth.error_view 76 | def error(error=None, error_description=None): 77 | return jsonify({'error': error, 'message': error_description}) 78 | 79 | 80 | if __name__ == '__main__': 81 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 82 | auth.init_app(app) 83 | app.run() 84 | -------------------------------------------------------------------------------- /example/test_example_app.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import json 4 | import pytest 5 | import responses 6 | from oic.oic import IdToken 7 | from six.moves.urllib.parse import parse_qsl, urlencode, urlparse 8 | 9 | from example.app import ISSUER1, ISSUER2, CLIENT1, CLIENT2 10 | from .app import app, auth 11 | 12 | 13 | class TestExampleApp(object): 14 | PROVIDER1_METADATA = { 15 | 'issuer': ISSUER1, 16 | 'authorization_endpoint': ISSUER1 + '/auth', 17 | 'jwks_uri': ISSUER1 + '/jwks', 18 | 'token_endpoint': ISSUER1 + '/token', 19 | 'userinfo_endpoint': ISSUER1 + '/userinfo' 20 | } 21 | PROVIDER2_METADATA = { 22 | 'issuer': ISSUER2, 23 | 'authorization_endpoint': ISSUER2 + '/auth', 24 | 'jwks_uri': ISSUER2 + '/jwks', 25 | 'token_endpoint': ISSUER2 + '/token', 26 | 'userinfo_endpoint': ISSUER2 + '/userinfo' 27 | } 28 | USER_ID = 'user1' 29 | 30 | @pytest.fixture(scope='session', autouse=True) 31 | def setup(self): 32 | app.testing = True 33 | 34 | with responses.RequestsMock() as r: 35 | # mock provider discovery 36 | r.add(responses.GET, ISSUER1 + '/.well-known/openid-configuration', json=self.PROVIDER1_METADATA) 37 | r.add(responses.GET, ISSUER2 + '/.well-known/openid-configuration', json=self.PROVIDER2_METADATA) 38 | auth.init_app(app) 39 | 40 | @responses.activate 41 | def perform_authentication(self, test_client, login_endpoint, client_id, provider_metadata): 42 | # index page should make auth request 43 | auth_redirect = test_client.get(login_endpoint) 44 | 
parsed_auth_request = dict(parse_qsl(urlparse(auth_redirect.location).query)) 45 | 46 | now = int(time.time()) 47 | # mock token response 48 | id_token = IdToken(iss=provider_metadata['issuer'], 49 | aud=client_id, 50 | sub=self.USER_ID, 51 | exp=now + 10, 52 | iat=now, 53 | nonce=parsed_auth_request['nonce']) 54 | token_response = {'access_token': 'test_access_token', 'token_type': 'Bearer', 'id_token': id_token.to_jwt()} 55 | responses.add(responses.POST, provider_metadata['token_endpoint'], json=token_response) 56 | 57 | # mock userinfo response 58 | userinfo = {'sub': self.USER_ID, 'name': 'Test User'} 59 | responses.add(responses.GET, provider_metadata['userinfo_endpoint'], json=userinfo) 60 | 61 | # fake auth response sent to redirect URI 62 | fake_auth_response = 'code=fake_auth_code&state={}'.format(parsed_auth_request['state']) 63 | logged_in_page = test_client.get('/redirect_uri?{}'.format(fake_auth_response), follow_redirects=True) 64 | result = json.loads(logged_in_page.data.decode('utf-8')) 65 | 66 | assert result['access_token'] == 'test_access_token' 67 | assert result['id_token'] == id_token.to_dict() 68 | assert result['userinfo'] == {'sub': self.USER_ID, 'name': 'Test User'} 69 | 70 | @pytest.mark.parametrize('login_endpoint, client_id, provider_metadata', [ 71 | ('/', CLIENT1, PROVIDER1_METADATA), 72 | ('/login2', CLIENT2, PROVIDER2_METADATA), 73 | ]) 74 | def test_login_logout(self, login_endpoint, client_id, provider_metadata): 75 | client = app.test_client() 76 | 77 | self.perform_authentication(client, login_endpoint, client_id, provider_metadata) 78 | 79 | response = client.get('/logout') 80 | assert response.data.decode('utf-8') == "You've been successfully logged out!" 
class UninitialisedSession(Exception):
    """Raised when an uninitialised session is picked up without a provider name."""


class UserSession:
    """Session object for user login state.

    Wraps the session storage and the comparison of times necessary for
    session handling. All state is kept in the wrapped storage, so every
    ``UserSession`` created on the same storage sees the same data.
    """

    # Every key this class may write to the storage; ``clear`` removes exactly these.
    KEYS = [
        'access_token',
        'access_token_expires_at',
        'current_provider',
        'id_token',
        'id_token_jwt',
        'last_authenticated',
        'last_session_refresh',
        'userinfo',
        'refresh_token'
    ]

    def __init__(self, session_storage, provider_name=None):
        """
        Args:
            session_storage (MutableMapping[str, Any]): storage backing the session
            provider_name (Optional[str]): name of the provider the session belongs to;
                may be omitted to pick up a session that already has a provider

        Raises:
            UninitialisedSession: if the storage has no current provider and
                ``provider_name`` is not given
        """
        self._session_storage = session_storage
        if 'current_provider' not in self._session_storage and not provider_name:
            raise UninitialisedSession("Trying to pick-up uninitialised session without specifying 'provider_name'")

        if provider_name:
            if 'current_provider' in self._session_storage and \
                    provider_name != self._session_storage['current_provider']:
                # provider has changed, initialise new session
                self.clear()

            self._session_storage['current_provider'] = provider_name

    def is_authenticated(self):
        """
        Return True if an authentication has been recorded in this session.

        'last_authenticated' is written by every call to `update` and removed by `clear`,
        so checking for its existence is enough to determine if we're authenticated.
        """

        return self._session_storage.get('last_authenticated') is not None

    def should_refresh(self, refresh_interval_seconds=None):
        """
        Return True if more than `refresh_interval_seconds` have passed since the last session refresh.

        Always False when no interval is given or the session has never been refreshed.

        Args:
            refresh_interval_seconds (Optional[int])
        """
        return refresh_interval_seconds is not None and \
            self._session_storage.get('last_session_refresh') is not None and \
            self._refresh_time(refresh_interval_seconds) < time.time()

    def _refresh_time(self, refresh_interval_seconds):
        """Return the timestamp at which the session is due for its next refresh."""
        last = self._session_storage.get('last_session_refresh', 0)
        return last + refresh_interval_seconds

    def update(self, *,
               access_token=None,
               expires_in=None,
               id_token=None,
               id_token_jwt=None,
               userinfo=None,
               refresh_token=None):
        """
        Record a (re-)authentication and store the given values.

        'last_session_refresh' is set to the current time and 'last_authenticated' to the
        'auth_time' of the ID Token, if it has one, otherwise to the current time.
        Arguments that are None (or otherwise falsy) are ignored, leaving any previously
        stored value untouched.

        Args:
            access_token (str)
            expires_in (int): lifetime of the access token in seconds, stored as the
                absolute time 'access_token_expires_at'
            id_token (Mapping[str, str])
            id_token_jwt (str)
            userinfo (Mapping[str, str])
            refresh_token (str)
        """

        def set_if_defined(session_key, value):
            if value:
                self._session_storage[session_key] = value

        now = int(time.time())
        auth_time = now
        if id_token:
            auth_time = id_token.get('auth_time', auth_time)

        self._session_storage['last_authenticated'] = auth_time
        self._session_storage['last_session_refresh'] = now
        set_if_defined('access_token', access_token)
        set_if_defined('access_token_expires_at', now + expires_in if expires_in else None)
        set_if_defined('id_token', id_token)
        set_if_defined('id_token_jwt', id_token_jwt)
        set_if_defined('userinfo', userinfo)
        set_if_defined('refresh_token', refresh_token)

    def clear(self):
        """Remove all session keys (see `KEYS`) from the storage; other keys are left alone."""
        for key in self.KEYS:
            self._session_storage.pop(key, None)

    @property
    def access_token(self):
        return self._session_storage.get('access_token')

    @property
    def access_token_expires_at(self):
        return self._session_storage.get('access_token_expires_at')

    @property
    def refresh_token(self):
        return self._session_storage.get('refresh_token')

    @property
    def id_token(self):
        return self._session_storage.get('id_token')

    @property
    def id_token_jwt(self):
        return self._session_storage.get('id_token_jwt')

    @property
    def userinfo(self):
        return self._session_storage.get('userinfo')

    @property
    def current_provider(self):
        return self._session_storage.get('current_provider')

    @property
    def last_authenticated(self):
        return self._session_storage.get('last_authenticated')

    @property
    def last_session_refresh(self):
        return self._session_storage.get('last_session_refresh')
36 | 37 | def test_unauthenticated_session(self): 38 | assert self.initialised_session({}).is_authenticated() is False 39 | 40 | def test_authenticated_session(self): 41 | assert self.initialised_session({'last_authenticated': 1234}).is_authenticated() is True 42 | 43 | def test_should_not_refresh_if_not_supported(self): 44 | assert self.initialised_session({}).should_refresh() is False 45 | 46 | def test_should_not_refresh_if_authenticated_within_refresh_interval(self): 47 | refresh_interval = 10 48 | session = self.initialised_session({'last_session_refresh': time.time() + (refresh_interval - 1)}) 49 | assert session.should_refresh(refresh_interval) is False 50 | 51 | def test_should_refresh_if_supported_and_necessary(self): 52 | refresh_interval = 10 53 | # authenticated too far in the past 54 | session_storage = {'last_session_refresh': time.time() - (refresh_interval + 1)} 55 | assert self.initialised_session(session_storage).should_refresh(refresh_interval) is True 56 | 57 | def test_should_not_refresh_if_not_previously_authenticated(self): 58 | assert self.initialised_session({}).should_refresh(10) is False 59 | 60 | @pytest.mark.parametrize('data', [ 61 | {'access_token': 'test_access_token'}, 62 | {'id_token': {'iss': 'issuer1', 'sub': 'user1', 'aud': 'client1', 'exp': 1235, 'iat': 1234}}, 63 | {'id_token_jwt': 'eyJh.eyJz.SflK'}, 64 | {'userinfo': {'sub': 'user1', 'name': 'Test User'}}, 65 | ]) 66 | @patch('time.time') 67 | def test_update(self, time_mock, data): 68 | storage = {} 69 | auth_time = 1234 70 | time_mock.return_value = auth_time 71 | 72 | self.initialised_session(storage).update(**data) 73 | 74 | expected_session_data = { 75 | 'last_authenticated': auth_time, 76 | 'last_session_refresh': auth_time, 77 | 'current_provider': self.PROVIDER_NAME 78 | } 79 | expected_session_data.update(**data) 80 | assert storage == expected_session_data 81 | 82 | def test_update_should_use_auth_time_from_id_token_if_it_exists(self): 83 | auth_time = 1234 84 | 
session = self.initialised_session({}) 85 | session.update(id_token={'auth_time': auth_time}) 86 | assert session.last_authenticated == auth_time 87 | 88 | @patch('time.time') 89 | def test_update_should_update_last_session_refresh_timestamp(self, time_mock): 90 | now_timestamp = 1234 91 | time_mock.return_value = now_timestamp 92 | data = {} 93 | session = self.initialised_session(data) 94 | session.update() 95 | assert data['last_session_refresh'] == now_timestamp 96 | 97 | def test_trying_to_update_uninitialised_session_should_throw_exception(self): 98 | with pytest.raises(UninitialisedSession): 99 | UserSession(session_storage={}).update() 100 | 101 | def test_clear(self): 102 | expected_data = {'initial data': 'should remain'} 103 | session_storage = expected_data.copy() 104 | 105 | session = self.initialised_session(session_storage) 106 | session.update(access_token='access_token', expires_in=3600, id_token={'sub': 'user1'}, id_token_jwt='eyJh.eyJz.SflK', userinfo={'sub': 'user1}'}, refresh_token='refresh_token') 107 | session.clear() 108 | 109 | assert session_storage == expected_data 110 | 111 | def test_access_token_expiry(self): 112 | session = self.initialised_session({}) 113 | expires_in = 3600 114 | session.update(expires_in=expires_in) 115 | assert session.access_token_expires_at == int(time.time()) + expires_in 116 | -------------------------------------------------------------------------------- /src/flask_pyoidc/auth_response_handler.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | 4 | from oic.exception import PyoidcError 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | AuthenticationResult = collections.namedtuple('AuthenticationResult', 9 | [ 10 | 'access_token', 11 | 'expires_in', 12 | 'id_token_claims', 13 | 'id_token_jwt', 14 | 'userinfo_claims', 15 | 'refresh_token' 16 | ]) 17 | 18 | 19 | class AuthResponseProcessError(ValueError): 20 | pass 21 | 22 | 23 | 
class AuthResponseUnexpectedStateError(AuthResponseProcessError): 24 | pass 25 | 26 | 27 | class InvalidIdTokenError(AuthResponseProcessError): 28 | pass 29 | 30 | 31 | class AuthResponseMismatchingSubjectError(AuthResponseProcessError): 32 | pass 33 | 34 | 35 | class AuthResponseErrorResponseError(AuthResponseProcessError): 36 | def __init__(self, error_response): 37 | """ 38 | Args: 39 | error_response (Mapping[str, str]): OAuth error response containing 'error' and 'error_description' 40 | """ 41 | self.error_response = error_response 42 | 43 | 44 | class AuthResponseHandler: 45 | def __init__(self, client): 46 | """ 47 | Args: 48 | client (flask_pyoidc.pyoidc_facade.PyoidcFacade): Client proxy to make requests to the provider 49 | """ 50 | self._client = client 51 | 52 | def process_auth_response(self, auth_response, auth_request, extra_token_args): 53 | """ 54 | Args: 55 | auth_response (Union[AuthorizationResponse, AuthorizationErrorResponse]): parsed OIDC auth response 56 | auth_request (Mapping[str, str]): original OIDC auth request 57 | extra_token_args (Mapping[str, Any]): extra arguments to pass to pyoidc 58 | Returns: 59 | AuthenticationResult: All relevant data associated with the authenticated user 60 | """ 61 | if 'error' in auth_response: 62 | raise AuthResponseErrorResponseError(auth_response.to_dict()) 63 | 64 | if auth_response['state'] != auth_request['state']: 65 | raise AuthResponseUnexpectedStateError() 66 | 67 | # implicit/hybrid flow may return tokens in the auth response 68 | access_token = auth_response.get('access_token', None) 69 | expires_in = auth_response.get('expires_in', None) 70 | id_token_claims = auth_response['id_token'].to_dict() if 'id_token' in auth_response else None 71 | id_token_jwt = auth_response.get('id_token_jwt', None) 72 | refresh_token = None # but never refresh token 73 | 74 | if 'code' in auth_response: 75 | token_resp = self._client.exchange_authorization_code(auth_response['code'], 76 | auth_response['state'], 
77 | extra_token_args) 78 | if token_resp: 79 | if 'error' in token_resp: 80 | raise AuthResponseErrorResponseError(token_resp.to_dict()) 81 | 82 | access_token = token_resp['access_token'] 83 | expires_in = token_resp.get('expires_in', None) 84 | refresh_token = token_resp.get('refresh_token', None) 85 | 86 | if 'id_token' in token_resp: 87 | id_token = token_resp['id_token'] 88 | logger.debug('received id token: %s', id_token.to_json()) 89 | 90 | try: 91 | self._client.verify_id_token(id_token, auth_request) 92 | except PyoidcError as ex: 93 | raise InvalidIdTokenError(str(ex)) from ex 94 | 95 | id_token_claims = id_token.to_dict() 96 | id_token_jwt = token_resp.get('id_token_jwt') 97 | 98 | # do userinfo request 99 | userinfo = self._client.userinfo_request(access_token) 100 | userinfo_claims = None 101 | if userinfo: 102 | userinfo_claims = userinfo.to_dict() 103 | 104 | if id_token_claims and userinfo_claims and userinfo_claims['sub'] != id_token_claims['sub']: 105 | raise AuthResponseMismatchingSubjectError('The \'sub\' of userinfo does not match \'sub\' of ID Token.') 106 | 107 | return AuthenticationResult(access_token, 108 | expires_in, 109 | id_token_claims, 110 | id_token_jwt, 111 | userinfo_claims, 112 | refresh_token) 113 | 114 | @classmethod 115 | def expect_fragment_encoded_response(cls, auth_request): 116 | if 'response_mode' in auth_request: 117 | return auth_request['response_mode'] == 'fragment' 118 | 119 | response_type = set(auth_request['response_type'].split(' ')) 120 | is_implicit_flow = response_type == {'id_token'} or \ 121 | response_type == {'id_token', 'token'} 122 | is_hybrid_flow = response_type == {'code', 'id_token'} or \ 123 | response_type == {'code', 'token'} or \ 124 | response_type == {'code', 'id_token', 'token'} 125 | 126 | return is_implicit_flow or is_hybrid_flow 127 | -------------------------------------------------------------------------------- /docs/configuration.md: 
-------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | Both static and dynamic provider configuration discovery, as well as static and dynamic client registration, are 4 | supported. The different modes of provider configuration can be combined with any of the client registration modes. 5 | 6 | ## Client Configuration 7 | 8 | ### Static Client Registration 9 | 10 | If you have already registered a client with the provider, specify the client credentials directly: 11 | ```python 12 | from flask_pyoidc.provider_configuration import ProviderConfiguration, ClientMetadata 13 | 14 | client_metadata = ClientMetadata(client_id='client1', client_secret='secret1') 15 | provider_config = ProviderConfiguration(client_metadata=client_metadata, [provider_configuration]) 16 | ``` 17 | 18 | **Note: The redirect URIs registered with the provider MUST include the URI specified in 19 | [`OIDC_REDIRECT_URI`](#flask-configuration).** 20 | 21 | 22 | ### Dynamic Client Registration 23 | 24 | To dynamically register a new client for your application, the required client registration info can be specified: 25 | 26 | ```python 27 | from flask_pyoidc.provider_configuration import ProviderConfiguration, ClientRegistrationInfo 28 | 29 | client_registration_info = ClientRegistrationInfo(client_name='Test App', contacts=['dev@example.com'], 30 | redirect_uris=['https://client.example.com/redirect', 31 | 'https://client.example.com/redirect2'], 32 | post_logout_redirect_uris=['https://client.example.com/logout', 33 | 'https://client.example.com/logout2'], 34 | registration_token='initial_access_token') 35 | provider_config = ProviderConfiguration(client_registration_info=client_registration_info, [provider_configuration]) 36 | ``` 37 | 38 | **Note: To register all `redirect_uris` and `post_logout_redirect_uris` with the provider, 39 | you must provide them as a list in their respective keyword arguments.** 40 | 41 | Identity Providers support two
ways in which new clients can be registered through Dynamic Client Registration: 42 | 43 | 1. Authenticated requests - the registration request must contain an "initial access token" obtained from your 44 | identity provider. 45 | If you want to use this method then you must provide the `registration_token` keyword argument to `ClientRegistrationInfo`. 46 | 47 | 2. Anonymous requests - the registration request doesn't need to contain any token. 48 | 49 | You can set any Client Metadata parameters for `ClientRegistrationInfo` during the registration. For a complete list of 50 | keyword arguments, see [Client Metadata](https://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata). 51 | Also refer to the 52 | [Client Registration Request example](https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest). 53 | 54 | ## Provider configuration 55 | 56 | ### Dynamic provider configuration 57 | 58 | To use a provider which supports dynamic discovery it suffices to specify the issuer URL: 59 | ```python 60 | from flask_pyoidc.provider_configuration import ProviderConfiguration 61 | 62 | provider_config = ProviderConfiguration(issuer='https://idp.example.com', [client configuration]) 63 | ``` 64 | 65 | ### Static provider configuration 66 | 67 | To use a provider not supporting dynamic discovery, the static provider metadata can be specified: 68 | ```python 69 | from flask_pyoidc.provider_configuration import ProviderConfiguration, ProviderMetadata 70 | 71 | provider_metadata = ProviderMetadata(issuer='https://idp.example.com', 72 | authorization_endpoint='https://idp.example.com/auth', 73 | token_endpoint='https://idp.example.com/token', 74 | introspection_endpoint='https://idp.example.com/introspect', 75 | userinfo_endpoint='https://idp.example.com/userinfo', 76 | end_session_endpoint='https://idp.example.com/logout', 77 | jwks_uri='https://idp.example.com/certs', 78 | registration_endpoint='https://idp.example.com/registration') 79 |
provider_config = ProviderConfiguration(provider_metadata=provider_metadata, [client configuration]) 80 | ``` 81 | 82 | See the OpenID Connect specification for more information about the 83 | [provider metadata](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata). 84 | 85 | As mentioned in OpenID Connect specification, `userinfo_endpoint` is optional. If it's not provided, no userinfo 86 | request will be done and `flask_pyoidc.UserSession.userinfo` will be set to `None`. 87 | 88 | ### Customizing authentication request parameters 89 | To customize the [authentication request parameters](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest), 90 | use `auth_request_params` in `ProviderConfiguration`: 91 | ```python 92 | auth_params = {'scope': ['openid', 'profile']} # specify the scope to request 93 | provider_config = ProviderConfiguration([provider/client config], auth_request_params=auth_params) 94 | ``` 95 | 96 | ### Session refresh 97 | 98 | If your provider supports the `prompt=none` authentication request parameter, this extension can automatically refresh 99 | user sessions. This ensures that the user attributes (OIDC claims, user being active, etc.) are kept up-to-date without 100 | having to log the user out and back in. 
To enable and configure the feature, specify the interval (in seconds) between 101 | refreshes: 102 | ```python 103 | from flask_pyoidc.provider_configuration import ProviderConfiguration 104 | 105 | provider_config = ProviderConfiguration(session_refresh_interval_seconds=1800, [provider/client config]) 106 | ``` 107 | 108 | **Note: The user will still be logged out when the session expires (as set in the Flask session configuration).** 109 | 110 | ## Flask configuration 111 | 112 | The application using this extension **MUST** set the following configuration parameters: 113 | 114 | * `SECRET_KEY`: This extension relies on [Flask sessions](https://flask.palletsprojects.com/en/2.0.x/quickstart/#sessions), which 115 | require [`SECRET_KEY`](https://flask.palletsprojects.com/en/2.0.x/config/#builtin-configuration-values). 116 | * `OIDC_REDIRECT_URI`: The URI used as redirect URI to receive authentication responses. This extension will add a url 117 | rule to handle all requests to the specified endpoint, so make sure the domain correctly points to your app and that 118 | the URL path is not already used in the app. 119 | 120 | This extension also uses the following configuration parameters: 121 | * `OIDC_SESSION_PERMANENT`: If set to `True` (which is the default) the user session will be kept until the configured 122 | session lifetime (see below). If set to `False` the session will be deleted when the user closes the browser. 123 | * `PERMANENT_SESSION_LIFETIME`: Control how long a user session is valid, see 124 | [Flask documentation](https://flask.palletsprojects.com/en/2.0.x/config/#PERMANENT_SESSION_LIFETIME) for more information. 125 | * `OIDC_CLOCK_SKEW`: Number of seconds of clock skew allowed when checking the “don’t use before” and “don’t use after” values for tokens.
126 | 127 | ### Legacy configuration parameters 128 | The following parameters have been deprecated: 129 | * `OIDC_REDIRECT_DOMAIN`: Set the domain (which may contain port number) used in the redirect_uri to receive 130 | authentication responses. Defaults to the `SERVER_NAME` configured for Flask. 131 | * `OIDC_REDIRECT_ENDPOINT`: Set the endpoint used in the redirect_uri to receive authentication responses. Defaults to 132 | `redirect_uri`, meaning the URL `/redirect_uri` needs to be registered with the provider(s). 133 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | To add authentication to one of your endpoints create the `OIDCAuthentication` object: 4 | 5 | ```python 6 | from flask import Flask 7 | 8 | from flask_pyoidc import OIDCAuthentication 9 | from flask_pyoidc.provider_configuration import ClientMetadata, ProviderConfiguration 10 | 11 | app = Flask(__name__) 12 | app.config.update( 13 | OIDC_REDIRECT_URI = 'https://example.com/redirect_uri', 14 | SECRET_KEY = ... 15 | ) 16 | 17 | # Static Client Registration 18 | # If you have already registered a client with the provider, specify the client 19 | # credentials directly: 20 | client_metadata = ClientMetadata( 21 | client_id='client1', 22 | client_secret='secret1', 23 | post_logout_redirect_uris=['https://example.com/logout']) 24 | 25 | 26 | provider_config = ProviderConfiguration(issuer='https://idp.example.com', 27 | client_metadata=client_metadata) 28 | 29 | auth = OIDCAuthentication({'default': provider_config}, app) 30 | ``` 31 | 32 | You can also use a Flask application factory: 33 | ```python 34 | config = ProviderConfiguration(...) 35 | auth = OIDCAuthentication({'default': config}) 36 | 37 | def create_app(): 38 | app = Flask(__name__) 39 | app.config.update( 40 | OIDC_REDIRECT_URI = 'https://example.com/redirect_uri', 41 | SECRET_KEY = ...
42 | ) 43 | auth.init_app(app) 44 | return app 45 | ``` 46 | 47 | ## OpenID Connect 48 | 49 | To add user authentication via an OpenID Connect provider to your endpoints use the `oidc_auth` decorator: 50 | ```python 51 | import flask 52 | from flask import jsonify 53 | from flask_pyoidc.user_session import UserSession 54 | 55 | @app.route('/') 56 | @auth.oidc_auth('default') 57 | def index(): 58 | user_session = UserSession(flask.session) 59 | return jsonify(access_token=user_session.access_token, 60 | id_token=user_session.id_token, 61 | userinfo=user_session.userinfo) 62 | ``` 63 | 64 | After a successful login, this extension will place three things in the user session (if they are received from the 65 | provider): 66 | * [ID Token](http://openid.net/specs/openid-connect-core-1_0.html#IDToken) 67 | * [Access Token](http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse) 68 | * [Userinfo Response](http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse) 69 | 70 | In addition to this documentation, you may have a look at a 71 | [code example](https://github.com/zamzterz/Flask-pyoidc/tree/master/example). 72 | 73 | ## Token-Based Authorization 74 | 75 | To add token-based authorization to your endpoints use the `token_auth` 76 | decorator. It authorizes requests to your endpoint with Bearer tokens in 77 | the `Authorization` header. 78 | 79 | The token-based authorization is backed by 80 | [token introspection](https://datatracker.ietf.org/doc/html/rfc7662) 81 | so make sure the Identity Provider's introspection endpoint is configured. 82 | ```python 83 | provider_metadata = ProviderMetadata( 84 | ..., 85 | introspection_endpoint='https://idp.example.com/introspect', 86 | ...) 87 | 88 | @app.route('/api') 89 | @auth.token_auth('default') 90 | def api(): 91 | current_token_identity = auth.current_token_identity 92 | ... 93 | 94 | # Optionally, you can specify scopes required by your endpoint.
95 | @app.route('/api') 96 | @auth.token_auth('default', 97 | scopes_required=['read', 'write']) 98 | def api(): 99 | current_token_identity = auth.current_token_identity 100 | ... 101 | ``` 102 | To obtain information about the token, use `auth.current_token_identity` inside 103 | your view function. `current_token_identity` persists for the current request only. 104 | 105 | ## Combined Authorization 106 | 107 | If you want to apply both OIDC-based authentication and token-based 108 | authorization on your endpoint, you can use the `access_control` decorator. 109 | Then the request will first be checked for a valid access token (in the `Authorization` header). 110 | If there is no token in the request it will fall back to the OIDC-based authentication mechanism. 111 | 112 | If there is a token in the request but it's invalid (e.g. expired, or missing required scopes) the request will 113 | be rejected with a `403 Forbidden` response. 114 | 115 | 116 | ```python 117 | # If you are using Static Provider Configuration, add introspection_endpoint 118 | # in ProviderMetadata. 119 | provider_metadata = ProviderMetadata( 120 | ..., 121 | introspection_endpoint='https://idp.example.com/introspect', 122 | ...) 123 | 124 | @app.route('/api') 125 | @auth.access_control('default') 126 | def api(): 127 | current_identity = None 128 | if auth.current_token_identity: 129 | current_identity = auth.current_token_identity 130 | else: 131 | current_identity = UserSession(flask.session) 132 | ... 133 | ``` 134 | 135 | ## Using multiple providers 136 | 137 | To allow users to log in with multiple different providers, configure all of them in the `OIDCAuthentication` 138 | constructor and specify which one to use by name for each endpoint: 139 | ```python 140 | auth = OIDCAuthentication({'provider1': ProviderConfiguration(...), 'provider2': ProviderConfiguration(...)}, app) 141 | 142 | @app.route('/login1') 143 | @auth.oidc_auth('provider1') 144 | def login1(): 145 | ...
146 | 147 | @app.route('/login2') 148 | @auth.oidc_auth('provider2') 149 | def login2(): 150 | ... 151 | ``` 152 | 153 | ## User logout 154 | 155 | To support user logout, use the `oidc_logout` decorator: 156 | ```python 157 | @app.route('/logout') 158 | @auth.oidc_logout 159 | def logout(): 160 | return "You've been successfully logged out!" 161 | ``` 162 | 163 | If the logout view is mounted under a custom endpoint (other than the default, which is 164 | [the name of the view function](https://flask.palletsprojects.com/en/2.0.x/api/#flask.Flask.route)), or if using Blueprints, you 165 | must specify the full URL in the Flask-pyoidc configuration using `post_logout_redirect_uris`: 166 | ```python 167 | ClientMetadata(..., post_logout_redirect_uris=['https://example.com/post_logout']) # if using static client registration 168 | ClientRegistrationInfo(..., post_logout_redirect_uris=['https://example.com/post_logout']) # if using dynamic client registration 169 | ``` 170 | 171 | This extension also supports [RP-Initiated Logout](http://openid.net/specs/openid-connect-session-1_0.html#RPLogout), 172 | if the provider allows it. Make sure the `end_session_endpoint` is defined in the provider metadata to enable notifying 173 | the provider when the user logs out. 174 | 175 | ## Refreshing the user access token 176 | 177 | If the provider returns a refresh token, this extension can use it to automatically refresh the access token when it 178 | has expired. Please see the helper method `OIDCAuthentication.valid_access_token()`. 
179 | 180 | ## Specify the error view 181 | 182 | If an OAuth error response is received, either in the authentication or token response, it will be passed to the 183 | "error view", specified using the `error_view` decorator: 184 | 185 | ```python 186 | from flask import jsonify 187 | 188 | @auth.error_view 189 | def error(error=None, error_description=None): 190 | return jsonify({'error': error, 'message': error_description}) 191 | ``` 192 | 193 | The function specified as the error view MUST accept two parameters, `error` and `error_description`, which correspond 194 | to the [OIDC error parameters](http://openid.net/specs/openid-connect-core-1_0.html#AuthError), and return the content 195 | that should be displayed to the user. 196 | 197 | If no error view is specified, a generic error message will be displayed to the user. 198 | 199 | ## Client Credentials Flow 200 | The [Client Credentials](https://tools.ietf.org/html/rfc6749#section-4.4) grant type can be used to obtain an 201 | access token for your service (outside the context of a user). 202 | 203 | You can obtain such an access token by using the `client_credentials_grant` method: 204 | 205 | ```python 206 | token_response = auth.clients['default'].client_credentials_grant() 207 | access_token = token_response.get('access_token') 208 | 209 | # Optionally, you can specify scopes for the access token. 210 | auth.clients['default'].client_credentials_grant( 211 | scope=['read', 'write']) 212 | # You can also specify extra keyword arguments to the client credentials flow.
213 | auth.clients['default'].client_credentials_grant( 214 | scope=['read', 'write'], audience=['client_id1', 'client_id2']) 215 | ``` 216 | -------------------------------------------------------------------------------- /src/flask_pyoidc/provider_configuration.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | import logging 3 | 4 | import requests 5 | from oic.oic import Client 6 | from oic.utils.settings import ClientSettings 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class OIDCData(collections.abc.MutableMapping): 12 | """ 13 | Basic OIDC data representation providing validation of required fields. 14 | """ 15 | 16 | def __init__(self, *args, **kwargs): 17 | """ 18 | Args: 19 | args (List[Tuple[String, String]]): key-value pairs to store 20 | kwargs (Dict[string, string]): key-value pairs to store 21 | """ 22 | self.store = {} 23 | self.update(dict(*args, **kwargs)) 24 | 25 | def __getitem__(self, key): 26 | return self.store[key] 27 | 28 | def __setitem__(self, key, value): 29 | self.store[key] = value 30 | 31 | def __delitem__(self, key): 32 | del self.store[key] 33 | 34 | def __iter__(self): 35 | return iter(self.store) 36 | 37 | def __len__(self): 38 | return len(self.store) 39 | 40 | def __str__(self): 41 | data = self.store.copy() 42 | if 'client_secret' in data: 43 | data['client_secret'] = '' 44 | return str(data) 45 | 46 | def __repr__(self): 47 | return str(self.store) 48 | 49 | def __bool__(self): 50 | return True 51 | 52 | def copy(self, **kwargs): 53 | values = self.to_dict() 54 | values.update(kwargs) 55 | return self.__class__(**values) 56 | 57 | def to_dict(self): 58 | return self.store.copy() 59 | 60 | 61 | class ProviderMetadata(OIDCData): 62 | 63 | def __init__(self, 64 | issuer=None, 65 | authorization_endpoint=None, 66 | jwks_uri=None, 67 | token_endpoint=None, 68 | userinfo_endpoint=None, 69 | introspection_endpoint=None, 70 | registration_endpoint=None, 71 
| **kwargs): 72 | """OpenID Providers have metadata describing their configuration. 73 | 74 | Parameters 75 | ---------- 76 | issuer: str, Optional 77 | OP Issuer Identifier. 78 | authorization_endpoint: str, Optional 79 | URL of the OP's OAuth 2.0 Authorization Endpoint. 80 | jwks_uri: str, Optional 81 | URL of the OP's JSON Web Key Set [JWK] document. 82 | token_endpoint: str, Optional 83 | URL of the OP's OAuth 2.0 Token Endpoint. 84 | userinfo_endpoint: str, Optional 85 | URL of the OP's UserInfo Endpoint. 86 | introspection_endpoint: str, Optional 87 | URL of the OP's token introspection endpoint. 88 | registration_endpoint: str, Optional 89 | URL of the OP's Dynamic Client Registration Endpoint. 90 | **kwargs : dict, Optional 91 | Extra arguments to [OpenID Provider Metadata](https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata) 92 | """ 93 | super().__init__(issuer=issuer, authorization_endpoint=authorization_endpoint, 94 | token_endpoint=token_endpoint, userinfo_endpoint=userinfo_endpoint, 95 | jwks_uri=jwks_uri, introspection_endpoint=introspection_endpoint, 96 | registration_endpoint=registration_endpoint, **kwargs) 97 | 98 | 99 | class ClientRegistrationInfo(OIDCData): 100 | pass 101 | 102 | 103 | class ClientMetadata(OIDCData): 104 | def __init__(self, client_id=None, client_secret=None, **kwargs): 105 | """ 106 | Args: 107 | client_id (str): client identifier representing the client 108 | client_secret (str): client secret to authenticate the client with 109 | the OP 110 | kwargs (dict): key-value pairs 111 | """ 112 | super().__init__(client_id=client_id, client_secret=client_secret, **kwargs) 113 | 114 | 115 | class ProviderConfiguration: 116 | """ 117 | Metadata for communicating with an OpenID Connect Provider (OP). 
118 | 119 | Attributes: 120 | auth_request_params (dict): Extra parameters, as key-value pairs, to include in the query parameters 121 | of the authentication request 122 | registered_client_metadata (ClientMetadata): The client metadata registered with the provider. 123 | requests_session (requests.Session): Requests object to use when communicating with the provider. 124 | session_refresh_interval_seconds (int): Number of seconds between updates of user data (tokens, user data, etc.) 125 | fetched from the provider. If `None` is specified, no silent updates of user data will be made. 126 | userinfo_endpoint_method (str): HTTP method ("GET" or "POST") to use when making the UserInfo Request. If 127 | `None` is specified, no UserInfo Request will be made. 128 | """ 129 | 130 | DEFAULT_REQUEST_TIMEOUT = 5 131 | 132 | def __init__(self, 133 | issuer=None, 134 | provider_metadata=None, 135 | userinfo_http_method='GET', 136 | client_registration_info=None, 137 | client_metadata=None, 138 | auth_request_params=None, 139 | session_refresh_interval_seconds=None, 140 | requests_session=None): 141 | """ 142 | Args: 143 | issuer (str): OP Issuer Identifier. If this is specified discovery will be used to fetch the provider 144 | metadata, otherwise `provider_metadata` must be specified. 145 | provider_metadata (ProviderMetadata): OP metadata. 146 | userinfo_http_method (Optional[str]): HTTP method (GET or POST) to use when sending the UserInfo Request. 147 | If `None` is specified, no userinfo request will be sent. 148 | client_registration_info (ClientRegistrationInfo): Client metadata to register your app 149 | dynamically with the provider. Either this or `client_metadata` must be specified. 150 | client_metadata (ClientMetadata): Client metadata if your app is statically 151 | registered with the provider. Either this or `client_registration_info` must be specified.
152 | auth_request_params (dict): Extra parameters that should be included in the authentication request. 153 | session_refresh_interval_seconds (int): Length of interval (in seconds) between attempted user data 154 | refreshes. 155 | requests_session (requests.Session): custom requests object to allow for example retry handling, etc. 156 | """ 157 | 158 | if not issuer and not provider_metadata: 159 | raise ValueError("Specify either 'issuer' or 'provider_metadata'.") 160 | 161 | if not client_registration_info and not client_metadata: 162 | raise ValueError("Specify either 'client_registration_info' or 'client_metadata'.") 163 | 164 | self._issuer = issuer 165 | self._provider_metadata = provider_metadata 166 | 167 | self._client_registration_info = client_registration_info 168 | self._client_metadata = client_metadata 169 | 170 | self.userinfo_endpoint_method = userinfo_http_method 171 | self.auth_request_params = auth_request_params or {} 172 | self.session_refresh_interval_seconds = session_refresh_interval_seconds 173 | # For session persistence 174 | self.client_settings = ClientSettings(timeout=self.DEFAULT_REQUEST_TIMEOUT, 175 | requests_session=requests_session or requests.Session()) 176 | 177 | def ensure_provider_metadata(self, client: Client): 178 | if not self._provider_metadata: 179 | resp = client.provider_config(self._issuer) 180 | logger.debug(f'Received discovery response: {resp.to_dict()}') 181 | 182 | self._provider_metadata = ProviderMetadata(**resp.to_dict()) 183 | 184 | return self._provider_metadata 185 | 186 | @property 187 | def registered_client_metadata(self): 188 | return self._client_metadata 189 | 190 | def register_client(self, client: Client): 191 | 192 | if not self._client_metadata: 193 | if not self._provider_metadata['registration_endpoint']: 194 | raise ValueError("Can't use dynamic client registration, provider metadata is missing " 195 | "'registration_endpoint'.") 196 | 197 | registration_request = 
self._client_registration_info.to_dict() 198 | 199 | # Send request to register the client dynamically. 200 | registration_response = client.register( 201 | url=self._provider_metadata['registration_endpoint'], 202 | **registration_request) 203 | logger.info('Received registration response.') 204 | self._client_metadata = ClientMetadata( 205 | **registration_response.to_dict()) 206 | 207 | return self._client_metadata 208 | -------------------------------------------------------------------------------- /tests/test_auth_response_handler.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from unittest.mock import create_autospec, MagicMock 3 | 4 | import pytest 5 | from flask_pyoidc.auth_response_handler import AuthResponseHandler, AuthResponseUnexpectedStateError, \ 6 | InvalidIdTokenError, AuthResponseErrorResponseError, AuthResponseMismatchingSubjectError 7 | from flask_pyoidc.provider_configuration import ProviderConfiguration, ProviderMetadata, ClientMetadata 8 | from flask_pyoidc.pyoidc_facade import PyoidcFacade 9 | from oic.oic import AuthorizationResponse, AccessTokenResponse, IdToken, TokenErrorResponse, \ 10 | OpenIDSchema, AuthorizationErrorResponse, AuthorizationRequest 11 | 12 | 13 | def _create_id_token(issuer, client_id, nonce): 14 | id_token = IdToken(**{'iss': issuer, 'sub': 'test_sub', 'aud': client_id, 'nonce': nonce, 'exp': time() + 60}) 15 | id_token.jws_header = {'alg': 'RS256'} 16 | return id_token 17 | 18 | 19 | class TestAuthResponseHandler: 20 | ISSUER = 'https://issuer.example.com' 21 | CLIENT_ID = 'client1' 22 | AUTH_REQUEST = AuthorizationRequest(**{'state': 'test_state', 'nonce': 'test_nonce'}) 23 | AUTH_RESPONSE = AuthorizationResponse(**{'code': 'test_auth_code', 'state': AUTH_REQUEST['state']}) 24 | TOKEN_RESPONSE = AccessTokenResponse(**{ 25 | 'access_token': 'test_token', 26 | 'expires_in': 3600, 27 | 'id_token': _create_id_token(ISSUER, CLIENT_ID, 
AUTH_REQUEST['nonce']), 28 | 'id_token_jwt': 'test_id_token_jwt', 29 | 'refresh_token': 'test_refresh_token' 30 | }) 31 | USERINFO_RESPONSE = OpenIDSchema(**{'sub': 'test_sub'}) 32 | ERROR_RESPONSE = {'error': 'test_error', 'error_description': 'something went wrong'} 33 | 34 | @pytest.fixture 35 | def client_mock(self): 36 | return create_autospec(PyoidcFacade, True, True) 37 | 38 | def test_should_detect_state_mismatch(self, client_mock): 39 | auth_request = {'state': 'other_state', 'nonce': self.AUTH_REQUEST['nonce']} 40 | with pytest.raises(AuthResponseUnexpectedStateError): 41 | AuthResponseHandler(client_mock).process_auth_response(self.AUTH_RESPONSE, auth_request, {}) 42 | 43 | def test_should_detect_nonce_mismatch(self): 44 | client = PyoidcFacade( 45 | ProviderConfiguration(provider_metadata=ProviderMetadata(issuer=self.ISSUER), 46 | client_metadata=ClientMetadata(client_id=self.CLIENT_ID)), 47 | redirect_uri='https://client.example.com/redirect') 48 | client.exchange_authorization_code = MagicMock(return_value=self.TOKEN_RESPONSE) 49 | auth_request = {'state': self.AUTH_RESPONSE['state'], 'nonce': 'other_nonce'} 50 | with pytest.raises(InvalidIdTokenError): 51 | AuthResponseHandler(client).process_auth_response(self.AUTH_RESPONSE, auth_request, {}) 52 | 53 | def test_should_handle_auth_error_response(self, client_mock): 54 | with pytest.raises(AuthResponseErrorResponseError) as exc: 55 | AuthResponseHandler(client_mock).process_auth_response(AuthorizationErrorResponse(**self.ERROR_RESPONSE), 56 | self.AUTH_REQUEST, 57 | {}) 58 | assert exc.value.error_response == self.ERROR_RESPONSE 59 | 60 | def test_should_handle_token_error_response(self, client_mock): 61 | client_mock.exchange_authorization_code.return_value = TokenErrorResponse(**self.ERROR_RESPONSE) 62 | with pytest.raises(AuthResponseErrorResponseError) as exc: 63 | AuthResponseHandler(client_mock).process_auth_response(AuthorizationResponse(**self.AUTH_RESPONSE), 64 | self.AUTH_REQUEST, 65 | {}) 
66 | assert exc.value.error_response == self.ERROR_RESPONSE 67 | 68 | def test_should_detect_mismatching_subject(self, client_mock): 69 | client_mock.exchange_authorization_code.return_value = AccessTokenResponse(**self.TOKEN_RESPONSE) 70 | client_mock.userinfo_request.return_value = OpenIDSchema(**{'sub': 'other_sub'}) 71 | with pytest.raises(AuthResponseMismatchingSubjectError): 72 | AuthResponseHandler(client_mock).process_auth_response(AuthorizationResponse(**self.AUTH_RESPONSE), 73 | self.AUTH_REQUEST, 74 | {}) 75 | 76 | def test_should_handle_auth_response_with_authorization_code(self, client_mock): 77 | client_mock.exchange_authorization_code.return_value = self.TOKEN_RESPONSE 78 | client_mock.userinfo_request.return_value = self.USERINFO_RESPONSE 79 | result = AuthResponseHandler(client_mock).process_auth_response(self.AUTH_RESPONSE, 80 | self.AUTH_REQUEST, 81 | {}) 82 | assert result.access_token == 'test_token' 83 | assert result.expires_in == self.TOKEN_RESPONSE['expires_in'] 84 | assert result.id_token_claims == self.TOKEN_RESPONSE['id_token'].to_dict() 85 | assert result.id_token_jwt == self.TOKEN_RESPONSE['id_token_jwt'] 86 | assert result.userinfo_claims == self.USERINFO_RESPONSE.to_dict() 87 | assert result.refresh_token == self.TOKEN_RESPONSE['refresh_token'] 88 | 89 | def test_should_handle_auth_response_without_authorization_code(self, client_mock): 90 | auth_response = AuthorizationResponse(**self.TOKEN_RESPONSE) 91 | auth_response['state'] = 'test_state' 92 | client_mock.userinfo_request.return_value = self.USERINFO_RESPONSE 93 | result = AuthResponseHandler(client_mock).process_auth_response(auth_response, self.AUTH_REQUEST, {}) 94 | assert not client_mock.exchange_authorization_code.called 95 | assert result.access_token == 'test_token' 96 | assert result.expires_in == self.TOKEN_RESPONSE['expires_in'] 97 | assert result.id_token_jwt == self.TOKEN_RESPONSE['id_token_jwt'] 98 | assert result.id_token_claims == 
self.TOKEN_RESPONSE['id_token'].to_dict() 99 | assert result.userinfo_claims == self.USERINFO_RESPONSE.to_dict() 100 | assert result.refresh_token is None 101 | 102 | def test_should_handle_token_response_without_id_token(self, client_mock): 103 | token_response = {'access_token': 'test_token'} 104 | client_mock.exchange_authorization_code.return_value = AccessTokenResponse(**token_response) 105 | result = AuthResponseHandler(client_mock).process_auth_response(AuthorizationResponse(**self.AUTH_RESPONSE), 106 | self.AUTH_REQUEST, 107 | {}) 108 | assert result.access_token == 'test_token' 109 | assert result.id_token_claims is None 110 | 111 | def test_should_handle_no_token_response(self, client_mock): 112 | client_mock.exchange_authorization_code.return_value = None 113 | client_mock.userinfo_request.return_value = None 114 | hybrid_auth_response = self.AUTH_RESPONSE.copy() 115 | hybrid_auth_response.update(self.TOKEN_RESPONSE) 116 | result = AuthResponseHandler(client_mock).process_auth_response(AuthorizationResponse(**hybrid_auth_response), 117 | self.AUTH_REQUEST, 118 | {}) 119 | assert result.access_token == 'test_token' 120 | assert result.id_token_claims == self.TOKEN_RESPONSE['id_token'].to_dict() 121 | assert result.id_token_jwt == self.TOKEN_RESPONSE['id_token_jwt'] 122 | 123 | @pytest.mark.parametrize('response_type, expected', [ 124 | ('code', False), # Authorization Code Flow 125 | ('id_token', True), # Implicit Flow 126 | ('id_token token', True), # Implicit Flow 127 | ('code id_token', True), # Hybrid Flow 128 | ('code token', True), # Hybrid Flow 129 | ('code id_token token', True) # Hybrid Flow 130 | ]) 131 | def test_expect_fragment_encoded_response_by_response_type(self, response_type, expected): 132 | assert AuthResponseHandler.expect_fragment_encoded_response({'response_type': response_type}) is expected 133 | 134 | @pytest.mark.parametrize('response_type, response_mode, expected', [ 135 | ('code', 'fragment', True), 136 | ('id_token', 'query', 
False), 137 | ('code token', 'form_post', False), 138 | ]) 139 | def test_expect_fragment_encoded_response_with_non_default_response_mode(self, 140 | response_type, 141 | response_mode, 142 | expected): 143 | auth_req = {'response_type': response_type, 'response_mode': response_mode} 144 | assert AuthResponseHandler.expect_fragment_encoded_response(auth_req) is expected 145 | -------------------------------------------------------------------------------- /tests/test_provider_configuration.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | import pytest 4 | import responses 5 | from flask_pyoidc.provider_configuration import ProviderConfiguration, ClientRegistrationInfo, ProviderMetadata, \ 6 | ClientMetadata, OIDCData 7 | from oic.oic import Client 8 | from oic.utils.authn.client import CLIENT_AUTHN_METHOD 9 | 10 | 11 | class TestProviderConfiguration: 12 | PROVIDER_BASEURL = 'https://op.example.com' 13 | 14 | @staticmethod 15 | def provider_metadata(**kwargs): 16 | return ProviderMetadata(issuer='', authorization_endpoint='', jwks_uri='', **kwargs) 17 | 18 | def test_missing_provider_metadata_raises_exception(self): 19 | with pytest.raises(ValueError) as exc_info: 20 | ProviderConfiguration(client_registration_info=ClientRegistrationInfo()) 21 | 22 | exc_message = str(exc_info.value) 23 | assert 'issuer' in exc_message 24 | assert 'provider_metadata' in exc_message 25 | 26 | def test_missing_client_metadata_raises_exception(self): 27 | with pytest.raises(ValueError) as exc_info: 28 | ProviderConfiguration(issuer=self.PROVIDER_BASEURL) 29 | 30 | exc_message = str(exc_info.value) 31 | assert 'client_registration_info' in exc_message 32 | assert 'client_metadata' in exc_message 33 | 34 | @responses.activate 35 | def test_should_fetch_provider_metadata_if_not_given(self): 36 | provider_metadata = { 37 | 'issuer': self.PROVIDER_BASEURL, 38 | 'authorization_endpoint': self.PROVIDER_BASEURL + '/auth', 39 | 
'jwks_uri': self.PROVIDER_BASEURL + '/jwks' 40 | } 41 | responses.add(responses.GET, 42 | self.PROVIDER_BASEURL + '/.well-known/openid-configuration', 43 | json=provider_metadata) 44 | 45 | provider_config = ProviderConfiguration(issuer=self.PROVIDER_BASEURL, 46 | client_registration_info=ClientRegistrationInfo()) 47 | provider_config.ensure_provider_metadata(Client(CLIENT_AUTHN_METHOD)) 48 | assert provider_config._provider_metadata['issuer'] == self.PROVIDER_BASEURL 49 | assert provider_config._provider_metadata['authorization_endpoint'] == self.PROVIDER_BASEURL + '/auth' 50 | assert provider_config._provider_metadata['jwks_uri'] == self.PROVIDER_BASEURL + '/jwks' 51 | 52 | def test_should_not_fetch_provider_metadata_if_given(self): 53 | provider_metadata = self.provider_metadata() 54 | provider_config = ProviderConfiguration(provider_metadata=provider_metadata, 55 | client_registration_info=ClientRegistrationInfo()) 56 | 57 | provider_config.ensure_provider_metadata(Client(CLIENT_AUTHN_METHOD)) 58 | assert provider_config._provider_metadata == provider_metadata 59 | 60 | @responses.activate 61 | def test_should_register_dynamic_client_if_client_registration_info_is_given(self): 62 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 63 | redirect_uris = ['https://client.example.com/redirect', 64 | 'https://client.example.com/redirect2'] 65 | post_logout_redirect_uris = ['https://client.example.com/logout'] 66 | responses.add(responses.POST, registration_endpoint, json={ 67 | 'client_id': 'client1', 'client_secret': 'secret1', 68 | 'redirect_uris': redirect_uris, 69 | 'post_logout_redirect_uris': post_logout_redirect_uris}) 70 | 71 | provider_config = ProviderConfiguration( 72 | provider_metadata=self.provider_metadata(registration_endpoint=registration_endpoint), 73 | client_registration_info=ClientRegistrationInfo( 74 | redirect_uris=redirect_uris, 75 | post_logout_redirect_uris=post_logout_redirect_uris)) 76 | 77 | 
provider_config.register_client(Client(CLIENT_AUTHN_METHOD)) 78 | assert provider_config._client_metadata['client_id'] == 'client1' 79 | assert provider_config._client_metadata['client_secret'] == 'secret1' 80 | assert provider_config._client_metadata['redirect_uris'] == redirect_uris 81 | assert provider_config._client_metadata[ 82 | 'post_logout_redirect_uris'] == post_logout_redirect_uris 83 | 84 | def test_should_not_register_dynamic_client_if_client_metadata_is_given(self): 85 | client_metadata = ClientMetadata(client_id='client1', 86 | client_secret='secret1') 87 | provider_config = ProviderConfiguration(provider_metadata=self.provider_metadata(), 88 | client_metadata=client_metadata) 89 | provider_config.register_client(None) 90 | assert provider_config._client_metadata == client_metadata 91 | 92 | def test_should_raise_exception_for_non_registered_client_when_missing_registration_endpoint(self): 93 | provider_config = ProviderConfiguration(provider_metadata=self.provider_metadata(), 94 | client_registration_info=ClientRegistrationInfo()) 95 | assert provider_config._provider_metadata['registration_endpoint'] is None 96 | with pytest.raises(ValueError): 97 | provider_config.register_client(None) 98 | 99 | @responses.activate 100 | def test_register_client_should_register_dynamic_client_if_initial_access_token_present(self): 101 | 102 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 103 | redirect_uris = ['https://client.example.com/redirect', 104 | 'https://client.example.com/redirect2'] 105 | post_logout_redirect_uris = ['https://client.example.com/logout'] 106 | client_registration_response = { 107 | 'client_id': 'client1', 108 | 'client_secret': 'secret1', 109 | 'client_name': 'Test Client', 110 | 'redirect_uris': redirect_uris, 111 | 'post_logout_redirect_uris': post_logout_redirect_uris, 112 | 'registration_client_uri': 'https://op.example.com/register/client1', 113 | 'registration_access_token': 'registration_access_token1' 114 | } 115 | 
responses.add(responses.POST, registration_endpoint, json=client_registration_response) 116 | provider_config = ProviderConfiguration( 117 | provider_metadata=self.provider_metadata(registration_endpoint=registration_endpoint), 118 | client_registration_info=ClientRegistrationInfo( 119 | client_name='Test Client', 120 | redirect_uris=redirect_uris, 121 | post_logout_redirect_uris=post_logout_redirect_uris, 122 | registration_token='initial_access_token')) 123 | 124 | provider_config.register_client(Client(CLIENT_AUTHN_METHOD)) 125 | 126 | assert provider_config._client_metadata['client_id'] == 'client1' 127 | assert provider_config._client_metadata['client_secret'] == 'secret1' 128 | assert provider_config._client_metadata['client_name'] == 'Test Client' 129 | assert provider_config._client_metadata['registration_client_uri'] == 'https://op.example.com/register/client1' 130 | assert provider_config._client_metadata['registration_access_token'] == 'registration_access_token1' 131 | assert provider_config._client_metadata['redirect_uris'] == redirect_uris 132 | assert provider_config._client_metadata[ 133 | 'post_logout_redirect_uris'] == post_logout_redirect_uris 134 | assert responses.calls[0].request.headers['Authorization'] == \ 135 | f"Bearer {base64.b64encode('initial_access_token'.encode()).decode()}" 136 | 137 | @responses.activate 138 | def test_register_client_should_register_client_even_if_post_logout_redirect_uris_missing(self): 139 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 140 | redirect_uris = ['https://client.example.com/redirect', 141 | 'https://client.example.com/redirect2'] 142 | client_registration_response = { 143 | 'client_id': 'client1', 144 | 'client_secret': 'secret1', 145 | 'client_name': 'Test Client', 146 | 'redirect_uris': redirect_uris, 147 | 'registration_client_uri': 'https://op.example.com/register/client1', 148 | 'registration_access_token': 'registration_access_token1' 149 | } 150 | responses.add(responses.POST, 
registration_endpoint, json=client_registration_response) 151 | provider_config = ProviderConfiguration( 152 | provider_metadata=self.provider_metadata(registration_endpoint=registration_endpoint), 153 | client_registration_info=ClientRegistrationInfo( 154 | client_name='Test Client', 155 | redirect_uris=redirect_uris, 156 | post_logout_redirect_uris=[], 157 | registration_token='initial_access_token')) 158 | 159 | provider_config.register_client(client=Client(CLIENT_AUTHN_METHOD)) 160 | 161 | assert provider_config._client_metadata['client_id'] == 'client1' 162 | assert provider_config._client_metadata['client_secret'] == 'secret1' 163 | assert provider_config._client_metadata['client_name'] == 'Test Client' 164 | assert provider_config._client_metadata['registration_client_uri'] == 'https://op.example.com/register/client1' 165 | assert provider_config._client_metadata['registration_access_token'] == 'registration_access_token1' 166 | assert provider_config._client_metadata['redirect_uris'] == redirect_uris 167 | assert provider_config._client_metadata.get('post_logout_redirect_uris') is None 168 | 169 | 170 | class TestOIDCData: 171 | def test_client_secret_should_not_be_in_string_representation(self): 172 | client_secret = 'secret123456' 173 | client_metadata = OIDCData(client_id='client1', client_secret=client_secret) 174 | assert client_secret not in str(client_metadata) 175 | assert client_secret in repr(client_metadata) 176 | 177 | def test_copy_should_overwrite_existing_value(self): 178 | data = OIDCData(abc='xyz') 179 | copy_data = data.copy(qwe='rty', abc='123') 180 | assert copy_data == {'abc': '123', 'qwe': 'rty'} 181 | 182 | def test_del_and_len(self): 183 | data = OIDCData(abc='xyz', qwe='rty') 184 | assert len(data) == 2 185 | del data['qwe'] 186 | assert data.to_dict() == {'abc': 'xyz'} 187 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
class PyoidcFacade:
    """
    Wrapper around pyoidc library, coupled with config for a simplified API for flask-pyoidc.

    Three pyoidc clients are created from the same provider configuration and
    kept side by side: the OIDC client (authentication, token, refresh and
    userinfo requests), the extension client (token introspection) and the
    OAuth2 client (client credentials flow).
    """

    def __init__(self, provider_configuration, redirect_uri):
        """
        Args:
            provider_configuration (flask_pyoidc.provider_configuration.ProviderConfiguration)
            redirect_uri (str): redirect URI of this client; it overrides 'redirect_uris' in already
                registered client metadata and is sent in authentication and token requests.
        """
        self._provider_configuration = provider_configuration
        self._client = Client(client_authn_method=CLIENT_AUTHN_METHOD,
                              settings=provider_configuration.client_settings)
        # Token Introspection is implemented under extension sub-package of
        # the client in pyoidc.
        self._client_extension = ClientExtension(client_authn_method=CLIENT_AUTHN_METHOD,
                                                 settings=provider_configuration.client_settings)
        # Client Credentials Flow is implemented under oauth2 sub-package of
        # the client in pyoidc.
        self._oauth2_client = Oauth2Client(client_authn_method=CLIENT_AUTHN_METHOD,
                                           message_factory=CCMessageFactory,
                                           settings=self._provider_configuration.client_settings)

        provider_metadata = provider_configuration.ensure_provider_metadata(self._client)
        # Should be called explicitly for "Static Provider Registration" to register the issuer.
        if not self._client.issuer:
            self._client.handle_provider_config(ProviderConfigurationResponse(**provider_metadata.to_dict()),
                                                provider_metadata['issuer'])

        if self._provider_configuration.registered_client_metadata:
            # The client is already registered: hand its metadata to pyoidc, with the
            # redirect URI of this application as the only registered redirect URI.
            client_metadata = self._provider_configuration.registered_client_metadata.to_dict()
            client_metadata.update(redirect_uris=[redirect_uri])
            self._store_registration_info(client_metadata)

        self._redirect_uri = redirect_uri

    def _store_registration_info(self, client_metadata):
        """
        Stores the client registration in all underlying pyoidc clients.

        Args:
            client_metadata (Mapping[str, Any]): client metadata, passed as keyword arguments
                to pyoidc's RegistrationResponse
        """
        registration_response = RegistrationResponse(**client_metadata)
        self._client.store_registration_info(registration_response)
        self._client_extension.store_registration_info(registration_response)
        # Set client_id and client_secret for _oauth2_client. This is used
        # by Client Credentials Flow.
        self._oauth2_client.client_id = registration_response['client_id']
        # 'client_secret' is optional in a registration response (e.g. for public clients),
        # so a missing secret must not abort the registration with a KeyError.
        self._oauth2_client.client_secret = registration_response.get('client_secret')

    def is_registered(self):
        """
        Returns:
            bool: True if the provider configuration holds registered client metadata.
        """
        return bool(self._provider_configuration.registered_client_metadata)

    def register(self):
        """
        Registers the client via the provider configuration and stores the resulting client
        metadata in the underlying pyoidc clients.
        """
        client_metadata = self._provider_configuration.register_client(self._client)
        # Lazy %-formatting: the message is only rendered if debug logging is enabled.
        logger.debug('client registration response: %s', client_metadata)
        self._store_registration_info(client_metadata)

    def authentication_request(self, state, nonce, extra_auth_params):
        """
        Builds the authentication request.

        Parameters are applied in increasing order of precedence: the defaults below, then
        'auth_request_params' of the provider configuration, then `extra_auth_params`.

        Args:
            state (str): authentication request parameter 'state'
            nonce (str): authentication request parameter 'nonce'
            extra_auth_params (Mapping[str, str]): extra authentication request parameters
        Returns:
            AuthorizationRequest: the authentication request
        """
        args = {
            'client_id': self._client.client_id,
            'response_type': 'code',
            'scope': ['openid'],
            'redirect_uri': self._redirect_uri,
            'state': state,
            'nonce': nonce,
        }

        args.update(self._provider_configuration.auth_request_params)
        args.update(extra_auth_params)
        auth_request = self._client.construct_AuthorizationRequest(request_args=args)
        logger.debug('sending authentication request: %s', auth_request.to_json())

        return auth_request

    def login_url(self, auth_request):
        """
        Args:
            auth_request (AuthorizationRequest): authentication request
        Returns:
            str: Authentication request as a URL to redirect the user to the provider.
        """
        return auth_request.request(self._client.authorization_endpoint)

    def parse_authentication_response(self, response_params):
        """
        Parses the authentication response received from the provider.

        Parameters
        ----------
        response_params: Mapping[str, str]
            authentication response parameters.

        Returns
        -------
        Union[AuthorizationResponse, AuthorizationErrorResponse]
            The parsed authorization response.
        """
        auth_resp = self._client.parse_response(AuthorizationResponse, info=response_params, sformat='dict')
        if 'id_token' in response_params:
            # Keep the ID Token exactly as it was received next to the parsed response.
            auth_resp['id_token_jwt'] = response_params['id_token']
        return auth_resp

    def exchange_authorization_code(self, authorization_code: str, state: str, extra_token_args: Mapping[str, Any]):
        """Requests tokens from an authorization code.

        Parameters
        ----------
        authorization_code: str
            authorization code issued to client after user authorization
        state: str
            state is used to keep track of responses to outstanding requests.
        extra_token_args: Mapping[str, Any]
            extra arguments to pass to pyoidc

        Returns
        -------
        Union[AccessTokenResponse, TokenErrorResponse, None]
            The parsed token response, or None if no token request was performed.
        """
        if not self._client.token_endpoint:
            # Without a token endpoint there is nowhere to send the code to.
            return None

        request_args = {
            'grant_type': 'authorization_code',
            'code': authorization_code,
            'redirect_uri': self._redirect_uri
        }
        logger.debug('making token request: %s', request_args)
        client_auth_method = self._client.registration_response.get('token_endpoint_auth_method',
                                                                    'client_secret_basic')
        token_response = self._client.do_access_token_request(state=state,
                                                              request_args=request_args,
                                                              authn_method=client_auth_method,
                                                              endpoint=self._client.token_endpoint,
                                                              **extra_token_args)
        logger.info('Received token response.')

        return token_response

    def verify_id_token(self, id_token, auth_request):
        """
        Verifies the ID Token.

        Args:
            id_token (Mapping[str, str]): ID token claims
            auth_request (Mapping[str, str]): original authentication request parameters to validate against
                (nonce, acr_values, max_age, etc.)

        Raises:
            PyoidcError: If the ID token is invalid.

        """
        self._client.verify_id_token(id_token, auth_request)

    def refresh_token(self, refresh_token: str):
        """Requests new tokens using a refresh token.

        Unlike `exchange_authorization_code`, no check for a missing token
        endpoint is made here; the request is always handed to pyoidc.

        Parameters
        ----------
        refresh_token: str
            refresh token issued to client after user authorization.

        Returns
        -------
        Union[AccessTokenResponse, TokenErrorResponse]
            The token response as returned by pyoidc.
        """
        request_args = {
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
            'redirect_uri': self._redirect_uri
        }
        client_auth_method = self._client.registration_response.get('token_endpoint_auth_method',
                                                                    'client_secret_basic')
        return self._client.do_access_token_refresh(request_args=request_args,
                                                    authn_method=client_auth_method,
                                                    token=Token(resp={'refresh_token': refresh_token}),
                                                    endpoint=self._client.token_endpoint
                                                    )

    def userinfo_request(self, access_token: str):
        """Makes a userinfo request with the given access token.

        Parameters
        ----------
        access_token: str
            Bearer access token to use when fetching userinfo.

        Returns
        -------
        Union[OpenIDSchema, UserInfoErrorResponse, ErrorResponse, None]
            The userinfo response, or None if no request was made (no access
            token, no userinfo HTTP method configured or no userinfo endpoint).
        """
        http_method = self._provider_configuration.userinfo_endpoint_method
        if not access_token or http_method is None or not self._client.userinfo_endpoint:
            return None

        logger.debug('making userinfo request')
        userinfo_response = self._client.do_user_info_request(method=http_method, token=access_token)
        logger.debug('received userinfo response: %s', userinfo_response)

        return userinfo_response

    def _token_introspection_request(self, access_token: str) -> TokenIntrospectionResponse:
        """Make token introspection request.

        Parameters
        ----------
        access_token: str
            Access token to be validated.

        Returns
        -------
        TokenIntrospectionResponse
            Response object contains result of the token introspection.
        """
        request_args = {
            'token': access_token,
            'token_type_hint': 'access_token'
        }
        client_auth_method = self._client.registration_response.get('introspection_endpoint_auth_method',
                                                                    'client_secret_basic')
        logger.info('making token introspection request')
        # The request is made by the extension client, while the endpoint is read from the OIDC client.
        token_introspection_response = self._client_extension.do_token_introspection(
            request_args=request_args, authn_method=client_auth_method, endpoint=self._client.introspection_endpoint)

        return token_introspection_response

    def client_credentials_grant(self, scope: list = None, **kwargs) -> AccessTokenResponse:
        """Public method to request access_token using client_credentials flow.
        This is useful for service to service communication where user-agent is
        not available which is required in authorization code flow. Your
        service can request access_token in order to access APIs of other
        services.

        On API call, token introspection will ensure that only valid token can
        be used to access your APIs.

        Parameters
        ----------
        scope: list, optional
            List of scopes to be requested. They are sent as a single
            space-separated string; an empty list or None sends no scope.
        **kwargs : dict, optional
            Extra arguments to client credentials flow.

        Returns
        -------
        AccessTokenResponse

        Examples
        --------
        ::

            auth = OIDCAuthentication({'default': provider_config},
                                      access_token_required=True)
            auth.init_app(app)
            auth.clients['default'].client_credentials_grant()

        Optionally, you can specify scopes for the access token.

        ::

            auth.clients['default'].client_credentials_grant(
                scope=['read', 'write'])

        You can also specify extra keyword arguments to client credentials flow.

        ::

            auth.clients['default'].client_credentials_grant(
                scope=['read', 'write'], audience=['client_id1', 'client_id2'])
        """
        request_args = {
            'grant_type': 'client_credentials',
            **kwargs
        }
        if scope:
            request_args['scope'] = ' '.join(scope)
        client_auth_method = self._client.registration_response.get('token_endpoint_auth_method',
                                                                    'client_secret_basic')
        access_token = self._oauth2_client.do_access_token_request(request_args=request_args,
                                                                   authn_method=client_auth_method,
                                                                   endpoint=self._client.token_endpoint)
        return access_token

    @property
    def session_refresh_interval_seconds(self):
        # Read straight from the provider configuration.
        return self._provider_configuration.session_refresh_interval_seconds

    @property
    def provider_end_session_endpoint(self):
        # 'end_session_endpoint' of the provider metadata, or None if the provider has none.
        provider_metadata = self._provider_configuration.ensure_provider_metadata(self._client)
        return provider_metadata.get('end_session_endpoint')

    @property
    def post_logout_redirect_uris(self):
        # 'post_logout_redirect_uris' of the stored registration, or None if not part of it.
        return self._client.registration_response.get('post_logout_redirect_uris')
test_registered_client_metadata_is_forwarded_to_pyoidc(self): 48 | config = ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, client_metadata=self.CLIENT_METADATA) 49 | facade = PyoidcFacade(config, REDIRECT_URI) 50 | 51 | expected = { 52 | 'client_id': self.CLIENT_METADATA['client_id'], 53 | 'client_secret': self.CLIENT_METADATA['client_secret'], 54 | 'redirect_uris': [REDIRECT_URI], 55 | 'token_endpoint_auth_method': 'client_secret_basic' 56 | } 57 | assert facade._client.registration_response.to_dict() == expected 58 | 59 | def test_no_registered_client_metadata_is_handled(self): 60 | config = ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 61 | client_registration_info=ClientRegistrationInfo()) 62 | facade = PyoidcFacade(config, REDIRECT_URI) 63 | assert not facade._client.registration_response 64 | 65 | def test_is_registered(self): 66 | unregistered = ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 67 | client_registration_info=ClientRegistrationInfo()) 68 | registered = ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 69 | client_metadata=self.CLIENT_METADATA) 70 | assert PyoidcFacade(unregistered, REDIRECT_URI).is_registered() is False 71 | assert PyoidcFacade(registered, REDIRECT_URI).is_registered() is True 72 | 73 | @responses.activate 74 | def test_register(self): 75 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 76 | redirect_uris = ['https://client.example.com/redirect'] 77 | post_logout_redirect_uris = ['https://client.example.com/logout'] 78 | client_registration_response = { 79 | 'client_id': 'client1', 80 | 'client_secret': 'secret1', 81 | 'client_name': 'Test Client', 82 | 'redirect_uris': redirect_uris, 83 | 'post_logout_redirect_uris': post_logout_redirect_uris 84 | } 85 | responses.add(responses.POST, registration_endpoint, json=client_registration_response) 86 | 87 | provider_metadata = self.PROVIDER_METADATA.copy(registration_endpoint=registration_endpoint) 88 | 
unregistered = ProviderConfiguration(provider_metadata=provider_metadata, 89 | client_registration_info=ClientRegistrationInfo( 90 | redirect_uris=redirect_uris, 91 | post_logout_redirect_uris=post_logout_redirect_uris 92 | )) 93 | facade = PyoidcFacade(unregistered, REDIRECT_URI) 94 | facade.register() 95 | assert facade.is_registered() is True 96 | 97 | def test_authentication_request(self): 98 | extra_user_auth_params = {'foo': 'bar', 'abc': 'xyz'} 99 | config = ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 100 | client_metadata=self.CLIENT_METADATA, 101 | auth_request_params=extra_user_auth_params) 102 | 103 | state = 'test_state' 104 | nonce = 'test_nonce' 105 | 106 | facade = PyoidcFacade(config, REDIRECT_URI) 107 | extra_lib_auth_params = {'foo': 'baz', 'qwe': 'rty'} 108 | auth_request = facade.authentication_request(state, nonce, extra_lib_auth_params) 109 | 110 | expected_auth_params = { 111 | 'scope': 'openid', 112 | 'response_type': 'code', 113 | 'client_id': self.CLIENT_METADATA['client_id'], 114 | 'redirect_uri': REDIRECT_URI, 115 | 'state': state, 116 | 'nonce': nonce 117 | } 118 | expected_auth_params.update(extra_user_auth_params) 119 | expected_auth_params.update(extra_lib_auth_params) 120 | assert auth_request.to_dict() == expected_auth_params 121 | 122 | def test_parse_authentication_response(self): 123 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 124 | client_metadata=self.CLIENT_METADATA), 125 | REDIRECT_URI) 126 | auth_code = 'auth_code-1234' 127 | state = 'state-1234' 128 | auth_response = AuthorizationResponse(**{'state': state, 'code': auth_code}) 129 | parsed_auth_response = facade.parse_authentication_response(auth_response.to_dict()) 130 | assert isinstance(parsed_auth_response, AuthorizationResponse) 131 | assert parsed_auth_response.to_dict() == auth_response.to_dict() 132 | 133 | def test_parse_authentication_response_handles_error_response(self): 134 | facade = 
PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 135 | client_metadata=self.CLIENT_METADATA), 136 | REDIRECT_URI) 137 | error_response = AuthorizationErrorResponse(**{'error': 'invalid_request', 'state': 'state-1234'}) 138 | parsed_auth_response = facade.parse_authentication_response(error_response) 139 | assert isinstance(parsed_auth_response, AuthorizationErrorResponse) 140 | assert parsed_auth_response.to_dict() == error_response.to_dict() 141 | 142 | @responses.activate 143 | def test_parse_authentication_response_preserves_id_token_jwt(self): 144 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 145 | client_metadata=self.CLIENT_METADATA), 146 | REDIRECT_URI) 147 | state = 'state-1234' 148 | now = int(time.time()) 149 | id_token, id_token_signing_key = signed_id_token({ 150 | 'iss': self.PROVIDER_METADATA['issuer'], 151 | 'sub': 'test_sub', 152 | 'aud': 'client1', 153 | 'exp': now + 1, 154 | 'iat': now 155 | }) 156 | responses.add(responses.GET, 157 | self.PROVIDER_METADATA['jwks_uri'], 158 | json={'keys': [id_token_signing_key.serialize()]}) 159 | auth_response = AuthorizationResponse(**{'state': state, 'id_token': id_token}) 160 | parsed_auth_response = facade.parse_authentication_response(auth_response) 161 | assert isinstance(parsed_auth_response, AuthorizationResponse) 162 | assert parsed_auth_response['state'] == state 163 | assert parsed_auth_response['id_token_jwt'] == id_token 164 | 165 | @pytest.mark.parametrize('request_func, expected_token_request', [ 166 | ( 167 | lambda facade: facade.exchange_authorization_code('auth-code', 'test-state', {}), 168 | { 169 | 'grant_type': 'authorization_code', 170 | 'state': 'test-state', 171 | 'redirect_uri': REDIRECT_URI 172 | } 173 | ), 174 | ( 175 | lambda facade: facade.refresh_token('refresh-token'), 176 | { 177 | 'grant_type': 'refresh_token', 178 | 'refresh_token': 'refresh-token', 179 | 'redirect_uri': REDIRECT_URI 180 | } 181 | ) 182 | ]) 
183 | @responses.activate 184 | def test_token_request(self, request_func, expected_token_request): 185 | token_endpoint = self.PROVIDER_BASEURL + '/token' 186 | now = int(time.time()) 187 | id_token_claims = { 188 | 'iss': self.PROVIDER_METADATA['issuer'], 189 | 'sub': 'test_user', 190 | 'aud': [self.CLIENT_METADATA['client_id']], 191 | 'exp': now + 1, 192 | 'iat': now, 193 | 'nonce': 'test_nonce' 194 | } 195 | id_token_jwt, id_token_signing_key = signed_id_token(id_token_claims) 196 | token_response = AccessTokenResponse(access_token='test_access_token', 197 | refresh_token='refresh-token', 198 | token_type='Bearer', 199 | id_token=id_token_jwt, 200 | expires_in=now + 1) 201 | 202 | responses.add(responses.POST, token_endpoint, json=token_response.to_dict()) 203 | 204 | provider_metadata = self.PROVIDER_METADATA.copy(token_endpoint=token_endpoint) 205 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, 206 | client_metadata=self.CLIENT_METADATA), 207 | REDIRECT_URI) 208 | grant = Grant(resp=token_response) 209 | grant.grant_expiration_time = now + grant.exp_in 210 | facade._client.grant = {'test-state': grant} 211 | 212 | responses.add(responses.GET, 213 | self.PROVIDER_METADATA['jwks_uri'], 214 | json={'keys': [id_token_signing_key.serialize()]}) 215 | token_response = request_func(facade) 216 | 217 | assert isinstance(token_response, AccessTokenResponse) 218 | expected_token_response = token_response.to_dict() 219 | expected_token_response['id_token'] = id_token_claims 220 | expected_token_response['id_token_jwt'] = id_token_jwt 221 | assert token_response.to_dict() == expected_token_response 222 | 223 | token_request = dict(parse_qsl(responses.calls[0].request.body)) 224 | assert token_request == expected_token_request 225 | 226 | @responses.activate 227 | def test_token_request_handles_error_response(self): 228 | token_endpoint = self.PROVIDER_BASEURL + '/token' 229 | token_response = TokenErrorResponse(error='invalid_request', 
error_description='test error description') 230 | responses.add(responses.POST, token_endpoint, json=token_response.to_dict(), status=400) 231 | 232 | provider_metadata = self.PROVIDER_METADATA.copy(token_endpoint=token_endpoint) 233 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, 234 | client_metadata=self.CLIENT_METADATA), 235 | REDIRECT_URI) 236 | state = 'test-state' 237 | grant = Grant() 238 | grant.grant_expiration_time = int(time.time()) + grant.exp_in 239 | facade._client.grant = {state: grant} 240 | assert facade.exchange_authorization_code('1234', state, {}) == token_response 241 | 242 | def test_token_request_handles_missing_provider_token_endpoint(self): 243 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 244 | client_metadata=self.CLIENT_METADATA), 245 | REDIRECT_URI) 246 | assert facade.exchange_authorization_code(None, None, {}) is None 247 | 248 | @pytest.mark.parametrize('userinfo_http_method', [ 249 | 'GET', 250 | 'POST' 251 | ]) 252 | @responses.activate 253 | def test_configurable_userinfo_endpoint_method_is_used(self, userinfo_http_method): 254 | userinfo_endpoint = self.PROVIDER_BASEURL + '/userinfo' 255 | userinfo_response = OpenIDSchema(sub='user1') 256 | responses.add(userinfo_http_method, userinfo_endpoint, json=userinfo_response.to_dict()) 257 | 258 | provider_metadata = self.PROVIDER_METADATA.copy(userinfo_endpoint=userinfo_endpoint) 259 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, 260 | client_metadata=self.CLIENT_METADATA, 261 | userinfo_http_method=userinfo_http_method), 262 | REDIRECT_URI) 263 | assert facade.userinfo_request('test_token') == userinfo_response 264 | 265 | def test_no_userinfo_request_is_made_if_no_userinfo_http_method_is_configured(self): 266 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 267 | client_metadata=self.CLIENT_METADATA, 268 | userinfo_http_method=None), 
269 | REDIRECT_URI) 270 | assert facade.userinfo_request('test_token') is None 271 | 272 | def test_no_userinfo_request_is_made_if_no_userinfo_endpoint_is_configured(self): 273 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 274 | client_metadata=self.CLIENT_METADATA), 275 | REDIRECT_URI) 276 | assert facade.userinfo_request('test_token') is None 277 | 278 | def test_no_userinfo_request_is_made_if_no_access_token(self): 279 | provider_metadata = self.PROVIDER_METADATA.copy(userinfo_endpoint=self.PROVIDER_BASEURL + '/userinfo') 280 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=provider_metadata, 281 | client_metadata=self.CLIENT_METADATA), 282 | REDIRECT_URI) 283 | assert facade.userinfo_request(None) is None 284 | 285 | @responses.activate 286 | @pytest.mark.parametrize('scope, extra_args', 287 | [(None, {}), 288 | (['read', 'write'], 289 | {'audience': ['client_id1, client_id2']}) 290 | ]) 291 | def test_client_credentials_grant(self, scope, extra_args): 292 | token_endpoint = f'{self.PROVIDER_BASEURL}/token' 293 | provider_metadata = self.PROVIDER_METADATA.copy( 294 | token_endpoint=token_endpoint) 295 | facade = PyoidcFacade( 296 | ProviderConfiguration( 297 | provider_metadata=provider_metadata, 298 | client_metadata=self.CLIENT_METADATA), 299 | REDIRECT_URI) 300 | client_credentials_grant_response = { 301 | 'access_token': 'access_token', 302 | 'expires_in': 60, 303 | 'not-before-policy': 0, 304 | 'refresh_expires_in': 0, 305 | 'scope': 'read write', 306 | 'token_type': 'Bearer'} 307 | responses.add(responses.POST, token_endpoint, 308 | json=client_credentials_grant_response) 309 | assert client_credentials_grant_response == facade.client_credentials_grant( 310 | scope=scope, **extra_args).to_dict() 311 | 312 | def test_post_logout_redirect_uris(self): 313 | post_logout_redirect_uris = ['https://client.example.com/logout'] 314 | client_metadata = self.CLIENT_METADATA.copy( 315 | 
post_logout_redirect_uris=post_logout_redirect_uris) 316 | facade = PyoidcFacade(ProviderConfiguration(provider_metadata=self.PROVIDER_METADATA, 317 | client_metadata=client_metadata), 318 | REDIRECT_URI) 319 | assert facade.post_logout_redirect_uris == post_logout_redirect_uris 320 | -------------------------------------------------------------------------------- /src/flask_pyoidc/flask_pyoidc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Samuel Gulliksson 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import functools 17 | import json 18 | import logging 19 | import time 20 | from typing import Optional 21 | from urllib.parse import parse_qsl 22 | 23 | import flask 24 | import importlib_resources 25 | from flask import current_app, g 26 | from flask.helpers import url_for 27 | from oic import rndstr 28 | from oic.extension.message import TokenIntrospectionResponse 29 | from oic.oic import AuthorizationRequest 30 | from oic.oic.message import EndSessionRequest 31 | from werkzeug.exceptions import Forbidden, Unauthorized 32 | from werkzeug.local import LocalProxy 33 | from werkzeug.routing import BuildError 34 | from werkzeug.utils import redirect 35 | 36 | from .auth_response_handler import AuthResponseProcessError, AuthResponseHandler, AuthResponseErrorResponseError 37 | from .pyoidc_facade import PyoidcFacade 38 | from .redirect_uri_config import RedirectUriConfig 39 | from .user_session import UninitialisedSession, UserSession 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class OIDCAuthentication: 45 | """ 46 | OIDCAuthentication object for Flask extension. 47 | """ 48 | 49 | def __init__(self, provider_configurations, app=None, 50 | redirect_uri_config=None): 51 | """ 52 | Args: 53 | provider_configurations (Mapping[str, ProviderConfiguration]): 54 | provider configurations by name 55 | app (flask.app.Flask): optional Flask app 56 | redirect_uri_config (RedirectUriConfig): optional redirect URI config to use instead of 57 | 'OIDC_REDIRECT_URI' config parameter. 58 | """ 59 | self._provider_configurations = provider_configurations 60 | 61 | self.clients = None 62 | self._logout_views = [] 63 | self._error_view = None 64 | # current_token_identity proxy to obtain user info whose token was 65 | # passed in the request. It is available until current request only and 66 | # is destroyed between the requests. The value is set by token_auth 67 | # decorator. 
68 | self.current_token_identity = LocalProxy(lambda: getattr( 69 | g, 'current_token_identity', None)) 70 | self._redirect_uri_config = redirect_uri_config 71 | 72 | if app: 73 | self.init_app(app) 74 | 75 | def init_app(self, app): 76 | if not self._redirect_uri_config: 77 | self._redirect_uri_config = RedirectUriConfig.from_config(app.config) 78 | 79 | # setup redirect_uri as a flask route 80 | app.add_url_rule('/' + self._redirect_uri_config.endpoint, 81 | self._redirect_uri_config.endpoint, 82 | self._handle_authentication_response, 83 | methods=['GET', 'POST']) 84 | 85 | # dynamically add the Flask redirect uri to the client info 86 | self.clients = { 87 | name: PyoidcFacade(configuration, self._redirect_uri_config.full_uri) 88 | for (name, configuration) in self._provider_configurations.items() 89 | } 90 | 91 | def _get_urls_for_logout_views(self): 92 | try: 93 | return [url_for(view.__name__, _external=True) for view in self._logout_views] 94 | except BuildError: 95 | logger.error('could not build url for logout view, it might be mounted under a custom endpoint') 96 | raise 97 | 98 | def _register_client(self, client): 99 | if not client._provider_configuration._client_registration_info.get('redirect_uris'): 100 | client._provider_configuration._client_registration_info[ 101 | 'redirect_uris'] = [self._redirect_uri_config.full_uri] 102 | post_logout_redirect_uris = client._provider_configuration._client_registration_info.get( 103 | 'post_logout_redirect_uris') 104 | if not post_logout_redirect_uris: 105 | client._provider_configuration._client_registration_info[ 106 | 'post_logout_redirect_uris'] = self._get_urls_for_logout_views() 107 | logger.debug( 108 | f'''registering with post_logout_redirect_uris = { 109 | client._provider_configuration._client_registration_info[ 110 | 'post_logout_redirect_uris']}''') 111 | client.register() 112 | 113 | def _authenticate(self, client, interactive=True): 114 | if not client.is_registered(): 115 | 
self._register_client(client) 116 | 117 | flask.session['destination'] = flask.request.full_path 118 | 119 | # Use silent authentication for session refresh 120 | # This will not show login prompt to the user 121 | extra_auth_params = {} 122 | if not interactive: 123 | extra_auth_params['prompt'] = 'none' 124 | 125 | auth_req = client.authentication_request(state=rndstr(), 126 | nonce=rndstr(), 127 | extra_auth_params=extra_auth_params) 128 | flask.session['auth_request'] = auth_req.to_json() 129 | login_url = client.login_url(auth_req) 130 | 131 | auth_params = dict(parse_qsl(login_url.split('?')[1])) 132 | flask.session['fragment_encoded_response'] = AuthResponseHandler.expect_fragment_encoded_response(auth_params) 133 | return redirect(login_url) 134 | 135 | def _handle_authentication_response(self): 136 | has_error = flask.request.args.get('error', False, lambda x: bool(int(x))) 137 | if has_error: 138 | if 'error' in flask.session: 139 | return self._show_error_response(flask.session.pop('error')) 140 | return 'Something went wrong.' 
141 | 142 | try: 143 | session = UserSession(flask.session) 144 | except UninitialisedSession: 145 | return self._handle_error_response({'error': 'unsolicited_response', 'error_description': 'No initialised user session.'}) 146 | 147 | if flask.session.pop('fragment_encoded_response', False): 148 | return (importlib_resources.files('flask_pyoidc') / 'parse_fragment.html').read_text(encoding='utf-8') 149 | 150 | if 'auth_request' not in flask.session: 151 | return self._handle_error_response({'error': 'unsolicited_response', 'error_description': 'No authentication request stored.'}) 152 | auth_request = AuthorizationRequest().from_json(flask.session.pop('auth_request')) 153 | 154 | is_processing_fragment_encoded_response = flask.request.method == 'POST' 155 | if is_processing_fragment_encoded_response: 156 | auth_resp = flask.request.form 157 | else: 158 | auth_resp = flask.request.args 159 | 160 | client = self.clients[session.current_provider] 161 | 162 | authn_resp = client.parse_authentication_response(auth_resp) 163 | logger.debug('received authentication response: %s', authn_resp.to_json()) 164 | 165 | try: 166 | extra_token_args = {} 167 | if 'OIDC_CLOCK_SKEW' in current_app.config: 168 | extra_token_args['skew'] = current_app.config['OIDC_CLOCK_SKEW'] 169 | result = AuthResponseHandler(client).process_auth_response(authn_resp, auth_request, extra_token_args) 170 | except AuthResponseErrorResponseError as ex: 171 | return self._handle_error_response(ex.error_response, is_processing_fragment_encoded_response) 172 | except AuthResponseProcessError as ex: 173 | return self._handle_error_response({'error': 'unexpected_error', 'error_description': str(ex)}, 174 | is_processing_fragment_encoded_response) 175 | 176 | if current_app.config.get('OIDC_SESSION_PERMANENT', True): 177 | flask.session.permanent = True 178 | 179 | UserSession(flask.session).update(access_token=result.access_token, 180 | expires_in=result.expires_in, 181 | id_token=result.id_token_claims, 
182 | id_token_jwt=result.id_token_jwt, 183 | userinfo=result.userinfo_claims, 184 | refresh_token=result.refresh_token) 185 | 186 | destination = flask.session.pop('destination') 187 | if is_processing_fragment_encoded_response: 188 | # if the POST request was from the JS page handling fragment encoded responses we need to return 189 | # the destination URL as the response body 190 | return destination 191 | 192 | return redirect(destination) 193 | 194 | def _handle_error_response(self, error_response, should_redirect=False): 195 | if should_redirect: 196 | # if the current request was from the JS page handling fragment encoded responses we need to return 197 | # a URL for the error page to redirect to 198 | flask.session['error'] = error_response 199 | return '/' + self._redirect_uri_config.endpoint + '?error=1' 200 | return self._show_error_response(error_response) 201 | 202 | def _show_error_response(self, error_response): 203 | logger.error(json.dumps(error_response)) 204 | if self._error_view: 205 | error = {k: error_response[k] for k in ['error', 'error_description'] if k in error_response} 206 | return self._error_view(**error) 207 | 208 | return 'Something went wrong with the authentication, please try to login again.' 209 | 210 | def oidc_auth(self, provider_name: str): 211 | 212 | if provider_name not in self._provider_configurations: 213 | raise ValueError( 214 | f"Provider name '{provider_name}' not in configured providers: {self._provider_configurations.keys()}." 
def _logout(self, post_logout_redirect_uri):
    """Clear the local user session and, if possible, start logout at the provider.

    Returns a 303 redirect to the provider's end session endpoint when the
    provider has one configured. Returns ``None`` when there is no
    initialised session or the provider has no end session endpoint.
    """
    logger.debug('user logout')
    try:
        user_session = UserSession(flask.session)
    except UninitialisedSession:
        logger.info('user was already logged out, doing nothing')
        return None

    # read everything needed from the session before it is wiped
    id_token_hint = user_session.id_token_jwt
    client = self.clients[user_session.current_provider]
    user_session.clear()

    if not client.provider_end_session_endpoint:
        return None

    # remember the state so the redirect back from the provider can be matched
    state = rndstr()
    flask.session['end_session_state'] = state

    end_session_request = EndSessionRequest(
        id_token_hint=id_token_hint,
        post_logout_redirect_uri=post_logout_redirect_uri,
        state=state,
    )
    logger.debug('send endsession request: %s', end_session_request.to_json())

    logout_url = end_session_request.request(client.provider_end_session_endpoint)
    return redirect(logout_url, 303)
def error_view(self, view_func):
    """Register ``view_func`` as the view used to present authentication errors.

    The function is stored on the instance and handed back unchanged, so this
    method can be applied as a decorator.
    """
    self._error_view = view_func
    return view_func
def introspect_token(self, request, client, scopes: list = None) ->\
        Optional[TokenIntrospectionResponse]:
    """RFC 7662: Token Introspection
    The Token Introspection extension defines a mechanism for resource
    servers to obtain information about access tokens. With this spec,
    resource servers can check the validity of access tokens, and find out
    other information such as which user and which scopes are associated
    with the token.

    Parameters
    ----------
    request : flask.Request
        flask request object.
    client : flask_pyoidc.pyoidc_facade.PyoidcFacade
        PyoidcFacade object contains metadata of the provider and client.
    scopes : list, optional
        Specify scopes required by your endpoint.

    Returns
    -------
    result: TokenIntrospectionResponse or None
        The introspection response if the access token is active, was
        issued for this client and carries all required scopes, otherwise
        None.
    """
    received_access_token = self._parse_access_token(request)
    # send token introspection request
    result = client._token_introspection_request(
        access_token=received_access_token)
    logger.debug(result)
    # Check if access_token is valid, active can be True or False
    if not result.get('active'):
        return None
    # Check if client_id is in audience claim. 'aud' is an optional member
    # of the introspection response, so a response without it is treated as
    # not being intended for this client instead of raising KeyError.
    audience = result.get('aud') or []
    if client._client.client_id not in audience:
        # the audience can be configured with the Identity Provider
        logger.info('Token is valid but required audience is missing.')
        return None
    # Check if the scopes associated with the access_token are the ones
    # required by the endpoint and not something else which is not
    # permitted. 'scope' is optional as well; a token without any scope
    # cannot satisfy a scope requirement.
    granted_scopes = set(result.get('scope') or [])
    if scopes and not set(scopes).issubset(granted_scopes):
        logger.info('Token is valid but does not have required scopes.')
        return None
    return result
def access_control(self, provider_name: str,
                   scopes_required: list = None):
    """This decorator serves a dual purpose: it can do both token based
    authorization and OIDC based authentication. If your API needs to be
    accessible by either mode, use this decorator, otherwise use either
    oidc_auth or token_auth.

    A request carrying a bearer token in its Authorization header is
    handled by token_auth; any other request falls back to oidc_auth.

    Parameters
    ----------
    provider_name : str
        Name of the provider registered with OIDCAuthentication.
    scopes_required : list, optional
        List of valid scopes associated with the endpoint.

    Raises
    ------
    Forbidden
        If access_token is invalid.

    Examples
    --------
    ::

        auth = OIDCAuthentication({'default': provider_config})
        @app.route('/')
        @auth.access_control(provider_name='default')
        def index():
            ...

    You can also specify scopes required by the endpoint:

    ::

        @auth.access_control(provider_name='default',
                             scopes_required=['read', 'write'])
    """
    def hybrid_decorator(view_func):

        # Build both protected variants once, when the view is decorated,
        # rather than re-creating the token_auth wrapper on every request.
        fallback_to_oidc = self.oidc_auth(provider_name)(view_func)
        token_protected = self.token_auth(provider_name, scopes_required)(
            view_func)

        @functools.wraps(view_func)
        def wrapper(*args, **kwargs):

            # Decide on the mode up front instead of catching Unauthorized:
            # a 401 raised by the view itself must reach the caller and
            # must not be mistaken for a missing authorization header
            # (which would run the view a second time through oidc_auth).
            if self._check_authorization_header(flask.request):
                # If the token is present but invalid, token_auth aborts
                # with 403; that exception propagates unchanged and there
                # is no fallback to oidc_auth.
                return token_protected(*args, **kwargs)
            return fallback_to_oidc(*args, **kwargs)

        return wrapper

    return hybrid_decorator
def get_view_mock(self, name='test_callback'):
    """Create a mock standing in for a Flask view function.

    The mock carries ``name`` as its ``__name__`` and returns
    ``CALLBACK_RETURN_VALUE`` when called.
    """
    view = MagicMock(return_value=self.CALLBACK_RETURN_VALUE)
    # MagicMock does not provide __name__ on its own, so it is set explicitly
    view.__name__ = name
    return view
def test_reauthenticate_silent_if_session_expired(self):
    """A session older than the refresh interval triggers a silent re-authentication."""
    authn = self.init_app(session_refresh_interval_seconds=1)
    view_mock = self.get_view_mock()

    with self.app.test_request_context('/'):
        one_second_ago = time.time() - 1
        # create the session with the clock turned back: authenticated in the past
        with patch('time.time', return_value=one_second_ago):
            UserSession(flask.session, self.PROVIDER_NAME).update()
        # the protected view is called with the real clock again
        auth_redirect = authn.oidc_auth(self.PROVIDER_NAME)(view_mock)()

    self.assert_auth_redirect(auth_redirect)
    # ensure silent auth is used
    assert 'prompt=none' in auth_redirect.location
    assert not view_mock.called
@responses.activate
@pytest.mark.parametrize('post_logout_redirect_uris', [
    None,
    ['https://example.com/post_logout']
])
def test_should_register_client_if_not_registered_before(self, post_logout_redirect_uris):
    """An unregistered client is registered with the provider when authentication is triggered.

    When no post logout redirect URIs are configured, the URL of the registered
    logout view is expected in the registration request instead.
    """
    registration_endpoint = self.PROVIDER_BASEURL + '/register'
    redirect_uris = [f'https://{self.CLIENT_DOMAIN}/redirect',
                     f'https://{self.CLIENT_DOMAIN}/redirect2']

    provider_metadata = ProviderMetadata(self.PROVIDER_BASEURL,
                                         self.PROVIDER_BASEURL + '/auth',
                                         self.PROVIDER_BASEURL + '/jwks',
                                         registration_endpoint=registration_endpoint)
    # hand out copies of the list so the expectations below cannot be affected by the code under test
    registration_info = ClientRegistrationInfo(redirect_uris=list(redirect_uris),
                                               post_logout_redirect_uris=post_logout_redirect_uris)
    provider_configuration = ProviderConfiguration(provider_metadata=provider_metadata,
                                                   client_registration_info=registration_info)
    authn = OIDCAuthentication({self.PROVIDER_NAME: provider_configuration})
    authn.init_app(self.app)

    # register logout view to force 'post_logout_redirect_uris' to be included in registration request
    logout_view = authn.oidc_logout(self.get_view_mock())
    self.app.add_url_rule('/logout', view_func=logout_view)

    expected_post_logout_redirect_uris = post_logout_redirect_uris or [f'http://{self.CLIENT_DOMAIN}/logout']
    registration_response = {
        'client_id': 'client1',
        'client_secret': 'secret1',
        'redirect_uris': list(redirect_uris),
        'post_logout_redirect_uris': expected_post_logout_redirect_uris,
    }
    responses.add(responses.POST, registration_endpoint, json=registration_response)

    view_mock = self.get_view_mock()
    with self.app.test_request_context('/'):
        auth_redirect = authn.oidc_auth(self.PROVIDER_NAME)(view_mock)()

    self.assert_auth_redirect(auth_redirect)

    # the first recorded HTTP call is the registration request
    registration_request = json.loads(responses.calls[0].request.body)
    assert registration_request == {
        'application_type': 'web',
        'response_types': ['code'],
        'redirect_uris': list(redirect_uris),
        'post_logout_redirect_uris': expected_post_logout_redirect_uris,
        'grant_types': ['authorization_code'],
    }
authn.oidc_auth(self.PROVIDER_NAME)(view_mock)() 217 | 218 | self.assert_auth_redirect(auth_redirect) 219 | 220 | registration_request = json.loads(responses.calls[0].request.body) 221 | expected_registration_request = {'application_type': 'web', 'response_types': ['code'], 222 | 'redirect_uris': redirect_uris, 223 | 'post_logout_redirect_uris': post_logout_redirect_uris, 224 | 'grant_types': ['authorization_code']} 225 | assert registration_request == expected_registration_request 226 | 227 | @responses.activate 228 | def test_register_client_should_return_empty_post_logout_redirect_uris_if_logout_view_not_exist(self): 229 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 230 | provider_metadata = ProviderMetadata(self.PROVIDER_BASEURL, 231 | self.PROVIDER_BASEURL + '/auth', 232 | self.PROVIDER_BASEURL + '/jwks', 233 | registration_endpoint=registration_endpoint) 234 | provider_configurations = { 235 | self.PROVIDER_NAME: ProviderConfiguration( 236 | provider_metadata=provider_metadata, 237 | client_registration_info=ClientRegistrationInfo()) 238 | } 239 | authn = OIDCAuthentication(provider_configurations) 240 | authn.init_app(self.app) 241 | 242 | redirect_uris = [f'http://{self.CLIENT_DOMAIN}/redirect_uri'] 243 | responses.add(responses.POST, registration_endpoint, json={ 244 | 'client_id': 'client1', 'client_secret': 'secret1', 245 | 'redirect_uris': redirect_uris}) 246 | view_mock = self.get_view_mock() 247 | with self.app.test_request_context('/'): 248 | auth_redirect = authn.oidc_auth(self.PROVIDER_NAME)(view_mock)() 249 | 250 | self.assert_auth_redirect(auth_redirect) 251 | assert authn.clients[self.PROVIDER_NAME]._provider_configuration._client_metadata.get( 252 | 'post_logout_redirect_uris') is None 253 | 254 | registration_request = json.loads(responses.calls[0].request.body) 255 | expected_registration_request = {'application_type': 'web', 'response_types': ['code'], 256 | 'redirect_uris': redirect_uris, 257 | 'grant_types': 
@patch('time.time')
@patch('oic.utils.time_util.utc_time_sans_frac')  # used internally by pyoidc when verifying ID Token
@responses.activate
def test_handle_authentication_response(self, utc_time_sans_frac_mock, time_mock):
    """A code flow authentication response populates the user session.

    Stacked ``patch`` decorators inject their mocks bottom-up, so the mock for
    ``utc_time_sans_frac`` (the lower decorator) is the first argument and the
    mock for ``time.time`` is the second.
    """
    # freeze time since ID Token validation includes expiration timestamps
    timestamp = time.mktime(datetime(2017, 1, 1).timetuple())
    time_mock.return_value = timestamp
    utc_time_sans_frac_mock.return_value = int(timestamp)

    # mock token response
    user_id = 'user1'
    exp_time = 10
    nonce = 'test_nonce'
    id_token_claims = {
        'iss': self.PROVIDER_BASEURL,
        'aud': [self.CLIENT_ID],
        'sub': user_id,
        'exp': int(timestamp) + exp_time,
        'iat': int(timestamp),
        'nonce': nonce
    }
    id_token_jwt, id_token_signing_key = signed_id_token(id_token_claims)
    access_token = 'test_access_token'
    expires_in = 3600
    token_response = {
        'access_token': access_token,
        'expires_in': expires_in,
        'token_type': 'Bearer',
        'id_token': id_token_jwt
    }
    token_endpoint = self.PROVIDER_BASEURL + '/token'
    responses.add(responses.POST, token_endpoint, json=token_response)
    responses.add(responses.GET,
                  self.PROVIDER_BASEURL + '/jwks',
                  json={'keys': [id_token_signing_key.serialize()]})

    # mock userinfo response
    userinfo = {'sub': user_id, 'name': 'Test User'}
    userinfo_endpoint = self.PROVIDER_BASEURL + '/userinfo'
    responses.add(responses.GET, userinfo_endpoint, json=userinfo)

    authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint,
                                                    'userinfo_endpoint': userinfo_endpoint})
    state = 'test_state'
    with self.app.test_request_context(f'/redirect_uri?state={state}&code=test'):
        # simulate the session state left behind by the authentication request
        UserSession(flask.session, self.PROVIDER_NAME)
        flask.session['destination'] = '/'
        flask.session['auth_request'] = json.dumps({'state': state, 'nonce': nonce})
        authn._handle_authentication_response()
        session = UserSession(flask.session)
        assert session.access_token == access_token
        assert session.access_token_expires_at == int(timestamp) + expires_in
        assert session.id_token == id_token_claims
        assert session.id_token_jwt == id_token_jwt
        assert session.userinfo == userinfo
**{'state': state, 'access_token': access_token, 'token_type': 'Bearer', 'id_token': id_token_jwt}) 354 | 355 | with self.app.test_client() as client: 356 | with client.session_transaction() as session: 357 | UserSession(session, self.PROVIDER_NAME) 358 | session['destination'] = '/' 359 | session['auth_request'] = json.dumps({'state': state, 'nonce': nonce}) 360 | session['fragment_encoded_response'] = True 361 | client.get(f'/redirect_uri#{auth_response.to_urlencoded()}') 362 | assert 'auth_request' in session # stored auth_request should not have been removed yet 363 | 364 | # fake the POST request from the 'parse_fragment.html' template 365 | resp = client.post('/redirect_uri', data=auth_response.to_dict()) 366 | user_session = UserSession(flask.session) 367 | assert user_session.access_token == access_token 368 | assert user_session.id_token == id_token_claims 369 | assert user_session.id_token_jwt == id_token_jwt 370 | assert user_session.userinfo == userinfo 371 | assert 'auth_request' not in flask.session # stored auth_request should have been removed now 372 | assert resp.data.decode('utf-8') == '/' # final redirect back to the protected endpoint 373 | 374 | def test_handle_authentication_response_POST(self): 375 | access_token = 'test_access_token' 376 | state = 'test_state' 377 | 378 | authn = self.init_app() 379 | auth_response = AuthorizationResponse(**{'state': state, 'token_type': 'Bearer', 'access_token': access_token}) 380 | 381 | with self.app.test_request_context('/redirect_uri', 382 | method='POST', 383 | data=auth_response.to_dict(), 384 | mimetype='application/x-www-form-urlencoded'): 385 | UserSession(flask.session, self.PROVIDER_NAME) 386 | flask.session['destination'] = '/test' 387 | flask.session['auth_request'] = json.dumps({'state': state, 'nonce': 'test_nonce'}) 388 | response = authn._handle_authentication_response() 389 | session = UserSession(flask.session) 390 | assert session.access_token == access_token 391 | assert response == 
def test_handle_error_response_POST(self):
    """An error response received via POST is stored in the session and answered with the error page URL."""
    authn = self.init_app()

    state = 'test_state'
    error_resp = {'state': state, 'error': 'invalid_request', 'error_description': 'test error'}
    stored_auth_request = json.dumps({'state': state, 'nonce': 'test_nonce'})
    request_options = {
        'method': 'POST',
        'data': error_resp,
        'mimetype': 'application/x-www-form-urlencoded',
    }

    with self.app.test_request_context('/redirect_uri', **request_options):
        UserSession(flask.session, self.PROVIDER_NAME)
        flask.session['auth_request'] = stored_auth_request
        response = authn._handle_authentication_response()
        assert flask.session['error'] == error_resp
        assert response == '/redirect_uri?error=1'
def test_handle_authentication_response_fragment_encoded(self):
    """A pending fragment encoded response is answered with the fragment parsing page.

    The handler returns the page before touching the stored authentication
    request, which is still needed when the page POSTs the parsed response
    back.
    """
    authn = self.init_app()
    with self.app.test_request_context('/redirect_uri'):
        UserSession(flask.session, self.PROVIDER_NAME)
        flask.session['auth_request'] = json.dumps({'state': 'test_state', 'nonce': 'test_nonce'})
        flask.session['fragment_encoded_response'] = True
        response = authn._handle_authentication_response()
        # NOTE(review): the previous check, startswith(''), is true for every
        # string; the intended prefix of parse_fragment.html is not visible
        # here, so assert the observable behaviour instead - confirm and
        # tighten against the actual page content.
        assert isinstance(response, str)
        assert response.strip()
        # the flag is consumed, the stored auth request is kept for the follow-up POST
        assert 'fragment_encoded_response' not in flask.session
        assert 'auth_request' in flask.session
462 | 463 | @patch('time.time') 464 | @patch('oic.utils.time_util.utc_time_sans_frac') # used internally by pyoidc when verifying ID Token 465 | @responses.activate 466 | def test_session_expiration_set_to_configured_lifetime(self, time_mock, utc_time_sans_frac_mock): 467 | timestamp = time.mktime(datetime(2017, 1, 1).timetuple()) 468 | time_mock.return_value = timestamp 469 | utc_time_sans_frac_mock.return_value = int(timestamp) 470 | 471 | exp_time = 10 472 | state = 'test_state' 473 | nonce = 'test_nonce' 474 | id_token = IdToken(iss=self.PROVIDER_BASEURL, 475 | aud=self.CLIENT_ID, 476 | sub='sub1', 477 | exp=int(timestamp) + exp_time, 478 | iat=int(timestamp), 479 | nonce=nonce) 480 | token_response = {'access_token': 'test', 'token_type': 'Bearer', 'id_token': id_token.to_jwt()} 481 | token_endpoint = self.PROVIDER_BASEURL + '/token' 482 | responses.add(responses.POST, token_endpoint, json=token_response) 483 | 484 | session_lifetime = 1234 485 | self.app.config['PERMANENT_SESSION_LIFETIME'] = session_lifetime 486 | self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 487 | 488 | with self.app.test_client() as client: 489 | with client.session_transaction() as session: 490 | UserSession(session, self.PROVIDER_NAME) 491 | session['destination'] = '/' 492 | session['auth_request'] = json.dumps({'state': state, 'nonce': nonce, 'response_type': 'code'}) 493 | resp = client.get(f'/redirect_uri?state={state}&code=test') 494 | 495 | cookies = SimpleCookie() 496 | cookies.load(resp.headers['Set-Cookie']) 497 | session_cookie_expiration = cookies[self.app.config['SESSION_COOKIE_NAME']]['expires'] 498 | parsed_expiration = datetime.strptime(session_cookie_expiration, '%a, %d %b %Y %H:%M:%S GMT') 499 | cookie_lifetime = (parsed_expiration - datetime.utcnow()).total_seconds() 500 | assert cookie_lifetime == pytest.approx(session_lifetime, abs=1) 501 | 502 | def test_logout_redirects_to_provider_if_end_session_endpoint_is_configured(self): 503 | 
end_session_endpoint = 'https://provider.example.com/end_session' 504 | client_metadata = {} 505 | 506 | authn = self.init_app(provider_metadata_extras={'end_session_endpoint': end_session_endpoint}, 507 | client_metadata_extras=client_metadata) 508 | logout_view_mock = self.get_view_mock() 509 | id_token = IdToken(**{'sub': 'sub1', 'nonce': 'nonce'}) 510 | 511 | # register logout view 512 | view_func = authn.oidc_logout(logout_view_mock) 513 | self.app.add_url_rule('/logout', view_func=view_func) 514 | 515 | with self.app.test_request_context('/logout'): 516 | UserSession(flask.session, self.PROVIDER_NAME).update(access_token='test_access_token', 517 | id_token=id_token.to_dict(), 518 | id_token_jwt=id_token.to_jwt(), 519 | userinfo={'sub': 'user1'}) 520 | end_session_redirect = view_func() 521 | # ensure user session has been cleared 522 | assert all(k not in flask.session for k in UserSession.KEYS) 523 | parsed_request = dict(parse_qsl(urlparse(end_session_redirect.headers['Location']).query)) 524 | assert parsed_request['state'] == flask.session['end_session_state'] 525 | 526 | assert end_session_redirect.status_code == 303 527 | assert end_session_redirect.location.startswith(end_session_endpoint) 528 | assert IdToken().from_jwt(parsed_request['id_token_hint']) == id_token 529 | 530 | assert parsed_request['post_logout_redirect_uri'] == f'http://{self.CLIENT_DOMAIN}/logout' 531 | assert not logout_view_mock.called 532 | 533 | @responses.activate 534 | def test_multiple_logout_endpoints_are_supported(self): 535 | end_session_endpoint = 'https://provider.example.com/end_session' 536 | registration_endpoint = self.PROVIDER_BASEURL + '/register' 537 | id_token = IdToken(**{'sub': 'sub1', 'nonce': 'nonce'}) 538 | 539 | provider_metadata = ProviderMetadata(self.PROVIDER_BASEURL, 540 | self.PROVIDER_BASEURL + '/auth', 541 | self.PROVIDER_BASEURL + '/jwks', 542 | registration_endpoint=registration_endpoint, 543 | end_session_endpoint=end_session_endpoint) 544 | 
post_logout_redirect_uris = [f'http://{self.CLIENT_DOMAIN}/logout1', 545 | f'http://{self.CLIENT_DOMAIN}/logout2'] 546 | client_registration_info = ClientRegistrationInfo(post_logout_redirect_uris=post_logout_redirect_uris) 547 | provider_configurations = { 548 | self.PROVIDER_NAME: ProviderConfiguration(provider_metadata=provider_metadata, 549 | client_registration_info=client_registration_info) 550 | } 551 | authn = OIDCAuthentication(provider_configurations) 552 | authn.init_app(self.app) 553 | 554 | # register multiple logout views 555 | view_func1 = authn.oidc_logout(self.get_view_mock('logout')) 556 | self.app.add_url_rule('/logout1', 'logout', view_func=view_func1) 557 | view_func2 = authn.oidc_logout(self.get_view_mock('otherlogout')) 558 | # register logout view with custom endpoint 559 | self.app.add_url_rule('/logout2', 'test.otherlogout', view_func=view_func2) 560 | 561 | # verify client registration includes all logout endpoints as 'post_logout_redirect_uris' 562 | responses.add(responses.POST, registration_endpoint, json={ 563 | 'client_id': 'client1', 'client_secret': 'secret1', 564 | 'redirect_uris': [f'http://{self.CLIENT_DOMAIN}/redirect_uri'], 565 | 'post_logout_redirect_uris': post_logout_redirect_uris}) 566 | view_mock = self.get_view_mock() 567 | with self.app.test_request_context('/'): 568 | authn.oidc_auth(self.PROVIDER_NAME)(view_mock)() 569 | 570 | registration_request = json.loads(responses.calls[0].request.body) 571 | expected_registration_request = {'application_type': 'web', 'response_types': ['code'], 572 | 'redirect_uris': [f'http://{self.CLIENT_DOMAIN}/redirect_uri'], 573 | 'post_logout_redirect_uris': post_logout_redirect_uris, 574 | 'grant_types': ['authorization_code']} 575 | assert registration_request == expected_registration_request 576 | 577 | # verify each logout endpoint can be called 578 | for endpoint, view_func in [('/logout1', view_func1), ('/logout2', view_func2)]: 579 | with self.app.test_request_context(endpoint): 
580 | UserSession(flask.session, self.PROVIDER_NAME).update(access_token='test_access_token', 581 | id_token=id_token.to_dict(), 582 | id_token_jwt=id_token.to_jwt(), 583 | userinfo={'sub': 'user1'}) 584 | end_session_redirect = view_func() 585 | # ensure user session has been cleared 586 | assert all(k not in flask.session for k in UserSession.KEYS) 587 | parsed_request = dict(parse_qsl(urlparse(end_session_redirect.headers['Location']).query)) 588 | assert parsed_request['state'] == flask.session['end_session_state'] 589 | 590 | assert end_session_redirect.status_code == 303 591 | assert end_session_redirect.location.startswith(end_session_endpoint) 592 | assert IdToken().from_jwt(parsed_request['id_token_hint']) == id_token 593 | 594 | expected_post_logout_redirect_uri = f'http://{self.CLIENT_DOMAIN}{endpoint}' 595 | assert parsed_request['post_logout_redirect_uri'] == expected_post_logout_redirect_uri 596 | 597 | def test_logout_with_missing_end_session_state_fails_gracefully(self): 598 | end_session_endpoint = 'https://provider.example.com/end_session' 599 | authn = self.init_app(provider_metadata_extras={'end_session_endpoint': end_session_endpoint}) 600 | id_token = IdToken(**{'sub': 'sub1', 'nonce': 'nonce'}) 601 | logout_view_mock = self.get_view_mock() 602 | 603 | # register logout view 604 | view_func = authn.oidc_logout(logout_view_mock) 605 | self.app.add_url_rule('/logout', view_func=view_func) 606 | 607 | with self.app.test_request_context('/logout?state=incorrect'): 608 | UserSession(flask.session, self.PROVIDER_NAME).update(access_token='test_access_token', 609 | id_token=id_token.to_dict(), 610 | id_token_jwt=id_token.to_jwt(), 611 | userinfo={'sub': 'user1'}) 612 | flask.session.pop('end_session_state', None) # make sure there's no 'end_session_state' 613 | logout_result = authn.oidc_logout(logout_view_mock)() 614 | 615 | self.assert_view_mock(logout_view_mock, logout_result) 616 | 617 | def 
test_logout_handles_provider_without_end_session_endpoint(self): 618 | authn = self.init_app() 619 | id_token = IdToken(**{'sub': 'sub1', 'nonce': 'nonce'}) 620 | logout_view_mock = self.get_view_mock() 621 | with self.app.test_request_context('/logout'): 622 | UserSession(flask.session, self.PROVIDER_NAME).update(access_token='test_access_token', 623 | id_token=id_token.to_dict(), 624 | id_token_jwt=id_token.to_jwt(), 625 | userinfo={'sub': 'user1'}) 626 | 627 | logout_result = authn.oidc_logout(logout_view_mock)() 628 | assert all(k not in flask.session for k in UserSession.KEYS) 629 | 630 | self.assert_view_mock(logout_view_mock, logout_result) 631 | 632 | def test_logout_handles_redirect_back_from_provider(self): 633 | authn = self.init_app() 634 | logout_view_mock = self.get_view_mock() 635 | state = 'end_session_123' 636 | with self.app.test_request_context(f'/logout?state={state}'): 637 | flask.session['end_session_state'] = state 638 | result = authn.oidc_logout(logout_view_mock)() 639 | assert 'end_session_state' not in flask.session 640 | 641 | self.assert_view_mock(logout_view_mock, result) 642 | 643 | def test_logout_handles_redirect_back_from_provider_with_incorrect_state(self, caplog): 644 | authn = self.init_app() 645 | logout_view_mock = self.get_view_mock() 646 | state = 'some_state' 647 | with self.app.test_request_context(f'/logout?state={state}'): 648 | flask.session['end_session_state'] = 'other_state' 649 | result = authn.oidc_logout(logout_view_mock)() 650 | assert 'end_session_state' not in flask.session 651 | 652 | self.assert_view_mock(logout_view_mock, result) 653 | assert caplog.record_tuples[-1] == ('flask_pyoidc.flask_pyoidc', 654 | logging.ERROR, 655 | f"Got unexpected state '{state}' after logout redirect.") 656 | 657 | def test_logout_handles_no_user_session(self): 658 | authn = self.init_app() 659 | logout_view_mock = self.get_view_mock() 660 | with self.app.test_request_context('/logout'): 661 | result = 
authn.oidc_logout(logout_view_mock)() 662 | 663 | self.assert_view_mock(logout_view_mock, result) 664 | 665 | def test_authentication_error_response_calls_to_error_view_if_set(self): 666 | state = 'test_tate' 667 | error_response = {'error': 'invalid_request', 'error_description': 'test error'} 668 | authn = self.init_app() 669 | error_view_mock = self.get_view_mock() 670 | authn.error_view(error_view_mock) 671 | with self.app.test_request_context(f'/redirect_uri?{urlencode(error_response)}&state={state}'): 672 | UserSession(flask.session, self.PROVIDER_NAME) 673 | flask.session['auth_request'] = json.dumps({'state': state, 'nonce': 'test_nonce'}) 674 | result = authn._handle_authentication_response() 675 | 676 | self.assert_view_mock(error_view_mock, result) 677 | error_view_mock.assert_called_with(**error_response) 678 | 679 | def test_authentication_error_response_returns_default_error_if_no_error_view_set(self): 680 | state = 'test_tate' 681 | error_response = {'error': 'invalid_request', 'error_description': 'test error', 'state': state} 682 | config = { 683 | 'provider_configuration_info': {'issuer': self.PROVIDER_BASEURL}, 684 | 'client_registration_info': {'client_id': 'abc', 'client_secret': 'foo'} 685 | } 686 | authn = self.init_app(config) 687 | with self.app.test_request_context(f'/redirect_uri?{urlencode(error_response)}'): 688 | UserSession(flask.session, self.PROVIDER_NAME) 689 | flask.session['state'] = state 690 | flask.session['nonce'] = 'test_nonce' 691 | response = authn._handle_authentication_response() 692 | assert response == "Something went wrong with the authentication, please try to login again." 
693 | 694 | @responses.activate 695 | def test_token_error_response_calls_to_error_view_if_set(self): 696 | token_endpoint = self.PROVIDER_BASEURL + '/token' 697 | error_response = {'error': 'invalid_request', 'error_description': 'test error'} 698 | responses.add(responses.POST, token_endpoint, json=error_response) 699 | 700 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 701 | error_view_mock = self.get_view_mock() 702 | authn.error_view(error_view_mock) 703 | state = 'test_tate' 704 | with self.app.test_request_context(f'/redirect_uri?code=foo&state={state}'): 705 | UserSession(flask.session, self.PROVIDER_NAME) 706 | flask.session['auth_request'] = json.dumps({'state': state, 'nonce': 'test_nonce'}) 707 | result = authn._handle_authentication_response() 708 | 709 | self.assert_view_mock(error_view_mock, result) 710 | error_view_mock.assert_called_with(**error_response) 711 | 712 | @responses.activate 713 | def test_token_error_response_returns_default_error_if_no_error_view_set(self): 714 | token_endpoint = self.PROVIDER_BASEURL + '/token' 715 | state = 'test_tate' 716 | error_response = {'error': 'invalid_request', 'error_description': 'test error', 'state': state} 717 | responses.add(responses.POST, token_endpoint, json=error_response) 718 | 719 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 720 | with self.app.test_request_context('/redirect_uri?code=foo&state=' + state): 721 | UserSession(flask.session, self.PROVIDER_NAME) 722 | flask.session['state'] = state 723 | flask.session['nonce'] = 'test_nonce' 724 | response = authn._handle_authentication_response() 725 | assert response == "Something went wrong with the authentication, please try to login again." 
726 | 727 | def test_using_unknown_provider_name_should_raise_exception(self): 728 | with pytest.raises(ValueError) as exc_info: 729 | self.init_app().oidc_auth('unknown') 730 | assert 'unknown' in str(exc_info.value) 731 | 732 | def test_should_not_refresh_if_no_user_session(self): 733 | with self.app.test_request_context('/foo'): 734 | assert self.init_app().valid_access_token() is None 735 | 736 | @responses.activate 737 | def test_should_refresh_expired_access_token(self): 738 | token_endpoint = self.PROVIDER_BASEURL + '/token' 739 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 740 | 741 | token_response = { 742 | 'access_token': 'new-access-token', 743 | 'expires_in': 3600, 744 | 'token_type': 'Bearer', 745 | 'refresh_token': 'new-refresh-token' 746 | } 747 | responses.add(responses.POST, token_endpoint, json=token_response) 748 | 749 | with self.app.test_request_context('/foo'): 750 | session = UserSession(flask.session, self.PROVIDER_NAME) 751 | session.update(expires_in=-10, refresh_token='refresh-token') 752 | assert authn.valid_access_token() == token_response['access_token'] 753 | assert session.access_token == token_response['access_token'] 754 | assert session.refresh_token == token_response['refresh_token'] 755 | 756 | def test_should_not_refresh_still_valid_access_token(self): 757 | authn = self.init_app() 758 | 759 | access_token = 'access_token' 760 | with self.app.test_request_context('/foo'): 761 | session = UserSession(flask.session, self.PROVIDER_NAME) 762 | session.update(access_token=access_token, expires_in=10, refresh_token='refresh-token') 763 | assert authn.valid_access_token() == access_token 764 | 765 | @responses.activate 766 | def test_should_refresh_still_valid_access_token_if_forced(self): 767 | token_endpoint = self.PROVIDER_BASEURL + '/token' 768 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 769 | 770 | token_response = { 771 | 'access_token': 
'new-access-token', 772 | 'expires_in': 3600, 773 | 'token_type': 'Bearer', 774 | 'refresh_token': 'new-refresh-token' 775 | } 776 | responses.add(responses.POST, token_endpoint, json=token_response) 777 | 778 | with self.app.test_request_context('/foo'): 779 | session = UserSession(flask.session, self.PROVIDER_NAME) 780 | session.update(expires_in=10, refresh_token='refresh-token') 781 | assert authn.valid_access_token(force_refresh=True) == token_response['access_token'] 782 | assert session.access_token == token_response['access_token'] 783 | assert session.refresh_token == token_response['refresh_token'] 784 | 785 | def test_should_not_refresh_without_refresh_token(self): 786 | authn = self.init_app() 787 | 788 | with self.app.test_request_context('/foo'): 789 | session = UserSession(flask.session, self.PROVIDER_NAME) 790 | session.update(expires_in=-10) 791 | assert authn.valid_access_token() is None 792 | 793 | def test_should_not_refresh_access_token_without_expiry(self): 794 | authn = self.init_app() 795 | 796 | access_token = 'access_token' 797 | with self.app.test_request_context('/foo'): 798 | session = UserSession(flask.session, self.PROVIDER_NAME) 799 | session.update(access_token=access_token, refresh_token='refresh-token') 800 | assert authn.valid_access_token() == access_token 801 | 802 | @responses.activate 803 | def test_should_return_None_if_token_refresh_request_fails(self): 804 | token_endpoint = self.PROVIDER_BASEURL + '/token' 805 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint}) 806 | 807 | token_response = { 808 | 'error': 'invalid_grant', 809 | 'error_description': 'The refresh token is invalid' 810 | } 811 | responses.add(responses.POST, token_endpoint, json=token_response) 812 | 813 | access_token = 'access_token' 814 | with self.app.test_request_context('/foo'): 815 | session = UserSession(flask.session, self.PROVIDER_NAME) 816 | session.update(access_token=access_token, expires_in=-10, 
refresh_token='refresh-token') 817 | assert authn.valid_access_token(force_refresh=True) is None 818 | assert session.access_token == access_token 819 | 820 | def test_should_check_for_authorization_header(self): 821 | 822 | authn = self.init_app() 823 | with self.app.test_request_context('/'): 824 | assert not authn._check_authorization_header(flask.request) 825 | flask.request.headers = { 826 | 'Authorization': 'Bearer access_token' 827 | } 828 | assert authn._check_authorization_header(flask.request) 829 | 830 | def test_should_parse_access_token_from_request_header(self): 831 | 832 | authn = self.init_app() 833 | with self.app.test_request_context('/'): 834 | flask.request.headers = { 835 | 'Authorization': 'Bearer access_token' 836 | } 837 | assert authn._parse_access_token(flask.request) == 'access_token' 838 | 839 | @responses.activate 840 | def test_introspect_token_should_return_none_if_invalid_access_token(self): 841 | 842 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 843 | authn = self.init_app(provider_metadata_extras={ 844 | 'introspection_endpoint': introspection_endpoint}) 845 | with self.app.test_request_context('/'): 846 | flask.request.headers = { 847 | 'Authorization': 'Bearer access_token' 848 | } 849 | responses.add(responses.POST, introspection_endpoint, 850 | json={'active': False}) 851 | assert authn.introspect_token( 852 | flask.request, authn.clients[self.PROVIDER_NAME]) is None 853 | 854 | @responses.activate 855 | def test_introspect_token_should_return_none_if_client_id_not_in_audience(self): 856 | 857 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 858 | authn = self.init_app(provider_metadata_extras={ 859 | 'introspection_endpoint': introspection_endpoint}) 860 | with self.app.test_request_context('/'): 861 | flask.request.headers = { 862 | 'Authorization': 'Bearer access_token' 863 | } 864 | token_introspection_response = { 865 | 'active': True, 866 | 'aud': ['admin', 'user'] 867 | } 
868 | responses.add(responses.POST, introspection_endpoint, 869 | json=token_introspection_response) 870 | assert authn.introspect_token( 871 | flask.request, authn.clients[self.PROVIDER_NAME]) is None 872 | 873 | @responses.activate 874 | def test_introspect_token_should_return_none_if_required_scopes_not_permitted(self): 875 | 876 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 877 | authn = self.init_app(provider_metadata_extras={ 878 | 'introspection_endpoint': introspection_endpoint}) 879 | with self.app.test_request_context('/'): 880 | flask.request.headers = { 881 | 'Authorization': 'Bearer access_token' 882 | } 883 | token_introspection_response = { 884 | 'active': True, 885 | 'aud': ['admin', 'user', self.CLIENT_ID], 886 | 'scope': ['read', 'write'] 887 | } 888 | responses.add(responses.POST, introspection_endpoint, 889 | json=token_introspection_response) 890 | assert authn.introspect_token( 891 | flask.request, authn.clients[self.PROVIDER_NAME], 892 | scopes=['read', 'write', 'delete']) is None 893 | 894 | @responses.activate 895 | def test_introspect_token_should_return_introspection_result_if_valid_access_token(self): 896 | 897 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 898 | authn = self.init_app(provider_metadata_extras={ 899 | 'introspection_endpoint': introspection_endpoint}) 900 | with self.app.test_request_context('/'): 901 | flask.request.headers = { 902 | 'Authorization': 'Bearer access_token' 903 | } 904 | token_introspection_response = { 905 | 'active': True, 906 | 'aud': ['admin', 'user', self.CLIENT_ID], 907 | 'scope': 'read write delete', 908 | 'client_id': self.CLIENT_ID 909 | } 910 | responses.add(responses.POST, introspection_endpoint, 911 | json=token_introspection_response) 912 | introspection_result = authn.introspect_token( 913 | flask.request, authn.clients[self.PROVIDER_NAME]) 914 | assert token_introspection_response == introspection_result.to_dict() 915 | 916 | def 
test_token_auth_should_raise_unauthorized_if_authorization_missing(self): 917 | 918 | authn = self.init_app() 919 | view_mock = self.get_view_mock() 920 | with self.app.test_request_context('/'): 921 | with pytest.raises(Unauthorized): 922 | authn.token_auth(self.PROVIDER_NAME)(view_mock)() 923 | 924 | @responses.activate 925 | def test_token_auth_should_run_view_function_if_valid_token(self): 926 | 927 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 928 | authn = self.init_app(provider_metadata_extras={ 929 | 'introspection_endpoint': introspection_endpoint}) 930 | view_mock = self.get_view_mock() 931 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 932 | token_introspection_response = { 933 | 'active': True, 934 | 'aud': ['admin', 'user', self.CLIENT_ID], 935 | 'scope': 'read write delete', 936 | 'client_id': self.CLIENT_ID 937 | } 938 | responses.add(responses.POST, introspection_endpoint, 939 | json=token_introspection_response) 940 | with self.app.test_request_context('/'): 941 | flask.request.headers = { 942 | 'Authorization': 'Bearer access_token' 943 | } 944 | authn.token_auth(self.PROVIDER_NAME, 945 | scopes_required=['read', 'write'])(view_mock)() 946 | assert view_mock.called 947 | assert flask.g.current_token_identity == token_introspection_response 948 | 949 | @responses.activate 950 | def test_token_auth_should_raise_forbidden_if_invalid_token(self): 951 | 952 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 953 | authn = self.init_app(provider_metadata_extras={ 954 | 'introspection_endpoint': introspection_endpoint}) 955 | view_mock = self.get_view_mock() 956 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 957 | token_introspection_response = { 958 | 'active': False, 959 | 'aud': ['admin', 'user', self.CLIENT_ID], 960 | 'scope': 'read write delete', 961 | 'client_id': self.CLIENT_ID 962 | } 963 | responses.add(responses.POST, introspection_endpoint, 964 | 
json=token_introspection_response) 965 | with self.app.test_request_context('/'): 966 | flask.request.headers = { 967 | 'Authorization': 'Bearer access_token' 968 | } 969 | with pytest.raises(Forbidden): 970 | authn.token_auth( 971 | self.PROVIDER_NAME, 972 | scopes_required=['read', 'write'])(view_mock)() 973 | 974 | @responses.activate 975 | def test_access_control_should_fallback_to_oidc_auth_on_401(self): 976 | 977 | authn = self.init_app() 978 | view_mock = self.get_view_mock() 979 | with self.app.test_request_context('/'): 980 | auth_fallback = authn.access_control( 981 | self.PROVIDER_NAME)(view_mock)() 982 | self.assert_auth_redirect(auth_fallback) 983 | assert not view_mock.called 984 | 985 | @responses.activate 986 | def test_access_control_should_abort_request_if_invalid_token(self): 987 | 988 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 989 | authn = self.init_app(provider_metadata_extras={ 990 | 'introspection_endpoint': introspection_endpoint}) 991 | view_mock = self.get_view_mock() 992 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 993 | token_introspection_response = { 994 | 'active': False, 995 | 'aud': ['admin', 'user', self.CLIENT_ID], 996 | 'scope': 'read write delete', 997 | 'client_id': self.CLIENT_ID 998 | } 999 | responses.add(responses.POST, introspection_endpoint, 1000 | json=token_introspection_response) 1001 | with self.app.test_request_context('/'): 1002 | flask.request.headers = { 1003 | 'Authorization': 'Bearer access_token' 1004 | } 1005 | with pytest.raises(Forbidden): 1006 | authn.access_control( 1007 | self.PROVIDER_NAME, 1008 | scopes_required=['read', 'write'])(view_mock)() 1009 | 1010 | @responses.activate 1011 | def test_access_control_should_run_view_function_if_valid_token(self): 1012 | 1013 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 1014 | authn = self.init_app(provider_metadata_extras={ 1015 | 'introspection_endpoint': introspection_endpoint}) 
1016 | view_mock = self.get_view_mock() 1017 | introspection_endpoint = f'{self.PROVIDER_BASEURL}/token/introspect' 1018 | token_introspection_response = { 1019 | 'active': True, 1020 | 'aud': ['admin', 'user', self.CLIENT_ID], 1021 | 'scope': 'read write delete', 1022 | 'client_id': self.CLIENT_ID 1023 | } 1024 | responses.add(responses.POST, introspection_endpoint, 1025 | json=token_introspection_response) 1026 | with self.app.test_request_context('/'): 1027 | flask.request.headers = { 1028 | 'Authorization': 'Bearer access_token' 1029 | } 1030 | authn.access_control( 1031 | self.PROVIDER_NAME, 1032 | scopes_required=['read', 'write'])(view_mock)() 1033 | assert view_mock.called 1034 | assert flask.g.current_token_identity == token_introspection_response 1035 | 1036 | def test_get_url_for_logout_view_should_raise_build_error_if_mounted_under_custom_endpoint(self): 1037 | authn = self.init_app() 1038 | logout_view_mock = self.get_view_mock() 1039 | self.app.add_url_rule('/logout', endpoint='test.logout', view_func=authn.oidc_logout(logout_view_mock)) 1040 | 1041 | with self.app.test_request_context('/'): 1042 | with pytest.raises(BuildError): 1043 | authn._get_urls_for_logout_views() 1044 | 1045 | @patch('time.time') 1046 | @patch('oic.utils.time_util.utc_time_sans_frac') # used internally by pyoidc when verifying ID Token 1047 | @responses.activate 1048 | def test_oidc_clock_skew_passed(self, time_mock, utc_time_sans_frac_mock): 1049 | # freeze time since ID Token validation includes expiration timestamps 1050 | timestamp = time.mktime(datetime(2017, 1, 1).timetuple()) 1051 | time_mock.return_value = timestamp 1052 | utc_time_sans_frac_mock.return_value = int(timestamp) 1053 | 1054 | # mock token response 1055 | user_id = 'user1' 1056 | skew = 10 1057 | exp_time = 10 1058 | nonce = 'test_nonce' 1059 | id_token_claims = { 1060 | 'iss': self.PROVIDER_BASEURL, 1061 | 'aud': [self.CLIENT_ID], 1062 | 'sub': user_id, 1063 | 'exp': int(timestamp) + exp_time + skew, 1064 
| 'iat': int(timestamp) + skew, 1065 | 'nonce': nonce 1066 | } 1067 | 1068 | id_token_jwt, id_token_signing_key = signed_id_token(id_token_claims) 1069 | access_token = 'test_access_token' 1070 | expires_in = 3600 1071 | token_response = { 1072 | 'access_token': access_token, 1073 | 'expires_in': expires_in, 1074 | 'token_type': 'Bearer', 1075 | 'id_token': id_token_jwt 1076 | } 1077 | token_endpoint = self.PROVIDER_BASEURL + '/token' 1078 | responses.add(responses.POST, token_endpoint, json=token_response) 1079 | responses.add(responses.GET, 1080 | self.PROVIDER_BASEURL + '/jwks', 1081 | json={'keys': [id_token_signing_key.serialize()]}) 1082 | 1083 | # mock userinfo response 1084 | userinfo = {'sub': user_id, 'name': 'Test User'} 1085 | userinfo_endpoint = self.PROVIDER_BASEURL + '/userinfo' 1086 | responses.add(responses.GET, userinfo_endpoint, json=userinfo) 1087 | 1088 | authn = self.init_app(provider_metadata_extras={'token_endpoint': token_endpoint, 1089 | 'userinfo_endpoint': userinfo_endpoint}) 1090 | state = 'test_state' 1091 | with self.app.test_request_context('/redirect_uri?state={}&code=test'.format(state)): 1092 | UserSession(flask.session, self.PROVIDER_NAME) 1093 | flask.session['destination'] = '/' 1094 | flask.session['auth_request'] = json.dumps({'state': state, 'nonce': nonce}) 1095 | authn._handle_authentication_response() 1096 | session = UserSession(flask.session) 1097 | assert session.access_token == access_token 1098 | assert session.access_token_expires_at == int(timestamp) + expires_in 1099 | assert session.id_token == id_token_claims 1100 | assert session.id_token_jwt == id_token_jwt 1101 | assert session.userinfo == userinfo 1102 | --------------------------------------------------------------------------------