├── .bumpversion.cfg ├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE ├── NOTICE ├── README.md ├── application.cfg ├── docker ├── setup.sh └── start.sh ├── example ├── README.md ├── __init__.py ├── app.py ├── application.cfg ├── https.crt ├── https.key ├── requirements.txt ├── signing_key.pem ├── templates │ └── logout.jinja2 ├── views.py └── wsgi.py ├── setup.py ├── signing_key.pem ├── src └── pyop │ ├── __init__.py │ ├── access_token.py │ ├── authz_state.py │ ├── client_authentication.py │ ├── crypto.py │ ├── exceptions.py │ ├── message.py │ ├── provider.py │ ├── request_validator.py │ ├── storage.py │ ├── subject_identifier.py │ ├── userinfo.py │ └── util.py ├── tests ├── pyop │ ├── test_access_token.py │ ├── test_authz_state.py │ ├── test_client_authentication.py │ ├── test_exceptions.py │ ├── test_provider.py │ ├── test_stateless_provider.py │ ├── test_storage.py │ └── test_util.py └── test_requirements.txt └── tox.ini /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 2.0.8 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .cache/ 3 | .tox/ 4 | __pycache__ 5 | *.egg-info 6 | build/ 7 | dist/ 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | os: linux 2 | dist: bionic 3 | language: python 4 | 5 | services: 6 | - docker 7 | - mongodb 8 | 9 | install: 10 | - pip install tox 11 | - pip install tox-travis 12 | 13 | script: 14 | - tox 15 | 16 | jobs: 17 | allow_failures: 18 | - python: 3.9-dev 19 | include: 20 | - python: 3.6 21 | - python: 3.7 22 | - python: 3.8 23 | - python: pypy3 24 | 25 | - stage: Deploy new release 26 
| script: skip 27 | deploy: 28 | - provider: pypi 29 | distributions: sdist bdist_wheel 30 | skip_existing: true 31 | user: Lundberg 32 | password: 33 | secure: H5d+Its9YTMSvVddRWX2qgChMb8Eur5zI+qRy3NAPdwRNs1RNyIk1a2z9/EPFmRIu6OsBBDcHsCiq4VXcwvpigdAqMu4iAoZ/Xe0xf88k21GCggfaAPbINRVL6031RFUQkfGZ4abT2cXnerDylMv2DporPZkfCEUJonq+we0GmtJHoCSemXewMxt28TSu0aPKRL4aBfbuRoAPx50jUns9ekxgc0sqpSLvE5qyxWxXIePK0/+8tX3OrdCcKMg/IshgoK7Yondu+DhN+qhf+AkQuPDXUQTx/TKdg/YDVqj8SHT6hIFFi6dCakuhkYIKlkggnSguLhZ2zhVUjYFt1f0NOv2j7dHuKxyUFR9Qm/49rdY/E3ir3CU5YgUEprcgo/jj5K3B1/jY2uXNez1JD97RC6IAPg4o+PwenVQ9a3pLwqnImSaJKPTQf9IyFfrV/xru3ZyQiftmUmCYtCPybDATOq5iqNAQa9Ec0Mg54OGcabPQkNp9CrNFkcO0sM3VNRnGTmuqdYIkjNxPwNCzjbAQKlwcXVNg48kHjH6vb9D+mxjt9CYwCJdfkGm2F2pekr5S8tDdLAkxE9VW+r7SrsRaJRFCHU+6AaejnWvOLCy2S+KJ0JhQesJm0k1iT2fsC8v92MKIzghrY/sqKck33pxB57cFqxLIQIVYgCaBFWbwfc= 34 | on: 35 | tags: true 36 | repo: IdentityPython/pyop 37 | if: tag IS present 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN /bin/echo -e "deb http://se.archive.ubuntu.com/ubuntu xenial main restricted universe\ndeb http://archive.ubuntu.com/ubuntu xenial-updates main restricted universe\ndeb http://security.ubuntu.com/ubuntu xenial-security main restricted universe" > /etc/apt/sources.list 6 | 7 | RUN apt-get update && \ 8 | apt-get -y dist-upgrade && \ 9 | apt-get -y install \ 10 | python3-pip \ 11 | python-virtualenv \ 12 | libpython3-dev \ 13 | python-setuptools \ 14 | build-essential \ 15 | libffi-dev \ 16 | libssl-dev \ 17 | iputils-ping \ 18 | && apt-get clean 19 | 20 | RUN rm -rf /var/lib/apt/lists/* 21 | 22 | RUN adduser --system --no-create-home --shell /bin/false --group pyop 23 | 24 | COPY . 
/opt/pyop/src/ 25 | COPY docker/setup.sh /opt/pyop/setup.sh 26 | COPY docker/start.sh /start.sh 27 | RUN /opt/pyop/setup.sh 28 | 29 | # Add Dockerfile to the container as documentation 30 | COPY Dockerfile /Dockerfile 31 | 32 | WORKDIR / 33 | 34 | EXPOSE 9090 35 | 36 | CMD ["bash", "/start.sh"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types.
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2016 Umeå universitet 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pyOP 2 | [![Build Status](https://travis-ci.org/IdentityPython/pyop.svg)](https://travis-ci.org/IdentityPython/pyop) 3 | [![PyPI](https://img.shields.io/pypi/v/pyop.svg)](https://pypi.python.org/pypi/pyop) 4 | 5 | 6 | OpenID Connect Provider (OP) library in Python. 7 | Uses [pyoidc](https://github.com/rohe/pyoidc/) and 8 | [pyjwkest](https://github.com/rohe/pyjwkest). 9 | 10 | # Provider implementations using pyOP 11 | * [se-leg-op](https://github.com/SUNET/se-leg-op) 12 | * [SATOSA OIDC frontend](https://github.com/its-dirg/SATOSA/blob/master/src/satosa/frontends/openid_connect.py) 13 | * [local example](example/views.py) 14 | 15 | # Introduction 16 | 17 | pyOP is a high-level library intended to be usable in any web server application. 18 | By only providing the core functionality for OpenID Connect the application can freely choose to implement any kind of 19 | authentication mechanisms, while pyOP provides a simple interface for the OpenID Connect messages to send back to 20 | clients. 
21 | 22 | ## OpenID Connect support 23 | * [Dynamic Provider Discovery](https://openid.net/specs/openid-connect-discovery-1_0.html) 24 | * [Dynamic Client Registration](https://openid.net/specs/openid-connect-registration-1_0.html) 25 | * [Core](http://openid.net/specs/openid-connect-core-1_0.html) 26 | * [Authorization Code Flow](http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps) 27 | * [Implicit Flow](http://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth) 28 | * [Hybrid Flow](http://openid.net/specs/openid-connect-core-1_0.html#HybridFlowAuth) 29 | * Claims 30 | * [Requesting Claims using Scope Values](http://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims) 31 | * [Claims Request Parameter](http://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter) 32 | * Crypto support 33 | * Currently only supports issuing signed ID Tokens 34 | 35 | # Configuration 36 | The provider instance can be configured through the provider configuration information. 
In the following example, a 37 | provider instance is initiated to use a MongoDB instance as its backend storage: 38 | 39 | ```python 40 | from jwkest.jwk import rsa_load, RSAKey 41 | 42 | from pyop.authz_state import AuthorizationState 43 | from pyop.provider import Provider 44 | from pyop.storage import MongoWrapper 45 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory 46 | from pyop.userinfo import Userinfo 47 | 48 | signing_key = RSAKey(key=rsa_load('signing_key.pem'), use='sig', alg='RS256') 49 | configuration_information = { 50 | 'issuer': 'https://example.com', 51 | 'authorization_endpoint': 'https://example.com/authorization', 52 | 'token_endpoint': 'https://example.com/token', 53 | 'userinfo_endpoint': 'https://example.com/userinfo', 54 | 'registration_endpoint': 'https://example.com/registration', 55 | 'response_types_supported': ['code', 'id_token token'], 56 | 'id_token_signing_alg_values_supported': [signing_key.alg], 57 | 'response_modes_supported': ['fragment', 'query'], 58 | 'subject_types_supported': ['public', 'pairwise'], 59 | 'grant_types_supported': ['authorization_code', 'implicit'], 60 | 'claim_types_supported': ['normal'], 61 | 'claims_parameter_supported': True, 62 | 'claims_supported': ['sub', 'name', 'given_name', 'family_name'], 63 | 'request_parameter_supported': False, 64 | 'request_uri_parameter_supported': False, 65 | 'scopes_supported': ['openid', 'profile'] 66 | } 67 | 68 | subject_id_factory = HashBasedSubjectIdentifierFactory(sub_hash_salt) 69 | authz_state = AuthorizationState(subject_id_factory, 70 | MongoWrapper(db_uri, 'provider', 'authz_codes'), 71 | MongoWrapper(db_uri, 'provider', 'access_tokens'), 72 | MongoWrapper(db_uri, 'provider', 'refresh_tokens'), 73 | MongoWrapper(db_uri, 'provider', 'subject_identifiers')) 74 | client_db = MongoWrapper(db_uri, 'provider', 'clients') 75 | user_db = MongoWrapper(db_uri, 'provider', 'users') 76 | provider = Provider(signing_key, configuration_information, 
authz_state, client_db, Userinfo(user_db)) 77 | ``` 78 | 79 | where `db_uri` is the [MongoDB connection URI](https://docs.mongodb.com/manual/reference/connection-string/) and 80 | `sub_hash_salt` is a secret string to use as a salt when creating hash based subject identifiers. 81 | 82 | ## Token lifetimes 83 | The ID token lifetime (in seconds) can be supplied to the `Provider` constructor with `id_token_lifetime`, e.g.: 84 | 85 | ```python 86 | Provider(..., id_token_lifetime=600) 87 | ``` 88 | If not specified, it defaults to 1 hour. 89 | 90 | The lifetime of authorization codes, access tokens, and refresh tokens is configured in the `AuthorizationState`, e.g.: 91 | 92 | ```python 93 | AuthorizationState(..., authorization_code_lifetime=300, access_token_lifetime=60*60*24, 94 | refresh_token_lifetime=60*60*24*365, refresh_token_threshold=None) 95 | ``` 96 | 97 | If not specified, the lifetimes default to the following values: 98 | * Authorization codes are valid for 10 minutes. 99 | * Access tokens are valid for 1 hour. 100 | * Refresh tokens are not issued. 101 | 102 | To make sure refresh tokens are issued in response to code exchange token requests, specify a 103 | `refresh_token_lifetime` > 0. 104 | To make sure refresh tokens are renewed if they are close to expiry in response to refresh token requests, 105 | specify a `refresh_token_threshold` > 0. 106 | 107 | # Dynamic discovery: Provider Configuration Information 108 | To publish the provider configuration information at an endpoint, use `Provider.provider_configuration`.
109 | 110 | The following example illustrates the high-level idea: 111 | 112 | ```python 113 | @app.route('/.well-known/openid-configuration') 114 | def provider_config(): 115 | return HTTPResponse(provider.provider_configuration.to_json(), content_type="application/json") 116 | ``` 117 | 118 | # Authorization endpoint 119 | An incoming authentication request can be validated by the provider using `Provider.parse_authentication_request`. 120 | If the request is valid, it should be stored and associated with the current user session to be able to retrieve it 121 | when the end-user authentication is completed. 122 | 123 | ```python 124 | from pyop.exceptions import InvalidAuthenticationRequest 125 | 126 | @app.route('/authorization') 127 | def authorization_endpoints(request): 128 | try: 129 | authn_req = provider.parse_authentication_request(request) 130 | except InvalidAuthenticationRequest as e: 131 | error_url = e.to_error_url() 132 | 133 | if error_url: 134 | return HTTPResponse(error_url, status=303) 135 | else: 136 | return HTTPResponse("Something went wrong: {}".format(str(e)), status=400) 137 | 138 | session['authn_req'] = authn_req.to_dict() 139 | // TODO initiate end-user authentication 140 | ``` 141 | 142 | When the authentication is completed by the user, the provider must be notified to make an authentication response 143 | to the client's 'redirect_uri'. This is done with `Provider.authorize`, where the local user id supplied must exist 144 | in the user database supplied on initialization. When using the included `MongoWrapper`, no mapping is done between 145 | user data and OpenID Connect claim names. Hence the underlying data source must contain the user information under the 146 | same names as the [standard claims of OpenID Connect](http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims). 
147 | 148 | ```python 149 | from pyop.message import AuthorizationRequest 150 | 151 | from pyop.util import should_fragment_encode 152 | 153 | authn_req = session['authn_req'] 154 | authn_response = provider.authorize(AuthorizationRequest().from_dict(authn_req), user_id) 155 | return_url = authn_response.request(authn_req['redirect_uri'], should_fragment_encode(authn_req)) 156 | 157 | return HTTPResponse(return_url, status=303) 158 | ``` 159 | 160 | ## Authentication request validation 161 | The provider instance is by default configured to validate authentication requests according to the OpenID Connect 162 | Core specification. If you need to add additional custom validation of authentication requests, that's possible by 163 | adding such validation functions to the list of authentication request validators. 164 | 165 | In this example an additional validator that checks that the 'nonce' parameter is included in all requests is added: 166 | 167 | ```python 168 | from pyop.exceptions import InvalidAuthenticationRequest 169 | 170 | def request_contains_nonce(authentication_request): 171 | if 'nonce' not in authentication_request: 172 | raise InvalidAuthenticationRequest('The request does not contain a nonce', authentication_request, 173 | oauth_error='invalid_request') 174 | 175 | provider.authentication_request_validators.append(request_contains_nonce) 176 | ``` 177 | 178 | # Token endpoint 179 | An incoming token request is processed by `Provider.handle_token_request`. 
It will validate the request and issue all 180 | necessary tokens (access token and possibly refresh token) 181 | 182 | ```python 183 | from oic.oic.message import TokenErrorResponse 184 | 185 | from pyop.exceptions import InvalidClientAuthentication 186 | from pyop.exceptions import OAuthError 187 | 188 | @app.route('/token', methods=['POST', 'GET']) 189 | def token_endpoint(request): 190 | try: 191 | token_response = provider.handle_token_request(request.get_data().decode('utf-8'), 192 | request.headers) 193 | return HTTPResponse(token_response.to_json(), content_type='application/json') 194 | except InvalidClientAuthentication as e: 195 | error_resp = TokenErrorResponse(error='invalid_client', error_description=str(e)) 196 | http_response = HTTPResponse(error_resp.to_json(), status=401, content_type='application/json') 197 | http_response.headers['WWW-Authenticate'] = 'Basic' 198 | return http_response 199 | except OAuthError as e: 200 | error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e)) 201 | return HTTPResponse(error_resp.to_json(), status=400, content_type='application/json') 202 | ``` 203 | 204 | 205 | # Userinfo endpoint 206 | An incoming userinfo request is processed by `Provider.handle_userinfo_request`. It will validate the request and return 207 | all requested userinfo. 
208 | 209 | ```python 210 | from oic.oic.message import UserInfoErrorResponse 211 | 212 | from pyop.access_token import AccessToken 213 | from pyop.exceptions import BearerTokenError 214 | from pyop.exceptions import InvalidAccessToken 215 | 216 | @app.route('/userinfo', methods=['GET', 'POST']) 217 | def userinfo_endpoint(request): 218 | try: 219 | response = provider.handle_userinfo_request(request.get_data().decode('utf-8'), 220 | request.headers) 221 | return HTTPResponse(response.to_json(), content_type='application/json') 222 | except (BearerTokenError, InvalidAccessToken) as e: 223 | error_resp = UserInfoErrorResponse(error='invalid_token', error_description=str(e)) 224 | http_response = HTTPResponse(error_resp.to_json(), status=401, content_type='application/json') 225 | http_response.headers['WWW-Authenticate'] = AccessToken.BEARER_TOKEN_TYPE 226 | return http_response 227 | ``` 228 | 229 | 230 | # Dynamic client registration 231 | 232 | An incoming client registration request is processed by `Provider.handle_client_registration_request`. It will validate the request, 233 | store the registered metadata and issue new client credentials. 234 | 235 | ```python 236 | from pyop.exceptions import InvalidClientRegistrationRequest 237 | 238 | @app.route('/registration', methods=['POST']) 239 | def registration_endpoint(request): 240 | try: 241 | response = provider.handle_client_registration_request(request.get_data().decode('utf-8')) 242 | return HTTPResponse(response.to_json(), status=201, content_type='application/json') 243 | except InvalidClientRegistrationRequest as e: 244 | return HTTPResponse(e.to_json(), status=400, content_type='application/json') 245 | ``` 246 | 247 | ## Registration request validation 248 | The provider instance is by default configured to validate registration requests according to the OpenID Connect 249 | Dynamic Registration specification.
If you need to add additional custom validation of registration requests, that's 250 | possible by adding such validation functions to the list of registration request validators. 251 | 252 | In this example an additional validator that checks that the 'software_statement' parameter is included in all requests 253 | is added: 254 | 255 | ```python 256 | def request_contains_software_statement(registration_request): 257 | if 'software_statement' not in registration_request: 258 | raise InvalidClientRegistrationRequest('The request does not contain a software_statement', registration_request, 259 | oauth_error='invalid_request') 260 | 261 | provider.registration_request_validators.append(request_contains_software_statement) 262 | ``` 263 | 264 | # User logout 265 | 266 | RP-initiated logout, as described in [Section 5 of OpenID Connect Session Management](http://openid.net/specs/openid-connect-session-1_0.html#RPLogout) 267 | is supported. The parsed request should be passed to `Provider.logout_user` together with any known subject identifier 268 | for the user, and then `Provider.do_post_logout_redirect` should be called to obey any valid `post_logout_redirect_uri` 269 | included in the request. 270 | 271 | ```python 272 | from oic.oic.message import EndSessionRequest 273 | 274 | from pyop.exceptions import InvalidSubjectIdentifier 275 | 276 | @app.route('/logout') 277 | def end_session_endpoint(request): 278 | end_session_request = EndSessionRequest().deserialize(request.get_data().decode('utf-8')) 279 | 280 | try: 281 | provider.logout_user(session.get('sub'), end_session_request) 282 | except InvalidSubjectIdentifier as e: 283 | return HTTPResponse('Logout unsuccessful!', content_type='text/html', status=400) 284 | 285 | # TODO automagic logout, should ask user first!
286 | redirect_url = provider.do_post_logout_redirect(end_session_request) 287 | if redirect_url: 288 | return HTTPResponse(redirect_url, status=303) 289 | 290 | return HTTPResponse('Logout successful!', content_type='text/html') 291 | ``` 292 | 293 | # Exceptions 294 | All exceptions, except `AuthorizationError`, inherit from `ValueError`. However, it might be necessary to distinguish 295 | between them to send the correct error message back to the client according to the OpenID Connect specifications. 296 | 297 | All OAuth errors contain the OAuth error code in `OAuthError.oauth_error`, together with the error description as the 298 | message of the exception (accessed by `str(exception)`). 299 | -------------------------------------------------------------------------------- /application.cfg: -------------------------------------------------------------------------------- 1 | example/application.cfg -------------------------------------------------------------------------------- /docker/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Install all requirements 4 | # 5 | 6 | set -e 7 | set -x 8 | 9 | PYPI="https://pypi.nordu.net/simple/" 10 | ping -c 1 -q pypiserver.docker && PYPI="http://pypiserver.docker:8080/simple/" 11 | 12 | echo "#############################################################" 13 | echo "$0: Using PyPi URL ${PYPI}" 14 | echo "#############################################################" 15 | 16 | virtualenv -p python3 /opt/pyop 17 | /opt/pyop/bin/pip install -U pip 18 | 19 | # setup.py points to current directory 20 | # so we need to change to the right one.
21 | cd /opt/pyop/src 22 | 23 | /opt/pyop/bin/python3 setup.py install 24 | /opt/pyop/bin/pip install Flask 25 | /opt/pyop/bin/pip install gunicorn 26 | -------------------------------------------------------------------------------- /docker/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | . /opt/pyop/bin/activate 7 | 8 | # nice to have in docker run output, to check what 9 | # version of something is actually running. 10 | /opt/pyop/bin/pip freeze 11 | 12 | export PYTHONPATH=/opt/pyop/src 13 | 14 | start-stop-daemon --start \ 15 | -c pyop:pyop \ 16 | --exec /opt/pyop/bin/gunicorn \ 17 | --pidfile /var/run/pyop.pid \ 18 | --chdir /opt/pyop/src \ 19 | -- \ 20 | example.wsgi:app \ 21 | -b :9090 \ 22 | --certfile example/https.crt \ 23 | --keyfile example/https.key 24 | 25 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # pyOP example application 2 | To run the example application, execute the following commands: 3 | 4 | ```bash 5 | cd example/ 6 | pip install -r requirements.txt # install the dependencies 7 | gunicorn wsgi:app -b :9090 --certfile https.crt --keyfile https.key # run the application 8 | ``` -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IdentityPython/pyop/fab87f9f6193079171fdad0c223c810fe9532dd2/example/__init__.py -------------------------------------------------------------------------------- /example/app.py: -------------------------------------------------------------------------------- 1 | from flask.app import Flask 2 | from flask.helpers import url_for 3 | from jwkest.jwk import RSAKey, rsa_load 4 | 5 | from pyop.authz_state import AuthorizationState 
6 | from pyop.provider import Provider 7 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory 8 | from pyop.userinfo import Userinfo 9 | 10 | 11 | def init_oidc_provider(app): 12 | with app.app_context(): 13 | issuer = url_for('oidc_provider.index')[:-1] 14 | authentication_endpoint = url_for('oidc_provider.authentication_endpoint') 15 | jwks_uri = url_for('oidc_provider.jwks_uri') 16 | token_endpoint = url_for('oidc_provider.token_endpoint') 17 | userinfo_endpoint = url_for('oidc_provider.userinfo_endpoint') 18 | registration_endpoint = url_for('oidc_provider.registration_endpoint') 19 | end_session_endpoint = url_for('oidc_provider.end_session_endpoint') 20 | 21 | configuration_information = { 22 | 'issuer': issuer, 23 | 'authorization_endpoint': authentication_endpoint, 24 | 'jwks_uri': jwks_uri, 25 | 'token_endpoint': token_endpoint, 26 | 'userinfo_endpoint': userinfo_endpoint, 27 | 'registration_endpoint': registration_endpoint, 28 | 'end_session_endpoint': end_session_endpoint, 29 | 'scopes_supported': ['openid', 'profile'], 30 | 'response_types_supported': ['code', 'code id_token', 'code token', 'code id_token token'], # code and hybrid 31 | 'response_modes_supported': ['query', 'fragment'], 32 | 'grant_types_supported': ['authorization_code', 'implicit'], 33 | 'subject_types_supported': ['pairwise'], 34 | 'token_endpoint_auth_methods_supported': ['client_secret_basic'], 35 | 'claims_parameter_supported': True 36 | } 37 | 38 | userinfo_db = Userinfo(app.users) 39 | signing_key = RSAKey(key=rsa_load('signing_key.pem'), alg='RS256') 40 | provider = Provider(signing_key, configuration_information, 41 | AuthorizationState(HashBasedSubjectIdentifierFactory(app.config['SUBJECT_ID_HASH_SALT'])), 42 | {}, userinfo_db) 43 | 44 | return provider 45 | 46 | 47 | def oidc_provider_init_app(name=None): 48 | name = name or __name__ 49 | app = Flask(name) 50 | app.config.from_pyfile('application.cfg') 51 | 52 | app.users = {'test_user': {'name': 'Testing 
Name'}} 53 | 54 | from .views import oidc_provider_views 55 | app.register_blueprint(oidc_provider_views) 56 | 57 | # Initialize the oidc_provider after views to be able to set correct urls 58 | app.provider = init_oidc_provider(app) 59 | 60 | return app 61 | -------------------------------------------------------------------------------- /example/application.cfg: -------------------------------------------------------------------------------- 1 | SERVER_NAME = 'localhost:9090' 2 | SECRET_KEY = 'secret_key' 3 | SESSION_COOKIE_NAME='pyop_session' 4 | SUBJECT_ID_HASH_SALT = 'salt' 5 | PREFERRED_URL_SCHEME = 'https' -------------------------------------------------------------------------------- /example/https.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEBjCCAu6gAwIBAgIJAIybVu7kfIK0MA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV 3 | BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX 4 | aWRnaXRzIFB0eSBMdGQxGDAWBgNVBAMTD2xva2kuaXRzLnVtdS5zZTAeFw0xNTEy 5 | MTAxNDQyMDFaFw0yNTEyMDcxNDQyMDFaMF8xCzAJBgNVBAYTAkFVMRMwEQYDVQQI 6 | EwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQx 7 | GDAWBgNVBAMTD2xva2kuaXRzLnVtdS5zZTCCASIwDQYJKoZIhvcNAQEBBQADggEP 8 | ADCCAQoCggEBALiLDBwIteIobC+7JHoNeQRrTIbws9BghN4UUzyLo7+xeP9YwHaS 9 | tq6HqYK4cVLyx8k06Siw/4PwqMPNj9/B4f/ZXhEkXgbBP5TP36UgKrUIk4zInRFb 10 | Rjy+DcqjSZdgW1CKBKWJstXjSYen5rPm+voM/0msi164NPcfDMQIZmcQWh0MmEfG 11 | qlvdwTvjdaAQt8p7CGsxIdu4gPfhubknbTQKu+BVq5/RCVP7VU830PSr1RYhthX8 12 | Gt8ir32jEdDdjIrfA/zFx6PChyLkQFXtg/9WymnIM1j2ngNreL2nppwnqMYRnI9i 13 | C/y7MY4al3WeL9IETrtgh1jzXUNpgpJ03B0CAwEAAaOBxDCBwTAdBgNVHQ4EFgQU 14 | cTNphzIIRpQBQ2VT0Vx9xQYzNN0wgZEGA1UdIwSBiTCBhoAUcTNphzIIRpQBQ2VT 15 | 0Vx9xQYzNN2hY6RhMF8xCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRl 16 | MSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxGDAWBgNVBAMTD2xv 17 | a2kuaXRzLnVtdS5zZYIJAIybVu7kfIK0MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcN 18 | 
AQEFBQADggEBAJYJfUqOPTyZ+tflKoN4l+scIXpBxqtQbjX+MYli6VHpl+M8y163 19 | KsCglXPddL7Z58KBrUDx1m6f7dFQ3PMYn/S2dUcRrNOdfaDKZ5QgyYj/iVr8HSOh 20 | 6i1OtMFaBqW5WyqA5YgvUz63hZ2kDOBHZcEfSn2+roylBUiueV9gFNKDWneNMLo2 21 | PMZxcGWZ3wIQbu9ahakbJUvTigFStKeLoY1A2ZSTH7W4elB5DDxYOKZSzd/KZpfn 22 | /o/Pc7YbbEUYgIyf3QNusdH+t2pw9ZkrlKMhiv9ZAmjAWMDY7O/i3r7u7AkODu7z 23 | OHbH3rJqkbaiS8/q0cMZG6AMUzPRglzsTc4= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /example/https.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEAuIsMHAi14ihsL7skeg15BGtMhvCz0GCE3hRTPIujv7F4/1jA 3 | dpK2roepgrhxUvLHyTTpKLD/g/Cow82P38Hh/9leESReBsE/lM/fpSAqtQiTjMid 4 | EVtGPL4NyqNJl2BbUIoEpYmy1eNJh6fms+b6+gz/SayLXrg09x8MxAhmZxBaHQyY 5 | R8aqW93BO+N1oBC3ynsIazEh27iA9+G5uSdtNAq74FWrn9EJU/tVTzfQ9KvVFiG2 6 | Ffwa3yKvfaMR0N2Mit8D/MXHo8KHIuRAVe2D/1bKacgzWPaeA2t4vaemnCeoxhGc 7 | j2IL/LsxjhqXdZ4v0gROu2CHWPNdQ2mCknTcHQIDAQABAoIBAQCooAWUqDDqUl1o 8 | z+voyt7FtvXaZ58mzMsb0h6suDwMMTKKwKI8tprOp4+wrrB+RvFfXUWftPwFp6XO 9 | JMtOfm7vxcM6jqyMJ5DdfYSx8c6UVR3eCoHbFjf70P3xJ3tbIuTNlw/f4w7Sejj6 10 | B+W6hVjXm4C55TwEdPWQyYJ0rehESxITn7DDjDNXxxwDqwAv8yPTYf8m7mb7qh3V 11 | EBnvZPIWpgEIVqV2crQSfHJwg29KhS7cnExJBYEPppBQ4aoUGyMqJN0EyBL4DrRo 12 | Ds6hPttLaXEmB+ACm/OQzhEeFKduob5OIKRSyp9Z8t6/B9uHiIKCENRU4O4zux57 13 | jZ8MIypRAoGBAO0anHYHlniyP5f/8u4kvTnW3wbDtRJ3L3zVSKN8shQmQnjNWmsI 14 | LhLWLU2OTRsXlJB7oYqNBojUBdeGimYot9kmjx4XkxELB6XAaYDG6pDvM7axC5qH 15 | iuC4jVHdEsIfy9dP6wZ/b+A+JOWWpS1vdAizfvgWI32JLGDPkjjUlzNvAoGBAMdA 16 | F5KZZLZFYsZM470/bb20qFMURDRI+5yz0VUHTNUQEr/xhMcYFske6ox14A7gXzwd 17 | SHAvTDkV8DsGu/FzSWzZmVhNc1EtdM3Cbe3Y7WJoIQoupuxBDEvrBE9riyOjW61q 18 | dYIO2ymfJfc2Vx6d7LAK9itXdqa3RIo7ZPK7EbMzAoGAAR1C5vsaJe8QhXJafewG 19 | R6NO4QVCcJfGzVtjQAFyBM45OcAdUKt1K/l9tQOaMSpnNFagZ7pJ8ZKthFnJhLlk 20 | Q8z+lzGdK1NV8d15oXVN3OiC4bTrTQqeCHhVkbDsSaVEm/pwLFOk/vTLz5hpplED 21 | 
xpaxXhEckZZ3cu0GzuWQ4FkCgYAb34tspqjAFtTKiNcTElx3vV4OwTcJWWxZb45J 22 | JsxIwgbdcxvv/h6x4/FL1PGTIzAvaKlJiFRRaBBDMZ35GPeckpQxFiSbppBAeIKI 23 | U2Bh888rbXtMcY0W0bm4ooLEaYXZrJrjptBh8jGNc7ycO9twhRgK2CFxERI1hDmK 24 | +0BuoQKBgHLRFdJYFkNKB+j56vTlB5AFTXcjs6x0dZ4j0SDAqVYgIHErpd2bhgOz 25 | s698rfdadwbYN4UEqvV/2NhDLx/jag4BvCermK71HaXvYnvqKw3UHGAG8K/dCCFR 26 | eSuZnxTSoWAPjMLi64hRzbWVT+oeZuy83lG6qDygbI1z1nlYWdo3 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /example/requirements.txt: -------------------------------------------------------------------------------- 1 | pyop 2 | Flask 3 | gunicorn 4 | oic>=1.2.1 5 | setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability 6 | werkzeug>=3.0.1 # not directly required, pinned by Snyk to avoid a vulnerability 7 | requests>=2.32.2 # not directly required, pinned by Snyk to avoid a vulnerability 8 | urllib3>=2.2.2 # not directly required, pinned by Snyk to avoid a vulnerability 9 | -------------------------------------------------------------------------------- /example/signing_key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIICXAIBAAKBgQCi7nye2Ye1MrUD/sZAplfkpMkXHYduydvfvv/+Ihx1ClxKS/KG 3 | /1EhqyBVVvhHLRs9pimKMyLm2pBE51rGOt//XKhsoFAa37VID2iz7DQuV6DGgyBS 4 | FaKgaYBpinEQy2WcjU4eABnABV2r+K2UmGkqVJheqHqOqHUKasT4gy/6kQIDAQAB 5 | AoGAf7u+YX6ioNCvDwHHBVojn/H8YK3axmVkhiYkZWTysGM99VVTPridL2sMfzse 6 | jBZ1u8Av4tOyMg/5eLtz8+KmRjljpeAEFfsA1htWE8vESXnvDFwKldXD9Vi/kppb 7 | CYqASGCBUX3i1LPYffvjUxIgD+Tjx4k56c5EN5G331flDV0CQQDP8fWraegLJ+K1 8 | iXGNQzpaqG3EI3vf35Yb2bJpmD39QIXFIcJJ5MZHVW+1TyvgiavM4hS2+LGA8kGh 9 | OvMWfbYTAkEAyJWGBUmAW9mooo1Vw4tJjEWAHjHvzcQ4dqIju+WN8Xy5JTWkDD6Z 10 | VgKGtgLt2HfpSsgej14+Rh5mrjo4SbYxSwJAKG0syq9jOk/9xjc7STBJtvhJprkT 11 | SxnHsBBpnBfJ7WNO3l1KzVzZo2Kbvg7vQ87gBIvrZQsCT0RJuBOi0LuN2wJAAInm 12 | 
Qj1gSt7axRT8FfpZyDankW0w56yPOkJVNjv3lZ5wINl0B1RjtQdstTBs0xf/WGQR 13 | MPFf2XBbdjxRymDi4QJBAM3MUYPOlUk1UVCSQKyCkBwL3zMaPjBjD5LXkhGJzxsb 14 | T74NznwmCib/r0Rl/KmD7/bAq7R4aheOS/OMaZyhbkk= 15 | -----END RSA PRIVATE KEY----- 16 | -------------------------------------------------------------------------------- /example/templates/logout.jinja2: -------------------------------------------------------------------------------- 1 | 2 | Logout 3 | 4 | Do you really want to logout? 5 |
6 | 7 | 8 |
-------------------------------------------------------------------------------- /example/views.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlencode, parse_qs 2 | 3 | import flask 4 | from flask import Blueprint, redirect 5 | from flask import current_app 6 | from flask import jsonify 7 | from flask.helpers import make_response 8 | from flask.templating import render_template 9 | from oic.oic.message import TokenErrorResponse, UserInfoErrorResponse, EndSessionRequest 10 | 11 | from pyop.access_token import AccessToken, BearerTokenError 12 | from pyop.exceptions import InvalidAuthenticationRequest, InvalidAccessToken, InvalidClientAuthentication, OAuthError, \ 13 | InvalidSubjectIdentifier, InvalidClientRegistrationRequest 14 | from pyop.util import should_fragment_encode 15 | 16 | oidc_provider_views = Blueprint('oidc_provider', __name__, url_prefix='') 17 | 18 | 19 | @oidc_provider_views.route('/') 20 | def index(): 21 | return 'Hello world!' 
22 | 23 | 24 | @oidc_provider_views.route('/registration', methods=['POST']) 25 | def registration_endpoint(): 26 | try: 27 | response = current_app.provider.handle_client_registration_request(flask.request.get_data().decode('utf-8')) 28 | return make_response(jsonify(response.to_dict()), 201) 29 | except InvalidClientRegistrationRequest as e: 30 | return make_response(jsonify(e.to_dict()), status=400) 31 | 32 | 33 | @oidc_provider_views.route('/authentication', methods=['GET']) 34 | def authentication_endpoint(): 35 | # parse authentication request 36 | try: 37 | auth_req = current_app.provider.parse_authentication_request(urlencode(flask.request.args), 38 | flask.request.headers) 39 | except InvalidAuthenticationRequest as e: 40 | current_app.logger.debug('received invalid authn request', exc_info=True) 41 | error_url = e.to_error_url() 42 | if error_url: 43 | return redirect(error_url, 303) 44 | else: 45 | # show error to user 46 | return make_response('Something went wrong: {}'.format(str(e)), 400) 47 | 48 | # automagic authentication 49 | authn_response = current_app.provider.authorize(auth_req, 'test_user') 50 | response_url = authn_response.request(auth_req['redirect_uri'], should_fragment_encode(auth_req)) 51 | return redirect(response_url, 303) 52 | 53 | 54 | @oidc_provider_views.route('/.well-known/openid-configuration') 55 | def provider_configuration(): 56 | return jsonify(current_app.provider.provider_configuration.to_dict()) 57 | 58 | 59 | @oidc_provider_views.route('/jwks') 60 | def jwks_uri(): 61 | return jsonify(current_app.provider.jwks) 62 | 63 | 64 | @oidc_provider_views.route('/token', methods=['POST']) 65 | def token_endpoint(): 66 | try: 67 | token_response = current_app.provider.handle_token_request(flask.request.get_data().decode('utf-8'), 68 | flask.request.headers) 69 | return jsonify(token_response.to_dict()) 70 | except InvalidClientAuthentication as e: 71 | current_app.logger.debug('invalid client authentication at token endpoint', 
exc_info=True) 72 | error_resp = TokenErrorResponse(error='invalid_client', error_description=str(e)) 73 | response = make_response(error_resp.to_json(), 401) 74 | response.headers['Content-Type'] = 'application/json' 75 | response.headers['WWW-Authenticate'] = 'Basic' 76 | return response 77 | except OAuthError as e: 78 | current_app.logger.debug('invalid request: %s', str(e), exc_info=True) 79 | error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e)) 80 | response = make_response(error_resp.to_json(), 400) 81 | response.headers['Content-Type'] = 'application/json' 82 | return response 83 | 84 | 85 | @oidc_provider_views.route('/userinfo', methods=['GET', 'POST']) 86 | def userinfo_endpoint(): 87 | try: 88 | response = current_app.provider.handle_userinfo_request(flask.request.get_data().decode('utf-8'), 89 | flask.request.headers) 90 | return jsonify(response.to_dict()) 91 | except (BearerTokenError, InvalidAccessToken) as e: 92 | error_resp = UserInfoErrorResponse(error='invalid_token', error_description=str(e)) 93 | response = make_response(error_resp.to_json(), 401) 94 | response.headers['WWW-Authenticate'] = AccessToken.BEARER_TOKEN_TYPE 95 | response.headers['Content-Type'] = 'application/json' 96 | return response 97 | 98 | 99 | def do_logout(end_session_request): 100 | try: 101 | current_app.provider.logout_user(end_session_request=end_session_request) 102 | except InvalidSubjectIdentifier as e: 103 | return make_response('Logout unsuccessful!', 400) 104 | 105 | redirect_url = current_app.provider.do_post_logout_redirect(end_session_request) 106 | if redirect_url: 107 | return redirect(redirect_url, 303) 108 | 109 | return make_response('Logout successful!') 110 | 111 | 112 | @oidc_provider_views.route('/logout', methods=['GET', 'POST']) 113 | def end_session_endpoint(): 114 | if flask.request.method == 'GET': 115 | # redirect from RP 116 | end_session_request = EndSessionRequest().deserialize(urlencode(flask.request.args)) 117 | 
flask.session['end_session_request'] = end_session_request.to_dict() 118 | return render_template('logout.jinja2') 119 | else: 120 | form = parse_qs(flask.request.get_data().decode('utf-8')) 121 | if 'logout' in form: 122 | return do_logout(EndSessionRequest().from_dict(flask.session['end_session_request'])) 123 | else: 124 | return make_response('You chose not to logout') 125 | -------------------------------------------------------------------------------- /example/wsgi.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from example.app import oidc_provider_init_app 4 | 5 | name = 'oidc_provider' 6 | app = oidc_provider_init_app(name) 7 | logging.basicConfig(level=logging.DEBUG) 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='pyop', 5 | version='3.4.2', 6 | packages=find_packages('src'), 7 | package_dir={'': 'src'}, 8 | url='https://github.com/IdentityPython/pyop', 9 | license='Apache 2.0', 10 | author='Rebecka Gulliksson', 11 | author_email='satosa-dev@lists.sunet.se', 12 | description='OpenID Connect Provider (OP) library in Python.', 13 | install_requires=[ 14 | 'oic >= 1.2.1', 15 | 'pycryptodomex', 16 | ], 17 | extras_require={ 18 | 'mongo': 'pymongo >= 3.12, < 4.0', 19 | 'redis': 'redis', 20 | }, 21 | ) 22 | -------------------------------------------------------------------------------- /signing_key.pem: -------------------------------------------------------------------------------- 1 | example/signing_key.pem -------------------------------------------------------------------------------- /src/pyop/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IdentityPython/pyop/fab87f9f6193079171fdad0c223c810fe9532dd2/src/pyop/__init__.py -------------------------------------------------------------------------------- /src/pyop/access_token.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from urllib.parse import parse_qsl 3 | 4 | from .exceptions import BearerTokenError 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class AccessToken(object): 10 | """ 11 | Representation of an access token. 12 | """ 13 | BEARER_TOKEN_TYPE = 'Bearer' 14 | 15 | def __init__(self, value, expires_in, typ=BEARER_TOKEN_TYPE): 16 | self.value = value 17 | self.expires_in = expires_in 18 | self.type = typ 19 | 20 | 21 | def extract_bearer_token_from_http_request(parsed_request=None, authz_header=None): 22 | # type (Optional[Mapping[str, str]], Optional[str] -> str 23 | """ 24 | Extracts a Bearer token from an http request 25 | :param parsed_request: parsed request (URL query part of request body) 26 | :param authz_header: HTTP Authorization header 27 | :return: Bearer access token, if found 28 | :raise BearerTokenError: if no Bearer token could be extracted from the request 29 | """ 30 | if authz_header: 31 | # Authorization Request Header Field: https://tools.ietf.org/html/rfc6750#section-2.1 32 | if authz_header.startswith(AccessToken.BEARER_TOKEN_TYPE): 33 | access_token = authz_header[len(AccessToken.BEARER_TOKEN_TYPE) + 1:] 34 | logger.debug('found access token %s in authz header', access_token) 35 | return access_token 36 | elif parsed_request: 37 | if 'access_token' in parsed_request: 38 | """ 39 | Form-Encoded Body Parameter: https://tools.ietf.org/html/rfc6750#section-2.2, and 40 | URI Query Parameter: https://tools.ietf.org/html/rfc6750#section-2.3 41 | """ 42 | access_token = parsed_request['access_token'] 43 | logger.debug('found access token %s in request', access_token) 44 | return access_token 45 | 46 | raise 
BearerTokenError('Bearer Token could not be found in the request') 47 | -------------------------------------------------------------------------------- /src/pyop/authz_state.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import uuid 4 | 5 | from oic.extension.message import TokenIntrospectionResponse 6 | 7 | from .message import AuthorizationRequest 8 | from .access_token import AccessToken 9 | from .exceptions import InvalidAccessToken 10 | from .exceptions import InvalidAuthorizationCode 11 | from .exceptions import InvalidRefreshToken 12 | from .exceptions import InvalidScope 13 | from .exceptions import InvalidSubjectIdentifier 14 | from .storage import StatelessWrapper 15 | from .util import requested_scope_is_allowed 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def rand_str(): 21 | return uuid.uuid4().hex 22 | 23 | 24 | # TODO remove expired/invalid authorization codes/tokens from db's 25 | 26 | class AuthorizationState(object): 27 | KEY_AUTHORIZATION_REQUEST = 'auth_req' 28 | KEY_USER_INFO = 'user_info' 29 | KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims' 30 | 31 | def __init__(self, subject_identifier_factory, authorization_code_db=None, access_token_db=None, 32 | refresh_token_db=None, subject_identifier_db=None, *, 33 | authorization_code_lifetime=600, access_token_lifetime=3600, refresh_token_lifetime=None, 34 | refresh_token_threshold=None): 35 | # type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any], 36 | # Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None 37 | """ 38 | :param subject_identifier_factory: callable to use when construction subject identifiers 39 | :param authorization_code_db: database for storing authorization codes, defaults to in-memory 40 | dict if not specified 41 | :param access_token_db: database for storing access tokens, defaults to in-memory 42 | dict if 
not specified 43 | :param refresh_token_db: database for storing refresh tokens, defaults to in-memory 44 | dict if not specified 45 | :param subject_identifier_db: database for storing subject identifiers, defaults to in-memory 46 | dict if not specified 47 | :param authorization_code_lifetime: how long before authorization codes should expire (in seconds), 48 | defaults to 10 minutes 49 | :param access_token_lifetime: how long before access tokens should expire (in seconds), 50 | defaults to 1 hour 51 | :param refresh_token_lifetime: how long before refresh tokens should expire (in seconds), 52 | defaults to never issuing a refresh token if not defined 53 | :param refresh_token_threshold: how long before refresh token expiry time a new one should be issued (in 54 | seconds) in a token refresh request, defaults to never issuing a new refresh token 55 | """ 56 | 57 | if not subject_identifier_factory: 58 | raise ValueError('subject_identifier_factory can\'t be None') 59 | self._subject_identifier_factory = subject_identifier_factory 60 | 61 | self.authorization_code_lifetime = authorization_code_lifetime 62 | """ 63 | Mapping of authorization codes to the subject identifier and auth request. 64 | """ 65 | self.authorization_codes = authorization_code_db if authorization_code_db is not None else {} 66 | 67 | self.access_token_lifetime = access_token_lifetime 68 | """ 69 | Mapping of access tokens to the scope, token type, client id and subject identifier. 70 | """ 71 | self.access_tokens = access_token_db if access_token_db is not None else {} 72 | 73 | self.refresh_token_lifetime = refresh_token_lifetime 74 | self.refresh_token_threshold = refresh_token_threshold 75 | """ 76 | Mapping of refresh tokens to access tokens. 77 | """ 78 | self.refresh_tokens = refresh_token_db if refresh_token_db is not None else {} 79 | 80 | """ 81 | Mapping of user id's to subject identifiers. 
82 | """ 83 | self.stateless = ( 84 | isinstance(self.authorization_codes, StatelessWrapper) 85 | or isinstance(self.access_tokens, StatelessWrapper) 86 | or isinstance(self.refresh_tokens, StatelessWrapper) 87 | ) 88 | self.subject_identifiers = ( 89 | {} 90 | if self.stateless 91 | else subject_identifier_db 92 | if subject_identifier_db is not None 93 | else {} 94 | ) 95 | 96 | def create_authorization_code( 97 | self, 98 | authorization_request, 99 | subject_identifier, 100 | scope=None, 101 | user_info=None, 102 | extra_id_token_claims=None, 103 | ): 104 | # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str 105 | """ 106 | Creates an authorization code bound to the authorization request and the authenticated user identified 107 | by the subject identifier. 108 | """ 109 | 110 | if not self._is_valid_subject_identifier(subject_identifier): 111 | raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier)) 112 | 113 | scope = ' '.join(scope or authorization_request['scope']) 114 | logger.debug('creating authz code for scope=%s', scope) 115 | 116 | authz_info = { 117 | 'used': False, 118 | 'exp': int(time.time()) + self.authorization_code_lifetime, 119 | 'sub': subject_identifier, 120 | 'granted_scope': scope, 121 | self.KEY_AUTHORIZATION_REQUEST: authorization_request.to_dict() 122 | } 123 | 124 | if self.stateless: 125 | if user_info: 126 | authz_info[self.KEY_USER_INFO] = user_info 127 | authz_info[self.KEY_EXTRA_ID_TOKEN_CLAIMS] = extra_id_token_claims or {} 128 | authorization_code = self.authorization_codes.pack(authz_info) 129 | else: 130 | authorization_code = rand_str() 131 | self.authorization_codes[authorization_code] = authz_info 132 | 133 | logger.debug('new authz_code=%s to client_id=%s for sub=%s valid_until=%s', authorization_code, 134 | authorization_request['client_id'], subject_identifier, authz_info['exp']) 135 | return authorization_code 136 | 137 | def 
create_access_token(self, authorization_request, subject_identifier, scope=None, user_info=None): 138 | # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict]) -> se_leg_op.access_token.AccessToken 139 | """ 140 | Creates an access token bound to the authentication request and the authenticated user identified by the 141 | subject identifier. 142 | """ 143 | if not self._is_valid_subject_identifier(subject_identifier): 144 | raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier)) 145 | 146 | scope = scope or authorization_request['scope'] 147 | 148 | return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope), 149 | user_info=user_info) 150 | 151 | def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None, 152 | user_info=None): 153 | # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken 154 | """ 155 | Creates an access token bound to the subject identifier, client id and requested scope. 
156 | """ 157 | scope = current_scope or granted_scope 158 | logger.debug('creating access token for scope=%s', scope) 159 | 160 | authz_info = { 161 | 'iat': int(time.time()), 162 | 'exp': int(time.time()) + self.access_token_lifetime, 163 | 'sub': subject_identifier, 164 | 'client_id': auth_req['client_id'], 165 | 'aud': [auth_req['client_id']], 166 | 'scope': scope, 167 | 'granted_scope': granted_scope, 168 | 'token_type': AccessToken.BEARER_TOKEN_TYPE, 169 | self.KEY_AUTHORIZATION_REQUEST: auth_req 170 | } 171 | 172 | if self.stateless: 173 | if user_info: 174 | authz_info[self.KEY_USER_INFO] = user_info 175 | access_token_val = self.access_tokens.pack(authz_info) 176 | else: 177 | access_token_val = rand_str() 178 | self.access_tokens[access_token_val] = authz_info 179 | 180 | logger.debug('new access_token=%s to client_id=%s for sub=%s valid_until=%s', 181 | access_token_val, auth_req['client_id'], subject_identifier, authz_info['exp']) 182 | access_token = AccessToken(access_token_val, self.access_token_lifetime) 183 | return access_token 184 | 185 | def exchange_code_for_token(self, authorization_code): 186 | # type: (str) -> se_leg_op.access_token.AccessToken 187 | """ 188 | Exchanges an authorization code for an access token. 
189 | """ 190 | if authorization_code not in self.authorization_codes: 191 | raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) 192 | 193 | authz_info = self.authorization_codes[authorization_code] 194 | if authz_info['used']: 195 | logger.debug('detected already used authz_code=%s', authorization_code) 196 | raise InvalidAuthorizationCode('{} has already been used'.format(authorization_code)) 197 | elif authz_info['exp'] < int(time.time()): 198 | logger.debug('detected expired authz_code=%s, now=%s > exp=%s ', 199 | authorization_code, int(time.time()), authz_info['exp']) 200 | raise InvalidAuthorizationCode('{} has expired'.format(authorization_code)) 201 | 202 | authz_info['used'] = True 203 | self.authorization_codes[authorization_code] = authz_info 204 | 205 | access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST], 206 | authz_info['granted_scope'], 207 | user_info=authz_info.get(self.KEY_USER_INFO)) 208 | 209 | logger.debug('authz_code=%s exchanged to access_token=%s', authorization_code, access_token.value) 210 | return access_token 211 | 212 | def introspect_access_token(self, access_token_value): 213 | # type: (str) -> Dict[str, Union[str, List[str]]] 214 | """ 215 | Returns authorization data associated with the access token. 216 | See "Token Introspection", Section 2.2. 
217 | """ 218 | if access_token_value not in self.access_tokens: 219 | raise InvalidAccessToken('{} unknown'.format(access_token_value)) 220 | 221 | authz_info = self.access_tokens[access_token_value] 222 | 223 | introspection = {'active': authz_info['exp'] >= int(time.time())} 224 | 225 | introspection_params = {k: v for k, v in authz_info.items() if k in TokenIntrospectionResponse.c_param} 226 | introspection.update(introspection_params) 227 | return introspection 228 | 229 | def create_refresh_token(self, access_token_value): 230 | # type: (str) -> str 231 | """ 232 | Creates an refresh token bound to the specified access token. 233 | """ 234 | if access_token_value not in self.access_tokens: 235 | raise InvalidAccessToken('{} unknown'.format(access_token_value)) 236 | 237 | if not self.refresh_token_lifetime: 238 | logger.debug('no refresh token issued for for access_token=%s', access_token_value) 239 | return None 240 | 241 | authz_info = {'access_token': access_token_value, 'exp': int(time.time()) + self.refresh_token_lifetime} 242 | 243 | if self.stateless: 244 | refresh_token = self.refresh_tokens.pack(authz_info) 245 | else: 246 | refresh_token = rand_str() 247 | self.refresh_tokens[refresh_token] = authz_info 248 | 249 | logger.debug('issued refresh_token=%s expiring=%d for access_token=%s', refresh_token, authz_info['exp'], 250 | access_token_value) 251 | return refresh_token 252 | 253 | def use_refresh_token(self, refresh_token, scope=None): 254 | # type (str, Optional[List[str]]) -> Tuple[se_leg_op.access_token.AccessToken, Optional[str]] 255 | """ 256 | Creates a new access token, and refresh token, based on the supplied refresh token. 
257 | :return: new access token and new refresh token if the old one had an expiration time 258 | """ 259 | 260 | if refresh_token not in self.refresh_tokens: 261 | raise InvalidRefreshToken('{} unknown'.format(refresh_token)) 262 | 263 | refresh_token_info = self.refresh_tokens[refresh_token] 264 | if 'exp' in refresh_token_info and refresh_token_info['exp'] < int(time.time()): 265 | raise InvalidRefreshToken('{} has expired'.format(refresh_token)) 266 | 267 | authz_info = self.access_tokens[refresh_token_info['access_token']] 268 | 269 | if scope: 270 | if not requested_scope_is_allowed(scope, authz_info['granted_scope']): 271 | logger.debug('trying to refresh token with superset scope, requested_scope=%s, granted_scope=%s', 272 | scope, authz_info['granted_scope']) 273 | raise InvalidScope('Requested scope includes non-granted value') 274 | scope = ' '.join(scope) 275 | logger.debug('refreshing token with new scope, old_scope=%s -> new_scope=%s', authz_info['scope'], scope) 276 | else: 277 | # OAuth 2.0: scope: "[...] 
if omitted is treated as equal to the scope originally granted by the resource owner" 278 | scope = authz_info['granted_scope'] 279 | 280 | new_access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST], 281 | authz_info['granted_scope'], scope, 282 | user_info=authz_info.get(self.KEY_USER_INFO)) 283 | 284 | new_refresh_token = None 285 | if self.refresh_token_threshold \ 286 | and 'exp' in refresh_token_info \ 287 | and refresh_token_info['exp'] - int(time.time()) < self.refresh_token_threshold: 288 | # refresh token is close to expiry, issue a new one 289 | new_refresh_token = self.create_refresh_token(new_access_token.value) 290 | else: 291 | self.refresh_tokens[refresh_token]['access_token'] = new_access_token.value 292 | 293 | logger.debug('refreshed tokens, new_access_token=%s new_refresh_token=%s old_refresh_token=%s', 294 | new_access_token, new_refresh_token, refresh_token) 295 | return new_access_token, new_refresh_token 296 | 297 | def get_subject_identifier(self, subject_type, user_id, sector_identifier=None): 298 | # type: (str, str, str) -> str 299 | """ 300 | Returns a subject identifier for the local user identifier. 301 | :param subject_type: 'pairwise' or 'public', see 302 | 303 | "OpenID Connect Core 1.0", Section 8. 
304 | :param user_id: local user identifier 305 | :param sector_identifier: the client's sector identifier, 306 | see 307 | "OpenID Connect Core 1.0", Section 1.2 308 | """ 309 | 310 | if user_id not in self.subject_identifiers: 311 | self.subject_identifiers[user_id] = {} 312 | 313 | if subject_type == 'public': 314 | if 'public' not in self.subject_identifiers[user_id]: 315 | new_sub = self._subject_identifier_factory.create_public_identifier(user_id) 316 | self.subject_identifiers[user_id] = {'public': new_sub} 317 | 318 | logger.debug('created new public sub=%s for user_id=%s', 319 | self.subject_identifiers[user_id]['public'], user_id) 320 | sub = self.subject_identifiers[user_id]['public'] 321 | logger.debug('returning public sub=%s', sub) 322 | return sub 323 | elif subject_type == 'pairwise': 324 | if not sector_identifier: 325 | raise ValueError('sector_identifier cannot be None or empty') 326 | 327 | subject_id = self._subject_identifier_factory.create_pairwise_identifier(user_id, sector_identifier) 328 | logger.debug('returning pairwise sub=%s for user_id=%s and sector_identifier=%s', 329 | subject_id, user_id, sector_identifier) 330 | sub = self.subject_identifiers[user_id] 331 | pairwise_set = set(sub.get('pairwise', [])) 332 | pairwise_set.add(subject_id) 333 | sub['pairwise'] = list(pairwise_set) 334 | self.subject_identifiers[user_id] = sub 335 | return subject_id 336 | 337 | raise ValueError('Unknown subject_type={}'.format(subject_type)) 338 | 339 | def _is_valid_subject_identifier(self, sub): 340 | # type: (str) -> bool 341 | """ 342 | Determines whether the subject identifier is known. 
343 | """ 344 | 345 | try: 346 | self.get_user_id_for_subject_identifier(sub) 347 | return True 348 | except InvalidSubjectIdentifier: 349 | return False 350 | 351 | def get_user_id_for_subject_identifier(self, subject_identifier): 352 | for user_id, subject_identifiers in self.subject_identifiers.items(): 353 | is_public_sub = 'public' in subject_identifiers and subject_identifier == subject_identifiers['public'] 354 | is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers['pairwise'] 355 | if is_public_sub or is_pairwise_sub: 356 | return user_id 357 | 358 | raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier)) 359 | 360 | def get_user_info_for_code(self, authorization_code): 361 | # type: (str) -> dict 362 | if authorization_code not in self.authorization_codes: 363 | raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) 364 | 365 | return self.authorization_codes[authorization_code].get(self.KEY_USER_INFO) 366 | 367 | def get_extra_id_token_claims_for_code(self, authorization_code): 368 | # type: (str) -> dict 369 | if authorization_code not in self.authorization_codes: 370 | raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) 371 | 372 | return self.authorization_codes[authorization_code].get(self.KEY_EXTRA_ID_TOKEN_CLAIMS) 373 | 374 | def get_user_info_for_access_token(self, access_token): 375 | # type: (str) -> dict 376 | if access_token not in self.access_tokens: 377 | raise InvalidAccessToken('{} unknown'.format(access_token)) 378 | 379 | return self.access_tokens[access_token].get(self.KEY_USER_INFO) 380 | 381 | def get_authorization_request_for_code(self, authorization_code): 382 | # type: (str) -> AuthorizationRequest 383 | if authorization_code not in self.authorization_codes: 384 | raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) 385 | 386 | return AuthorizationRequest().from_dict( 387 | 
self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST]) 388 | 389 | def get_authorization_request_for_access_token(self, access_token_value): 390 | # type: (str) -> 391 | if access_token_value not in self.access_tokens: 392 | raise InvalidAccessToken('{} unknown'.format(access_token_value)) 393 | 394 | return AuthorizationRequest().from_dict(self.access_tokens[access_token_value][self.KEY_AUTHORIZATION_REQUEST]) 395 | 396 | def get_subject_identifier_for_code(self, authorization_code): 397 | # type: (str) -> AuthorizationRequest 398 | if authorization_code not in self.authorization_codes: 399 | raise InvalidAuthorizationCode('{} unknown'.format(authorization_code)) 400 | 401 | return self.authorization_codes[authorization_code]['sub'] 402 | 403 | def delete_state_for_subject_identifier(self, subject_identifier): 404 | # type (str) -> None 405 | if not self._is_valid_subject_identifier(subject_identifier): 406 | raise InvalidSubjectIdentifier('Trying to delete state for unknown subject identifier') 407 | 408 | for tokens in [self.authorization_codes, self.access_tokens]: 409 | tokens_to_remove = [k for k, v in tokens.items() if v['sub'] == subject_identifier] 410 | for ac in tokens_to_remove: 411 | tokens.pop(ac, None) 412 | -------------------------------------------------------------------------------- /src/pyop/client_authentication.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | from urllib.parse import unquote 4 | 5 | from .exceptions import InvalidClientAuthentication 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def verify_client_authentication(clients, parsed_request, authz_header=None): 11 | # type: (Mapping[str, str], Mapping[str, Mapping[str, Any]], Optional[str]) -> bool 12 | """ 13 | Verifies client authentication at the token endpoint, see 14 | "The OAuth 2.0 Authorization Framework", 15 | Section 2.3.1 16 | :param parsed_request: key-value 
pairs from parsed urlencoded request 17 | :param clients: clients db 18 | :param authz_header: the HTTP Authorization header value 19 | :return: the unmodified parsed request 20 | :raise InvalidClientAuthentication: if the client authentication was incorrect 21 | """ 22 | client_id = None 23 | client_secret = None 24 | authn_method = None 25 | if authz_header: 26 | logger.debug('client authentication in Authorization header %s', authz_header) 27 | 28 | authz_scheme = authz_header.split(maxsplit=1)[0] 29 | if authz_scheme == 'Basic': 30 | authn_method = 'client_secret_basic' 31 | credentials = authz_header[len('Basic '):] 32 | missing_padding = 4 - len(credentials) % 4 33 | if missing_padding: 34 | credentials += '=' * missing_padding 35 | try: 36 | auth = base64.urlsafe_b64decode(credentials.encode('utf-8')).decode('utf-8') 37 | except UnicodeDecodeError as e: 38 | raise InvalidClientAuthentication('Could not userid/password from authorization header'.format(authz_scheme)) 39 | client_id, client_secret = [unquote(part) for part in auth.split(':')] 40 | else: 41 | raise InvalidClientAuthentication('Unknown scheme in authorization header, {} != Basic'.format(authz_scheme)) 42 | elif 'client_id' in parsed_request: 43 | logger.debug('client authentication in request body %s', parsed_request) 44 | 45 | client_id = parsed_request['client_id'] 46 | if 'client_secret' in parsed_request: 47 | authn_method = 'client_secret_post' 48 | client_secret = parsed_request['client_secret'] 49 | else: 50 | authn_method = 'none' 51 | client_secret = None 52 | 53 | if client_id not in clients: 54 | raise InvalidClientAuthentication('client_id \'{}\' unknown'.format(client_id)) 55 | 56 | client_info = clients[client_id] 57 | if client_secret != client_info.get('client_secret', None): 58 | raise InvalidClientAuthentication('Incorrect client_secret') 59 | 60 | expected_authn_method = client_info.get('token_endpoint_auth_method', 'client_secret_basic') 61 | if authn_method != 
expected_authn_method: 62 | raise InvalidClientAuthentication( 63 | 'Wrong authentication method used, MUST use \'{}\''.format(expected_authn_method)) 64 | 65 | return client_id 66 | -------------------------------------------------------------------------------- /src/pyop/crypto.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | 4 | from Cryptodome import Random 5 | from Cryptodome.Cipher import AES 6 | 7 | 8 | class _AESCipher(object): 9 | """ 10 | This class will perform AES encryption/decryption with a keylength of 256. 11 | 12 | @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256 13 | """ 14 | 15 | def __init__(self, key): 16 | """ 17 | Constructor 18 | 19 | :type key: str 20 | 21 | :param key: The key used for encryption and decryption. The longer key the better. 22 | """ 23 | self.bs = 32 24 | self.key = hashlib.sha256(key.encode()).digest() 25 | 26 | def encrypt(self, raw): 27 | """ 28 | Encryptes the parameter raw. 29 | 30 | :type raw: bytes 31 | :rtype: str 32 | 33 | :param: bytes to be encrypted. 34 | 35 | :return: A base 64 encoded string. 36 | """ 37 | raw = self._pad(raw) 38 | iv = Random.new().read(AES.block_size) 39 | cipher = AES.new(self.key, AES.MODE_CBC, iv) 40 | return base64.urlsafe_b64encode(iv + cipher.encrypt(raw)) 41 | 42 | def decrypt(self, enc): 43 | """ 44 | Decryptes the parameter enc. 45 | 46 | :type enc: bytes 47 | :rtype: bytes 48 | 49 | :param: The value to be decrypted. 50 | :return: The decrypted value. 51 | """ 52 | enc = base64.urlsafe_b64decode(enc) 53 | iv = enc[:AES.block_size] 54 | cipher = AES.new(self.key, AES.MODE_CBC, iv) 55 | return self._unpad(cipher.decrypt(enc[AES.block_size:])) 56 | 57 | def _pad(self, b): 58 | """ 59 | Will padd the param to be of the correct length for the encryption alg. 
60 | 61 | :type b: bytes 62 | :rtype: bytes 63 | """ 64 | return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8") 65 | 66 | @staticmethod 67 | def _unpad(b): 68 | """ 69 | Removes the padding performed by the method _pad. 70 | 71 | :type b: bytes 72 | :rtype: bytes 73 | """ 74 | return b[:-ord(b[len(b) - 1:])] 75 | -------------------------------------------------------------------------------- /src/pyop/exceptions.py: -------------------------------------------------------------------------------- 1 | from oic.oic.message import AuthorizationErrorResponse, ClientRegistrationErrorResponse 2 | 3 | from .util import should_fragment_encode 4 | 5 | 6 | class OAuthError(ValueError): 7 | def __init__(self, message, oauth_error): 8 | super().__init__(message) 9 | self.oauth_error = oauth_error 10 | 11 | 12 | class InvalidAuthorizationCode(OAuthError): 13 | def __init__(self, message): 14 | super().__init__(message, 'invalid_grant') 15 | 16 | 17 | class InvalidRefreshToken(OAuthError): 18 | def __init__(self, message): 19 | super().__init__(message, 'invalid_grant') 20 | 21 | 22 | class InvalidAccessToken(OAuthError): 23 | def __init__(self, message): 24 | super().__init__(message, 'invalid_token') 25 | 26 | 27 | class InvalidScope(OAuthError): 28 | def __init__(self, message): 29 | super().__init__(message, 'invalid_scope') 30 | 31 | 32 | class InvalidClientAuthentication(OAuthError): 33 | def __init__(self, message): 34 | super().__init__(message, 'invalid_client') 35 | 36 | 37 | class InvalidSubjectIdentifier(ValueError): 38 | pass 39 | 40 | 41 | class InvalidRequestError(OAuthError): 42 | def __init__(self, message, parsed_request, oauth_error): 43 | super().__init__(message, oauth_error) 44 | self.request = parsed_request 45 | 46 | 47 | class InvalidAuthenticationRequest(InvalidRequestError): 48 | def __init__(self, message, parsed_request, oauth_error=None): 49 | super().__init__(message, parsed_request, oauth_error) 50 | 51 | def 
to_error_url(self): 52 | redirect_uri = self.request.get('redirect_uri') 53 | response_type = self.request.get('response_type') 54 | if redirect_uri and response_type and self.oauth_error: 55 | error_resp = AuthorizationErrorResponse(error=self.oauth_error, error_message=str(self), 56 | state=self.request.get('state')) 57 | return error_resp.request(redirect_uri, should_fragment_encode(self.request)) 58 | 59 | return None 60 | 61 | 62 | class InvalidRedirectURI(InvalidAuthenticationRequest): 63 | def to_error_url(self): 64 | return None 65 | 66 | 67 | class InvalidTokenRequest(InvalidRequestError): 68 | def __init__(self, message, parsed_request, oauth_error='invalid_request'): 69 | super().__init__(message, parsed_request, oauth_error) 70 | 71 | 72 | class InvalidClientRegistrationRequest(InvalidRequestError): 73 | def __init__(self, message, parsed_request, oauth_error='invalid_request'): 74 | super().__init__(message, parsed_request, oauth_error) 75 | 76 | def to_json(self): 77 | error = ClientRegistrationErrorResponse(error=self.oauth_error, error_description=str(self)) 78 | return error.to_json() 79 | 80 | 81 | class BearerTokenError(ValueError): 82 | pass 83 | 84 | 85 | class AuthorizationError(Exception): 86 | pass 87 | -------------------------------------------------------------------------------- /src/pyop/message.py: -------------------------------------------------------------------------------- 1 | from oic.oauth2.message import SINGLE_OPTIONAL_STRING 2 | from oic.oic import message 3 | 4 | class AccessTokenRequest(message.AccessTokenRequest): 5 | c_param = message.AccessTokenRequest.c_param.copy() 6 | c_param.update( 7 | { 8 | 'code_verifier': SINGLE_OPTIONAL_STRING 9 | } 10 | ) 11 | 12 | class AuthorizationRequest(message.AuthorizationRequest): 13 | c_param = message.AuthorizationRequest.c_param.copy() 14 | c_param.update( 15 | { 16 | 'code_challenge': SINGLE_OPTIONAL_STRING, 17 | 'code_challenge_method': SINGLE_OPTIONAL_STRING 18 | } 19 | ) 20 | 21 
| c_allowed_values = message.AuthorizationRequest.c_allowed_values.copy() 22 | c_allowed_values.update( 23 | { 24 | "code_challenge_method": [ 25 | "plain", 26 | "S256" 27 | ] 28 | } 29 | ) 30 | -------------------------------------------------------------------------------- /src/pyop/provider.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import logging 4 | import time 5 | import uuid 6 | from urllib.parse import parse_qsl 7 | from urllib.parse import urlparse 8 | 9 | from jwkest import jws 10 | from oic import rndstr 11 | from oic.exception import MessageException 12 | from oic.oic import PREFERENCE2PROVIDER 13 | from oic.oic import scope2claims 14 | from oic.oic.message import AccessTokenResponse 15 | from oic.oic.message import AuthorizationResponse 16 | from oic.oic.message import EndSessionRequest 17 | from oic.oic.message import EndSessionResponse 18 | from oic.oic.message import IdToken 19 | from oic.oic.message import OpenIDSchema 20 | from oic.oic.message import ProviderConfigurationResponse 21 | from oic.oic.message import RefreshAccessTokenRequest 22 | from oic.oic.message import RegistrationRequest 23 | from oic.oic.message import RegistrationResponse 24 | from oic.extension.provider import Provider as OICProviderExtensions 25 | 26 | from .message import AuthorizationRequest 27 | from .message import AccessTokenRequest 28 | from .access_token import extract_bearer_token_from_http_request 29 | from .client_authentication import verify_client_authentication 30 | from .exceptions import AuthorizationError 31 | from .exceptions import InvalidAccessToken 32 | from .exceptions import InvalidTokenRequest 33 | from .exceptions import InvalidAuthorizationCode 34 | from .exceptions import OAuthError 35 | from .request_validator import authorization_request_verify 36 | from .request_validator import client_id_is_known 37 | from .request_validator import 
client_preferences_match_provider_capabilities 38 | from .request_validator import redirect_uri_is_in_registered_redirect_uris 39 | from .request_validator import registration_request_verify 40 | from .request_validator import requested_scope_is_supported 41 | from .request_validator import response_type_is_in_registered_response_types 42 | from .request_validator import userinfo_claims_only_specified_when_access_token_is_issued 43 | from .util import find_common_values 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | class Provider(object): 49 | def __init__(self, signing_key, configuration_information, authz_state, clients, userinfo, *, 50 | id_token_lifetime=3600, extra_scopes=None): 51 | # type: (jwkest.jwk.Key, Dict[str, Union[str, Sequence[str]]], se_leg_op.authz_state.AuthorizationState, 52 | # Mapping[str, Mapping[str, Any]], se_leg_op.userinfo.Userinfo, int) -> None 53 | """ 54 | Creates a new provider instance. 55 | :param configuration_information: see 56 | 57 | "OpenID Connect Discovery 1.0", Section 3 58 | :param clients: see 59 | "OpenID Connect Dynamic Client Registration 1.0", Section 2 60 | :param userinfo: read-only interface for user info 61 | :param id_token_lifetime: how long the signed ID Tokens should be valid (in seconds), defaults to 1 hour 62 | """ 63 | self.signing_key = signing_key 64 | self.configuration_information = ProviderConfigurationResponse(**configuration_information) 65 | if 'subject_types_supported' not in configuration_information: 66 | self.configuration_information['subject_types_supported'] = ['pairwise'] 67 | if 'id_token_signing_alg_values_supported' not in configuration_information: 68 | self.configuration_information['id_token_signing_alg_values_supported'] = ['RS256'] 69 | if 'scopes_supported' not in configuration_information: 70 | self.configuration_information['scopes_supported'] = ['openid'] 71 | if 'response_types_supported' not in configuration_information: 72 | 
self.configuration_information['response_types_supported'] = ['code', 'id_token', 'token id_token'] 73 | 74 | self.extra_scopes = {} if extra_scopes is None else extra_scopes 75 | _scopes = self.configuration_information['scopes_supported'] 76 | _scopes.extend(self.extra_scopes.keys()) 77 | self.configuration_information['scopes_supported'] = list(set(_scopes)) 78 | 79 | self.configuration_information.verify() 80 | 81 | self.authz_state = authz_state 82 | self.stateless = self.authz_state and self.authz_state.stateless 83 | 84 | self.clients = clients 85 | self.userinfo = userinfo 86 | self.id_token_lifetime = id_token_lifetime 87 | 88 | self.authentication_request_validators = [] # type: List[Callable[[AuthorizationRequest], Boolean]] 89 | self.authentication_request_validators.append(authorization_request_verify) 90 | self.authentication_request_validators.append( 91 | functools.partial(client_id_is_known, self)) 92 | self.authentication_request_validators.append( 93 | functools.partial(redirect_uri_is_in_registered_redirect_uris, self)) 94 | self.authentication_request_validators.append( 95 | functools.partial(response_type_is_in_registered_response_types, self)) 96 | self.authentication_request_validators.append(userinfo_claims_only_specified_when_access_token_is_issued) 97 | self.authentication_request_validators.append(functools.partial(requested_scope_is_supported, self)) 98 | 99 | self.registration_request_validators = [] # type: List[Callable[[oic.oic.message.RegistrationRequest], Boolean]] 100 | self.registration_request_validators.append(registration_request_verify) 101 | self.registration_request_validators.append( 102 | functools.partial(client_preferences_match_provider_capabilities, self)) 103 | 104 | @property 105 | def provider_configuration(self): 106 | """ 107 | The provider configuration information. 
108 | """ 109 | return copy.deepcopy(self.configuration_information) 110 | 111 | @property 112 | def jwks(self): 113 | """ 114 | All keys published by the provider as JSON Web Key Set. 115 | """ 116 | 117 | keys = [self.signing_key.serialize()] 118 | return {'keys': keys} 119 | 120 | def parse_authentication_request(self, request_body, http_headers=None): 121 | # type: (str, Optional[Mapping[str, str]]) -> AuthorizationRequest 122 | """ 123 | Parses and verifies an authentication request. 124 | 125 | :param request_body: urlencoded authentication request 126 | :param http_headers: http headers 127 | """ 128 | 129 | auth_req = AuthorizationRequest().deserialize(request_body) 130 | 131 | for validator in self.authentication_request_validators: 132 | validator(auth_req) 133 | 134 | logger.debug('parsed authentication_request: %s', auth_req) 135 | return auth_req 136 | 137 | def authorize( 138 | self, 139 | authentication_request, # type: AuthorizationRequest 140 | user_id, # type: str 141 | extra_id_token_claims=None, # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]] 142 | ): 143 | # type: (...) -> oic.oic.message.AuthorizationResponse 144 | """ 145 | Creates an Authentication Response for the specified authentication request and local identifier of the 146 | authenticated user. 
147 | """ 148 | custom_sub = self.userinfo[user_id].get('sub') 149 | if custom_sub: 150 | self.authz_state.subject_identifiers[user_id] = {'public': custom_sub} 151 | sub = custom_sub 152 | else: 153 | sub = self._create_subject_identifier( 154 | user_id, 155 | authentication_request['client_id'], 156 | authentication_request['redirect_uri'], 157 | ) 158 | 159 | self._check_subject_identifier_matches_requested(authentication_request, sub) 160 | 161 | if extra_id_token_claims is None: 162 | extra_id_token_claims = {} 163 | elif callable(extra_id_token_claims): 164 | extra_id_token_claims = extra_id_token_claims( 165 | user_id, authentication_request['client_id'] 166 | ) 167 | 168 | response = AuthorizationResponse() 169 | 170 | authz_code = None 171 | if 'code' in authentication_request['response_type']: 172 | authz_code = self.authz_state.create_authorization_code( 173 | authentication_request, 174 | sub, 175 | user_info=self.userinfo[user_id], 176 | extra_id_token_claims=extra_id_token_claims, 177 | ) 178 | response['code'] = authz_code 179 | 180 | access_token_value = None 181 | if 'token' in authentication_request['response_type']: 182 | access_token = self.authz_state.create_access_token( 183 | authentication_request, sub, user_info=self.userinfo[user_id] 184 | ) 185 | access_token_value = access_token.value 186 | self._add_access_token_to_response(response, access_token) 187 | 188 | if 'id_token' in authentication_request['response_type']: 189 | requested_claims = self._get_requested_claims_in(authentication_request, 'id_token') 190 | if len(authentication_request['response_type']) == 1: 191 | # only id token is issued -> no way of doing userinfo request, so include all claims in ID Token, 192 | # even those requested by the scope parameter 193 | requested_claims.update( 194 | scope2claims( 195 | authentication_request['scope'], extra_scope_dict=self.extra_scopes 196 | ) 197 | ) 198 | 199 | user_claims = self.userinfo.get_claims_for(user_id, requested_claims) 
200 | response['id_token'] = self._create_signed_id_token( 201 | authentication_request['client_id'], 202 | sub, 203 | user_claims, 204 | authentication_request.get('nonce'), 205 | authz_code, 206 | access_token_value, 207 | extra_id_token_claims, 208 | ) 209 | logger.debug( 210 | 'issued id_token=%s from requested_claims=%s userinfo=%s extra_claims=%s', 211 | response['id_token'], 212 | requested_claims, 213 | user_claims, 214 | extra_id_token_claims, 215 | ) 216 | 217 | if 'state' in authentication_request: 218 | response['state'] = authentication_request['state'] 219 | return response 220 | 221 | def _add_access_token_to_response(self, response, access_token): 222 | # type: (oic.message.AccessTokenResponse, se_leg_op.access_token.AccessToken) -> None 223 | """ 224 | Adds the Access Token and the associated parameters to the Token Response. 225 | """ 226 | response['access_token'] = access_token.value 227 | response['token_type'] = access_token.type 228 | response['expires_in'] = access_token.expires_in 229 | 230 | def _create_subject_identifier(self, user_id, client_id, redirect_uri): 231 | # type (str, str, str) -> str 232 | """ 233 | Creates a subject identifier for the specified client and user 234 | see 235 | "OpenID Connect Core 1.0", Section 1.2. 
236 | :param user_id: local user identifier 237 | :param client_id: which client to generate a subject identifier for 238 | :param redirect_uri: the clients' redirect_uri 239 | :return: a subject identifier for the user intended for client who made the authentication request 240 | """ 241 | supported_subject_types = self.configuration_information['subject_types_supported'][0] 242 | subject_type = self.clients[client_id].get('subject_type', supported_subject_types) 243 | sector_identifier_uri = self.clients[client_id].get('sector_identifier_uri', redirect_uri) 244 | sector_identifier = urlparse(sector_identifier_uri).netloc 245 | return self.authz_state.get_subject_identifier(subject_type, user_id, sector_identifier) 246 | 247 | def _get_requested_claims_in(self, authentication_request, response_method): 248 | # type (AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]] 249 | """ 250 | Parses any claims requested using the 'claims' request parameter, see 251 | 252 | "OpenID Connect Core 1.0", Section 5.5. 
253 | :param authentication_request: the authentication request 254 | :param response_method: 'id_token' or 'userinfo' 255 | """ 256 | if response_method != 'id_token' and response_method != 'userinfo': 257 | raise ValueError('response_method must be \'id_token\' or \'userinfo\'') 258 | 259 | requested_claims = {} 260 | 261 | if 'claims' in authentication_request and response_method in authentication_request['claims']: 262 | requested_claims.update(authentication_request['claims'][response_method]) 263 | return requested_claims 264 | 265 | def _create_signed_id_token(self, 266 | client_id, # type: str 267 | sub, # type: str 268 | user_claims=None, # type: Optional[Mapping[str, Union[str, List[str]]]] 269 | nonce=None, # type: Optional[str] 270 | authorization_code=None, # type: Optional[str] 271 | access_token_value=None, # type: Optional[str] 272 | extra_id_token_claims=None): # type: Optional[Mappings[str, Union[str, List[str]]]] 273 | # type: (...) -> str 274 | """ 275 | Creates a signed ID Token. 
276 | :param client_id: who the ID Token is intended for 277 | :param sub: who the ID Token is regarding 278 | :param user_claims: any claims about the user to be included 279 | :param nonce: nonce from the authentication request 280 | :param authorization_code: the authorization code issued together with this ID Token 281 | :param access_token_value: the access token issued together with this ID Token 282 | :param extra_id_token_claims: any extra claims that should be included in the ID Token 283 | :return: a JWS, containing the ID Token as payload 284 | """ 285 | 286 | alg = self.clients[client_id].get('id_token_signed_response_alg', 287 | self.configuration_information['id_token_signing_alg_values_supported'][0]) 288 | args = {} 289 | 290 | hash_alg = 'HS{}'.format(alg[-3:]) 291 | if authorization_code: 292 | args['c_hash'] = jws.left_hash(authorization_code.encode('utf-8'), hash_alg) 293 | if access_token_value: 294 | args['at_hash'] = jws.left_hash(access_token_value.encode('utf-8'), hash_alg) 295 | 296 | if user_claims: 297 | args.update(user_claims) 298 | 299 | if extra_id_token_claims: 300 | args.update(extra_id_token_claims) 301 | 302 | id_token = IdToken(iss=self.configuration_information['issuer'], 303 | sub=sub, 304 | aud=client_id, 305 | iat=int(time.time()), 306 | exp=int(time.time()) + self.id_token_lifetime, 307 | **args) 308 | 309 | if nonce: 310 | id_token['nonce'] = nonce 311 | 312 | logger.debug('signed id_token with kid=%s using alg=%s', self.signing_key, alg) 313 | return id_token.to_jwt([self.signing_key], alg) 314 | 315 | def _check_subject_identifier_matches_requested(self, authentication_request, sub): 316 | # type (AuthorizationRequest, str) -> None 317 | """ 318 | Verifies the subject identifier against any requested subject identifier using the claims request parameter. 
319 | :param authentication_request: authentication request 320 | :param sub: subject identifier 321 | :raise AuthorizationError: if the subject identifier does not match the requested one 322 | """ 323 | if 'claims' in authentication_request: 324 | requested_id_token_sub = authentication_request['claims'].get('id_token', {}).get('sub') 325 | requested_userinfo_sub = authentication_request['claims'].get('userinfo', {}).get('sub') 326 | if requested_id_token_sub and requested_userinfo_sub and requested_id_token_sub != requested_userinfo_sub: 327 | raise AuthorizationError('Requested different subject identifier for IDToken and userinfo: {} != {}' 328 | .format(requested_id_token_sub, requested_userinfo_sub)) 329 | 330 | requested_sub = requested_id_token_sub or requested_userinfo_sub 331 | if requested_sub and sub != requested_sub: 332 | raise AuthorizationError('Requested subject identifier \'{}\' could not be matched' 333 | .format(requested_sub)) 334 | 335 | def handle_token_request(self, request_body, # type: str 336 | http_headers=None, # type: Optional[Mapping[str, str]] 337 | extra_id_token_claims=None 338 | # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]] 339 | ): 340 | # type: (...) -> oic.oic.message.AccessTokenResponse 341 | """ 342 | Handles a token request, either for exchanging an authorization code or using a refresh token. 
343 | :param request_body: urlencoded token request 344 | :param http_headers: http headers 345 | :param extra_id_token_claims: extra claims to include in the signed ID Token 346 | """ 347 | 348 | token_request = self._verify_client_authentication(request_body, http_headers) 349 | 350 | if 'grant_type' not in token_request: 351 | raise InvalidTokenRequest('grant_type missing', token_request) 352 | elif token_request['grant_type'] == 'authorization_code': 353 | return self._do_code_exchange(token_request, extra_id_token_claims) 354 | elif token_request['grant_type'] == 'refresh_token': 355 | return self._do_token_refresh(token_request) 356 | 357 | raise InvalidTokenRequest('grant_type \'{}\' unknown'.format(token_request['grant_type']), token_request, 358 | oauth_error='unsupported_grant_type') 359 | 360 | def _PKCE_verify(self, 361 | token_request, # type: AccessTokenRequest 362 | authentication_request # type: AuthorizationRequest 363 | ): 364 | # type: (...) -> bool 365 | """ 366 | Verify that the given code_verifier complies with the initially supplied code_challenge. 367 | 368 | Only supports the SHA256 code challenge method, plaintext is regarded as unsafe. 369 | 370 | :param token_request: the token request containing the initially supplied code challenge and code_challenge method. 371 | :param authentication_request: the code_verfier to check against the code challenge. 372 | :returns: whether the code_verifier is what was expected given the cc_cm 373 | """ 374 | if not 'code_verifier' in token_request: 375 | return False 376 | 377 | if not 'code_challenge_method' in authentication_request: 378 | raise InvalidTokenRequest("A code_challenge and code_verifier have been supplied" 379 | "but missing code_challenge_method in authentication_request", token_request) 380 | 381 | # OIC Provider extension returns either a boolean or Response object containing an error. To support 382 | # stricter typing guidelines, return if True. 
Error handling support should be in encapsulating function. 383 | return OICProviderExtensions.verify_code_challenge(token_request['code_verifier'], 384 | authentication_request['code_challenge'], authentication_request['code_challenge_method']) == True 385 | 386 | def _verify_code_exchange_req(self, 387 | token_request, # type: AccessTokenRequest 388 | authentication_request # type: AuthorizationRequest 389 | ): 390 | # type: (...) -> None 391 | """ 392 | Verify that the code exchange request is valid. In order to be valid we validate 393 | the expected client and redirect_uri. Finally, if requested by the client, perform a 394 | PKCE check. 395 | 396 | :param token_request: The request asking for a token given a code, and optionally a code_verifier 397 | :param authentication_request: The authentication request belonging to the provided code. 398 | :raises InvalidTokenRequest, InvalidAuthorizationCode: If request is invalid, throw a representing exception. 399 | """ 400 | if token_request['client_id'] != authentication_request['client_id']: 401 | logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'', 402 | token_request['code'], authentication_request['client_id'], token_request['client_id']) 403 | raise InvalidAuthorizationCode('{} unknown'.format(token_request['code'])) 404 | if token_request['redirect_uri'] != authentication_request['redirect_uri']: 405 | raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(token_request['redirect_uri'], 406 | authentication_request['redirect_uri']), 407 | token_request) 408 | if 'code_challenge' in authentication_request and not self._PKCE_verify(token_request, authentication_request): 409 | raise InvalidTokenRequest('Unexpected Code Verifier: {}'.format(authentication_request['code_challenge']), 410 | token_request) 411 | 412 | def _do_code_exchange(self, request, # type: Dict[str, str] 413 | extra_id_token_claims=None 414 | # type: Optional[Union[Mapping[str, Union[str, List[str]]], 
Callable[[str, str], Mapping[str, Union[str, List[str]]]]] 415 | ): 416 | # type: (...) -> oic.message.AccessTokenResponse 417 | """ 418 | Handles a token request for exchanging an authorization code for an access token 419 | (grant_type=authorization_code). 420 | :param request: parsed http request parameters 421 | :param extra_id_token_claims: any extra parameters to include in the signed ID Token, either as a dict-like 422 | object or as a callable object accepting the local user identifier and client identifier which returns 423 | any extra claims which might depend on the user id and/or client id. 424 | :return: a token response containing a signed ID Token, an Access Token, and a Refresh Token 425 | :raise InvalidTokenRequest: if the token request is invalid 426 | """ 427 | token_request = AccessTokenRequest().from_dict(request) 428 | try: 429 | token_request.verify() 430 | except MessageException as e: 431 | raise InvalidTokenRequest(str(e), token_request) from e 432 | 433 | authentication_request = self.authz_state.get_authorization_request_for_code(token_request['code']) 434 | 435 | self._verify_code_exchange_req(token_request, authentication_request) 436 | 437 | sub = self.authz_state.get_subject_identifier_for_code(token_request['code']) 438 | if not self.stateless: 439 | user_id = self.authz_state.get_user_id_for_subject_identifier(sub) 440 | 441 | response = AccessTokenResponse() 442 | 443 | access_token = self.authz_state.exchange_code_for_token(token_request['code']) 444 | self._add_access_token_to_response(response, access_token) 445 | refresh_token = self.authz_state.create_refresh_token(access_token.value) 446 | if refresh_token is not None: 447 | response['refresh_token'] = refresh_token 448 | 449 | extra_id_token_claims = {} 450 | if self.stateless: 451 | extra_id_token_claims_in_code = self.authz_state.get_extra_id_token_claims_for_code(token_request['code']) 452 | extra_id_token_claims.update(extra_id_token_claims_in_code) 453 | elif 
callable(extra_id_token_claims): 454 | extra_id_token_claims = extra_id_token_claims(user_id, authentication_request['client_id']) 455 | 456 | requested_claims = self._get_requested_claims_in(authentication_request, 'id_token') 457 | if self.stateless: 458 | user_info = self.authz_state.get_user_info_for_code(token_request['code']) 459 | user_claims = self.userinfo.get_claims_for(None, requested_claims, user_info) 460 | else: 461 | user_claims = self.userinfo.get_claims_for(user_id, requested_claims) 462 | response['id_token'] = self._create_signed_id_token(authentication_request['client_id'], sub, 463 | user_claims, 464 | authentication_request.get('nonce'), 465 | None, access_token.value, 466 | extra_id_token_claims) 467 | logger.debug('issued id_token=%s from requested_claims=%s userinfo=%s extra_claims=%s', 468 | response['id_token'], requested_claims, user_claims, extra_id_token_claims) 469 | 470 | return response 471 | 472 | def _do_token_refresh(self, request): 473 | # type: (Mapping[str, str]) -> oic.oic.message.AccessTokenResponse 474 | """ 475 | Handles a token request for refreshing an access token (grant_type=refresh_token). 
476 | :param request: parsed http request parameters 477 | :return: a token response containing a new Access Token and possibly a new Refresh Token 478 | :raise InvalidTokenRequest: if the token request is invalid 479 | """ 480 | token_request = RefreshAccessTokenRequest().from_dict(request) 481 | try: 482 | token_request.verify() 483 | except MessageException as e: 484 | raise InvalidTokenRequest(str(e), token_request) from e 485 | 486 | response = AccessTokenResponse() 487 | 488 | access_token, refresh_token = self.authz_state.use_refresh_token(token_request['refresh_token'], 489 | scope=token_request.get('scope')) 490 | self._add_access_token_to_response(response, access_token) 491 | if refresh_token: 492 | response['refresh_token'] = refresh_token 493 | 494 | return response 495 | 496 | def _verify_client_authentication(self, request_body, http_headers=None): 497 | # type (str, Optional[Mapping[str, str]] -> Mapping[str, str] 498 | """ 499 | Verifies the client authentication. 500 | :param request_body: urlencoded token request 501 | :param http_headers: 502 | :return: The parsed request body. 503 | """ 504 | if http_headers is None: 505 | http_headers = {} 506 | token_request = dict(parse_qsl(request_body)) 507 | token_request['client_id'] = verify_client_authentication(self.clients, token_request, http_headers.get('Authorization')) 508 | return token_request 509 | 510 | def handle_userinfo_request(self, request=None, http_headers=None): 511 | # type: (Optional[str], Optional[Mapping[str, str]]) -> oic.oic.message.OpenIDSchema 512 | """ 513 | Handles a userinfo request. 
514 | :param request: urlencoded request (either query string or POST body) 515 | :param http_headers: http headers 516 | """ 517 | if http_headers is None: 518 | http_headers = {} 519 | userinfo_request = dict(parse_qsl(request)) 520 | bearer_token = extract_bearer_token_from_http_request(userinfo_request, http_headers.get('Authorization')) 521 | 522 | introspection = self.authz_state.introspect_access_token(bearer_token) 523 | if not introspection['active']: 524 | raise InvalidAccessToken('The access token has expired') 525 | scopes = introspection['scope'].split() 526 | 527 | if not self.stateless: 528 | user_id = self.authz_state.get_user_id_for_subject_identifier(introspection['sub']) 529 | 530 | requested_claims = scope2claims(scopes, extra_scope_dict=self.extra_scopes) 531 | authentication_request = self.authz_state.get_authorization_request_for_access_token(bearer_token) 532 | requested_claims.update(self._get_requested_claims_in(authentication_request, 'userinfo')) 533 | 534 | if self.stateless: 535 | user_info = self.authz_state.get_user_info_for_access_token(bearer_token) 536 | user_claims = self.userinfo.get_claims_for(None, requested_claims, user_info) 537 | else: 538 | user_claims = self.userinfo.get_claims_for(user_id, requested_claims) 539 | 540 | user_claims.setdefault('sub', introspection['sub']) 541 | response = OpenIDSchema(**user_claims) 542 | logger.debug('userinfo=%s from requested_claims=%s userinfo=%s', 543 | response, requested_claims, user_claims) 544 | return response 545 | 546 | def _issue_new_client(self): 547 | # create unique client id 548 | client_id = rndstr(12) 549 | while client_id in self.clients: 550 | client_id = rndstr(12) 551 | # create random secret 552 | client_secret = uuid.uuid4().hex 553 | 554 | return client_id, client_secret 555 | 556 | def match_client_preferences_with_provider_capabilities(self, client_preferences): 557 | # type: (oic.message.RegistrationRequest) -> Mapping[str, Union[str, List[str]]] 558 | """ 559 
| Match as many as of the client preferences as possible. 560 | :param client_preferences: requested preferences from client registration request 561 | :return: the matched preferences selected by the provider 562 | """ 563 | matched_prefs = client_preferences.to_dict() 564 | for pref in ['response_types', 'default_acr_values']: 565 | if pref not in client_preferences: 566 | continue 567 | 568 | capability = PREFERENCE2PROVIDER[pref] 569 | # only preserve the common values 570 | matched_values = find_common_values(client_preferences[pref], self.configuration_information[capability]) 571 | # deal with space separated values 572 | matched_prefs[pref] = [' '.join(v) for v in matched_values] 573 | 574 | return matched_prefs 575 | 576 | def handle_client_registration_request(self, request, http_headers=None): 577 | # type: (Optional[str], Optional[Mapping[str, str]]) -> oic.oic.message.RegistrationResponse 578 | """ 579 | Handles a client registration request. 580 | :param request: JSON request from POST body 581 | :param http_headers: http headers 582 | """ 583 | registration_req = RegistrationRequest().deserialize(request, 'json') 584 | 585 | for validator in self.registration_request_validators: 586 | validator(registration_req) 587 | logger.debug('parsed authentication_request: %s', registration_req) 588 | 589 | client_id, client_secret = self._issue_new_client() 590 | credentials = { 591 | 'client_id': client_id, 592 | 'client_id_issued_at': int(time.time()), 593 | 'client_secret': client_secret, 594 | 'client_secret_expires_at': 0 # never expires 595 | } 596 | 597 | response_params = self.match_client_preferences_with_provider_capabilities(registration_req) 598 | response_params.update(credentials) 599 | self.clients[client_id] = copy.deepcopy(response_params) 600 | 601 | registration_resp = RegistrationResponse(**response_params) 602 | logger.debug('registration_resp=%s from registration_req=%s', registration_resp, registration_req) 603 | return registration_resp 
604 | 605 | def logout_user(self, subject_identifier=None, end_session_request=None): 606 | # type: (Optional[str], Optional[oic.oic.message.EndSessionRequest]) -> None 607 | if self.stateless: 608 | raise OAuthError("Logout is not supported with stateless storage provider", "invalid_request") 609 | if not end_session_request: 610 | end_session_request = EndSessionRequest() 611 | if 'id_token_hint' in end_session_request: 612 | id_token = IdToken().from_jwt(end_session_request['id_token_hint'], key=[self.signing_key]) 613 | subject_identifier = id_token['sub'] 614 | 615 | self.authz_state.delete_state_for_subject_identifier(subject_identifier) 616 | 617 | def do_post_logout_redirect(self, end_session_request): 618 | # type: (oic.oic.message.EndSessionRequest) -> oic.oic.message.EndSessionResponse 619 | if 'post_logout_redirect_uri' not in end_session_request: 620 | return None 621 | 622 | client_id = None 623 | if 'id_token_hint' in end_session_request: 624 | id_token = IdToken().from_jwt(end_session_request['id_token_hint'], key=[self.signing_key]) 625 | client_id = id_token['aud'][0] 626 | 627 | if 'post_logout_redirect_uri' in end_session_request: 628 | if not client_id: 629 | return None 630 | if not end_session_request['post_logout_redirect_uri'] in self.clients[client_id].get( 631 | 'post_logout_redirect_uris', []): 632 | return None 633 | 634 | end_session_response = EndSessionResponse() 635 | if 'state' in end_session_request: 636 | end_session_response['state'] = end_session_request['state'] 637 | 638 | return end_session_response.request(end_session_request['post_logout_redirect_uri']) 639 | -------------------------------------------------------------------------------- /src/pyop/request_validator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from oic.exception import MessageException 4 | from oic.oic import PREFERENCE2PROVIDER 5 | 6 | from .exceptions import 
InvalidClientRegistrationRequest 7 | from .exceptions import InvalidAuthenticationRequest 8 | from .exceptions import InvalidRedirectURI 9 | from .util import is_allowed_response_type, find_common_values 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def authorization_request_verify(authentication_request): 15 | """ 16 | Verifies that all required parameters and correct values are included in the authentication request. 17 | :param authentication_request: the authentication request to verify 18 | :raise InvalidAuthenticationRequest: if the authentication is incorrect 19 | """ 20 | try: 21 | authentication_request.verify() 22 | except MessageException as e: 23 | raise InvalidAuthenticationRequest(str(e), authentication_request, oauth_error='invalid_request') from e 24 | 25 | 26 | def client_id_is_known(provider, authentication_request): 27 | """ 28 | Verifies the client identifier is known. 29 | :param provider: provider instance 30 | :param authentication_request: the authentication request to verify 31 | :raise InvalidAuthenticationRequest: if the client_id is unknown 32 | """ 33 | if authentication_request['client_id'] not in provider.clients: 34 | logger.error('Unknown client_id \'{}\''.format(authentication_request['client_id'])) 35 | raise InvalidAuthenticationRequest('Unknown client_id', 36 | authentication_request, 37 | oauth_error='unauthorized_client') 38 | 39 | def redirect_uri_is_in_registered_redirect_uris(provider, authentication_request): 40 | """ 41 | Verifies the redirect uri is registered for the client making the request. 
42 | :param provider: provider instance 43 | :param authentication_request: authentication request to verify 44 | :raise InvalidAuthenticationRequest: if the redirect uri is not registered 45 | """ 46 | try: 47 | allowed_redirect_uris = provider.clients[authentication_request['client_id']]['redirect_uris'] 48 | except KeyError as e: 49 | logger.error('client metadata is missing redirect_uris') 50 | raise InvalidRedirectURI( 51 | 'No redirect uri registered for this client', 52 | authentication_request, 53 | oauth_error="invalid_request", 54 | ) 55 | 56 | if authentication_request['redirect_uri'] not in allowed_redirect_uris: 57 | logger.error("Redirect uri \'{0}\' is not registered for this client".format(authentication_request['redirect_uri'])) 58 | raise InvalidRedirectURI( 59 | 'Redirect uri is not registered for this client', 60 | authentication_request, 61 | oauth_error="invalid_request", 62 | ) 63 | 64 | 65 | def response_type_is_in_registered_response_types(provider, authentication_request): 66 | """ 67 | Verifies that the requested response type is allowed for the client making the request. 
68 | :param provider: provider instance 69 | :param authentication_request: authentication request to verify 70 | :raise InvalidAuthenticationRequest: if the response type is not allowed 71 | """ 72 | error = InvalidAuthenticationRequest('Response type is not registered', 73 | authentication_request, 74 | oauth_error='invalid_request') 75 | try: 76 | allowed_response_types = provider.clients[authentication_request['client_id']]['response_types'] 77 | except KeyError as e: 78 | logger.error('client metadata is missing response_types') 79 | raise error 80 | 81 | if not is_allowed_response_type(authentication_request['response_type'], allowed_response_types): 82 | logger.error('Response type \'{}\' is not registered'.format(' '.join(authentication_request['response_type']))) 83 | raise error 84 | 85 | 86 | def userinfo_claims_only_specified_when_access_token_is_issued(authentication_request): 87 | """ 88 | According to 89 | "OpenID Connect Core 1.0", Section 5.5: "When the userinfo member is used, the request MUST 90 | also use a response_type value that results in an Access Token being issued to the Client for 91 | use at the UserInfo Endpoint." 
92 | :param authentication_request: the authentication request to verify 93 | :raise InvalidAuthenticationRequest: if the requested claims can not be returned according to the request 94 | """ 95 | will_issue_access_token = authentication_request['response_type'] != ['id_token'] 96 | contains_userinfo_claims_request = 'claims' in authentication_request and 'userinfo' in authentication_request[ 97 | 'claims'] 98 | if not will_issue_access_token and contains_userinfo_claims_request: 99 | raise InvalidAuthenticationRequest('Userinfo claims cannot be requested, when response_type=\'id_token\'', 100 | authentication_request, 101 | oauth_error='invalid_request') 102 | 103 | 104 | def requested_scope_is_supported(provider, authentication_request): 105 | requested_scopes = set(authentication_request['scope']) 106 | supported_scopes = set(provider.provider_configuration['scopes_supported']) 107 | requested_unsupported_scopes = requested_scopes - supported_scopes 108 | if requested_unsupported_scopes: 109 | logger.warning('Request contains unsupported/unknown scopes: {}' 110 | .format(', '.join(requested_unsupported_scopes))) 111 | 112 | 113 | def registration_request_verify(registration_request): 114 | """ 115 | Verifies that all required parameters and correct values are included in the client registration request. 116 | :param registration_request: the authentication request to verify 117 | :raise InvalidClientRegistrationRequest: if the registration is incorrect 118 | """ 119 | try: 120 | registration_request.verify() 121 | except MessageException as e: 122 | raise InvalidClientRegistrationRequest(str(e), registration_request, oauth_error='invalid_request') from e 123 | 124 | 125 | def client_preferences_match_provider_capabilities(provider, registration_request): 126 | """ 127 | Verifies that all requested preferences in the client metadata can be fulfilled by this provider. 
128 | :param registration_request: the authentication request to verify 129 | :raise InvalidClientRegistrationRequest: if the registration is incorrect 130 | """ 131 | 132 | def match(client_preference, provider_capability): 133 | if isinstance(client_preference, list): 134 | # deal with comparing space separated values, e.g. 'response_types', without considering the order 135 | # at least one requested preference must be matched 136 | return len(find_common_values(client_preference, provider_capability)) > 0 137 | 138 | return client_preference in provider_capability 139 | 140 | for client_preference in registration_request.keys(): 141 | if client_preference not in PREFERENCE2PROVIDER: 142 | # metadata parameter that shouldn't be matched 143 | continue 144 | 145 | provider_capability = PREFERENCE2PROVIDER[client_preference] 146 | if not match(registration_request[client_preference], provider.configuration_information[provider_capability]): 147 | raise InvalidClientRegistrationRequest( 148 | 'Could not match client preference {}={} with provider capability {}={}'.format( 149 | client_preference, registration_request[client_preference], provider_capability, 150 | provider.configuration_information[provider_capability]), 151 | registration_request, 152 | oauth_error='invalid_request') 153 | -------------------------------------------------------------------------------- /src/pyop/storage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from abc import ABC, abstractmethod 4 | import copy 5 | import json 6 | import logging 7 | from datetime import datetime 8 | from urllib.parse import urlparse 9 | from urllib.parse import parse_qs 10 | 11 | from .crypto import _AESCipher 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | try: 18 | import pymongo 19 | except ImportError: 20 | _has_pymongo = False 21 | else: 22 | _has_pymongo = True 23 | 24 | try: 25 | from redis.client import Redis 
26 | except ImportError: 27 | _has_redis = False 28 | else: 29 | _has_redis = True 30 | 31 | 32 | class StorageBase(ABC): 33 | _ttl = None 34 | 35 | @abstractmethod 36 | def __setitem__(self, key, value): 37 | pass 38 | 39 | @abstractmethod 40 | def pack(self, value): 41 | pass 42 | 43 | @abstractmethod 44 | def __getitem__(self, key): 45 | pass 46 | 47 | @abstractmethod 48 | def __delitem__(self, key): 49 | pass 50 | 51 | @abstractmethod 52 | def __contains__(self, key): 53 | pass 54 | 55 | @abstractmethod 56 | def items(self): 57 | pass 58 | 59 | def pop(self, key, default=None): 60 | try: 61 | data = self[key] 62 | except KeyError: 63 | return default 64 | del self[key] 65 | return data 66 | 67 | @classmethod 68 | def from_uri(cls, db_uri, collection, db_name=None, ttl=None, **kwargs): 69 | url = urlparse(db_uri) 70 | 71 | if url.scheme == "mongodb": 72 | return MongoWrapper( 73 | db_uri=db_uri, 74 | db_name=db_name, 75 | collection=collection, 76 | ttl=ttl, 77 | extra_options=kwargs, 78 | ) 79 | elif url.scheme == "redis" or url.scheme == "unix": 80 | return RedisWrapper( 81 | db_uri=db_uri, 82 | db_name=db_name, 83 | collection=collection, 84 | ttl=ttl, 85 | extra_options=kwargs, 86 | ) 87 | elif url.scheme == "stateless": 88 | alg = parse_qs(url.query).get("alg") if url.query else None 89 | alg = alg[0] if alg else None 90 | return StatelessWrapper( 91 | collection=collection, 92 | encryption_key=url.password, 93 | alg=alg 94 | ) 95 | 96 | raise ValueError(f"Invalid DB URI: {db_uri}") 97 | 98 | @classmethod 99 | def type(cls, db_uri): 100 | url = urlparse(db_uri) 101 | if url.scheme == "mongodb": 102 | return "mongodb" 103 | elif url.scheme == "redis" or url.scheme == "unix": 104 | return "redis" 105 | elif url.scheme == "stateless": 106 | return "stateless" 107 | 108 | raise ValueError(f"Invalid DB URI: {db_uri}") 109 | 110 | @property 111 | def ttl(self): 112 | return self._ttl 113 | 114 | 115 | class MongoWrapper(StorageBase): 116 | def __init__(self, 
db_uri, db_name, collection, ttl=None, extra_options=None): 117 | if not _has_pymongo: 118 | raise ImportError("pymongo module is required but it is not available") 119 | 120 | if not extra_options: 121 | extra_options = {} 122 | 123 | mongo_options = extra_options.pop("mongo_kwargs", None) or {} 124 | 125 | self._db_uri = db_uri 126 | self._coll_name = collection 127 | self._db = MongoDB(db_uri, db_name=db_name, **mongo_options) 128 | self._coll = self._db.get_collection(collection) 129 | self._coll.create_index('lookup_key', unique=True) 130 | 131 | if ttl is None or (isinstance(ttl, int) and ttl >= 0): 132 | self._ttl = ttl 133 | else: 134 | raise ValueError("TTL must be a non-negative integer or None") 135 | if ttl is not None: 136 | self._coll.create_index( 137 | 'last_modified', 138 | expireAfterSeconds=ttl, 139 | name="expiry" 140 | ) 141 | 142 | def __setitem__(self, key, value): 143 | doc = { 144 | 'lookup_key': key, 145 | 'data': value, 146 | 'last_modified': datetime.utcnow() 147 | } 148 | self._coll.replace_one({'lookup_key': key}, doc, upsert=True) 149 | 150 | def pack(self, value): 151 | raise NotImplementedError 152 | 153 | def __getitem__(self, key): 154 | doc = self._coll.find_one({'lookup_key': key}) 155 | if not doc: 156 | raise KeyError(key) 157 | return doc['data'] 158 | 159 | def __delitem__(self, key): 160 | self._coll.delete_one({'lookup_key': key}) 161 | 162 | def __contains__(self, key): 163 | count = self._coll.count_documents({'lookup_key': key}) 164 | return bool(count) 165 | 166 | def items(self): 167 | for doc in self._coll.find(): 168 | yield (doc['lookup_key'], doc['data']) 169 | 170 | 171 | class RedisWrapper(StorageBase): 172 | """ 173 | Simple wrapper for a dict-like storage in Redis. 174 | Supports JSON-serializable data types. 
175 | """ 176 | 177 | def __init__( 178 | self, db_uri, *, db_name=None, collection, ttl=None, extra_options=None 179 | ): 180 | if not _has_redis: 181 | raise ImportError("redis module is required but it is not available") 182 | 183 | if not extra_options: 184 | extra_options = {} 185 | 186 | redis_kwargs = extra_options.pop("redis_kwargs", None) or {} 187 | redis_options = { 188 | "decode_responses": True, "db": db_name, **redis_kwargs 189 | } 190 | 191 | self._db = Redis.from_url(db_uri, **redis_options) 192 | self._collection = collection 193 | if ttl is None or (isinstance(ttl, int) and ttl >= 0): 194 | self._ttl = ttl 195 | else: 196 | raise ValueError("TTL must be a non-negative integer or None") 197 | 198 | def _make_key(self, key): 199 | if not isinstance(key, str): 200 | raise TypeError(f"Keys must be strings, {type(key).__name__} given") 201 | 202 | return ":".join([self._collection, key]) 203 | 204 | def __setitem__(self, key, value): 205 | # Replacing the value of a key resets the ttl counter 206 | encoded = json.dumps({ "value": value }) 207 | self._db.set(self._make_key(key), encoded, ex=self.ttl) 208 | 209 | def pack(self, value): 210 | raise NotImplementedError 211 | 212 | def __getitem__(self, key): 213 | encoded = self._db.get(self._make_key(key)) 214 | if encoded is None: 215 | raise KeyError(key) 216 | return json.loads(encoded).get("value") 217 | 218 | def __delitem__(self, key): 219 | # Deleting a non-existent key is allowed 220 | self._db.delete(self._make_key(key)) 221 | 222 | def __contains__(self, key): 223 | return (self._db.get(self._make_key(key)) is not None) 224 | 225 | def items(self): 226 | for key in self._db.keys(self._collection + "*"): 227 | visible_key = key[len(self._collection) + 1 :] 228 | 229 | if isinstance(visible_key, bytes): 230 | visible_key = visible_key.decode() 231 | 232 | try: 233 | yield (visible_key, self[visible_key]) 234 | except KeyError: 235 | pass 236 | 237 | 238 | class StatelessWrapper(StorageBase): 239 | 
def __init__(self, collection, encryption_key, alg=None): 240 | self.collection = collection 241 | if not alg or alg.lower() == "aes256": 242 | self.cipher = _AESCipher(encryption_key) 243 | else: 244 | raise ValueError(f"Invalid encryption algorithm: {alg}") 245 | 246 | def __setitem__(self, key, value): 247 | pass 248 | 249 | def pack(self, value): 250 | key = None 251 | if value: 252 | if isinstance(value, dict): 253 | value = json.dumps(value) 254 | key = self.cipher.encrypt(value.encode("UTF-8")).decode("UTF-8") 255 | return key 256 | 257 | def __getitem__(self, key): 258 | return self._unpack(key) 259 | 260 | def __delitem__(self, key): 261 | raise NotImplementedError 262 | 263 | def __contains__(self, key): 264 | if self._unpack(key): 265 | return True 266 | return False 267 | 268 | def items(self): 269 | raise NotImplementedError 270 | 271 | def _unpack(self, value): 272 | unpacked_val = None 273 | try: 274 | if value: 275 | unpacked_val = self.cipher.decrypt(value.encode("UTF-8")).decode("UTF-8") 276 | unpacked_val = json.loads(unpacked_val) 277 | except ValueError: 278 | if unpacked_val: 279 | logger.debug("Value '%s' is not a dict", value) 280 | else: 281 | logger.warning("Value '%s' is invalid for %s", value, self.collection) 282 | return unpacked_val 283 | 284 | 285 | class MongoDB(object): 286 | """Simple wrapper to get pymongo real objects from the settings uri""" 287 | 288 | def __init__(self, db_uri, db_name=None, connection_factory=None, **kwargs): 289 | if db_uri is None: 290 | raise ValueError('db_uri not supplied') 291 | 292 | self._sanitized_uri = None 293 | self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri) 294 | 295 | db_name = self._parsed_uri.get('database') or db_name 296 | if db_name is None: 297 | raise ValueError( 298 | "Database name must be provided either in the URI or as an argument" 299 | ) 300 | self._database_name = self._parsed_uri['database'] = db_name 301 | 302 | if 'replicaSet' in kwargs and kwargs['replicaSet'] is 
None: 303 | del kwargs['replicaSet'] 304 | 305 | self._options = self._parsed_uri.get('options') 306 | if connection_factory is None: 307 | connection_factory = pymongo.MongoClient 308 | if 'replicaSet' in kwargs: 309 | connection_factory = pymongo.MongoReplicaSetClient 310 | if 'replicaSet' in self._options and self._options['replicaSet'] is not None: 311 | connection_factory = pymongo.MongoReplicaSetClient 312 | kwargs['replicaSet'] = self._options['replicaSet'] 313 | 314 | if 'replicaSet' in kwargs: 315 | if 'socketTimeoutMS' not in kwargs: 316 | kwargs['socketTimeoutMS'] = 5000 317 | if 'connectTimeoutMS' not in kwargs: 318 | kwargs['connectTimeoutMS'] = 5000 319 | 320 | self._db_uri = _format_mongodb_uri(self._parsed_uri) 321 | 322 | try: 323 | self._connection = connection_factory( 324 | host=self._db_uri, 325 | tz_aware=True, 326 | **kwargs) 327 | except pymongo.errors.ConnectionFailure as e: 328 | raise e 329 | 330 | def __repr__(self): 331 | return '<{!s}: {!s} {!s}>'.format(self.__class__.__name__, 332 | self._db_uri, 333 | self._database_name) 334 | 335 | @property 336 | def sanitized_uri(self): 337 | """ 338 | Return the database URI we're using in a format sensible for logging etc. 339 | 340 | :return: db_uri 341 | """ 342 | if self._sanitized_uri is None: 343 | _parsed = copy.copy(self._parsed_uri) 344 | if 'username' in _parsed: 345 | _parsed['password'] = 'secret' 346 | _parsed['nodelist'] = [_parsed['nodelist'][0]] 347 | self._sanitized_uri = _format_mongodb_uri(_parsed) 348 | return self._sanitized_uri 349 | 350 | def get_connection(self): 351 | """ 352 | Get the raw pymongo connection object. 353 | :return: Pymongo connection object 354 | """ 355 | return self._connection 356 | 357 | def get_database(self, database_name=None, username=None, password=None): 358 | """ 359 | Get a pymongo database handle, after authenticating. 
360 | 361 | Authenticates using the username/password in the DB URI given to 362 | __init__() unless username/password is supplied as arguments. 363 | 364 | :param database_name: (optional) Name of database 365 | :param username: (optional) Username to login with 366 | :param password: (optional) Password to login with 367 | :return: Pymongo database object 368 | """ 369 | if database_name is None: 370 | database_name = self._database_name 371 | if database_name is None: 372 | raise ValueError('No database_name supplied, and no default provided to __init__') 373 | db = self._connection[database_name] 374 | if username and password: 375 | db.authenticate(username, password) 376 | elif self._parsed_uri.get("username", None): 377 | if 'authSource' in self._options and self._options['authSource'] is not None: 378 | db.authenticate( 379 | self._parsed_uri.get("username", None), 380 | self._parsed_uri.get("password", None), 381 | source=self._options['authSource'] 382 | ) 383 | else: 384 | db.authenticate( 385 | self._parsed_uri.get("username", None), 386 | self._parsed_uri.get("password", None) 387 | ) 388 | return db 389 | 390 | def get_collection(self, collection, database_name=None, username=None, password=None): 391 | """ 392 | Get a pymongo collection handle. 393 | 394 | :param collection: Name of collection 395 | :param database_name: (optional) Name of database 396 | :param username: (optional) Username to login with 397 | :param password: (optional) Password to login with 398 | :return: Pymongo collection object 399 | """ 400 | _db = self.get_database(database_name, username, password) 401 | return _db[collection] 402 | 403 | def close(self): 404 | self._connection.close() 405 | 406 | 407 | def _format_mongodb_uri(parsed_uri): 408 | """ 409 | Painstakingly reconstruct a MongoDB URI parsed using pymongo.uri_parser.parse_uri. 
410 | 411 | :param parsed_uri: Result of pymongo.uri_parser.parse_uri 412 | :type parsed_uri: dict 413 | 414 | :return: New URI 415 | :rtype: str | unicode 416 | """ 417 | user_pass = '' 418 | if parsed_uri.get('username') and parsed_uri.get('password'): 419 | user_pass = '{username!s}:{password!s}@'.format(**parsed_uri) 420 | 421 | _nodes = [] 422 | for host, port in parsed_uri.get('nodelist'): 423 | if ':' in host and not host.endswith(']'): 424 | # IPv6 address without brackets 425 | host = '[{!s}]'.format(host) 426 | if port == 27017: 427 | _nodes.append(host) 428 | else: 429 | _nodes.append('{!s}:{!s}'.format(host, port)) 430 | nodelist = ','.join(_nodes) 431 | 432 | options = '' 433 | if parsed_uri.get('options'): 434 | _opt_list = [] 435 | for key, value in parsed_uri.get('options').items(): 436 | if isinstance(value, bool): 437 | value = str(value).lower() 438 | _opt_list.append('{!s}={!s}'.format(key, value)) 439 | options = '?' + '&'.join(_opt_list) 440 | 441 | db_name = parsed_uri.get('database') or '' 442 | 443 | res = "mongodb://{user_pass!s}{nodelist!s}/{db_name!s}{options!s}".format( 444 | user_pass=user_pass, 445 | nodelist=nodelist, 446 | db_name=db_name, 447 | # collection is ignored 448 | options=options) 449 | return res 450 | -------------------------------------------------------------------------------- /src/pyop/subject_identifier.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | 4 | class SubjectIdentifierFactory(object): 5 | """ 6 | Interface for implementation of an algorithm for creating pairwise subject identifiers, see 7 | 8 | "OpenID Connect Core 1.0", Section 8.1. 
9 | """ 10 | 11 | def create_public_identifier(self, user_id): 12 | # type: (str) -> str 13 | raise NotImplementedError() 14 | 15 | def create_pairwise_identifier(self, user_id, sector_identifier): 16 | # type: (str, str) -> str 17 | raise NotImplementedError() 18 | 19 | 20 | class HashBasedSubjectIdentifierFactory(object): 21 | """ 22 | Implements a hash based algorithm for creating a pairwise subject identifier. 23 | """ 24 | 25 | def __init__(self, hash_salt): 26 | # type: (str) -> None 27 | self.hash_salt = hash_salt 28 | 29 | def create_public_identifier(self, user_id): 30 | return self._hash(user_id) 31 | 32 | def create_pairwise_identifier(self, user_id, sector_identifier): 33 | return self._hash(sector_identifier + user_id) 34 | 35 | def _hash(self, data): 36 | # type: (str) -> str 37 | hash_input = data + self.hash_salt 38 | return hashlib.sha256(hash_input.encode('utf-8')).hexdigest() 39 | -------------------------------------------------------------------------------- /src/pyop/userinfo.py: -------------------------------------------------------------------------------- 1 | class Userinfo(object): 2 | """ 3 | Wrapper providing a read-only interface for a database containing user info. 
4 | 5 | The backing database must use a local identifier as key, and all userinfo that should be returned in OpenID 6 | Connect ID Tokens or Userinfo Responses must follow the format of OpenID Connect standard claims, see 7 | 8 | "OpenID Connect Core 1.0", Section 5.1 9 | """ 10 | 11 | def __init__(self, db): 12 | # type: (Mapping[str, Union[str, List[str]]]) -> None 13 | self._db = db 14 | 15 | def __getitem__(self, item): 16 | return self._db[item] 17 | 18 | def __contains__(self, item): 19 | return item in self._db 20 | 21 | def get_claims_for(self, user_id, requested_claims, userinfo=None): 22 | # type: (str, Mapping[str, Optional[Mapping[str, Union[str, List[str]]]]) -> Dict[str, Union[str, List[str]]] 23 | """ 24 | Filter the userinfo based on which claims where requested. 25 | :param user_id: user identifier 26 | :param requested_claims: see 27 | "OpenID Connect Core 1.0", Section 5.5 for structure 28 | :param userinfo: if user_info is specified the claims will be filtered from the user_info directly instead 29 | first querying the storage against the user_id 30 | :return: All requested claims available from the userinfo. 
31 | """ 32 | 33 | if not userinfo: 34 | userinfo = self._db[user_id] if user_id else {} 35 | claims = {claim: userinfo[claim] for claim in requested_claims if claim in userinfo} 36 | return claims 37 | -------------------------------------------------------------------------------- /src/pyop/util.py: -------------------------------------------------------------------------------- 1 | def should_fragment_encode(authentication_request): 2 | if authentication_request['response_type'] == ['code']: 3 | # Authorization Code Flow -> query encode 4 | return False 5 | 6 | return True 7 | 8 | 9 | def is_allowed_response_type(response_type, supported_response_types): 10 | return frozenset(response_type) in [frozenset(rt.split()) for rt in supported_response_types] 11 | 12 | 13 | def find_common_values(preference_values, supported_values): 14 | unordered_preference_values = {frozenset(p.split()) for p in preference_values} 15 | unordered_supported_values = {frozenset(s.split()) for s in supported_values} 16 | return unordered_supported_values.intersection(unordered_preference_values) 17 | 18 | 19 | def requested_scope_is_allowed(requested_scope, allowed_scope): 20 | return frozenset(requested_scope).issubset(frozenset(allowed_scope.split())) 21 | -------------------------------------------------------------------------------- /tests/pyop/test_access_token.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pyop.access_token import extract_bearer_token_from_http_request, BearerTokenError 4 | 5 | ACCESS_TOKEN = 'abcdef' 6 | 7 | 8 | class TestExtractBearerTokenFromHttpRequest(object): 9 | def test_authorization_header(self): 10 | assert extract_bearer_token_from_http_request(authz_header='Bearer {}'.format(ACCESS_TOKEN)) == ACCESS_TOKEN 11 | 12 | def test_non_bearer_authorization_header(self): 13 | with pytest.raises(BearerTokenError): 14 | extract_bearer_token_from_http_request(authz_header='Basic 
{}'.format(ACCESS_TOKEN)) 15 | 16 | def test_access_token_in_request(self): 17 | data = { 18 | 'foo': 'bar', 19 | 'access_token': ACCESS_TOKEN 20 | } 21 | assert extract_bearer_token_from_http_request(data) == ACCESS_TOKEN 22 | 23 | def test_request_without_access_token(self): 24 | data = { 25 | 'foo': 'bar', 26 | } 27 | with pytest.raises(BearerTokenError): 28 | extract_bearer_token_from_http_request(data) 29 | -------------------------------------------------------------------------------- /tests/pyop/test_authz_state.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import functools 3 | import time 4 | from unittest.mock import patch, Mock 5 | 6 | import pytest 7 | 8 | from pyop.message import AuthorizationRequest 9 | from pyop.authz_state import AccessToken, InvalidScope 10 | from pyop.authz_state import AuthorizationState 11 | from pyop.exceptions import InvalidSubjectIdentifier, InvalidAccessToken, InvalidAuthorizationCode, InvalidRefreshToken 12 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory 13 | import pyop.storage 14 | 15 | MOCK_TIME = Mock(return_value=time.mktime(dt.datetime(2016, 6, 21).timetuple())) 16 | INVALID_INPUT = [None, '', 'noexist'] 17 | 18 | 19 | class TestAuthorizationState(object): 20 | TEST_TOKEN_LIFETIME = 60 * 50 # 50 minutes 21 | TEST_SUBJECT_IDENTIFIER = 'sub' 22 | 23 | def set_valid_subject_identifier(self, authorization_state): 24 | is_valid_sub_mock = Mock() 25 | is_valid_sub_mock.side_effect = lambda sub: sub == self.TEST_SUBJECT_IDENTIFIER 26 | authorization_state._is_valid_subject_identifier = is_valid_sub_mock 27 | 28 | @pytest.fixture 29 | def authorization_request(self): 30 | authn_req = AuthorizationRequest(**{'scope': 'openid', 'client_id': 'client1'}) 31 | return authn_req 32 | 33 | @pytest.fixture 34 | def authorization_state_factory(self): 35 | return functools.partial(AuthorizationState, HashBasedSubjectIdentifierFactory('salt')) 36 
| 37 | @pytest.fixture 38 | def authorization_state(self, authorization_state_factory): 39 | return authorization_state_factory(refresh_token_lifetime=3600) 40 | 41 | @pytest.fixture 42 | def stateless_storage(self): 43 | return pyop.storage.StatelessWrapper("pyop", "abc123") 44 | 45 | def assert_access_token(self, authorization_request, access_token, access_token_db, iat): 46 | assert isinstance(access_token, AccessToken) 47 | assert access_token.expires_in == self.TEST_TOKEN_LIFETIME 48 | assert access_token.value 49 | assert access_token.BEARER_TOKEN_TYPE == 'Bearer' 50 | 51 | assert access_token.value in access_token_db 52 | self.assert_introspected_token(authorization_request, access_token_db[access_token.value], access_token, iat) 53 | assert access_token_db[access_token.value]['exp'] == iat + self.TEST_TOKEN_LIFETIME 54 | 55 | def assert_introspected_token(self, authorization_request, token_introspection, access_token, iat): 56 | auth_req = authorization_request.to_dict() 57 | 58 | assert token_introspection['scope'] == auth_req['scope'] 59 | assert token_introspection['client_id'] == auth_req['client_id'] 60 | assert token_introspection['token_type'] == access_token.type 61 | assert token_introspection['sub'] == self.TEST_SUBJECT_IDENTIFIER 62 | assert token_introspection['aud'] == [auth_req['client_id']] 63 | assert token_introspection['iat'] == iat 64 | 65 | @patch('time.time', MOCK_TIME) 66 | def test_create_authorization_code(self, authorization_state_factory, authorization_request): 67 | code_lifetime = 60 * 2 # two minutes 68 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime) 69 | self.set_valid_subject_identifier(authorization_state) 70 | 71 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 72 | assert authz_code in authorization_state.authorization_codes 73 | assert authorization_state.authorization_codes[authz_code]['exp'] == int(time.time()) 
+ code_lifetime 74 | assert authorization_state.authorization_codes[authz_code]['used'] is False 75 | assert authorization_state.authorization_codes[authz_code][AuthorizationState.KEY_AUTHORIZATION_REQUEST] == \ 76 | authorization_request.to_dict() 77 | assert authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER 78 | 79 | @patch('time.time', MOCK_TIME) 80 | def test_create_authorization_code_with_stateless_storage(self, authorization_state_factory, authorization_request, 81 | stateless_storage): 82 | code_lifetime = 60 * 2 # two minutes 83 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime, 84 | authorization_code_db=stateless_storage) 85 | self.set_valid_subject_identifier(authorization_state) 86 | 87 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 88 | assert authz_code in authorization_state.authorization_codes 89 | assert authorization_state.authorization_codes[authz_code]['exp'] == int(time.time()) + code_lifetime 90 | assert authorization_state.authorization_codes[authz_code]['used'] is False 91 | assert authorization_state.authorization_codes[authz_code][AuthorizationState.KEY_AUTHORIZATION_REQUEST] == \ 92 | authorization_request.to_dict() 93 | assert authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER 94 | 95 | def test_create_authorization_code_with_scope_other_than_auth_req(self, authorization_state, authorization_request): 96 | scope = ['openid', 'extra'] 97 | self.set_valid_subject_identifier(authorization_state) 98 | 99 | authz_code = authorization_state.create_authorization_code(authorization_request, 100 | self.TEST_SUBJECT_IDENTIFIER, scope=scope) 101 | assert authorization_state.authorization_codes[authz_code]['granted_scope'] == ' '.join(scope) 102 | 103 | @pytest.mark.parametrize('sub', INVALID_INPUT) 104 | def 
test_create_authorization_code_with_invalid_subject_identifier(self, sub, authorization_state, 105 | authorization_request): 106 | with pytest.raises(InvalidSubjectIdentifier): 107 | authorization_state.create_authorization_code(authorization_request, sub) 108 | 109 | @patch('time.time', MOCK_TIME) 110 | def test_create_access_token(self, authorization_state_factory, authorization_request): 111 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME) 112 | self.set_valid_subject_identifier(authorization_state) 113 | 114 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 115 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens, 116 | MOCK_TIME.return_value) 117 | 118 | @patch('time.time', MOCK_TIME) 119 | def test_create_access_token_with_stateless_storage(self, authorization_state_factory, authorization_request, 120 | stateless_storage): 121 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME, 122 | access_token_db=stateless_storage) 123 | self.set_valid_subject_identifier(authorization_state) 124 | 125 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 126 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens, 127 | MOCK_TIME.return_value) 128 | 129 | def test_create_access_token_with_scope_other_than_auth_req(self, authorization_state, authorization_request): 130 | scope = ['openid', 'extra'] 131 | self.set_valid_subject_identifier(authorization_state) 132 | 133 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER, 134 | scope=scope) 135 | assert authorization_state.access_tokens[access_token.value]['scope'] == ' '.join(scope) 136 | 137 | @pytest.mark.parametrize('sub', INVALID_INPUT) 138 | def 
test_create_access_token_with_invalid_subject_identifier(self, sub, authorization_state, authorization_request): 139 | self.set_valid_subject_identifier(authorization_state) 140 | with pytest.raises(InvalidSubjectIdentifier): 141 | authorization_state.create_access_token(authorization_request, sub) 142 | 143 | @patch('time.time', MOCK_TIME) 144 | def test_introspect_access_token(self, authorization_state_factory, authorization_request): 145 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME) 146 | self.set_valid_subject_identifier(authorization_state) 147 | 148 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 149 | token_introspection = authorization_state.introspect_access_token(access_token.value) 150 | assert token_introspection['active'] is True 151 | self.assert_introspected_token(authorization_request, token_introspection, access_token, MOCK_TIME.return_value) 152 | 153 | def test_introspect_access_token_with_expired_token(self, authorization_state_factory, authorization_request): 154 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME) 155 | self.set_valid_subject_identifier(authorization_state) 156 | 157 | with patch('time.time', MOCK_TIME): 158 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 159 | 160 | mock_time2 = Mock() 161 | mock_time2.return_value = MOCK_TIME.return_value + self.TEST_TOKEN_LIFETIME + 1 # time after token expiration 162 | with patch('time.time', mock_time2): 163 | token_introspection = authorization_state.introspect_access_token(access_token.value) 164 | assert token_introspection['active'] is False 165 | self.assert_introspected_token(authorization_request, token_introspection, access_token, MOCK_TIME.return_value) 166 | 167 | @pytest.mark.parametrize('access_token', INVALID_INPUT) 168 | def 
test_introspect_access_token_with_invalid_access_token(self, access_token, authorization_state): 169 | with pytest.raises(InvalidAccessToken): 170 | authorization_state.introspect_access_token(access_token) 171 | 172 | @pytest.mark.parametrize('authz_code', INVALID_INPUT) 173 | def test_exchange_code_for_token_with_invalid_code(self, authz_code, authorization_state): 174 | with pytest.raises(InvalidAuthorizationCode): 175 | authorization_state.exchange_code_for_token(authz_code) 176 | 177 | @patch('time.time', MOCK_TIME) 178 | def test_exchange_code_for_token(self, authorization_state_factory, authorization_request): 179 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME) 180 | self.set_valid_subject_identifier(authorization_state) 181 | 182 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 183 | access_token = authorization_state.exchange_code_for_token(authz_code) 184 | 185 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens, MOCK_TIME.return_value) 186 | assert authorization_state.authorization_codes[authz_code]['used'] == True 187 | 188 | @patch('time.time', MOCK_TIME) 189 | def test_exchange_code_for_token_with_stateless_storage(self, authorization_state_factory, authorization_request, 190 | stateless_storage): 191 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME, 192 | authorization_code_db=stateless_storage, 193 | access_token_db=stateless_storage) 194 | self.set_valid_subject_identifier(authorization_state) 195 | 196 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 197 | access_token = authorization_state.exchange_code_for_token(authz_code) 198 | 199 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens, 200 | MOCK_TIME.return_value) 201 | assert 
authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER 202 | assert authorization_state.authorization_codes[authz_code]['used'] == False 203 | 204 | def test_exchange_code_for_token_with_scope_other_than_auth_req(self, authorization_state, 205 | authorization_request): 206 | scope = ['openid', 'extra'] 207 | self.set_valid_subject_identifier(authorization_state) 208 | 209 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER, 210 | scope=scope) 211 | access_token = authorization_state.exchange_code_for_token(authz_code) 212 | 213 | assert authorization_state.access_tokens[access_token.value]['scope'] == ' '.join(scope) 214 | 215 | def test_exchange_code_for_token_with_expired_token(self, authorization_state_factory, authorization_request): 216 | code_lifetime = 2 217 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime) 218 | self.set_valid_subject_identifier(authorization_state) 219 | 220 | with patch('time.time', MOCK_TIME): 221 | authz_code = authorization_state.create_authorization_code(authorization_request, 222 | self.TEST_SUBJECT_IDENTIFIER) 223 | 224 | time_mock = Mock() 225 | time_mock.return_value = MOCK_TIME.return_value + code_lifetime + 1 # time after code expiration 226 | with patch('time.time', time_mock), pytest.raises(InvalidAuthorizationCode): 227 | authorization_state.exchange_code_for_token(authz_code) 228 | 229 | def test_exchange_code_for_token_with_used_token(self, authorization_state_factory, authorization_request): 230 | code_lifetime = 2 231 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime) 232 | self.set_valid_subject_identifier(authorization_state) 233 | 234 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 235 | assert authorization_state.exchange_code_for_token(authz_code) # successful use once 
236 | with pytest.raises(InvalidAuthorizationCode): 237 | authorization_state.exchange_code_for_token(authz_code) # fail on second use 238 | 239 | def test_create_refresh_token(self, authorization_state, authorization_request): 240 | self.set_valid_subject_identifier(authorization_state) 241 | 242 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 243 | refresh_token = authorization_state.create_refresh_token(access_token.value) 244 | 245 | assert refresh_token in authorization_state.refresh_tokens 246 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value 247 | assert 'exp' in authorization_state.refresh_tokens[refresh_token] 248 | 249 | def test_create_refresh_token_with_stateless_storage(self, authorization_state_factory, authorization_request, 250 | stateless_storage): 251 | authorization_state = authorization_state_factory(refresh_token_lifetime=3600, 252 | authorization_code_db=stateless_storage, 253 | access_token_db=stateless_storage, 254 | refresh_token_db=stateless_storage) 255 | self.set_valid_subject_identifier(authorization_state) 256 | 257 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 258 | refresh_token = authorization_state.create_refresh_token(access_token.value) 259 | 260 | assert refresh_token in authorization_state.refresh_tokens 261 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value 262 | assert 'exp' in authorization_state.refresh_tokens[refresh_token] 263 | 264 | def test_create_refresh_token_issues_no_refresh_token_if_no_lifetime_is_specified(self, authorization_state_factory, 265 | authorization_request): 266 | authorization_state = authorization_state_factory(refresh_token_lifetime=None) 267 | self.set_valid_subject_identifier(authorization_state) 268 | 269 | access_token = 
authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 270 | refresh_token = authorization_state.create_refresh_token(access_token.value) 271 | 272 | assert refresh_token is None 273 | 274 | @pytest.mark.parametrize('access_token', INVALID_INPUT) 275 | def test_create_refresh_token_with_invalid_access_token_value(self, access_token, authorization_state): 276 | with pytest.raises(InvalidAccessToken): 277 | authorization_state.create_refresh_token(access_token) 278 | 279 | @patch('time.time', MOCK_TIME) 280 | def test_create_refresh_token_with_expiration_time(self, authorization_state_factory, authorization_request): 281 | refresh_token_lifetime = 60 * 60 * 24 # 24 hours 282 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_token_lifetime) 283 | self.set_valid_subject_identifier(authorization_state) 284 | 285 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 286 | refresh_token = authorization_state.create_refresh_token(access_token.value) 287 | 288 | assert refresh_token in authorization_state.refresh_tokens 289 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value 290 | assert authorization_state.refresh_tokens[refresh_token]['exp'] == int(time.time()) + refresh_token_lifetime 291 | 292 | def test_use_refresh_token(self, authorization_state_factory, authorization_request): 293 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME, 294 | refresh_token_lifetime=300) 295 | self.set_valid_subject_identifier(authorization_state) 296 | 297 | with patch('time.time', MOCK_TIME): 298 | old_access_token = authorization_state.create_access_token(authorization_request, 299 | self.TEST_SUBJECT_IDENTIFIER) 300 | refresh_token = authorization_state.create_refresh_token(old_access_token.value) 301 | 302 | mock_time2 = Mock() 303 | mock_time2.return_value = 
MOCK_TIME.return_value + 100 304 | with patch('time.time', mock_time2): 305 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token) 306 | 307 | assert new_refresh_token is None 308 | assert new_access_token.value != old_access_token.value 309 | assert new_access_token.type == old_access_token.type 310 | 311 | assert authorization_state.access_tokens[new_access_token.value]['exp'] > \ 312 | authorization_state.access_tokens[old_access_token.value]['exp'] 313 | assert authorization_state.access_tokens[new_access_token.value]['iat'] > \ 314 | authorization_state.access_tokens[old_access_token.value]['iat'] 315 | self.assert_access_token(authorization_request, new_access_token, authorization_state.access_tokens, mock_time2.return_value) 316 | 317 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == new_access_token.value 318 | 319 | @pytest.mark.parametrize('refresh_token', INVALID_INPUT) 320 | def test_use_refresh_token_with_invalid_refresh_token(self, refresh_token, authorization_state): 321 | with pytest.raises(InvalidRefreshToken): 322 | authorization_state.use_refresh_token(refresh_token) 323 | 324 | def test_use_refresh_token_issues_new_refresh_token_if_the_old_is_close_to_expiration( 325 | self, authorization_state_factory, authorization_request): 326 | refresh_threshold = 3600 327 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_threshold * 2, 328 | refresh_token_threshold=refresh_threshold) 329 | self.set_valid_subject_identifier(authorization_state) 330 | 331 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 332 | refresh_token = authorization_state.create_refresh_token(old_access_token.value) 333 | 334 | close_to_expiration = int(time.time()) + authorization_state.refresh_token_lifetime - 50 335 | with patch('time.time', Mock(return_value=close_to_expiration)): 336 | new_access_token, 
new_refresh_token = authorization_state.use_refresh_token(refresh_token) 337 | 338 | assert new_refresh_token is not None 339 | assert new_refresh_token in authorization_state.refresh_tokens 340 | assert authorization_state.refresh_tokens[new_refresh_token]['access_token'] == new_access_token.value 341 | 342 | def test_use_refresh_token_doesnt_issue_new_refresh_token_if_the_old_is_far_from_expiration( 343 | self, authorization_state_factory, authorization_request): 344 | refresh_threshold = 3600 345 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_threshold * 2, 346 | refresh_token_threshold=refresh_threshold) 347 | 348 | self.set_valid_subject_identifier(authorization_state) 349 | 350 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 351 | refresh_token = authorization_state.create_refresh_token(old_access_token.value) 352 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token) 353 | 354 | assert new_refresh_token is None 355 | 356 | def test_use_refresh_token_doesnt_issue_new_refresh_token_if_no_refresh_token_threshold_is_set( 357 | self, authorization_state_factory, authorization_request): 358 | authorization_state = authorization_state_factory(refresh_token_lifetime=400) 359 | 360 | self.set_valid_subject_identifier(authorization_state) 361 | 362 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 363 | refresh_token = authorization_state.create_refresh_token(old_access_token.value) 364 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token) 365 | 366 | assert new_refresh_token is None 367 | 368 | def test_use_refresh_token_with_expired_refresh_token(self, authorization_state_factory, authorization_request): 369 | refresh_token_lifetime = 2 370 | authorization_state = 
authorization_state_factory(refresh_token_lifetime=refresh_token_lifetime) 371 | self.set_valid_subject_identifier(authorization_state) 372 | 373 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 374 | with patch('time.time', MOCK_TIME): 375 | refresh_token = authorization_state.create_refresh_token(access_token.value) 376 | 377 | time_mock = Mock() 378 | time_mock.return_value = MOCK_TIME.return_value + refresh_token_lifetime + 1 # time after refresh_token expiration 379 | with patch('time.time', time_mock), pytest.raises(InvalidRefreshToken): 380 | authorization_state.use_refresh_token(refresh_token) 381 | 382 | def test_use_refresh_token_with_superset_scope(self, authorization_state, authorization_request): 383 | self.set_valid_subject_identifier(authorization_state) 384 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 385 | refresh_token = authorization_state.create_refresh_token(access_token.value) 386 | with pytest.raises(InvalidScope): 387 | authorization_state.use_refresh_token(refresh_token, scope=['openid', 'extra']) 388 | 389 | def test_use_refresh_token_with_subset_scope(self, authorization_state, authorization_request): 390 | self.set_valid_subject_identifier(authorization_state) 391 | authorization_request['scope'] = 'openid profile' 392 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 393 | refresh_token = authorization_state.create_refresh_token(access_token.value) 394 | access_token, _ = authorization_state.use_refresh_token(refresh_token, scope=['openid']) 395 | 396 | assert authorization_state.access_tokens[access_token.value]['scope'] == 'openid' 397 | 398 | def test_use_refresh_token_with_subset_scope_does_not_minimize_granted_scope(self, authorization_state, 399 | authorization_request): 400 | scope = 'openid profile' 401 | 
self.set_valid_subject_identifier(authorization_state) 402 | authorization_request['scope'] = scope 403 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 404 | refresh_token = authorization_state.create_refresh_token(access_token.value) 405 | 406 | # first time: issue access token with subset of granted scope 407 | access_token, _ = authorization_state.use_refresh_token(refresh_token, scope=['openid']) 408 | assert authorization_state.access_tokens[access_token.value]['scope'] == 'openid' 409 | 410 | # second time: issue access token with exactly granted scope 411 | access_token, _ = authorization_state.use_refresh_token(refresh_token) 412 | assert authorization_state.access_tokens[access_token.value]['scope'] == scope 413 | 414 | def test_use_refresh_token_without_scope(self, authorization_state, authorization_request): 415 | self.set_valid_subject_identifier(authorization_state) 416 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 417 | refresh_token = authorization_state.create_refresh_token(access_token.value) 418 | access_token, _ = authorization_state.use_refresh_token(refresh_token) 419 | 420 | assert authorization_state.access_tokens[access_token.value]['scope'] == \ 421 | ' '.join(authorization_request['scope']) 422 | 423 | def test_create_subject_identifier_public(self, authorization_state): 424 | user_id = 'test_user' 425 | sub1 = authorization_state.get_subject_identifier('public', user_id) 426 | sub2 = authorization_state.get_subject_identifier('public', user_id) 427 | assert sub1 == sub2 428 | assert authorization_state.subject_identifiers[user_id]['public'] == sub1 429 | 430 | def test_create_subject_identifier_pairwise_with_diffent_redirect_uris(self, authorization_state): 431 | user_id = 'test_user' 432 | sector_identifier1 = 'client1.example.com' 433 | sector_identifier2 = 'client2.example.com' 434 | sub1 = 
authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier1) 435 | sub2 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier2) 436 | assert sub1 != sub2 437 | assert all(s in authorization_state.subject_identifiers[user_id]['pairwise'] for s in [sub1, sub2]) 438 | 439 | def test_create_subject_identifier_pairwise_with_same_hostname(self, authorization_state): 440 | user_id = 'test_user' 441 | sector_identifier = 'client.example.com' 442 | sub1 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier) 443 | sub2 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier) 444 | assert sub1 == sub2 445 | assert sub1 in authorization_state.subject_identifiers[user_id]['pairwise'] 446 | 447 | def test_create_subject_identifier_pairwise_without_sector_identifier(self, authorization_state): 448 | with pytest.raises(ValueError): 449 | authorization_state.get_subject_identifier('pairwise', 'test_user', None) 450 | 451 | def test_create_subject_identifier_with_unknown_subject_type(self, authorization_state): 452 | with pytest.raises(ValueError): 453 | authorization_state.get_subject_identifier('unknown', 'test_user', None) 454 | 455 | @pytest.mark.parametrize('subject_type', [ 456 | 'public', 457 | 'pairwise' 458 | ]) 459 | def test_is_valid_subject_identifier(self, subject_type, authorization_state): 460 | sub = authorization_state.get_subject_identifier(subject_type, 'test_user', 'client.example.com') 461 | assert authorization_state._is_valid_subject_identifier(sub) is True 462 | 463 | def test_get_authentication_request_for_code(self, authorization_state, authorization_request): 464 | self.set_valid_subject_identifier(authorization_state) 465 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 466 | request = authorization_state.get_authorization_request_for_code(authz_code) 467 | assert request.to_dict() 
== authorization_request.to_dict() 468 | 469 | def test_get_authentication_request_for_access_token(self, authorization_state, authorization_request): 470 | self.set_valid_subject_identifier(authorization_state) 471 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 472 | request = authorization_state.get_authorization_request_for_access_token(access_token.value) 473 | assert request.to_dict() == authorization_request.to_dict() 474 | 475 | def test_get_subject_identifier_for_code(self, authorization_state, authorization_request): 476 | self.set_valid_subject_identifier(authorization_state) 477 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 478 | sub = authorization_state.get_subject_identifier_for_code(authz_code) 479 | assert sub == self.TEST_SUBJECT_IDENTIFIER 480 | 481 | def test_remove_state_for_subject_identifier(self, authorization_state, authorization_request): 482 | self.set_valid_subject_identifier(authorization_state) 483 | authz_code1 = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 484 | authz_code2 = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 485 | access_token1 = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 486 | access_token2 = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER) 487 | 488 | authorization_state.delete_state_for_subject_identifier(self.TEST_SUBJECT_IDENTIFIER) 489 | 490 | for ac in [authz_code1, authz_code2]: 491 | assert ac not in authorization_state.authorization_codes 492 | for at in [access_token1, access_token2]: 493 | assert at.value not in authorization_state.access_tokens 494 | 495 | def test_remove_state_for_unknown_subject_identifier(self, authorization_state): 496 | with 
pytest.raises(InvalidSubjectIdentifier): 497 | authorization_state.delete_state_for_subject_identifier('unknown') 498 | -------------------------------------------------------------------------------- /tests/pyop/test_client_authentication.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | import pytest 4 | 5 | from pyop.client_authentication import verify_client_authentication 6 | from pyop.exceptions import InvalidClientAuthentication 7 | 8 | TEST_CLIENT_ID = 'client1' 9 | TEST_CLIENT_SECRET = 'my_secret' 10 | 11 | 12 | class TestVerifyClientAuthentication(object): 13 | def create_basic_auth(self, client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET): 14 | credentials = client_id + ':' + client_secret 15 | auth = base64.urlsafe_b64encode(credentials.encode('utf-8')).decode('utf-8') 16 | return 'Basic {}'.format(auth) 17 | 18 | @pytest.fixture(autouse=True) 19 | def create_request_args(self): 20 | self.token_request_args = { 21 | 'grant_type': 'authorization_code', 22 | 'code': 'code', 23 | 'redirect_uri': 'https://client.example.com', 24 | 'client_id': TEST_CLIENT_ID, 25 | 'client_secret': TEST_CLIENT_SECRET 26 | } 27 | 28 | self.clients = { 29 | TEST_CLIENT_ID: { 30 | 'client_secret': TEST_CLIENT_SECRET, 31 | 'token_endpoint_auth_method': 'client_secret_post' 32 | } 33 | } 34 | 35 | def test_wrong_authentication_method(self): 36 | # do client_secret_basic, while client_secret_post is expected 37 | authz_header = self.create_basic_auth() 38 | with pytest.raises(InvalidClientAuthentication): 39 | verify_client_authentication(self.clients, None, authz_header) 40 | 41 | def test_authentication_method_defaults_to_client_secret_basic(self): 42 | del self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] 43 | authz_header = self.create_basic_auth() 44 | assert verify_client_authentication(self.clients, self.token_request_args, authz_header) == TEST_CLIENT_ID 45 | 46 | def 
test_client_secret_post(self): 47 | self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'client_secret_post' 48 | assert verify_client_authentication(self.clients, self.token_request_args) == TEST_CLIENT_ID 49 | 50 | def test_client_secret_basic(self): 51 | self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'client_secret_basic' 52 | authz_header = self.create_basic_auth() 53 | assert verify_client_authentication(self.clients, self.token_request_args, authz_header) == TEST_CLIENT_ID 54 | 55 | def test_unknown_client_id(self): 56 | self.token_request_args['client_id'] = 'unknown' 57 | with pytest.raises(InvalidClientAuthentication): 58 | verify_client_authentication(self.clients, self.token_request_args) == TEST_CLIENT_ID 59 | 60 | def test_wrong_client_secret(self): 61 | self.token_request_args['client_secret'] = 'foobar' 62 | with pytest.raises(InvalidClientAuthentication): 63 | verify_client_authentication(self.clients, self.token_request_args) 64 | 65 | def test_public_client_no_auth(self): 66 | del self.token_request_args['client_secret'] 67 | # public client 68 | self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'none' 69 | del self.clients[TEST_CLIENT_ID]['client_secret'] 70 | 71 | assert verify_client_authentication(self.clients, self.token_request_args, None) == TEST_CLIENT_ID 72 | 73 | def test_invalid_authorization_scheme(self): 74 | authz_header = self.create_basic_auth() 75 | with pytest.raises(InvalidClientAuthentication): 76 | verify_client_authentication(self.clients, self.token_request_args, 77 | authz_header.replace('Basic', 'invalid')) 78 | 79 | def test_invalid_userid_password(self): 80 | with pytest.raises(InvalidClientAuthentication): 81 | verify_client_authentication(self.clients, self.token_request_args, 'Basic invalid') 82 | -------------------------------------------------------------------------------- /tests/pyop/test_exceptions.py: 
-------------------------------------------------------------------------------- 1 | from urllib.parse import urlparse, parse_qsl 2 | 3 | from pyop.message import AuthorizationRequest 4 | from pyop.exceptions import InvalidAuthenticationRequest 5 | 6 | 7 | class TestInvalidAuthenticationRequest: 8 | def test_error_url_should_contain_state_from_authentication_request(self): 9 | authn_params = {'redirect_uri': 'test_redirect_uri', 'response_type': 'code', 'state': 'test_state'} 10 | authn_req = AuthorizationRequest().from_dict(authn_params) 11 | error_url = InvalidAuthenticationRequest('test', authn_req, 'invalid_request').to_error_url() 12 | 13 | error = dict(parse_qsl(urlparse(error_url).query)) 14 | assert error['state'] == authn_params['state'] 15 | 16 | def test_error_url_should_handle_unknown_response_type(self): 17 | authn_params = {'redirect_uri': 'test_redirect_uri', 'state': 'test_state'} # no 'response_type' 18 | authn_req = AuthorizationRequest().from_dict(authn_params) 19 | 20 | error = InvalidAuthenticationRequest('test', authn_req, 'invalid_request') 21 | assert error.to_error_url() is None -------------------------------------------------------------------------------- /tests/pyop/test_provider.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import json 3 | import time 4 | from collections import Counter 5 | from unittest.mock import Mock, patch 6 | from urllib.parse import urlencode 7 | 8 | import pytest 9 | from Cryptodome.PublicKey import RSA 10 | from jwkest import jws 11 | from jwkest.jwk import RSAKey 12 | from oic import rndstr 13 | from oic.oauth2.message import MissingRequiredValue, MissingRequiredAttribute 14 | from oic.oic import PREFERENCE2PROVIDER 15 | from oic.oic.message import IdToken, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse 16 | 17 | from pyop.message import AuthorizationRequest 18 | from pyop.access_token import BearerTokenError 19 | from 
pyop.authz_state import AuthorizationState 20 | from pyop.client_authentication import InvalidClientAuthentication 21 | from pyop.exceptions import InvalidAuthenticationRequest, AuthorizationError, InvalidTokenRequest, \ 22 | InvalidClientRegistrationRequest, InvalidAccessToken, InvalidAuthorizationCode, InvalidSubjectIdentifier 23 | from pyop.provider import Provider, redirect_uri_is_in_registered_redirect_uris, \ 24 | response_type_is_in_registered_response_types 25 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory 26 | from pyop.userinfo import Userinfo 27 | 28 | TEST_CLIENT_ID = 'client1' 29 | TEST_CLIENT_SECRET = 'secret' 30 | INVALID_TEST_CLIENT_ID = 'invalid_client' 31 | INVALID_TEST_CLIENT_SECRET = 'invalid_secret' 32 | TEST_REDIRECT_URI = 'https://client.example.com' 33 | ISSUER = 'https://provider.example.com' 34 | TEST_USER_ID = 'user1' 35 | 36 | MOCK_TIME = Mock(return_value=time.mktime(dt.datetime(2016, 6, 21).timetuple())) 37 | 38 | 39 | def rsa_key(): 40 | return RSAKey(key=RSA.generate(1024), use="sig", alg="RS256", kid=rndstr(4)) 41 | 42 | 43 | def provide_configuration(): 44 | conf = { 45 | 'issuer': ISSUER, 46 | 'jwks_uri': '/jwks', 47 | 'authorization_endpoint': '/authorization', 48 | 'token_endpoint': '/token' 49 | } 50 | return conf 51 | 52 | 53 | def assert_id_token_base_claims(jws, verification_key, provider, auth_req): 54 | id_token = IdToken().from_jwt(jws, key=[verification_key]) 55 | assert id_token['nonce'] == auth_req['nonce'] 56 | assert id_token['iss'] == ISSUER 57 | assert provider.authz_state.get_user_id_for_subject_identifier(id_token['sub']) == TEST_USER_ID 58 | assert id_token['iat'] == MOCK_TIME.return_value 59 | assert id_token['exp'] == id_token['iat'] + provider.id_token_lifetime 60 | assert TEST_CLIENT_ID in id_token['aud'] 61 | 62 | return id_token 63 | 64 | 65 | @pytest.fixture 66 | def auth_req_args(request): 67 | request.instance.authn_request_args = { 68 | 'scope': 'openid', 69 | 'response_type': 
'code', 70 | 'client_id': TEST_CLIENT_ID, 71 | 'redirect_uri': TEST_REDIRECT_URI, 72 | 'state': 'state', 73 | 'nonce': 'nonce' 74 | } 75 | 76 | 77 | @pytest.fixture 78 | def inject_provider(request): 79 | clients = { 80 | TEST_CLIENT_ID: { 81 | 'subject_type': 'pairwise', 82 | 'redirect_uris': [TEST_REDIRECT_URI], 83 | 'response_types': ['code'], 84 | 'client_secret': TEST_CLIENT_SECRET, 85 | 'token_endpoint_auth_method': 'client_secret_post', 86 | 'post_logout_redirect_uris': ['https://client.example.com/post_logout'] 87 | }, 88 | INVALID_TEST_CLIENT_ID: { 89 | 'subject_type': 'pairwise', 90 | 'redirect_uris': 'http://invalid.redirect.loc', 91 | 'response_types': ['code'], 92 | 'client_secret': INVALID_TEST_CLIENT_SECRET, 93 | 'token_endpoint_auth_method': 'client_secret_post', 94 | 'post_logout_redirect_uris': ['https://client.example.com/post_logout'] 95 | } 96 | } 97 | 98 | userinfo = Userinfo({ 99 | TEST_USER_ID: { 100 | 'name': 'The T. Tester', 101 | 'family_name': 'Tester', 102 | 'given_name': 'The', 103 | 'middle_name': 'Theodore', 104 | 'nickname': 'testster', 105 | 'email': 'testster@example.com', 106 | } 107 | }) 108 | request.instance.provider = Provider(rsa_key(), provide_configuration(), 109 | AuthorizationState(HashBasedSubjectIdentifierFactory('salt')), 110 | clients, userinfo) 111 | 112 | 113 | @pytest.mark.usefixtures('inject_provider', 'auth_req_args') 114 | class TestProviderParseAuthenticationRequest(object): 115 | def test_parse_authentication_request(self): 116 | nonce = 'nonce' 117 | self.authn_request_args['nonce'] = nonce 118 | 119 | received_request = self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 120 | assert received_request.to_dict() == self.authn_request_args 121 | 122 | def test_reject_request_with_missing_required_parameter(self): 123 | del self.authn_request_args['redirect_uri'] 124 | 125 | with pytest.raises(InvalidAuthenticationRequest) as exc: 126 | 
self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 127 | assert isinstance(exc.value.__cause__, MissingRequiredAttribute) 128 | 129 | def test_reject_request_with_scope_without_openid(self): 130 | self.authn_request_args['scope'] = 'foobar' # does not contain 'openid' 131 | 132 | with pytest.raises(InvalidAuthenticationRequest) as exc: 133 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 134 | assert isinstance(exc.value.__cause__, MissingRequiredValue) 135 | 136 | def test_custom_validation_hook_reject(self): 137 | class TestException(Exception): 138 | pass 139 | 140 | def fail_all_requests(auth_req): 141 | raise InvalidAuthenticationRequest("Test exception", auth_req) from TestException() 142 | 143 | self.provider.authentication_request_validators.append(fail_all_requests) 144 | with pytest.raises(InvalidAuthenticationRequest) as exc: 145 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 146 | 147 | assert isinstance(exc.value.__cause__, TestException) 148 | 149 | def test_redirect_uri_not_matching_registered_redirect_uris(self): 150 | self.authn_request_args['redirect_uri'] = 'https://something.example.com' 151 | with pytest.raises(InvalidAuthenticationRequest): 152 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 153 | 154 | def test_response_type_not_matching_registered_response_types(self): 155 | self.authn_request_args['response_type'] = 'id_token' 156 | with pytest.raises(InvalidAuthenticationRequest): 157 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 158 | 159 | def test_unknown_client_id(self): 160 | self.authn_request_args['client_id'] = 'unknown' 161 | with pytest.raises(InvalidAuthenticationRequest): 162 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 163 | 164 | def test_include_userinfo_claims_request_with_response_type_id_token(self): 165 | 
self.authn_request_args['claims'] = ClaimsRequest(userinfo=Claims(nickname=None)).to_json() 166 | self.provider.clients[TEST_CLIENT_ID]['response_types'] = ['id_token'] 167 | self.authn_request_args['response_type'] = 'id_token' 168 | with pytest.raises(InvalidAuthenticationRequest): 169 | self.provider.parse_authentication_request(urlencode(self.authn_request_args)) 170 | 171 | 172 | @pytest.mark.usefixtures('auth_req_args') 173 | class TestAuthenticationRequestValidators(object): 174 | @pytest.fixture 175 | def provider_mock(self): 176 | provider = Mock() 177 | provider.clients = Mock() 178 | provider.clients.__getitem__ = Mock(return_value={}) 179 | return provider 180 | 181 | def test_redirect_uri_is_in_registered_redirect_uris_with_no_redirect_uris(self, provider_mock): 182 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 183 | with pytest.raises(InvalidAuthenticationRequest): 184 | redirect_uri_is_in_registered_redirect_uris(provider_mock, auth_req) 185 | 186 | def test_response_type_is_in_registered_response_types_with_no_response_types(self, provider_mock): 187 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 188 | with pytest.raises(InvalidAuthenticationRequest): 189 | response_type_is_in_registered_response_types(provider_mock, auth_req) 190 | 191 | 192 | @pytest.mark.usefixtures('inject_provider', 'auth_req_args') 193 | class TestProviderAuthorize(object): 194 | def test_authorize(self): 195 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 196 | resp = self.provider.authorize(auth_req, TEST_USER_ID) 197 | assert resp['code'] in self.provider.authz_state.authorization_codes 198 | assert resp['state'] == self.authn_request_args['state'] 199 | 200 | def test_authorize_with_custom_sub(self, monkeypatch): 201 | sub = 'test_sub1' 202 | monkeypatch.setitem(self.provider.userinfo._db[TEST_USER_ID], 'sub', sub) 203 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 204 | resp = 
self.provider.authorize(auth_req, TEST_USER_ID) 205 | assert resp['code'] in self.provider.authz_state.authorization_codes 206 | assert resp['state'] == self.authn_request_args['state'] 207 | assert self.provider.authz_state.authorization_codes[resp['code']]['sub'] == sub 208 | 209 | @patch('time.time', MOCK_TIME) 210 | @pytest.mark.parametrize('extra_claims', [ 211 | {'foo': 'bar'}, 212 | lambda user_id, client_id: {'foo': 'bar'} 213 | ]) 214 | def test_authorize_with_extra_id_token_claims(self, extra_claims): 215 | self.authn_request_args['response_type'] = ['id_token'] # make sure ID Token is produced 216 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 217 | resp = self.provider.authorize(auth_req, TEST_USER_ID, extra_claims) 218 | id_token = assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider, auth_req) 219 | assert id_token['foo'] == 'bar' 220 | 221 | def test_authorize_include_user_claims_from_scope_in_id_token_if_no_userinfo_req_can_be_made(self): 222 | self.authn_request_args['response_type'] = 'id_token' 223 | self.authn_request_args['scope'] = 'openid profile' 224 | self.authn_request_args['claims'] = ClaimsRequest(id_token=Claims(email={'essential': True})) 225 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 226 | resp = self.provider.authorize(auth_req, TEST_USER_ID) 227 | 228 | id_token = IdToken().from_jwt(resp['id_token'], key=[self.provider.signing_key]) 229 | # verify all claims are part of the ID Token 230 | assert all(id_token[claim] == value for claim, value in self.provider.userinfo[TEST_USER_ID].items()) 231 | 232 | @patch('time.time', MOCK_TIME) 233 | def test_authorize_includes_requested_id_token_claims_even_if_token_request_can_be_made(self): 234 | self.authn_request_args['response_type'] = ['id_token', 'token'] 235 | self.authn_request_args['claims'] = ClaimsRequest(id_token=Claims(email=None)) 236 | auth_req = 
AuthorizationRequest().from_dict(self.authn_request_args) 237 | resp = self.provider.authorize(auth_req, TEST_USER_ID) 238 | id_token = assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider, auth_req) 239 | assert id_token['email'] == self.provider.userinfo[TEST_USER_ID]['email'] 240 | 241 | @patch('time.time', MOCK_TIME) 242 | def test_hybrid_flow(self): 243 | self.authn_request_args['response_type'] = 'code id_token token' 244 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 245 | resp = self.provider.authorize(auth_req, TEST_USER_ID, extra_id_token_claims={'foo': 'bar'}) 246 | 247 | assert resp['state'] == self.authn_request_args['state'] 248 | assert resp['code'] in self.provider.authz_state.authorization_codes 249 | 250 | assert resp['access_token'] in self.provider.authz_state.access_tokens 251 | assert resp['expires_in'] == self.provider.authz_state.access_token_lifetime 252 | assert resp['token_type'] == 'Bearer' 253 | 254 | id_token = IdToken().from_jwt(resp['id_token'], key=[self.provider.signing_key]) 255 | assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider, self.authn_request_args) 256 | assert id_token["c_hash"] == jws.left_hash(resp['code'].encode('utf-8'), 'HS256') 257 | assert id_token["at_hash"] == jws.left_hash(resp['access_token'].encode('utf-8'), 'HS256') 258 | assert id_token['foo'] == 'bar' 259 | 260 | @pytest.mark.parametrize('claims_location', [ 261 | 'id_token', 262 | 'userinfo' 263 | ]) 264 | def test_with_requested_sub_not_matching(self, claims_location): 265 | self.authn_request_args['claims'] = ClaimsRequest(**{claims_location: Claims(sub={'value': 'nomatch'})}) 266 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 267 | with pytest.raises(AuthorizationError): 268 | self.provider.authorize(auth_req, TEST_USER_ID) 269 | 270 | def test_with_multiple_requested_sub(self): 271 | self.authn_request_args['claims'] = 
ClaimsRequest(userinfo=Claims(sub={'value': 'nomatch1'}), 272 | id_token=Claims(sub={'value': 'nomatch2'})) 273 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 274 | with pytest.raises(AuthorizationError) as exc: 275 | self.provider.authorize(auth_req, TEST_USER_ID) 276 | 277 | assert 'different' in str(exc.value) 278 | 279 | 280 | @pytest.mark.usefixtures('inject_provider', 'auth_req_args') 281 | class TestProviderHandleTokenRequest(object): 282 | def create_authz_code(self, extra_auth_req_params=None): 283 | sub = self.provider.authz_state.get_subject_identifier('pairwise', TEST_USER_ID, 'client1.example.com') 284 | 285 | if extra_auth_req_params: 286 | self.authn_request_args.update(extra_auth_req_params) 287 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 288 | return self.provider.authz_state.create_authorization_code(auth_req, sub) 289 | 290 | def create_refresh_token(self): 291 | sub = self.provider.authz_state.get_subject_identifier('pairwise', TEST_USER_ID, 'client1.example.com') 292 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 293 | access_token = self.provider.authz_state.create_access_token(auth_req, sub) 294 | return self.provider.authz_state.create_refresh_token(access_token.value) 295 | 296 | @pytest.fixture(autouse=True) 297 | def create_token_request_args(self): 298 | self.authorization_code_exchange_request_args = { 299 | 'grant_type': 'authorization_code', 300 | 'code': None, 301 | 'redirect_uri': 'https://client.example.com', 302 | 'client_id': TEST_CLIENT_ID, 303 | 'client_secret': TEST_CLIENT_SECRET 304 | } 305 | 306 | self.refresh_token_request_args = { 307 | 'grant_type': 'refresh_token', 308 | 'refresh_token': None, 309 | 'scope': 'openid', 310 | 'client_id': TEST_CLIENT_ID, 311 | 'client_secret': TEST_CLIENT_SECRET 312 | } 313 | 314 | @patch('time.time', MOCK_TIME) 315 | def test_code_exchange_request(self): 316 | self.authorization_code_exchange_request_args['code'] = 
self.create_authz_code() 317 | response = self.provider._do_code_exchange(self.authorization_code_exchange_request_args, None) 318 | assert response['access_token'] in self.provider.authz_state.access_tokens 319 | assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, 320 | self.authn_request_args) 321 | 322 | @patch('time.time', MOCK_TIME) 323 | def test_pkce_code_exchange_request(self): 324 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code( 325 | { 326 | "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw", 327 | "code_challenge_method": "S256" 328 | } 329 | ) 330 | self.authorization_code_exchange_request_args['code_verifier'] = "SoOEDN-mZKNhw7Mc52VXxyiqTvFB3mod36MwPru253c" 331 | response = self.provider._do_code_exchange(self.authorization_code_exchange_request_args, None) 332 | assert response['access_token'] in self.provider.authz_state.access_tokens 333 | assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, 334 | self.authn_request_args) 335 | 336 | @patch('time.time', MOCK_TIME) 337 | def test_code_exchange_request_with_claims_requested_in_id_token(self): 338 | claims_req = {'claims': ClaimsRequest(id_token=Claims(email=None))} 339 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code(extra_auth_req_params=claims_req) 340 | response = self.provider._do_code_exchange(self.authorization_code_exchange_request_args, None) 341 | assert response['access_token'] in self.provider.authz_state.access_tokens 342 | id_token = assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, 343 | self.authn_request_args) 344 | assert id_token['email'] == self.provider.userinfo[TEST_USER_ID]['email'] 345 | 346 | @patch('time.time', MOCK_TIME) 347 | @pytest.mark.parametrize('extra_claims', [ 348 | {'foo': 'bar'}, 349 | lambda user_id, client_id: {'foo': 'bar'} 350 | ]) 351 | def 
test_handle_token_request_with_extra_id_token_claims(self, extra_claims): 352 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 353 | response = self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args), 354 | extra_id_token_claims=extra_claims) 355 | assert response['access_token'] in self.provider.authz_state.access_tokens 356 | id_token = assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider, 357 | self.authn_request_args) 358 | assert id_token['foo'] == 'bar' 359 | 360 | def test_handle_token_request_reject_invalid_client_authentication(self): 361 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 362 | self.authorization_code_exchange_request_args['client_secret'] = 'invalid' 363 | with pytest.raises(InvalidClientAuthentication): 364 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args), 365 | extra_id_token_claims={'foo': 'bar'}) 366 | 367 | def test_handle_token_request_reject_code_not_issued_to_client(self): 368 | self.authorization_code_exchange_request_args['client_id'] = INVALID_TEST_CLIENT_ID 369 | self.authorization_code_exchange_request_args['client_secret'] = INVALID_TEST_CLIENT_SECRET 370 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 371 | with pytest.raises(InvalidAuthorizationCode): 372 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 373 | 374 | def test_handle_token_request_reject_invalid_redirect_uri_in_exchange_request(self): 375 | self.authorization_code_exchange_request_args['redirect_uri'] = 'https://invalid.com' 376 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 377 | with pytest.raises(InvalidTokenRequest): 378 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 379 | 380 | def 
test_handle_token_request_reject_invalid_grant_type(self): 381 | self.authorization_code_exchange_request_args['grant_type'] = 'invalid' 382 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 383 | with pytest.raises(InvalidTokenRequest): 384 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 385 | 386 | def test_handle_token_request_reject_missing_grant_type(self): 387 | del self.authorization_code_exchange_request_args['grant_type'] 388 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code() 389 | with pytest.raises(InvalidTokenRequest): 390 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 391 | 392 | def test_handle_token_request_reject_invalid_code_verifier(self): 393 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code( 394 | { 395 | "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", 396 | "code_challenge_method": "S256" 397 | } 398 | ) 399 | self.authorization_code_exchange_request_args['code_verifier'] = "ThiS Cer_tainly Ain't Valid" 400 | with pytest.raises(InvalidTokenRequest): 401 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 402 | 403 | def test_handle_token_request_reject_unsynced_requests(self): 404 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code( 405 | { 406 | "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", 407 | "code_challenge_method": "S256" 408 | } 409 | ) 410 | with pytest.raises(InvalidTokenRequest): 411 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 412 | 413 | def test_handle_token_request_reject_missing_code_challenge_method(self): 414 | self.authorization_code_exchange_request_args['code'] = self.create_authz_code( 415 | { 416 | "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=", 417 | } 418 | 
) 419 | with pytest.raises(InvalidTokenRequest): 420 | self.provider.handle_token_request(urlencode(self.authorization_code_exchange_request_args)) 421 | 422 | def test_refresh_request(self): 423 | self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'), 424 | refresh_token_lifetime=600) 425 | self.refresh_token_request_args['refresh_token'] = self.create_refresh_token() 426 | response = self.provider.handle_token_request(urlencode(self.refresh_token_request_args)) 427 | assert response['access_token'] in self.provider.authz_state.access_tokens 428 | assert 'refresh_token' not in response 429 | 430 | def test_refresh_request_with_refresh_token_close_to_expiry_issues_new_refresh_token(self): 431 | self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'), 432 | refresh_token_lifetime=10, 433 | refresh_token_threshold=2) 434 | self.refresh_token_request_args['refresh_token'] = self.create_refresh_token() 435 | 436 | close_to_expiration = int(time.time()) + self.provider.authz_state.refresh_token_lifetime - 1 437 | with patch('time.time', Mock(return_value=close_to_expiration)): 438 | response = self.provider.handle_token_request(urlencode(self.refresh_token_request_args)) 439 | assert response['access_token'] in self.provider.authz_state.access_tokens 440 | assert response['refresh_token'] in self.provider.authz_state.refresh_tokens 441 | 442 | def test_refresh_request_without_scope_parameter_defaults_to_scope_from_authentication_request(self): 443 | self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'), 444 | refresh_token_lifetime=600) 445 | self.refresh_token_request_args['refresh_token'] = self.create_refresh_token() 446 | del self.refresh_token_request_args['scope'] 447 | response = self.provider.handle_token_request(urlencode(self.refresh_token_request_args)) 448 | assert response['access_token'] in self.provider.authz_state.access_tokens 449 | assert 
self.provider.authz_state.access_tokens[response['access_token']]['scope'] == self.authn_request_args[ 450 | 'scope'] 451 | 452 | @pytest.mark.parametrize('missing_parameter', [ 453 | 'grant_type', 454 | 'code', 455 | 'redirect_uri' 456 | ]) 457 | def test_code_exchange_request_with_missing_parameter(self, missing_parameter): 458 | request_args = { 459 | 'grant_type': 'authorization_code', 460 | 'code': None, 461 | 'redirect_uri': TEST_REDIRECT_URI, 462 | } 463 | del request_args[missing_parameter] 464 | with pytest.raises(InvalidTokenRequest): 465 | self.provider._do_code_exchange(request_args) 466 | 467 | @pytest.mark.parametrize('missing_parameter', [ 468 | 'grant_type', 469 | 'refresh_token', 470 | ]) 471 | def test_refresh_token_request_with_missing_parameter(self, missing_parameter): 472 | request_args = { 473 | 'grant_type': 'refresh_token', 474 | 'refresh_token': None, 475 | } 476 | del request_args[missing_parameter] 477 | with pytest.raises(InvalidTokenRequest): 478 | self.provider._do_token_refresh(request_args) 479 | 480 | 481 | @pytest.mark.usefixtures('inject_provider', 'auth_req_args') 482 | class TestProviderHandleUserinfoRequest(object): 483 | def create_access_token(self, extra_auth_req_params=None): 484 | sub = self.provider.authz_state.get_subject_identifier('pairwise', TEST_USER_ID, 'client1.example.com') 485 | 486 | if extra_auth_req_params: 487 | self.authn_request_args.update(extra_auth_req_params) 488 | 489 | auth_req = AuthorizationRequest().from_dict(self.authn_request_args) 490 | access_token = self.provider.authz_state.create_access_token(auth_req, sub) 491 | return access_token.value 492 | 493 | def test_handle_userinfo(self): 494 | claims_request = ClaimsRequest(userinfo=Claims(email=None)) 495 | access_token = self.create_access_token({'scope': 'openid profile', 'claims': claims_request}) 496 | response = self.provider.handle_userinfo_request(urlencode({'access_token': access_token})) 497 | 498 | response_sub = response['sub'] 499 | 
del response['sub'] 500 | assert response.to_dict() == self.provider.userinfo[TEST_USER_ID] 501 | assert self.provider.authz_state.get_user_id_for_subject_identifier(response_sub) == TEST_USER_ID 502 | 503 | def test_handle_userinfo_with_custom_sub(self, monkeypatch): 504 | sub = 'test_sub1' 505 | monkeypatch.setitem(self.provider.userinfo._db[TEST_USER_ID], 'sub', sub) 506 | claims_request = ClaimsRequest(userinfo=Claims(email=None)) 507 | access_token = self.create_access_token({'scope': 'openid profile', 'claims': claims_request}) 508 | response = self.provider.handle_userinfo_request(urlencode({'access_token': access_token})) 509 | 510 | assert response['sub'] == sub 511 | 512 | def test_handle_userinfo_rejects_request_missing_access_token(self): 513 | with pytest.raises(BearerTokenError) as exc: 514 | self.provider.handle_userinfo_request() 515 | 516 | def test_handle_userinfo_rejects_invalid_access_token(self): 517 | access_token = self.create_access_token() 518 | self.provider.authz_state.access_tokens[access_token]['exp'] = 0 519 | with pytest.raises(InvalidAccessToken): 520 | self.provider.handle_userinfo_request(urlencode({'access_token': access_token})) 521 | 522 | 523 | @pytest.mark.usefixtures('inject_provider') 524 | class TestProviderHandleRegistrationRequest(object): 525 | def test_handle_registration_request(self): 526 | request = {'redirect_uris': ['https://client.example.com/redirect']} 527 | response = self.provider.handle_client_registration_request(json.dumps(request)) 528 | assert 'client_id' in response 529 | assert 'client_id_issued_at' in response 530 | assert 'client_secret' in response 531 | assert response['client_secret_expires_at'] == 0 532 | assert all(k in response.items() for k in request.items()) 533 | 534 | assert response['client_id'] in self.provider.clients 535 | 536 | def test_rejects_invalid_request(self): 537 | request = {'application_type': 'web', 'client_name': 'test client'} # missing 'redirect_uris' 538 | with 
pytest.raises(InvalidClientRegistrationRequest): 539 | self.provider.handle_client_registration_request(json.dumps(request)) 540 | 541 | @pytest.mark.parametrize('client_preference, provider_capability, client_value, provider_value', [ 542 | ('request_object_signing_alg', 'request_object_signing_alg_values_supported', 543 | 'HS256', ['none', 'RS256']), 544 | ('request_object_encryption_alg', 'request_object_encryption_alg_values_supported', 545 | 'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']), 546 | ('request_object_encryption_enc', 'request_object_encryption_enc_values_supported', 547 | 'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512 ']), 548 | ('userinfo_signed_response_alg', 'userinfo_signing_alg_values_supported', 549 | 'HS256', ['none', 'RS256']), 550 | ('userinfo_encrypted_response_alg', 'userinfo_encryption_alg_values_supported', 551 | 'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']), 552 | ('userinfo_encrypted_response_enc', 'userinfo_encryption_enc_values_supported', 553 | 'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512 ']), 554 | ('id_token_signed_response_alg', 'id_token_signing_alg_values_supported', 555 | 'HS256', ['none', 'RS256']), 556 | ('id_token_encrypted_response_alg', 'id_token_encryption_alg_values_supported', 557 | 'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']), 558 | ('id_token_encrypted_response_enc', 'id_token_encryption_enc_values_supported', 559 | 'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512 ']), 560 | ('default_acr_values', 'acr_values_supported', 561 | ['1', '2'], ['3', '4']), 562 | ('subject_type', 'subject_types_supported', 563 | 'public', ['pairwise']), 564 | ('token_endpoint_auth_method', 'token_endpoint_auth_methods_supported', 565 | 'private_key_jwt', ['client_secret_post', 'client_secret_basic']), 566 | ('token_endpoint_auth_signing_alg', 'token_endpoint_auth_signing_alg_values_supported', 567 | 'HS256', ['none', 'RS256']), 568 | ('response_types', 'response_types_supported', 569 | ['id_token token'], ['code', 'code token']), 570 | 
('grant_types', 'grant_types_supported', 571 | 'implicit', ['authorization_code']) 572 | ]) 573 | def test_rejects_mismatching_request(self, client_preference, provider_capability, client_value, provider_value): 574 | request = {'redirect_uris': ['https://client.example.com/redirect']} 575 | provider_capabilities = provide_configuration() 576 | 577 | if client_preference.startswith(('request_object_encryption', 'id_token_encrypted', 'userinfo_encrypted')): 578 | # provide default value for the metadata params that come in pairs 579 | param = client_preference[:-4] 580 | alg_param = param + '_alg' 581 | request[alg_param] = 'RSA-OAEP' 582 | provider_capabilities[PREFERENCE2PROVIDER[alg_param]] = ['RSA-OAEP'] 583 | enc_param = param + '_enc' 584 | request[enc_param] = 'A192CBC-HS256' 585 | provider_capabilities[PREFERENCE2PROVIDER[enc_param]] = ['A192CBC-HS256'] 586 | 587 | request[client_preference] = client_value 588 | provider_capabilities[provider_capability] = provider_value 589 | provider = Provider(rsa_key(), provider_capabilities, 590 | AuthorizationState(HashBasedSubjectIdentifierFactory('salt')), {}, None) 591 | with pytest.raises(InvalidClientRegistrationRequest): 592 | provider.handle_client_registration_request(json.dumps(request)) 593 | 594 | @pytest.mark.parametrize('client_preference, provider_capability, client_value, provider_value', [ 595 | ('response_types', 'response_types_supported', 596 | ['code id_token token', 'code', 'id_token', 'id_token token'], 597 | ['code id_token token', 'id_token token']), 598 | ('default_acr_values', 'acr_values_supported', 599 | ['1', '2', '3'], ['2', '3', '4']), 600 | ]) 601 | def test_matches_common_set_of_metadata_values(self, client_preference, provider_capability, 602 | client_value, provider_value): 603 | provider_capabilities = provide_configuration() 604 | provider_capabilities.update({provider_capability: provider_value}) 605 | provider = Provider(rsa_key(), provider_capabilities, 606 | 
AuthorizationState(HashBasedSubjectIdentifierFactory('salt')), {}, None) 607 | request = {'redirect_uris': ['https://client.example.com/redirect'], client_preference: client_value} 608 | response = provider.handle_client_registration_request(json.dumps(request)) 609 | expected_values = set(client_value).intersection(provider_value) 610 | assert Counter(frozenset(v.split()) for v in response[client_preference]) == \ 611 | Counter(frozenset(v.split()) for v in expected_values) 612 | 613 | def test_match_space_separated_response_type_without_order(self): 614 | registration_request = {'redirect_uris': ['https://client.example.com/redirect'], 615 | 'response_types': ['id_token token']} 616 | # should not raise an exception 617 | assert self.provider.handle_client_registration_request(json.dumps(registration_request)) 618 | 619 | def test_client_can_use_registered_space_separated_response_type_in_authentication_request(self): 620 | response_type = 'id_token token' 621 | registration_request = {'redirect_uris': ['https://client.example.com/redirect'], 622 | 'response_types': [response_type]} 623 | registration_response = self.provider.handle_client_registration_request(json.dumps(registration_request)) 624 | 625 | authentication_request = {'client_id': registration_response['client_id'], 626 | 'redirect_uri': 'https://client.example.com/redirect', 627 | 'scope': 'openid', 628 | 'nonce': 'nonce', 629 | 'response_type': response_type} 630 | # should not raise an exception 631 | assert self.provider.parse_authentication_request(urlencode(authentication_request)) 632 | 633 | 634 | class TestProviderProviderConfiguration(object): 635 | def test_provider_configuration(self): 636 | config = provide_configuration() 637 | config.update({'foo': 'bar', 'abc': 'xyz'}) 638 | provider = Provider(None, config, None, None, None) 639 | provider_config = provider.provider_configuration 640 | assert all(k in provider_config for k in config) 641 | 642 | 643 | class TestProviderJWKS(object): 
644 | def test_jwks(self): 645 | provider = Provider(rsa_key(), provide_configuration(), None, None, None) 646 | assert provider.jwks == {'keys': [provider.signing_key.serialize()]} 647 | 648 | 649 | @pytest.mark.usefixtures('inject_provider') 650 | class TestRPInitiatedLogout(object): 651 | def test_logout_user_with_subject_identifier(self): 652 | auth_req = AuthorizationRequest(response_type='code id_token token', scope='openid', client_id='client1', 653 | redirect_uri='https://client.example.com/redirect') 654 | auth_resp = self.provider.authorize(auth_req, 'user1') 655 | 656 | id_token = IdToken().from_jwt(auth_resp['id_token'], key=[self.provider.signing_key]) 657 | self.provider.logout_user(subject_identifier=id_token['sub']) 658 | with pytest.raises(InvalidAccessToken): 659 | self.provider.authz_state.introspect_access_token(auth_resp['access_token']) 660 | with pytest.raises(InvalidAuthorizationCode): 661 | self.provider.authz_state.exchange_code_for_token(auth_resp['code']) 662 | 663 | def test_logout_user_with_id_token_hint(self): 664 | auth_req = AuthorizationRequest(response_type='code id_token token', scope='openid', client_id='client1', 665 | redirect_uri='https://client.example.com/redirect') 666 | auth_resp = self.provider.authorize(auth_req, 'user1') 667 | 668 | self.provider.logout_user(end_session_request=EndSessionRequest(id_token_hint=auth_resp['id_token'])) 669 | with pytest.raises(InvalidAccessToken): 670 | self.provider.authz_state.introspect_access_token(auth_resp['access_token']) 671 | with pytest.raises(InvalidAuthorizationCode): 672 | self.provider.authz_state.exchange_code_for_token(auth_resp['code']) 673 | 674 | def test_logout_user_with_unknown_subject_identifier(self): 675 | with pytest.raises(InvalidSubjectIdentifier): 676 | self.provider.logout_user(subject_identifier='unknown') 677 | 678 | def test_post_logout_redirect(self): 679 | auth_req = AuthorizationRequest(response_type='code id_token token', scope='openid', 
client_id='client1', 680 | redirect_uri='https://client.example.com/redirect') 681 | auth_resp = self.provider.authorize(auth_req, 'user1') 682 | end_session_request = EndSessionRequest(id_token_hint=auth_resp['id_token'], 683 | post_logout_redirect_uri='https://client.example.com/post_logout', 684 | state='state') 685 | redirect_url = self.provider.do_post_logout_redirect(end_session_request) 686 | assert redirect_url == EndSessionResponse(state='state').request('https://client.example.com/post_logout') 687 | 688 | def test_post_logout_redirect_without_post_logout_redirect_uri(self): 689 | assert self.provider.do_post_logout_redirect(EndSessionRequest()) is None 690 | 691 | def test_post_logout_redirect_with_unknown_client_for_post_logout_redirect_uri(self): 692 | end_session_request = EndSessionRequest(post_logout_redirect_uri='https://client.example.com/post_logout') 693 | assert self.provider.do_post_logout_redirect(end_session_request) is None 694 | 695 | def test_post_logout_redirect_with_unknown_post_logout_redirect_uri(self): 696 | auth_req = AuthorizationRequest(response_type='code id_token token', scope='openid', client_id='client1', 697 | redirect_uri='https://client.example.com/redirect') 698 | auth_resp = self.provider.authorize(auth_req, 'user1') 699 | end_session_request = EndSessionRequest(id_token_hint=auth_resp['id_token'], 700 | post_logout_redirect_uri='https://client.example.com/unknown') 701 | assert self.provider.do_post_logout_redirect(end_session_request) is None 702 | -------------------------------------------------------------------------------- /tests/pyop/test_storage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | 5 | from abc import ABC, abstractmethod 6 | from contextlib import contextmanager 7 | from redis.client import Redis 8 | import datetime 9 | import fakeredis 10 | import mongomock 11 | import pymongo 12 | import time 13 | 14 | import 
pyop.storage 15 | 16 | __author__ = 'lundberg' 17 | 18 | 19 | db_specs_list = [ 20 | {"uri": "mongodb://localhost:1234/pyop", "name": "pyop"}, 21 | {"uri": "redis://localhost/0", "name": 0}, 22 | ] 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def mock_redis(monkeypatch): 27 | def mockreturn(*args, **kwargs): 28 | return fakeredis.FakeStrictRedis(*args, **kwargs) 29 | monkeypatch.setattr(Redis, "from_url", mockreturn) 30 | 31 | @pytest.fixture(autouse=True) 32 | def mock_mongo(): 33 | pymongo.MongoClient = mongomock.MongoClient 34 | 35 | 36 | class TestStorage(object): 37 | @pytest.fixture(params=db_specs_list) 38 | def db(self, request): 39 | return pyop.storage.StorageBase.from_uri( 40 | request.param["uri"], db_name=request.param["name"], collection="test" 41 | ) 42 | 43 | def test_write(self, db): 44 | db['foo'] = 'bar' 45 | assert db['foo'] == 'bar' 46 | 47 | def test_multilevel_dict(self, db): 48 | db['foo'] = {} 49 | assert db['foo'] == {} 50 | db['foo'] = {'bar': 'baz'} 51 | assert db['foo']['bar'] == 'baz' 52 | 53 | def test_contains(self, db): 54 | db['foo'] = 'bar' 55 | assert 'foo' in db 56 | 57 | def test_pop(self, db): 58 | db['foo'] = 'bar' 59 | assert db.pop('foo') == 'bar' 60 | try: 61 | db['foo'] 62 | except Exception as e: 63 | assert isinstance(e, KeyError) 64 | 65 | def test_items(self, db): 66 | db['foo'] = 'foorbar' 67 | db['bar'] = True 68 | db['baz'] = {'foo': 'bar'} 69 | for key, item in db.items(): 70 | assert key 71 | assert item 72 | 73 | @pytest.mark.parametrize( 74 | "args,kwargs", 75 | [ 76 | (["mongodb://localhost/pyop"], {"collection": "test", "ttl": 3}), 77 | (["mongodb://localhost"], {"db_name": "pyop", "collection": "test"}), 78 | (["mongodb://localhost", "test", "pyop"], {}), 79 | (["mongodb://localhost/pyop", "test"], {}), 80 | (["mongodb://localhost/pyop"], {"db_name": "other", "collection": "test"}), 81 | (["redis://localhost"], {"collection": "test"}), 82 | (["redis://localhost", "test"], {}), 83 | 
(["redis://localhost"], {"db_name": 2, "collection": "test"}), 84 | (["unix://localhost/0"], {"collection": "test", "ttl": 3}), 85 | ], 86 | ) 87 | def test_from_uri(self, args, kwargs): 88 | store = pyop.storage.StorageBase.from_uri(*args, **kwargs) 89 | store["test"] = "value" 90 | assert store["test"] == "value" 91 | 92 | @pytest.mark.parametrize( 93 | "error,args,kwargs", 94 | [ 95 | (ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}), 96 | ( 97 | TypeError, 98 | ["mongodb://localhost", "ouch"], 99 | {"db_name": 3, "collection": "test", "ttl": None}, 100 | ), 101 | ( 102 | TypeError, 103 | ["mongodb://localhost", "ouch"], 104 | {"db_name": "pyop", "collection": "test", "ttl": None}, 105 | ), 106 | ( 107 | TypeError, 108 | ["mongodb://localhost", "pyop"], 109 | {"collection": "test", "ttl": None}, 110 | ), 111 | ( 112 | TypeError, 113 | ["redis://localhost", "ouch"], 114 | {"db_name": 3, "collection": "test", "ttl": None}, 115 | ), 116 | (TypeError, ["redis://localhost/0"], {}), 117 | (TypeError, ["redis://localhost/0"], {"db_name": "pyop"}), 118 | ], 119 | ) 120 | def test_from_uri_invalid_parameters(self, error, args, kwargs): 121 | with pytest.raises(error): 122 | pyop.storage.StorageBase.from_uri(*args, **kwargs) 123 | 124 | 125 | class StorageTTLTest(ABC): 126 | def prepare_db(self, uri, ttl): 127 | self.db = pyop.storage.StorageBase.from_uri( 128 | uri, 129 | collection="test", 130 | ttl=ttl, 131 | ) 132 | self.db["foo"] = {"bar": "baz"} 133 | 134 | @abstractmethod 135 | def set_time(self, offset, monkey): 136 | pass 137 | 138 | @contextmanager 139 | def adjust_time(self, offset): 140 | mp = pytest.MonkeyPatch() 141 | try: 142 | yield self.set_time(offset, mp) 143 | finally: 144 | mp.undo() 145 | 146 | def execute_ttl_test(self, uri, ttl): 147 | self.prepare_db(uri, ttl) 148 | assert self.db["foo"] 149 | with self.adjust_time(offset=int(ttl / 2)): 150 | assert self.db["foo"] 151 | with self.adjust_time(offset=int(ttl * 2)): 152 | 
with pytest.raises(KeyError): 153 | self.db["foo"] 154 | 155 | @pytest.mark.parametrize("spec", db_specs_list) 156 | @pytest.mark.parametrize("ttl", ["invalid", -1, 2.3, {}]) 157 | def test_invalid_ttl(self, spec, ttl): 158 | with pytest.raises(ValueError): 159 | self.prepare_db(spec["uri"], ttl) 160 | 161 | 162 | class TestRedisTTL(StorageTTLTest): 163 | def set_time(self, offset, monkeypatch): 164 | now = time.time() 165 | def new_time(): 166 | return now + offset 167 | 168 | monkeypatch.setattr(time, "time", new_time) 169 | 170 | def test_ttl(self): 171 | self.execute_ttl_test("redis://localhost/0", 3600) 172 | 173 | def test_missing_module(self): 174 | pyop.storage._has_redis = False 175 | self.prepare_db("mongodb://localhost/0", None) 176 | with pytest.raises(ImportError): 177 | self.prepare_db("redis://localhost/0", None) 178 | pyop.storage._has_redis = True 179 | 180 | 181 | class TestMongoTTL(StorageTTLTest): 182 | def set_time(self, offset, monkeypatch): 183 | now = datetime.datetime.utcnow() 184 | def new_time(): 185 | return now + datetime.timedelta(seconds=offset) 186 | 187 | monkeypatch.setattr(mongomock, "utcnow", new_time) 188 | 189 | def test_ttl(self): 190 | self.execute_ttl_test("mongodb://localhost/pyop", 3600) 191 | 192 | def test_missing_module(self): 193 | pyop.storage._has_pymongo = False 194 | self.prepare_db("redis://localhost/0", None) 195 | with pytest.raises(ImportError): 196 | self.prepare_db("mongodb://localhost/0", None) 197 | pyop.storage._has_pymongo = True 198 | 199 | 200 | class TestStatelessWrapper(object): 201 | @pytest.fixture 202 | def db(self): 203 | return pyop.storage.StatelessWrapper("pyop", "abc123") 204 | 205 | def test_write(self, db): 206 | db['foo'] = 'bar' 207 | assert db['foo'] is None 208 | 209 | def test_pack_and_unpack(self, db): 210 | val_1 = {'foo': 'bar'} 211 | key = db.pack(val_1) 212 | val_2 = db[key] 213 | assert val_1 == val_2 214 | 215 | def test_pack_with_non_dict_val(self, db): 216 | val_1 = 'this is 
not a dict' 217 | key = db.pack(val_1) 218 | val_2 = db[key] 219 | assert val_1 == val_2 220 | 221 | def test_contains(self, db): 222 | val_1 = {'foo': 'bar'} 223 | key = db.pack(val_1) 224 | assert key in db 225 | 226 | def test_items(self, db): 227 | with pytest.raises(NotImplementedError): 228 | db['foo'] = 'bar' 229 | db.items() 230 | 231 | def test_delitem(self, db): 232 | with pytest.raises(NotImplementedError): 233 | db['foo'] = 'bar' 234 | del db['foo'] -------------------------------------------------------------------------------- /tests/pyop/test_util.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pyop.message import AuthorizationRequest 4 | from pyop.util import should_fragment_encode 5 | 6 | 7 | class TestShouldFragmentEncode(object): 8 | @pytest.mark.parametrize('response_type, expected', [ 9 | ('code', False), 10 | ('id_token', True), 11 | ('id_token token', True), 12 | ('code id_token', True), 13 | ('code token', True), 14 | ('code id_token token', True), 15 | ]) 16 | def test_by_response_type(self, response_type, expected): 17 | auth_req = {'response_type': response_type} 18 | assert should_fragment_encode(AuthorizationRequest(**auth_req)) is expected 19 | -------------------------------------------------------------------------------- /tests/test_requirements.txt: -------------------------------------------------------------------------------- 1 | pytest >= 6.2 2 | pip >= 19.0 3 | responses 4 | pycryptodomex 5 | fakeredis 6 | mongomock 7 | pymongo 8 | redis 9 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py34, py35 3 | 4 | [testenv] 5 | commands = py.test -rs -vvv tests/ 6 | deps = -rtests/test_requirements.txt 7 | --------------------------------------------------------------------------------