├── .bumpversion.cfg
├── .gitignore
├── .travis.yml
├── Dockerfile
├── LICENSE
├── NOTICE
├── README.md
├── application.cfg
├── docker
├── setup.sh
└── start.sh
├── example
├── README.md
├── __init__.py
├── app.py
├── application.cfg
├── https.crt
├── https.key
├── requirements.txt
├── signing_key.pem
├── templates
│ └── logout.jinja2
├── views.py
└── wsgi.py
├── setup.py
├── signing_key.pem
├── src
└── pyop
│ ├── __init__.py
│ ├── access_token.py
│ ├── authz_state.py
│ ├── client_authentication.py
│ ├── crypto.py
│ ├── exceptions.py
│ ├── message.py
│ ├── provider.py
│ ├── request_validator.py
│ ├── storage.py
│ ├── subject_identifier.py
│ ├── userinfo.py
│ └── util.py
├── tests
├── pyop
│ ├── test_access_token.py
│ ├── test_authz_state.py
│ ├── test_client_authentication.py
│ ├── test_exceptions.py
│ ├── test_provider.py
│ ├── test_stateless_provider.py
│ ├── test_storage.py
│ └── test_util.py
└── test_requirements.txt
└── tox.ini
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 2.0.8
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 |
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .cache/
3 | .tox/
4 | __pycache__
5 | *.egg-info
6 | build/
7 | dist/
8 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | os: linux
2 | dist: bionic
3 | language: python
4 |
5 | services:
6 | - docker
7 | - mongodb
8 |
9 | install:
10 | - pip install tox
11 | - pip install tox-travis
12 |
13 | script:
14 | - tox
15 |
16 | jobs:
17 | allow_failures:
18 | - python: 3.9-dev
19 | include:
20 | - python: 3.6
21 | - python: 3.7
22 | - python: 3.8
23 | - python: pypy3
24 |
25 | - stage: Deploy new release
26 | script: skip
27 | deploy:
28 | - provider: pypi
29 | distributions: sdist bdist_wheel
30 | skip_existing: true
31 | user: Lundberg
32 | password:
33 | secure: H5d+Its9YTMSvVddRWX2qgChMb8Eur5zI+qRy3NAPdwRNs1RNyIk1a2z9/EPFmRIu6OsBBDcHsCiq4VXcwvpigdAqMu4iAoZ/Xe0xf88k21GCggfaAPbINRVL6031RFUQkfGZ4abT2cXnerDylMv2DporPZkfCEUJonq+we0GmtJHoCSemXewMxt28TSu0aPKRL4aBfbuRoAPx50jUns9ekxgc0sqpSLvE5qyxWxXIePK0/+8tX3OrdCcKMg/IshgoK7Yondu+DhN+qhf+AkQuPDXUQTx/TKdg/YDVqj8SHT6hIFFi6dCakuhkYIKlkggnSguLhZ2zhVUjYFt1f0NOv2j7dHuKxyUFR9Qm/49rdY/E3ir3CU5YgUEprcgo/jj5K3B1/jY2uXNez1JD97RC6IAPg4o+PwenVQ9a3pLwqnImSaJKPTQf9IyFfrV/xru3ZyQiftmUmCYtCPybDATOq5iqNAQa9Ec0Mg54OGcabPQkNp9CrNFkcO0sM3VNRnGTmuqdYIkjNxPwNCzjbAQKlwcXVNg48kHjH6vb9D+mxjt9CYwCJdfkGm2F2pekr5S8tDdLAkxE9VW+r7SrsRaJRFCHU+6AaejnWvOLCy2S+KJ0JhQesJm0k1iT2fsC8v92MKIzghrY/sqKck33pxB57cFqxLIQIVYgCaBFWbwfc=
34 | on:
35 | tags: true
36 | repo: IdentityPython/pyop
37 | if: tag IS present
38 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:16.04
2 |
3 | ENV DEBIAN_FRONTEND noninteractive
4 |
5 | RUN /bin/echo -e "deb http://se.archive.ubuntu.com/ubuntu xenial main restricted universe\ndeb http://archive.ubuntu.com/ubuntu xenial-updates main restricted universe\ndeb http://security.ubuntu.com/ubuntu xenial-security main restricted universe" > /etc/apt/sources.list
6 |
7 | RUN apt-get update && \
8 | apt-get -y dist-upgrade && \
9 | apt-get -y install \
10 | python3-pip \
11 | python-virtualenv \
12 | libpython3-dev \
13 | python-setuptools \
14 | build-essential \
15 | libffi-dev \
16 | libssl-dev \
17 | iputils-ping \
18 | && apt-get clean
19 |
20 | RUN rm -rf /var/lib/apt/lists/*
21 |
22 | RUN adduser --system --no-create-home --shell /bin/false --group pyop
23 |
24 | COPY . /opt/pyop/src/
25 | COPY docker/setup.sh /opt/pyop/setup.sh
26 | COPY docker/start.sh /start.sh
27 | RUN /opt/pyop/setup.sh
28 |
29 | # Add Dockerfile to the container as documentation
30 | COPY Dockerfile /Dockerfile
31 |
32 | WORKDIR /
33 |
34 | EXPOSE 9090
35 |
36 | CMD ["bash", "/start.sh"]
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2016 Umeå universitet
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pyOP
2 | [](https://travis-ci.org/IdentityPython/pyop)
3 | [](https://pypi.python.org/pypi/pyop)
4 |
5 |
6 | OpenID Connect Provider (OP) library in Python.
7 | Uses [pyoidc](https://github.com/rohe/pyoidc/) and
8 | [pyjwkest](https://github.com/rohe/pyjwkest).
9 |
10 | # Provider implementations using pyOP
11 | * [se-leg-op](https://github.com/SUNET/se-leg-op)
12 | * [SATOSA OIDC frontend](https://github.com/its-dirg/SATOSA/blob/master/src/satosa/frontends/openid_connect.py)
13 | * [local example](example/views.py)
14 |
15 | # Introduction
16 |
17 | pyOP is a high-level library intended to be usable in any web server application.
18 | By only providing the core functionality for OpenID Connect the application can freely choose to implement any kind of
19 | authentication mechanisms, while pyOP provides a simple interface for the OpenID Connect messages to send back to
20 | clients.
21 |
22 | ## OpenID Connect support
23 | * [Dynamic Provider Discovery](https://openid.net/specs/openid-connect-discovery-1_0.html)
24 | * [Dynamic Client Registration](https://openid.net/specs/openid-connect-registration-1_0.html)
25 | * [Core](http://openid.net/specs/openid-connect-core-1_0.html)
26 | * [Authorization Code Flow](http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps)
27 | * [Implicit Flow](http://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth)
28 | * [Hybrid Flow](http://openid.net/specs/openid-connect-core-1_0.html#HybridFlowAuth)
29 | * Claims
30 | * [Requesting Claims using Scope Values](http://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims)
31 | * [Claims Request Parameter](http://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter)
32 | * Crypto support
33 | * Currently only supports issuing signed ID Tokens
34 |
35 | # Configuration
36 | The provider instance can be configured through the provider configuration information. In the following example, a
37 | provider instance is initiated to use a MongoDB instance as its backend storage:
38 |
39 | ```python
40 | from jwkest.jwk import rsa_load, RSAKey
41 |
42 | from pyop.authz_state import AuthorizationState
43 | from pyop.provider import Provider
44 | from pyop.storage import MongoWrapper
45 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory
46 | from pyop.userinfo import Userinfo
47 |
48 | signing_key = RSAKey(key=rsa_load('signing_key.pem'), use='sig', alg='RS256')
49 | configuration_information = {
50 | 'issuer': 'https://example.com',
51 | 'authorization_endpoint': 'https://example.com/authorization',
52 | 'token_endpoint': 'https://example.com/token',
53 | 'userinfo_endpoint': 'https://example.com/userinfo',
54 | 'registration_endpoint': 'https://example.com/registration',
55 | 'response_types_supported': ['code', 'id_token token'],
56 | 'id_token_signing_alg_values_supported': [signing_key.alg],
57 | 'response_modes_supported': ['fragment', 'query'],
58 | 'subject_types_supported': ['public', 'pairwise'],
59 | 'grant_types_supported': ['authorization_code', 'implicit'],
60 | 'claim_types_supported': ['normal'],
61 | 'claims_parameter_supported': True,
62 | 'claims_supported': ['sub', 'name', 'given_name', 'family_name'],
63 | 'request_parameter_supported': False,
64 | 'request_uri_parameter_supported': False,
65 | 'scopes_supported': ['openid', 'profile']
66 | }
67 |
68 | subject_id_factory = HashBasedSubjectIdentifierFactory(sub_hash_salt)
69 | authz_state = AuthorizationState(subject_id_factory,
70 | MongoWrapper(db_uri, 'provider', 'authz_codes'),
71 | MongoWrapper(db_uri, 'provider', 'access_tokens'),
72 | MongoWrapper(db_uri, 'provider', 'refresh_tokens'),
73 | MongoWrapper(db_uri, 'provider', 'subject_identifiers'))
74 | client_db = MongoWrapper(db_uri, 'provider', 'clients')
75 | user_db = MongoWrapper(db_uri, 'provider', 'users')
76 | provider = Provider(signing_key, configuration_information, authz_state, client_db, Userinfo(user_db))
77 | ```
78 |
79 | where `db_uri` is the [MongoDB connection URI](https://docs.mongodb.com/manual/reference/connection-string/) and
80 | `sub_hash_salt` is a secret string to use as a salt when creating hash based subject identifiers.
81 |
82 | ## Token lifetimes
83 | The ID token lifetime (in seconds) can be supplied to the `Provider` constructor with `id_token_lifetime`, e.g.:
84 |
85 | ```python
86 | Provider(..., id_token_lifetime=600)
87 | ```
88 | If not specified it will default to 1 hour.
89 |
90 | The lifetime of authorization codes, access tokens, and refresh tokens is configured in the `AuthorizationState`, e.g.:
91 |
92 | ```python
93 | AuthorizationState(..., authorization_code_lifetime=300, access_token_lifetime=60*60*24,
94 | refresh_token_lifetime=60*60*24*365, refresh_token_threshold=None)
95 | ```
96 |
97 | If not specified, the lifetimes default to the following values:
98 | * Authorization codes are valid for 10 minutes.
99 | * Access tokens are valid for 1 hour.
100 | * Refresh tokens are not issued.
101 |
102 | To make sure refresh tokens are issued in response to code exchange token requests, specify a
103 | `refresh_token_lifetime` > 0.
104 | To make sure refresh tokens are renewed if they are close to expiry in response to refresh token requests,
105 | specify a `refresh_token_threshold` > 0.
106 |
107 | # Dynamic discovery: Provider Configuration Information
108 | To publish the provider configuration information at an endpoint, use `Provider.provider_configuration`.
109 |
110 | The following example illustrates the high-level idea:
111 |
112 | ```python
113 | @app.route('/.well-known/openid-configuration')
114 | def provider_config():
115 | return HTTPResponse(provider.provider_configuration.to_json(), content_type="application/json")
116 | ```
117 |
118 | # Authorization endpoint
119 | An incoming authentication request can be validated by the provider using `Provider.parse_authentication_request`.
120 | If the request is valid, it should be stored and associated with the current user session to be able to retrieve it
121 | when the end-user authentication is completed.
122 |
123 | ```python
124 | from pyop.exceptions import InvalidAuthenticationRequest
125 |
126 | @app.route('/authorization')
127 | def authorization_endpoints(request):
128 | try:
129 | authn_req = provider.parse_authentication_request(request)
130 | except InvalidAuthenticationRequest as e:
131 | error_url = e.to_error_url()
132 |
133 | if error_url:
134 | return HTTPResponse(error_url, status=303)
135 | else:
136 | return HTTPResponse("Something went wrong: {}".format(str(e)), status=400)
137 |
138 | session['authn_req'] = authn_req.to_dict()
139 | # TODO initiate end-user authentication
140 | ```
141 |
142 | When the authentication is completed by the user, the provider must be notified to make an authentication response
143 | to the client's 'redirect_uri'. This is done with `Provider.authorize`, where the local user id supplied must exist
144 | in the user database supplied on initialization. When using the included `MongoWrapper`, no mapping is done between
145 | user data and OpenID Connect claim names. Hence the underlying data source must contain the user information under the
146 | same names as the [standard claims of OpenID Connect](http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims).
147 |
148 | ```python
149 | from pyop.message import AuthorizationRequest
150 |
151 | from pyop.util import should_fragment_encode
152 |
153 | authn_req = session['authn_req']
154 | authn_response = provider.authorize(AuthorizationRequest().from_dict(authn_req), user_id)
155 | return_url = authn_response.request(authn_req['redirect_uri'], should_fragment_encode(authn_req))
156 |
157 | return HTTPResponse(return_url, status=303)
158 | ```
159 |
160 | ## Authentication request validation
161 | The provider instance is by default configured to validate authentication requests according to the OpenID Connect
162 | Core specification. If you need to add additional custom validation of authentication requests, that's possible by
163 | adding such validation functions to the list of authentication request validators.
164 |
165 | In this example an additional validator that checks that the 'nonce' parameter is included in all requests is added:
166 |
167 | ```python
168 | from pyop.exceptions import InvalidAuthenticationRequest
169 |
170 | def request_contains_nonce(authentication_request):
171 | if 'nonce' not in authentication_request:
172 | raise InvalidAuthenticationRequest('The request does not contain a nonce', authentication_request,
173 | oauth_error='invalid_request')
174 |
175 | provider.authentication_request_validators.append(request_contains_nonce)
176 | ```
177 |
178 | # Token endpoint
179 | An incoming token request is processed by `Provider.handle_token_request`. It will validate the request and issue all
180 | necessary tokens (access token and possibly refresh token).
181 |
182 | ```python
183 | from oic.oic.message import TokenErrorResponse
184 |
185 | from pyop.exceptions import InvalidClientAuthentication
186 | from pyop.exceptions import OAuthError
187 |
188 | @app.route('/token', methods=['POST', 'GET'])
189 | def token_endpoint(request):
190 | try:
191 | token_response = provider.handle_token_request(request.get_data().decode('utf-8'),
192 | request.headers)
193 | return HTTPResponse(token_response.to_json(), content_type='application/json')
194 | except InvalidClientAuthentication as e:
195 | error_resp = TokenErrorResponse(error='invalid_client', error_description=str(e))
196 | http_response = HTTPResponse(error_resp.to_json(), status=401, content_type='application/json')
197 | http_response.headers['WWW-Authenticate'] = 'Basic'
198 | return http_response
199 | except OAuthError as e:
200 | error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e))
201 | return HTTPResponse(error_resp.to_json(), status=400, content_type='application/json')
202 | ```
203 |
204 |
205 | # Userinfo endpoint
206 | An incoming userinfo request is processed by `Provider.handle_userinfo_request`. It will validate the request and return
207 | all requested userinfo.
208 |
209 | ```python
210 | from oic.oic.message import UserInfoErrorResponse
211 |
212 | from pyop.access_token import AccessToken
213 | from pyop.exceptions import BearerTokenError
214 | from pyop.exceptions import InvalidAccessToken
215 |
216 | @app.route('/userinfo', methods=['GET', 'POST'])
217 | def userinfo_endpoint(request):
218 | try:
219 | response = provider.handle_userinfo_request(request.get_data().decode('utf-8'),
220 | request.headers)
221 | return HTTPResponse(response.to_json(), content_type='application/json')
222 | except (BearerTokenError, InvalidAccessToken) as e:
223 | error_resp = UserInfoErrorResponse(error='invalid_token', error_description=str(e))
224 | http_response = HTTPResponse(error_resp.to_json(), status=401, content_type='application/json')
225 | http_response.headers['WWW-Authenticate'] = AccessToken.BEARER_TOKEN_TYPE
226 | return http_response
227 | ```
228 |
229 |
230 | # Dynamic client registration
231 |
232 | An incoming client registration request is processed by `Provider.handle_client_registration_request`. It will validate the request,
233 | store the registered metadata and issue new client credentials.
234 |
235 | ```python
236 | from pyop.exceptions import InvalidClientRegistrationRequest
237 |
238 | @app.route('/registration', methods=['POST'])
239 | def registration_endpoint(request):
240 | try:
241 | response = provider.handle_client_registration_request(request.get_data().decode('utf-8'))
242 | return HTTPResponse(response.to_json(), status=201, content_type='application/json')
243 | except InvalidClientRegistrationRequest as e:
244 | return HTTPResponse(e.to_json(), status=400, content_type='application/json')
245 | ```
246 |
247 | ## Registration request validation
248 | The provider instance is by default configured to validate registration requests according to the OpenID Connect
249 | Dynamic Registration specification. If you need to add additional custom validation of registration requests, that's
250 | possible by adding such validation functions to the list of registration request validators.
251 |
252 | In this example an additional validator that checks that the 'software_statement' parameter is included in all requests
253 | is added:
254 |
255 | ```python
256 | def request_contains_software_statement(registration_request):
257 | if 'software_statement' not in registration_request:
258 | raise InvalidClientRegistrationRequest('The request does not contain a software_statement', registration_request,
259 | oauth_error='invalid_request')
260 |
261 | provider.registration_request_validators.append(request_contains_software_statement)
262 | ```
263 |
264 | # User logout
265 |
266 | RP-initiated logout, as described in [Section 5 of OpenID Connect Session Management](http://openid.net/specs/openid-connect-session-1_0.html#RPLogout)
267 | is supported. The parsed request should be passed to `Provider.logout_user` together with any known subject identifier
268 | for the user, and then `Provider.do_post_logout_redirect` should be called to obey any valid `post_logout_redirect_uri`
269 | included in the request.
270 |
271 | ```python
272 | from oic.oic.message import EndSessionRequest
273 |
274 | from pyop.exceptions import InvalidSubjectIdentifier
275 |
276 | @app.route('/logout')
277 | def end_session_endpoint(request):
278 | end_session_request = EndSessionRequest().deserialize(request.get_data().decode('utf-8'))
279 |
280 | try:
281 | provider.logout_user(session.get('sub'), end_session_request)
282 | except InvalidSubjectIdentifier as e:
283 | return HTTPResponse('Logout unsuccessful!', content_type='text/html', status=400)
284 |
285 | # TODO automagic logout, should ask user first!
286 | redirect_url = provider.do_post_logout_redirect(end_session_request)
287 | if redirect_url:
288 | return HTTPResponse(redirect_url, status=303)
289 |
290 | return HTTPResponse('Logout successful!', content_type='text/html')
291 | ```
292 |
293 | # Exceptions
294 | All exceptions, except `AuthorizationError`, inherit from `ValueError`. However, it might be necessary to distinguish
295 | between them to send the correct error message back to the client according to the OpenID Connect specifications.
296 |
297 | All OAuth errors contain the OAuth error code in `OAuthError.oauth_error`, together with the error description as the
298 | message of the exception (accessed by `str(exception)`).
299 |
--------------------------------------------------------------------------------
/application.cfg:
--------------------------------------------------------------------------------
1 | example/application.cfg
--------------------------------------------------------------------------------
/docker/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Install all requirements
4 | #
5 |
6 | set -e
7 | set -x
8 |
9 | PYPI="https://pypi.nordu.net/simple/"
10 | ping -c 1 -q pypiserver.docker && PYPI="http://pypiserver.docker:8080/simple/"
11 |
12 | echo "#############################################################"
13 | echo "$0: Using PyPi URL ${PYPI}"
14 | echo "#############################################################"
15 |
16 | virtualenv -p python3 /opt/pyop
17 | /opt/pyop/bin/pip install -U pip
18 |
19 | # setup.py points to current directory
20 | # so we need to change to the right one.
21 | cd /opt/pyop/src
22 |
23 | /opt/pyop/bin/python3 setup.py install
24 | /opt/pyop/bin/pip install Flask
25 | /opt/pyop/bin/pip install gunicorn
26 |
--------------------------------------------------------------------------------
/docker/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 | set -x
5 |
6 | . /opt/pyop/bin/activate
7 |
8 | # nice to have in docker run output, to check what
9 | # version of something is actually running.
10 | /opt/pyop/bin/pip freeze
11 |
12 | export PYTHONPATH=/opt/pyop/src
13 |
14 | start-stop-daemon --start \
15 | -c pyop:pyop \
16 | --exec /opt/pyop/bin/gunicorn \
17 | --pidfile /var/run/pyop.pid \
18 | --chdir /opt/pyop/src \
19 | -- \
20 | example.wsgi:app \
21 | -b :9090 \
22 | --certfile example/https.crt \
23 | --keyfile example/https.key
24 |
25 |
--------------------------------------------------------------------------------
/example/README.md:
--------------------------------------------------------------------------------
1 | # pyOP example application
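2 | To run the example application, execute the following commands from the repository root (not from inside `example/`):
3 |
4 | ```bash
5 | cd path/to/pyop # the repository root: the app is loaded as the `example` package and reads application.cfg / signing_key.pem from here
6 | pip install -r example/requirements.txt # install the dependencies
7 | gunicorn example.wsgi:app -b :9090 --certfile example/https.crt --keyfile example/https.key # run the application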
8 | ```
--------------------------------------------------------------------------------
/example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IdentityPython/pyop/fab87f9f6193079171fdad0c223c810fe9532dd2/example/__init__.py
--------------------------------------------------------------------------------
/example/app.py:
--------------------------------------------------------------------------------
1 | from flask.app import Flask
2 | from flask.helpers import url_for
3 | from jwkest.jwk import RSAKey, rsa_load
4 |
5 | from pyop.authz_state import AuthorizationState
6 | from pyop.provider import Provider
7 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory
8 | from pyop.userinfo import Userinfo
9 |
10 |
def init_oidc_provider(app):
    """
    Build the pyop Provider for the example application.

    Endpoint URLs are resolved with ``url_for`` inside an application context, so the
    'oidc_provider' blueprint must already be registered on ``app``.

    :param app: the Flask application; ``app.users`` and ``app.config['SUBJECT_ID_HASH_SALT']`` are read
    :return: the configured provider
    """
    # (provider metadata name, view function of the 'oidc_provider' blueprint serving it)
    endpoints = (
        ('authorization_endpoint', 'authentication_endpoint'),
        ('jwks_uri', 'jwks_uri'),
        ('token_endpoint', 'token_endpoint'),
        ('userinfo_endpoint', 'userinfo_endpoint'),
        ('registration_endpoint', 'registration_endpoint'),
        ('end_session_endpoint', 'end_session_endpoint'),
    )

    with app.app_context():
        # The index URL ends with '/', which must not be part of the issuer identifier.
        configuration_information = {'issuer': url_for('oidc_provider.index')[:-1]}
        for metadata_name, view_name in endpoints:
            configuration_information[metadata_name] = url_for('oidc_provider.{}'.format(view_name))

    configuration_information.update(
        scopes_supported=['openid', 'profile'],
        response_types_supported=['code', 'code id_token', 'code token', 'code id_token token'],  # code and hybrid
        response_modes_supported=['query', 'fragment'],
        grant_types_supported=['authorization_code', 'implicit'],
        subject_types_supported=['pairwise'],
        token_endpoint_auth_methods_supported=['client_secret_basic'],
        claims_parameter_supported=True,
    )

    userinfo_db = Userinfo(app.users)
    # NOTE(review): relative path, presumably resolved against the working directory.
    signing_key = RSAKey(key=rsa_load('signing_key.pem'), alg='RS256')
    subject_id_factory = HashBasedSubjectIdentifierFactory(app.config['SUBJECT_ID_HASH_SALT'])
    # The empty dict is passed as the provider's fourth argument (presumably the client database).
    return Provider(signing_key, configuration_information, AuthorizationState(subject_id_factory),
                    {}, userinfo_db)
45 |
46 |
def oidc_provider_init_app(name=None):
    """
    Create the example Flask application with the OpenID Connect provider attached.

    :param name: import name for the Flask app; defaults to this module's ``__name__``
    :return: the Flask application, with ``app.users`` and ``app.provider`` set
    """
    app = Flask(name if name else __name__)
    app.config.from_pyfile('application.cfg')

    # Static in-memory user store for the demo.
    app.users = {'test_user': {'name': 'Testing Name'}}

    # The views must be registered first: init_oidc_provider derives endpoint URLs from them.
    from .views import oidc_provider_views
    app.register_blueprint(oidc_provider_views)

    app.provider = init_oidc_provider(app)
    return app
61 |
--------------------------------------------------------------------------------
/example/application.cfg:
--------------------------------------------------------------------------------
1 | SERVER_NAME = 'localhost:9090'
2 | SECRET_KEY = 'secret_key'
3 | SESSION_COOKIE_NAME='pyop_session'
4 | SUBJECT_ID_HASH_SALT = 'salt'
5 | PREFERRED_URL_SCHEME = 'https'
--------------------------------------------------------------------------------
/example/https.crt:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIEBjCCAu6gAwIBAgIJAIybVu7kfIK0MA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV
3 | BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
4 | aWRnaXRzIFB0eSBMdGQxGDAWBgNVBAMTD2xva2kuaXRzLnVtdS5zZTAeFw0xNTEy
5 | MTAxNDQyMDFaFw0yNTEyMDcxNDQyMDFaMF8xCzAJBgNVBAYTAkFVMRMwEQYDVQQI
6 | EwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQx
7 | GDAWBgNVBAMTD2xva2kuaXRzLnVtdS5zZTCCASIwDQYJKoZIhvcNAQEBBQADggEP
8 | ADCCAQoCggEBALiLDBwIteIobC+7JHoNeQRrTIbws9BghN4UUzyLo7+xeP9YwHaS
9 | tq6HqYK4cVLyx8k06Siw/4PwqMPNj9/B4f/ZXhEkXgbBP5TP36UgKrUIk4zInRFb
10 | Rjy+DcqjSZdgW1CKBKWJstXjSYen5rPm+voM/0msi164NPcfDMQIZmcQWh0MmEfG
11 | qlvdwTvjdaAQt8p7CGsxIdu4gPfhubknbTQKu+BVq5/RCVP7VU830PSr1RYhthX8
12 | Gt8ir32jEdDdjIrfA/zFx6PChyLkQFXtg/9WymnIM1j2ngNreL2nppwnqMYRnI9i
13 | C/y7MY4al3WeL9IETrtgh1jzXUNpgpJ03B0CAwEAAaOBxDCBwTAdBgNVHQ4EFgQU
14 | cTNphzIIRpQBQ2VT0Vx9xQYzNN0wgZEGA1UdIwSBiTCBhoAUcTNphzIIRpQBQ2VT
15 | 0Vx9xQYzNN2hY6RhMF8xCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRl
16 | MSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxGDAWBgNVBAMTD2xv
17 | a2kuaXRzLnVtdS5zZYIJAIybVu7kfIK0MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcN
18 | AQEFBQADggEBAJYJfUqOPTyZ+tflKoN4l+scIXpBxqtQbjX+MYli6VHpl+M8y163
19 | KsCglXPddL7Z58KBrUDx1m6f7dFQ3PMYn/S2dUcRrNOdfaDKZ5QgyYj/iVr8HSOh
20 | 6i1OtMFaBqW5WyqA5YgvUz63hZ2kDOBHZcEfSn2+roylBUiueV9gFNKDWneNMLo2
21 | PMZxcGWZ3wIQbu9ahakbJUvTigFStKeLoY1A2ZSTH7W4elB5DDxYOKZSzd/KZpfn
22 | /o/Pc7YbbEUYgIyf3QNusdH+t2pw9ZkrlKMhiv9ZAmjAWMDY7O/i3r7u7AkODu7z
23 | OHbH3rJqkbaiS8/q0cMZG6AMUzPRglzsTc4=
24 | -----END CERTIFICATE-----
25 |
--------------------------------------------------------------------------------
/example/https.key:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEowIBAAKCAQEAuIsMHAi14ihsL7skeg15BGtMhvCz0GCE3hRTPIujv7F4/1jA
3 | dpK2roepgrhxUvLHyTTpKLD/g/Cow82P38Hh/9leESReBsE/lM/fpSAqtQiTjMid
4 | EVtGPL4NyqNJl2BbUIoEpYmy1eNJh6fms+b6+gz/SayLXrg09x8MxAhmZxBaHQyY
5 | R8aqW93BO+N1oBC3ynsIazEh27iA9+G5uSdtNAq74FWrn9EJU/tVTzfQ9KvVFiG2
6 | Ffwa3yKvfaMR0N2Mit8D/MXHo8KHIuRAVe2D/1bKacgzWPaeA2t4vaemnCeoxhGc
7 | j2IL/LsxjhqXdZ4v0gROu2CHWPNdQ2mCknTcHQIDAQABAoIBAQCooAWUqDDqUl1o
8 | z+voyt7FtvXaZ58mzMsb0h6suDwMMTKKwKI8tprOp4+wrrB+RvFfXUWftPwFp6XO
9 | JMtOfm7vxcM6jqyMJ5DdfYSx8c6UVR3eCoHbFjf70P3xJ3tbIuTNlw/f4w7Sejj6
10 | B+W6hVjXm4C55TwEdPWQyYJ0rehESxITn7DDjDNXxxwDqwAv8yPTYf8m7mb7qh3V
11 | EBnvZPIWpgEIVqV2crQSfHJwg29KhS7cnExJBYEPppBQ4aoUGyMqJN0EyBL4DrRo
12 | Ds6hPttLaXEmB+ACm/OQzhEeFKduob5OIKRSyp9Z8t6/B9uHiIKCENRU4O4zux57
13 | jZ8MIypRAoGBAO0anHYHlniyP5f/8u4kvTnW3wbDtRJ3L3zVSKN8shQmQnjNWmsI
14 | LhLWLU2OTRsXlJB7oYqNBojUBdeGimYot9kmjx4XkxELB6XAaYDG6pDvM7axC5qH
15 | iuC4jVHdEsIfy9dP6wZ/b+A+JOWWpS1vdAizfvgWI32JLGDPkjjUlzNvAoGBAMdA
16 | F5KZZLZFYsZM470/bb20qFMURDRI+5yz0VUHTNUQEr/xhMcYFske6ox14A7gXzwd
17 | SHAvTDkV8DsGu/FzSWzZmVhNc1EtdM3Cbe3Y7WJoIQoupuxBDEvrBE9riyOjW61q
18 | dYIO2ymfJfc2Vx6d7LAK9itXdqa3RIo7ZPK7EbMzAoGAAR1C5vsaJe8QhXJafewG
19 | R6NO4QVCcJfGzVtjQAFyBM45OcAdUKt1K/l9tQOaMSpnNFagZ7pJ8ZKthFnJhLlk
20 | Q8z+lzGdK1NV8d15oXVN3OiC4bTrTQqeCHhVkbDsSaVEm/pwLFOk/vTLz5hpplED
21 | xpaxXhEckZZ3cu0GzuWQ4FkCgYAb34tspqjAFtTKiNcTElx3vV4OwTcJWWxZb45J
22 | JsxIwgbdcxvv/h6x4/FL1PGTIzAvaKlJiFRRaBBDMZ35GPeckpQxFiSbppBAeIKI
23 | U2Bh888rbXtMcY0W0bm4ooLEaYXZrJrjptBh8jGNc7ycO9twhRgK2CFxERI1hDmK
24 | +0BuoQKBgHLRFdJYFkNKB+j56vTlB5AFTXcjs6x0dZ4j0SDAqVYgIHErpd2bhgOz
25 | s698rfdadwbYN4UEqvV/2NhDLx/jag4BvCermK71HaXvYnvqKw3UHGAG8K/dCCFR
26 | eSuZnxTSoWAPjMLi64hRzbWVT+oeZuy83lG6qDygbI1z1nlYWdo3
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/example/requirements.txt:
--------------------------------------------------------------------------------
1 | pyop
2 | Flask
3 | gunicorn
4 | oic>=1.2.1
5 | setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability
6 | werkzeug>=3.0.1 # not directly required, pinned by Snyk to avoid a vulnerability
7 | requests>=2.32.2 # not directly required, pinned by Snyk to avoid a vulnerability
8 | urllib3>=2.2.2 # not directly required, pinned by Snyk to avoid a vulnerability
9 |
--------------------------------------------------------------------------------
/example/signing_key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIICXAIBAAKBgQCi7nye2Ye1MrUD/sZAplfkpMkXHYduydvfvv/+Ihx1ClxKS/KG
3 | /1EhqyBVVvhHLRs9pimKMyLm2pBE51rGOt//XKhsoFAa37VID2iz7DQuV6DGgyBS
4 | FaKgaYBpinEQy2WcjU4eABnABV2r+K2UmGkqVJheqHqOqHUKasT4gy/6kQIDAQAB
5 | AoGAf7u+YX6ioNCvDwHHBVojn/H8YK3axmVkhiYkZWTysGM99VVTPridL2sMfzse
6 | jBZ1u8Av4tOyMg/5eLtz8+KmRjljpeAEFfsA1htWE8vESXnvDFwKldXD9Vi/kppb
7 | CYqASGCBUX3i1LPYffvjUxIgD+Tjx4k56c5EN5G331flDV0CQQDP8fWraegLJ+K1
8 | iXGNQzpaqG3EI3vf35Yb2bJpmD39QIXFIcJJ5MZHVW+1TyvgiavM4hS2+LGA8kGh
9 | OvMWfbYTAkEAyJWGBUmAW9mooo1Vw4tJjEWAHjHvzcQ4dqIju+WN8Xy5JTWkDD6Z
10 | VgKGtgLt2HfpSsgej14+Rh5mrjo4SbYxSwJAKG0syq9jOk/9xjc7STBJtvhJprkT
11 | SxnHsBBpnBfJ7WNO3l1KzVzZo2Kbvg7vQ87gBIvrZQsCT0RJuBOi0LuN2wJAAInm
12 | Qj1gSt7axRT8FfpZyDankW0w56yPOkJVNjv3lZ5wINl0B1RjtQdstTBs0xf/WGQR
13 | MPFf2XBbdjxRymDi4QJBAM3MUYPOlUk1UVCSQKyCkBwL3zMaPjBjD5LXkhGJzxsb
14 | T74NznwmCib/r0Rl/KmD7/bAq7R4aheOS/OMaZyhbkk=
15 | -----END RSA PRIVATE KEY-----
16 |
--------------------------------------------------------------------------------
/example/templates/logout.jinja2:
--------------------------------------------------------------------------------
1 |
2 |
Logout
3 |
4 | Do you really want to logout?
5 |
--------------------------------------------------------------------------------
/example/views.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlencode, parse_qs
2 |
3 | import flask
4 | from flask import Blueprint, redirect
5 | from flask import current_app
6 | from flask import jsonify
7 | from flask.helpers import make_response
8 | from flask.templating import render_template
9 | from oic.oic.message import TokenErrorResponse, UserInfoErrorResponse, EndSessionRequest
10 |
11 | from pyop.access_token import AccessToken, BearerTokenError
12 | from pyop.exceptions import InvalidAuthenticationRequest, InvalidAccessToken, InvalidClientAuthentication, OAuthError, \
13 | InvalidSubjectIdentifier, InvalidClientRegistrationRequest
14 | from pyop.util import should_fragment_encode
15 |
oidc_provider_views = Blueprint(name='oidc_provider', import_name=__name__, url_prefix='')


@oidc_provider_views.route('/')
def index():
    """Landing page of the example provider; the issuer identifier is derived from this URL."""
    greeting = 'Hello world!'
    return greeting
22 |
23 |
@oidc_provider_views.route('/registration', methods=['POST'])
def registration_endpoint():
    """
    Dynamic client registration endpoint.

    :return: HTTP 201 with the registration response as JSON, or HTTP 400 with the error as JSON
    """
    try:
        response = current_app.provider.handle_client_registration_request(flask.request.get_data().decode('utf-8'))
    except InvalidClientRegistrationRequest as e:
        current_app.logger.debug('invalid client registration request', exc_info=True)
        # flask.make_response only takes positional arguments; passing status=... raises TypeError.
        return make_response(jsonify(e.to_dict()), 400)
    return make_response(jsonify(response.to_dict()), 201)
31 |
32 |
@oidc_provider_views.route('/authentication', methods=['GET'])
def authentication_endpoint():
    """
    Authentication (authorization) endpoint. Validates the request and, since this is a demo,
    immediately authenticates the user as 'test_user'.

    :return: a 303 redirect carrying the authentication response, or the error when it can be
        returned to the client; otherwise a 400 plain-text error page
    """
    # parse authentication request
    try:
        auth_req = current_app.provider.parse_authentication_request(urlencode(flask.request.args),
                                                                     flask.request.headers)
    except InvalidAuthenticationRequest as e:
        current_app.logger.debug('received invalid authn request', exc_info=True)
        error_url = e.to_error_url()
        if error_url:
            return redirect(error_url, 303)
        # Show the error to the user. The message may echo request parameters, so serve it as
        # text/plain instead of Flask's default text/html to avoid reflected XSS.
        response = make_response('Something went wrong: {}'.format(str(e)), 400)
        response.mimetype = 'text/plain'
        return response

    # automagic authentication
    authn_response = current_app.provider.authorize(auth_req, 'test_user')
    response_url = authn_response.request(auth_req['redirect_uri'], should_fragment_encode(auth_req))
    return redirect(response_url, 303)
52 |
53 |
@oidc_provider_views.route('/.well-known/openid-configuration')
def provider_configuration():
    """Serve the provider's configuration (discovery metadata) as JSON."""
    metadata = current_app.provider.provider_configuration
    return jsonify(metadata.to_dict())
57 |
58 |
@oidc_provider_views.route('/jwks')
def jwks_uri():
    """Serve ``provider.jwks`` (presumably the public signing keys) as JSON."""
    key_set = current_app.provider.jwks
    return jsonify(key_set)
62 |
63 |
@oidc_provider_views.route('/token', methods=['POST'])
def token_endpoint():
    """
    Token endpoint.

    :return: the token response as JSON; a 401 JSON 'invalid_client' error with a Basic challenge when
        client authentication fails; a 400 JSON error for any other OAuth error
    """
    try:
        token_response = current_app.provider.handle_token_request(flask.request.get_data().decode('utf-8'),
                                                                   flask.request.headers)
        return jsonify(token_response.to_dict())
    except InvalidClientAuthentication as e:
        current_app.logger.debug('invalid client authentication at token endpoint', exc_info=True)
        error_resp = TokenErrorResponse(error='invalid_client', error_description=str(e))
        status, extra_headers = 401, {'WWW-Authenticate': 'Basic'}
    except OAuthError as e:
        current_app.logger.debug('invalid request: %s', str(e), exc_info=True)
        error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e))
        status, extra_headers = 400, {}

    response = make_response(error_resp.to_json(), status)
    response.headers['Content-Type'] = 'application/json'
    for header_name, header_value in extra_headers.items():
        response.headers[header_name] = header_value
    return response
83 |
84 |
@oidc_provider_views.route('/userinfo', methods=['GET', 'POST'])
def userinfo_endpoint():
    """
    UserInfo endpoint.

    :return: the UserInfo response as JSON, or a 401 JSON 'invalid_token' error with a Bearer
        challenge when the access token is missing or invalid
    """
    try:
        userinfo_response = current_app.provider.handle_userinfo_request(
            flask.request.get_data().decode('utf-8'), flask.request.headers)
        return jsonify(userinfo_response.to_dict())
    except (BearerTokenError, InvalidAccessToken) as e:
        error_body = UserInfoErrorResponse(error='invalid_token', error_description=str(e)).to_json()

    challenge = make_response(error_body, 401)
    challenge.headers['WWW-Authenticate'] = AccessToken.BEARER_TOKEN_TYPE
    challenge.headers['Content-Type'] = 'application/json'
    return challenge
97 |
98 |
def do_logout(end_session_request):
    """
    Log the user out for the given end-session request.

    :param end_session_request: the RP's end-session request
    :return: a 303 redirect when the provider returns a post-logout redirect URL, otherwise a plain
        confirmation; HTTP 400 if the provider raises InvalidSubjectIdentifier
    """
    try:
        current_app.provider.logout_user(end_session_request=end_session_request)
    except InvalidSubjectIdentifier:
        current_app.logger.debug('logout rejected: invalid subject identifier', exc_info=True)
        return make_response('Logout unsuccessful!', 400)

    redirect_url = current_app.provider.do_post_logout_redirect(end_session_request)
    if redirect_url:
        return redirect(redirect_url, 303)
    return make_response('Logout successful!')
110 |
111 |
@oidc_provider_views.route('/logout', methods=['GET', 'POST'])
def end_session_endpoint():
    """
    RP-initiated logout.

    GET: the RP redirected the user here. The end-session request is kept in the session and a
    confirmation page is rendered. POST: the confirmation form. The user is logged out if a
    'logout' field was submitted.
    """
    if flask.request.method == 'GET':
        # redirect from RP
        end_session_request = EndSessionRequest().deserialize(urlencode(flask.request.args))
        flask.session['end_session_request'] = end_session_request.to_dict()
        return render_template('logout.jinja2')

    form = parse_qs(flask.request.get_data().decode('utf-8'))
    # Consume the stored request so it cannot be replayed by a later POST. A POST without a
    # preceding GET has no stored request and previously ended in a KeyError (HTTP 500).
    stored_request = flask.session.pop('end_session_request', None)
    if 'logout' not in form:
        return make_response('You chose not to logout')
    if stored_request is None:
        return make_response('No pending logout request', 400)
    return do_logout(EndSessionRequest().from_dict(stored_request))
125 |
--------------------------------------------------------------------------------
/example/wsgi.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from example.app import oidc_provider_init_app
4 |
# Configure logging before building the app, so that debug records emitted while the
# provider is being initialised are not dropped.
logging.basicConfig(level=logging.DEBUG)

name = 'oidc_provider'
app = oidc_provider_init_app(name)
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# NOTE(review): .bumpversion.cfg still has current_version = 2.0.8, which bumpversion will not find
# in this file; keep the two versions in sync.
setup(
    name='pyop',
    version='3.4.2',
    description='OpenID Connect Provider (OP) library in Python.',
    url='https://github.com/IdentityPython/pyop',
    license='Apache 2.0',
    author='Rebecka Gulliksson',
    author_email='satosa-dev@lists.sunet.se',
    # src-layout: packages live under src/
    packages=find_packages('src'),
    package_dir={'': 'src'},
    install_requires=[
        'oic >= 1.2.1',
        'pycryptodomex',
    ],
    extras_require={
        'mongo': ['pymongo >= 3.12, < 4.0'],
        'redis': ['redis'],
    },
)
22 |
--------------------------------------------------------------------------------
/signing_key.pem:
--------------------------------------------------------------------------------
1 | example/signing_key.pem
--------------------------------------------------------------------------------
/src/pyop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IdentityPython/pyop/fab87f9f6193079171fdad0c223c810fe9532dd2/src/pyop/__init__.py
--------------------------------------------------------------------------------
/src/pyop/access_token.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from urllib.parse import parse_qsl
3 |
4 | from .exceptions import BearerTokenError
5 |
logger = logging.getLogger(__name__)


class AccessToken(object):
    """
    An issued access token.

    Attributes:
        value: the token string handed to the client
        expires_in: lifetime of the token (in seconds, as passed by AuthorizationState)
        type: the token type, 'Bearer' unless given otherwise
    """
    BEARER_TOKEN_TYPE = 'Bearer'

    def __init__(self, value, expires_in, typ=BEARER_TOKEN_TYPE):
        # The parameter is 'typ' to avoid shadowing the builtin; it is exposed as self.type.
        self.value, self.expires_in, self.type = value, expires_in, typ
19 |
20 |
def extract_bearer_token_from_http_request(parsed_request=None, authz_header=None):
    """
    Extracts a Bearer token from an http request.

    If an Authorization header is present, only that header is considered. The parsed request
    is used only when there is no Authorization header.

    :param parsed_request: parsed request (URL query part or request body), a mapping of parameter
        name to value (Optional[Mapping[str, str]])
    :param authz_header: HTTP Authorization header (Optional[str])
    :return: Bearer access token, if found
    :raise BearerTokenError: if no Bearer token could be extracted from the request
    """
    if authz_header:
        # Authorization Request Header Field: https://tools.ietf.org/html/rfc6750#section-2.1
        # The auth-scheme is case-insensitive (RFC 7235, section 2.1) and must be separated from
        # a non-empty token by a space, so e.g. 'BearerX' or a bare 'Bearer' is not accepted.
        scheme, _, credentials = authz_header.partition(' ')
        access_token = credentials.strip()
        if scheme.lower() == AccessToken.BEARER_TOKEN_TYPE.lower() and access_token:
            logger.debug('found access token %s in authz header', access_token)
            return access_token
    elif parsed_request:
        if 'access_token' in parsed_request:
            # Form-Encoded Body Parameter: https://tools.ietf.org/html/rfc6750#section-2.2, and
            # URI Query Parameter: https://tools.ietf.org/html/rfc6750#section-2.3
            access_token = parsed_request['access_token']
            logger.debug('found access token %s in request', access_token)
            return access_token

    raise BearerTokenError('Bearer Token could not be found in the request')
47 |
--------------------------------------------------------------------------------
/src/pyop/authz_state.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import uuid
4 |
5 | from oic.extension.message import TokenIntrospectionResponse
6 |
7 | from .message import AuthorizationRequest
8 | from .access_token import AccessToken
9 | from .exceptions import InvalidAccessToken
10 | from .exceptions import InvalidAuthorizationCode
11 | from .exceptions import InvalidRefreshToken
12 | from .exceptions import InvalidScope
13 | from .exceptions import InvalidSubjectIdentifier
14 | from .storage import StatelessWrapper
15 | from .util import requested_scope_is_allowed
16 |
logger = logging.getLogger(__name__)


def rand_str():
    """
    Return a fresh random identifier: the 32 lowercase hex digits of a version 4 UUID.

    Used as the value of authorization codes and tokens when not running stateless.
    """
    return str(uuid.uuid4()).replace('-', '')
22 |
23 |
24 | # TODO remove expired/invalid authorization codes/tokens from db's
25 |
26 | class AuthorizationState(object):
27 | KEY_AUTHORIZATION_REQUEST = 'auth_req'
28 | KEY_USER_INFO = 'user_info'
29 | KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'
30 |
def __init__(self, subject_identifier_factory, authorization_code_db=None, access_token_db=None,
             refresh_token_db=None, subject_identifier_db=None, *,
             authorization_code_lifetime=600, access_token_lifetime=3600, refresh_token_lifetime=None,
             refresh_token_threshold=None):
    # type: (SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
    #        Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
    """
    :param subject_identifier_factory: factory used when constructing subject identifiers; must be truthy
    :param authorization_code_db: storage for authorization codes, defaults to an in-memory dict
    :param access_token_db: storage for access tokens, defaults to an in-memory dict
    :param refresh_token_db: storage for refresh tokens, defaults to an in-memory dict
    :param subject_identifier_db: storage for subject identifiers, defaults to an in-memory dict. It is
        ignored (an in-memory dict is used) when any token storage is a StatelessWrapper.
    :param authorization_code_lifetime: seconds until authorization codes expire, defaults to 10 minutes
    :param access_token_lifetime: seconds until access tokens expire, defaults to 1 hour
    :param refresh_token_lifetime: seconds until refresh tokens expire; if unset no refresh tokens are issued
    :param refresh_token_threshold: when a refresh token used in a refresh request expires within this many
        seconds, a new refresh token is issued; if unset, a new one is never issued
    :raise ValueError: if subject_identifier_factory is falsy
    """
    if not subject_identifier_factory:
        raise ValueError('subject_identifier_factory can\'t be None')
    self._subject_identifier_factory = subject_identifier_factory

    def _or_new_dict(db):
        # Only None means "not given": an explicitly passed, possibly empty, storage is kept.
        return {} if db is None else db

    # Authorization code -> subject identifier, granted scope and authorization request.
    self.authorization_code_lifetime = authorization_code_lifetime
    self.authorization_codes = _or_new_dict(authorization_code_db)

    # Access token -> scope, token type, client id and subject identifier.
    self.access_token_lifetime = access_token_lifetime
    self.access_tokens = _or_new_dict(access_token_db)

    # Refresh token -> the access token it is bound to.
    self.refresh_token_lifetime = refresh_token_lifetime
    self.refresh_token_threshold = refresh_token_threshold
    self.refresh_tokens = _or_new_dict(refresh_token_db)

    # Stateless mode: codes/tokens are produced by the wrapper's pack() instead of being stored.
    storages = (self.authorization_codes, self.access_tokens, self.refresh_tokens)
    self.stateless = any(isinstance(storage, StatelessWrapper) for storage in storages)

    # User id -> subject identifiers; always an in-memory dict in stateless mode.
    if self.stateless:
        self.subject_identifiers = {}
    else:
        self.subject_identifiers = _or_new_dict(subject_identifier_db)
95 |
def create_authorization_code(
    self,
    authorization_request,
    subject_identifier,
    scope=None,
    user_info=None,
    extra_id_token_claims=None,
):
    # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mapping[str, Union[str, List[str]]]]) -> str
    """
    Creates an authorization code bound to the authorization request and to the authenticated user
    identified by the subject identifier.

    :param authorization_request: the authorization request
    :param subject_identifier: a subject identifier known to this instance
    :param scope: scope values to grant; defaults to the scope of the authorization request
    :param user_info: user claims, embedded in the code only in stateless mode
    :param extra_id_token_claims: extra ID Token claims, embedded in the code only in stateless mode
    :return: the authorization code
    :raise InvalidSubjectIdentifier: if the subject identifier is unknown
    """
    if not self._is_valid_subject_identifier(subject_identifier):
        raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))

    granted_scope = ' '.join(scope or authorization_request['scope'])
    logger.debug('creating authz code for scope=%s', granted_scope)

    code_info = {
        'used': False,
        'exp': int(time.time()) + self.authorization_code_lifetime,
        'sub': subject_identifier,
        'granted_scope': granted_scope,
        self.KEY_AUTHORIZATION_REQUEST: authorization_request.to_dict(),
    }

    if not self.stateless:
        # Stateful: a random code is the key under which the data is stored.
        authorization_code = rand_str()
        self.authorization_codes[authorization_code] = code_info
    else:
        # Stateless: everything needed later is packed into the code itself.
        if user_info:
            code_info[self.KEY_USER_INFO] = user_info
        code_info[self.KEY_EXTRA_ID_TOKEN_CLAIMS] = extra_id_token_claims or {}
        authorization_code = self.authorization_codes.pack(code_info)

    logger.debug('new authz_code=%s to client_id=%s for sub=%s valid_until=%s', authorization_code,
                 authorization_request['client_id'], subject_identifier, code_info['exp'])
    return authorization_code
136 |
def create_access_token(self, authorization_request, subject_identifier, scope=None, user_info=None):
    # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict]) -> AccessToken
    """
    Creates an access token bound to the authorization request and to the authenticated user identified by
    the subject identifier.

    :param scope: scope values to grant; defaults to the scope of the authorization request
    :param user_info: user claims, embedded in the token only in stateless mode
    :return: the new access token
    :raise InvalidSubjectIdentifier: if the subject identifier is unknown
    """
    if self._is_valid_subject_identifier(subject_identifier):
        granted_scope = ' '.join(scope or authorization_request['scope'])
        return self._create_access_token(subject_identifier, authorization_request.to_dict(), granted_scope,
                                         user_info=user_info)
    raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))
150 |
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None,
                         user_info=None):
    # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> AccessToken
    """
    Creates an access token bound to the subject identifier, the client id and the requested scope.

    :param subject_identifier: subject the token is issued for
    :param auth_req: the authorization request, as a dict
    :param granted_scope: space-separated scope granted by the user
    :param current_scope: space-separated scope of this token (e.g. narrowed on refresh);
        defaults to granted_scope
    :param user_info: user claims, embedded in the token only in stateless mode
    :return: the new access token
    """
    scope = current_scope or granted_scope
    logger.debug('creating access token for scope=%s', scope)

    # Read the clock once so that 'exp' - 'iat' always equals the configured lifetime.
    issued_at = int(time.time())
    authz_info = {
        'iat': issued_at,
        'exp': issued_at + self.access_token_lifetime,
        'sub': subject_identifier,
        'client_id': auth_req['client_id'],
        'aud': [auth_req['client_id']],
        'scope': scope,
        'granted_scope': granted_scope,
        'token_type': AccessToken.BEARER_TOKEN_TYPE,
        self.KEY_AUTHORIZATION_REQUEST: auth_req
    }

    if self.stateless:
        if user_info:
            authz_info[self.KEY_USER_INFO] = user_info
        access_token_val = self.access_tokens.pack(authz_info)
    else:
        access_token_val = rand_str()
        self.access_tokens[access_token_val] = authz_info

    logger.debug('new access_token=%s to client_id=%s for sub=%s valid_until=%s',
                 access_token_val, auth_req['client_id'], subject_identifier, authz_info['exp'])
    return AccessToken(access_token_val, self.access_token_lifetime)
184 |
def exchange_code_for_token(self, authorization_code):
    # type: (str) -> AccessToken
    """
    Exchanges an authorization code for an access token. Each code can be exchanged only once.

    :param authorization_code: a code issued by create_authorization_code
    :return: a new access token for the code's subject, authorization request and granted scope
    :raise InvalidAuthorizationCode: if the code is unknown, has already been used, or has expired
    """
    if authorization_code not in self.authorization_codes:
        raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

    code_info = self.authorization_codes[authorization_code]
    if code_info['used']:
        logger.debug('detected already used authz_code=%s', authorization_code)
        raise InvalidAuthorizationCode('{} has already been used'.format(authorization_code))
    if code_info['exp'] < int(time.time()):
        logger.debug('detected expired authz_code=%s, now=%s > exp=%s ',
                     authorization_code, int(time.time()), code_info['exp'])
        raise InvalidAuthorizationCode('{} has expired'.format(authorization_code))

    # Flag the code as consumed and store the record again (presumably so that non-dict
    # storage backends persist the flag).
    code_info['used'] = True
    self.authorization_codes[authorization_code] = code_info

    access_token = self._create_access_token(
        code_info['sub'],
        code_info[self.KEY_AUTHORIZATION_REQUEST],
        code_info['granted_scope'],
        user_info=code_info.get(self.KEY_USER_INFO),
    )
    logger.debug('authz_code=%s exchanged to access_token=%s', authorization_code, access_token.value)
    return access_token
211 |
def introspect_access_token(self, access_token_value):
    # type: (str) -> Dict[str, Union[str, List[str]]]
    """
    Returns the authorization data associated with the access token, shaped like a token
    introspection response (RFC 7662, Section 2.2).

    :return: dict with 'active' (False once the token has expired) plus every stored field that is a
        TokenIntrospectionResponse parameter
    :raise InvalidAccessToken: if the token is unknown
    """
    if access_token_value not in self.access_tokens:
        raise InvalidAccessToken('{} unknown'.format(access_token_value))

    token_info = self.access_tokens[access_token_value]
    known_params = TokenIntrospectionResponse.c_param

    introspection = {'active': token_info['exp'] >= int(time.time())}
    for key, value in token_info.items():
        if key in known_params:
            introspection[key] = value
    return introspection
228 |
def create_refresh_token(self, access_token_value):
    # type: (str) -> Optional[str]
    """
    Creates a refresh token bound to the given access token.

    :return: the refresh token, or None when refresh tokens are disabled (no refresh_token_lifetime)
    :raise InvalidAccessToken: if the access token is unknown
    """
    if access_token_value not in self.access_tokens:
        raise InvalidAccessToken('{} unknown'.format(access_token_value))

    lifetime = self.refresh_token_lifetime
    if not lifetime:
        logger.debug('no refresh token issued for for access_token=%s', access_token_value)
        return None

    token_info = {'access_token': access_token_value, 'exp': int(time.time()) + lifetime}
    if self.stateless:
        refresh_token = self.refresh_tokens.pack(token_info)
    else:
        refresh_token = rand_str()
        self.refresh_tokens[refresh_token] = token_info

    logger.debug('issued refresh_token=%s expiring=%d for access_token=%s', refresh_token, token_info['exp'],
                 access_token_value)
    return refresh_token
252 |
def use_refresh_token(self, refresh_token, scope=None):
    # type: (str, Optional[List[str]]) -> Tuple[AccessToken, Optional[str]]
    """
    Creates a new access token, and possibly a new refresh token, based on the supplied refresh token.

    :param refresh_token: the refresh token presented by the client
    :param scope: scope values for the new access token; must be a subset of the originally granted
        scope. If omitted, the originally granted scope is used.
    :return: the new access token and a new refresh token. The latter is None unless
        refresh_token_threshold is set and the old refresh token expires within that many seconds.
        Otherwise the old refresh token is rebound to the new access token.
    :raise InvalidRefreshToken: if the refresh token is unknown or has expired
    :raise InvalidScope: if the requested scope includes values that were never granted
    """
    if refresh_token not in self.refresh_tokens:
        raise InvalidRefreshToken('{} unknown'.format(refresh_token))

    refresh_token_info = self.refresh_tokens[refresh_token]
    if 'exp' in refresh_token_info and refresh_token_info['exp'] < int(time.time()):
        raise InvalidRefreshToken('{} has expired'.format(refresh_token))

    authz_info = self.access_tokens[refresh_token_info['access_token']]

    if scope:
        if not requested_scope_is_allowed(scope, authz_info['granted_scope']):
            logger.debug('trying to refresh token with superset scope, requested_scope=%s, granted_scope=%s',
                         scope, authz_info['granted_scope'])
            raise InvalidScope('Requested scope includes non-granted value')
        scope = ' '.join(scope)
        logger.debug('refreshing token with new scope, old_scope=%s -> new_scope=%s', authz_info['scope'], scope)
    else:
        # OAuth 2.0: scope: "[...] if omitted is treated as equal to the scope originally granted by the resource owner"
        scope = authz_info['granted_scope']

    new_access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
                                                 authz_info['granted_scope'], scope,
                                                 user_info=authz_info.get(self.KEY_USER_INFO))

    new_refresh_token = None
    if self.refresh_token_threshold \
            and 'exp' in refresh_token_info \
            and refresh_token_info['exp'] - int(time.time()) < self.refresh_token_threshold:
        # refresh token is close to expiry, issue a new one
        new_refresh_token = self.create_refresh_token(new_access_token.value)
    else:
        # Rebind the refresh token to the new access token. Write the record back explicitly, as
        # exchange_code_for_token does. Mutating the looked-up value only reaches a plain dict, not
        # storage backends that may return copies.
        refresh_token_info['access_token'] = new_access_token.value
        self.refresh_tokens[refresh_token] = refresh_token_info

    logger.debug('refreshed tokens, new_access_token=%s new_refresh_token=%s old_refresh_token=%s',
                 new_access_token.value, new_refresh_token, refresh_token)
    return new_access_token, new_refresh_token
296 |
def get_subject_identifier(self, subject_type, user_id, sector_identifier=None):
    # type: (str, str, str) -> str
    """
    Returns a subject identifier for the local user identifier, recording it so it can later be
    resolved back to the user.

    :param subject_type: 'pairwise' or 'public', see "OpenID Connect Core 1.0", Section 8
        (https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes)
    :param user_id: local user identifier
    :param sector_identifier: the client's sector identifier, required for 'pairwise'; see
        "OpenID Connect Core 1.0", Section 1.2 (https://openid.net/specs/openid-connect-core-1_0.html#Terminology)
    :return: the subject identifier
    :raise ValueError: if subject_type is unknown, or sector_identifier is missing for 'pairwise'
    """
    if user_id not in self.subject_identifiers:
        self.subject_identifiers[user_id] = {}

    if subject_type == 'public':
        subject_ids = self.subject_identifiers[user_id]
        if 'public' not in subject_ids:
            # Add to the existing record instead of replacing it: replacing it discarded any
            # pairwise identifiers already issued for this user, making them unresolvable.
            # Then write the record back so non-dict storage backends persist the change.
            subject_ids['public'] = self._subject_identifier_factory.create_public_identifier(user_id)
            self.subject_identifiers[user_id] = subject_ids

            logger.debug('created new public sub=%s for user_id=%s',
                         self.subject_identifiers[user_id]['public'], user_id)
        sub = self.subject_identifiers[user_id]['public']
        logger.debug('returning public sub=%s', sub)
        return sub
    elif subject_type == 'pairwise':
        if not sector_identifier:
            raise ValueError('sector_identifier cannot be None or empty')

        subject_id = self._subject_identifier_factory.create_pairwise_identifier(user_id, sector_identifier)
        logger.debug('returning pairwise sub=%s for user_id=%s and sector_identifier=%s',
                     subject_id, user_id, sector_identifier)
        sub = self.subject_identifiers[user_id]
        pairwise_set = set(sub.get('pairwise', []))
        pairwise_set.add(subject_id)
        sub['pairwise'] = list(pairwise_set)
        self.subject_identifiers[user_id] = sub
        return subject_id

    raise ValueError('Unknown subject_type={}'.format(subject_type))
338 |
def _is_valid_subject_identifier(self, sub):
    # type: (str) -> bool
    """
    Tells whether the subject identifier belongs to any known user.

    :param sub: subject identifier to look up
    :return: True if get_user_id_for_subject_identifier resolves it, otherwise False
    """
    try:
        self.get_user_id_for_subject_identifier(sub)
    except InvalidSubjectIdentifier:
        return False
    else:
        return True
350 |
def get_user_id_for_subject_identifier(self, subject_identifier):
    # type: (str) -> str
    """
    Finds the local user identifier that owns a public or pairwise subject identifier.

    :param subject_identifier: the subject identifier to resolve
    :return: the local user identifier
    :raise InvalidSubjectIdentifier: if no user has this subject identifier
    """
    def owns(identifiers):
        if 'public' in identifiers and subject_identifier == identifiers['public']:
            return True
        return 'pairwise' in identifiers and subject_identifier in identifiers['pairwise']

    for user_id, identifiers in self.subject_identifiers.items():
        if owns(identifiers):
            return user_id

    raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))
359 |
def get_user_info_for_code(self, authorization_code):
    # type: (str) -> dict
    """
    Returns the user info stored together with an authorization code.

    :param authorization_code: the authorization code
    :return: the stored user info, or None if none was stored
    :raise InvalidAuthorizationCode: if the authorization code is unknown
    """
    if authorization_code in self.authorization_codes:
        return self.authorization_codes[authorization_code].get(self.KEY_USER_INFO)
    raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
366 |
def get_extra_id_token_claims_for_code(self, authorization_code):
    # type: (str) -> dict
    """
    Returns the extra ID Token claims stored together with an authorization code.

    :param authorization_code: the authorization code
    :return: the stored extra claims, or None if none were stored
    :raise InvalidAuthorizationCode: if the authorization code is unknown
    """
    if authorization_code in self.authorization_codes:
        return self.authorization_codes[authorization_code].get(self.KEY_EXTRA_ID_TOKEN_CLAIMS)
    raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
373 |
def get_user_info_for_access_token(self, access_token):
    # type: (str) -> dict
    """
    Returns the user info stored together with an access token.

    :param access_token: the access token value
    :return: the stored user info, or None if none was stored
    :raise InvalidAccessToken: if the access token is unknown
    """
    if access_token in self.access_tokens:
        return self.access_tokens[access_token].get(self.KEY_USER_INFO)
    raise InvalidAccessToken('{} unknown'.format(access_token))
380 |
def get_authorization_request_for_code(self, authorization_code):
    # type: (str) -> AuthorizationRequest
    """
    Returns the authentication request an authorization code was issued for.

    :param authorization_code: the authorization code
    :return: the stored authentication request, deserialized into an AuthorizationRequest
    :raise InvalidAuthorizationCode: if the authorization code is unknown
    """
    if authorization_code not in self.authorization_codes:
        raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

    stored_request = self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST]
    return AuthorizationRequest().from_dict(stored_request)
388 |
def get_authorization_request_for_access_token(self, access_token_value):
    # type: (str) -> AuthorizationRequest
    """
    Returns the authentication request an access token was issued for.

    :param access_token_value: the access token value
    :return: the stored authentication request, deserialized into an AuthorizationRequest
    :raise InvalidAccessToken: if the access token is unknown
    """
    if access_token_value not in self.access_tokens:
        raise InvalidAccessToken('{} unknown'.format(access_token_value))

    return AuthorizationRequest().from_dict(self.access_tokens[access_token_value][self.KEY_AUTHORIZATION_REQUEST])
395 |
def get_subject_identifier_for_code(self, authorization_code):
    # type: (str) -> str
    """
    Returns the subject identifier an authorization code was issued for.

    :param authorization_code: the authorization code
    :return: the stored subject identifier ('sub')
    :raise InvalidAuthorizationCode: if the authorization code is unknown
    """
    if authorization_code not in self.authorization_codes:
        raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))

    return self.authorization_codes[authorization_code]['sub']
402 |
def delete_state_for_subject_identifier(self, subject_identifier):
    # type: (str) -> None
    """
    Removes all authorization codes and access tokens issued for a subject identifier.

    :param subject_identifier: the subject identifier whose state should be removed
    :raise InvalidSubjectIdentifier: if the subject identifier is unknown
    """
    if not self._is_valid_subject_identifier(subject_identifier):
        raise InvalidSubjectIdentifier('Trying to delete state for unknown subject identifier')

    for tokens in [self.authorization_codes, self.access_tokens]:
        # Collect the keys first so the mapping is not modified while it is being iterated.
        tokens_to_remove = [k for k, v in tokens.items() if v['sub'] == subject_identifier]
        for token in tokens_to_remove:
            tokens.pop(token, None)
412 |
--------------------------------------------------------------------------------
/src/pyop/client_authentication.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import logging
3 | from urllib.parse import unquote
4 |
5 | from .exceptions import InvalidClientAuthentication
6 |
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
8 |
9 |
def _client_secrets_match(provided, expected):
    """
    Compares a presented client secret with the registered one.

    Behaves like ``provided == expected`` (in particular two ``None`` values match), but string
    secrets are compared with ``hmac.compare_digest`` so the comparison time does not depend on
    how many leading characters are correct.
    """
    import hmac

    if isinstance(provided, str) and isinstance(expected, str):
        return hmac.compare_digest(provided.encode('utf-8'), expected.encode('utf-8'))
    return provided == expected


def verify_client_authentication(clients, parsed_request, authz_header=None):
    # type: (Mapping[str, Mapping[str, Any]], Mapping[str, str], Optional[str]) -> str
    """
    Verifies client authentication at the token endpoint, see
    "The OAuth 2.0 Authorization Framework",
    Section 2.3.1
    :param clients: clients db, mapping client_id to the client's metadata
    :param parsed_request: key-value pairs from parsed urlencoded request
    :param authz_header: the HTTP Authorization header value
    :return: the authenticated client_id
    :raise InvalidClientAuthentication: if the client authentication was incorrect
    """
    client_id = None
    client_secret = None
    authn_method = None
    if authz_header:
        header_parts = authz_header.split(maxsplit=1)
        authz_scheme = header_parts[0] if header_parts else ''
        # Only the scheme is logged: the header value itself carries the client secret.
        logger.debug('client authentication in Authorization header, scheme=%s', authz_scheme)

        if authz_scheme != 'Basic':
            raise InvalidClientAuthentication('Unknown scheme in authorization header, {} != Basic'.format(authz_scheme))

        authn_method = 'client_secret_basic'
        credentials = header_parts[1].strip() if len(header_parts) > 1 else ''
        # Restore stripped base64 padding; (-n) % 4 adds nothing when the length is already aligned.
        credentials += '=' * (-len(credentials) % 4)
        try:
            auth = base64.urlsafe_b64decode(credentials.encode('utf-8')).decode('utf-8')
        except ValueError as e:
            # binascii.Error (malformed base64) and UnicodeDecodeError both derive from ValueError.
            raise InvalidClientAuthentication('Could not decode userid/password from authorization header') from e

        # RFC 7617: the user-id cannot contain ':' but the password may, so split on the first one only.
        encoded_id, separator, encoded_secret = auth.partition(':')
        if not separator:
            raise InvalidClientAuthentication('Missing \':\' between userid and password in authorization header')
        client_id, client_secret = unquote(encoded_id), unquote(encoded_secret)
    elif 'client_id' in parsed_request:
        client_id = parsed_request['client_id']
        # Only the client_id is logged: the request body may contain the client secret.
        logger.debug('client authentication in request body, client_id=%s', client_id)

        if 'client_secret' in parsed_request:
            authn_method = 'client_secret_post'
            client_secret = parsed_request['client_secret']
        else:
            authn_method = 'none'
            client_secret = None

    if client_id not in clients:
        raise InvalidClientAuthentication('client_id \'{}\' unknown'.format(client_id))

    client_info = clients[client_id]
    if not _client_secrets_match(client_secret, client_info.get('client_secret', None)):
        raise InvalidClientAuthentication('Incorrect client_secret')

    expected_authn_method = client_info.get('token_endpoint_auth_method', 'client_secret_basic')
    if authn_method != expected_authn_method:
        raise InvalidClientAuthentication(
            'Wrong authentication method used, MUST use \'{}\''.format(expected_authn_method))

    return client_id
66 |
--------------------------------------------------------------------------------
/src/pyop/crypto.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 |
4 | from Cryptodome import Random
5 | from Cryptodome.Cipher import AES
6 |
7 |
class _AESCipher(object):
    """
    AES-256 encryption/decryption in CBC mode, with a random IV prepended to each ciphertext.

    The constructor's key is hashed with SHA-256 to derive the 32-byte AES key. Plaintexts are
    padded to a multiple of ``bs`` (32) bytes: every padding byte holds the number of padding
    bytes, so between 1 and 32 bytes are always appended.

    NOTE(review): no MAC is computed, so decrypt() cannot detect tampered or foreign ciphertexts;
    switching to an authenticated mode (e.g. AES-GCM) would require a new ciphertext format.

    @see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
    """

    def __init__(self, key):
        """
        Constructor

        :type key: str

        :param key: The key used for encryption and decryption. The longer key the better.
        """
        self.bs = 32
        self.key = hashlib.sha256(key.encode()).digest()

    def encrypt(self, raw):
        """
        Encrypts the parameter raw.

        :type raw: bytes
        :rtype: bytes

        :param raw: bytes to be encrypted.

        :return: the URL-safe base64 encoding of IV + ciphertext.
        """
        raw = self._pad(raw)
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))

    def decrypt(self, enc):
        """
        Decrypts the parameter enc.

        :type enc: bytes | str
        :rtype: bytes

        :param enc: the URL-safe base64 encoding of IV + ciphertext, as returned by encrypt().
        :return: The decrypted value.
        """
        enc = base64.urlsafe_b64decode(enc)
        iv = enc[:AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size:]))

    def _pad(self, b):
        """
        Pads the param to a multiple of ``self.bs`` bytes; each padding byte holds the padding length.

        :type b: bytes
        :rtype: bytes
        """
        pad_len = self.bs - len(b) % self.bs
        # bytes([n]) is exactly one byte for any n <= 255 (chr(n).encode() would emit two bytes above 127).
        return b + bytes([pad_len]) * pad_len

    @staticmethod
    def _unpad(b):
        """
        Removes the padding performed by the method _pad, using the last byte as the padding length.

        :type b: bytes
        :rtype: bytes
        """
        # NOTE(review): the padding bytes are not validated, so corrupted input yields truncated data.
        return b[:-ord(b[len(b) - 1:])]
75 |
--------------------------------------------------------------------------------
/src/pyop/exceptions.py:
--------------------------------------------------------------------------------
1 | from oic.oic.message import AuthorizationErrorResponse, ClientRegistrationErrorResponse
2 |
3 | from .util import should_fragment_encode
4 |
5 |
class OAuthError(ValueError):
    """
    Base class for errors that correspond to an OAuth 2.0 error code.

    :ivar oauth_error: the OAuth error code, e.g. 'invalid_grant'
    """

    def __init__(self, message, oauth_error):
        self.oauth_error = oauth_error
        super().__init__(message)
10 |
11 |
class InvalidAuthorizationCode(OAuthError):
    """Raised for an authorization code that cannot be used; OAuth error code 'invalid_grant'."""

    def __init__(self, message):
        super().__init__(message, oauth_error='invalid_grant')
15 |
16 |
class InvalidRefreshToken(OAuthError):
    """Raised for a refresh token that cannot be used; OAuth error code 'invalid_grant'."""

    def __init__(self, message):
        super().__init__(message, oauth_error='invalid_grant')
20 |
21 |
class InvalidAccessToken(OAuthError):
    """Raised for an access token that cannot be used; OAuth error code 'invalid_token'."""

    def __init__(self, message):
        super().__init__(message, oauth_error='invalid_token')
25 |
26 |
class InvalidScope(OAuthError):
    """Raised when a requested scope is not acceptable; OAuth error code 'invalid_scope'."""

    def __init__(self, message):
        super().__init__(message, oauth_error='invalid_scope')
30 |
31 |
class InvalidClientAuthentication(OAuthError):
    """Raised when client authentication fails; OAuth error code 'invalid_client'."""

    def __init__(self, message):
        super().__init__(message, oauth_error='invalid_client')
35 |
36 |
class InvalidSubjectIdentifier(ValueError):
    """Raised when a subject identifier is not known."""
39 |
40 |
class InvalidRequestError(OAuthError):
    """
    Base class for errors caused by a specific request.

    :ivar request: the parsed request that caused the error
    """

    def __init__(self, message, parsed_request, oauth_error):
        self.request = parsed_request
        super().__init__(message, oauth_error)
45 |
46 |
class InvalidAuthenticationRequest(InvalidRequestError):
    """
    Error in an authentication request, which may be reported to the client via its redirect_uri.
    """

    def __init__(self, message, parsed_request, oauth_error=None):
        super().__init__(message, parsed_request, oauth_error)

    def to_error_url(self):
        """
        Builds the error response URL for the request's redirect_uri, see
        "The OAuth 2.0 Authorization Framework", Section 4.1.2.1.

        :return: the error URL, or None if the request has no redirect_uri or response_type,
            or no OAuth error code is set
        """
        redirect_uri = self.request.get('redirect_uri')
        response_type = self.request.get('response_type')
        if redirect_uri and response_type and self.oauth_error:
            # 'error_description' is the parameter defined by RFC 6749 (also used by
            # InvalidClientRegistrationRequest.to_json); 'error_message' is not a standard parameter.
            error_resp = AuthorizationErrorResponse(error=self.oauth_error, error_description=str(self),
                                                    state=self.request.get('state'))
            return error_resp.request(redirect_uri, should_fragment_encode(self.request))

        return None
60 |
61 |
class InvalidRedirectURI(InvalidAuthenticationRequest):
    """
    Authentication request with an unacceptable redirect_uri.

    No error URL is produced, since the error must not be sent to that redirect_uri.
    """

    def to_error_url(self):
        # Never redirect to the rejected redirect_uri.
        return None
65 |
66 |
class InvalidTokenRequest(InvalidRequestError):
    """Error in a token request; the OAuth error code defaults to 'invalid_request'."""

    def __init__(self, message, parsed_request, oauth_error='invalid_request'):
        super().__init__(message, parsed_request, oauth_error=oauth_error)
70 |
71 |
class InvalidClientRegistrationRequest(InvalidRequestError):
    """Error in a client registration request; the OAuth error code defaults to 'invalid_request'."""

    def __init__(self, message, parsed_request, oauth_error='invalid_request'):
        super().__init__(message, parsed_request, oauth_error=oauth_error)

    def to_json(self):
        """Serializes the error as a JSON client registration error response."""
        return ClientRegistrationErrorResponse(error=self.oauth_error, error_description=str(self)).to_json()
79 |
80 |
class BearerTokenError(ValueError):
    """Error in handling a Bearer token."""
83 |
84 |
class AuthorizationError(Exception):
    """Raised when an authorization cannot be granted, e.g. a requested subject identifier does not match."""
87 |
--------------------------------------------------------------------------------
/src/pyop/message.py:
--------------------------------------------------------------------------------
1 | from oic.oauth2.message import SINGLE_OPTIONAL_STRING
2 | from oic.oic import message
3 |
class AccessTokenRequest(message.AccessTokenRequest):
    """
    oic's AccessTokenRequest, extended with the optional PKCE 'code_verifier' parameter (RFC 7636).
    """
    c_param = dict(message.AccessTokenRequest.c_param, code_verifier=SINGLE_OPTIONAL_STRING)
11 |
class AuthorizationRequest(message.AuthorizationRequest):
    """
    oic's AuthorizationRequest, extended with the optional PKCE parameters 'code_challenge' and
    'code_challenge_method' (RFC 7636); the latter is restricted to "plain" and "S256".
    """
    c_param = {
        **message.AuthorizationRequest.c_param,
        'code_challenge': SINGLE_OPTIONAL_STRING,
        'code_challenge_method': SINGLE_OPTIONAL_STRING,
    }

    c_allowed_values = {
        **message.AuthorizationRequest.c_allowed_values,
        "code_challenge_method": ["plain", "S256"],
    }
30 |
--------------------------------------------------------------------------------
/src/pyop/provider.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import logging
4 | import time
5 | import uuid
6 | from urllib.parse import parse_qsl
7 | from urllib.parse import urlparse
8 |
9 | from jwkest import jws
10 | from oic import rndstr
11 | from oic.exception import MessageException
12 | from oic.oic import PREFERENCE2PROVIDER
13 | from oic.oic import scope2claims
14 | from oic.oic.message import AccessTokenResponse
15 | from oic.oic.message import AuthorizationResponse
16 | from oic.oic.message import EndSessionRequest
17 | from oic.oic.message import EndSessionResponse
18 | from oic.oic.message import IdToken
19 | from oic.oic.message import OpenIDSchema
20 | from oic.oic.message import ProviderConfigurationResponse
21 | from oic.oic.message import RefreshAccessTokenRequest
22 | from oic.oic.message import RegistrationRequest
23 | from oic.oic.message import RegistrationResponse
24 | from oic.extension.provider import Provider as OICProviderExtensions
25 |
26 | from .message import AuthorizationRequest
27 | from .message import AccessTokenRequest
28 | from .access_token import extract_bearer_token_from_http_request
29 | from .client_authentication import verify_client_authentication
30 | from .exceptions import AuthorizationError
31 | from .exceptions import InvalidAccessToken
32 | from .exceptions import InvalidTokenRequest
33 | from .exceptions import InvalidAuthorizationCode
34 | from .exceptions import OAuthError
35 | from .request_validator import authorization_request_verify
36 | from .request_validator import client_id_is_known
37 | from .request_validator import client_preferences_match_provider_capabilities
38 | from .request_validator import redirect_uri_is_in_registered_redirect_uris
39 | from .request_validator import registration_request_verify
40 | from .request_validator import requested_scope_is_supported
41 | from .request_validator import response_type_is_in_registered_response_types
42 | from .request_validator import userinfo_claims_only_specified_when_access_token_is_issued
43 | from .util import find_common_values
44 |
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
46 |
47 |
48 | class Provider(object):
def __init__(self, signing_key, configuration_information, authz_state, clients, userinfo, *,
             id_token_lifetime=3600, extra_scopes=None):
    # type: (jwkest.jwk.Key, Dict[str, Union[str, Sequence[str]]], pyop.authz_state.AuthorizationState,
    #        Mapping[str, Mapping[str, Any]], pyop.userinfo.Userinfo, int,
    #        Optional[Mapping[str, Sequence[str]]]) -> None
    """
    Creates a new provider instance.
    :param signing_key: key used to sign ID Tokens; also published via the jwks property
    :param configuration_information: see
    "OpenID Connect Discovery 1.0", Section 3
    :param authz_state: state store used to create authorization codes, tokens and subject identifiers
    :param clients: see
    "OpenID Connect Dynamic Client Registration 1.0", Section 2
    :param userinfo: read-only interface for user info
    :param id_token_lifetime: how long the signed ID Tokens should be valid (in seconds), defaults to 1 hour
    :param extra_scopes: additional scopes, passed to scope2claims as extra_scope_dict (presumably
    mapping each scope to the claims it grants); their names are added to scopes_supported
    """
    self.signing_key = signing_key
    self.configuration_information = ProviderConfigurationResponse(**configuration_information)
    if 'subject_types_supported' not in configuration_information:
        self.configuration_information['subject_types_supported'] = ['pairwise']
    if 'id_token_signing_alg_values_supported' not in configuration_information:
        self.configuration_information['id_token_signing_alg_values_supported'] = ['RS256']
    if 'scopes_supported' not in configuration_information:
        self.configuration_information['scopes_supported'] = ['openid']
    if 'response_types_supported' not in configuration_information:
        self.configuration_information['response_types_supported'] = ['code', 'id_token', 'token id_token']

    self.extra_scopes = {} if extra_scopes is None else extra_scopes
    # Build a new list instead of extending the configured one in place (which may be the caller's
    # own list), and de-duplicate while keeping first-seen order (list(set(...)) made it arbitrary).
    scopes = list(self.configuration_information['scopes_supported'])
    scopes.extend(self.extra_scopes.keys())
    self.configuration_information['scopes_supported'] = list(dict.fromkeys(scopes))

    self.configuration_information.verify()

    self.authz_state = authz_state
    self.stateless = self.authz_state and self.authz_state.stateless

    self.clients = clients
    self.userinfo = userinfo
    self.id_token_lifetime = id_token_lifetime

    # Applied in order by parse_authentication_request; each validator raises on an invalid request.
    self.authentication_request_validators = [
        authorization_request_verify,
        functools.partial(client_id_is_known, self),
        functools.partial(redirect_uri_is_in_registered_redirect_uris, self),
        functools.partial(response_type_is_in_registered_response_types, self),
        userinfo_claims_only_specified_when_access_token_is_issued,
        functools.partial(requested_scope_is_supported, self),
    ]  # type: List[Callable[[AuthorizationRequest], Boolean]]

    self.registration_request_validators = [
        registration_request_verify,
        functools.partial(client_preferences_match_provider_capabilities, self),
    ]  # type: List[Callable[[oic.oic.message.RegistrationRequest], Boolean]]
103 |
@property
def provider_configuration(self):
    """
    The provider configuration information, as a deep copy so callers cannot modify the provider's own.
    """
    configuration = self.configuration_information
    return copy.deepcopy(configuration)
110 |
@property
def jwks(self):
    """
    All keys published by the provider, as a JSON Web Key Set ({'keys': [...]}).
    """
    return {'keys': [self.signing_key.serialize()]}
119 |
def parse_authentication_request(self, request_body, http_headers=None):
    # type: (str, Optional[Mapping[str, str]]) -> AuthorizationRequest
    """
    Parses an urlencoded authentication request and runs every configured validator on it.

    :param request_body: urlencoded authentication request
    :param http_headers: http headers (not used here)
    :return: the parsed authentication request
    """
    parsed_request = AuthorizationRequest().deserialize(request_body)

    for validate in self.authentication_request_validators:
        validate(parsed_request)

    logger.debug('parsed authentication_request: %s', parsed_request)
    return parsed_request
136 |
def authorize(
    self,
    authentication_request,  # type: AuthorizationRequest
    user_id,  # type: str
    extra_id_token_claims=None,  # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]]
):
    # type: (...) -> oic.oic.message.AuthorizationResponse
    """
    Creates an Authentication Response for the specified authentication request and local identifier of the
    authenticated user.

    Depending on the request's response_type, the response contains an authorization code ('code'), an access
    token ('token') and/or a signed ID Token ('id_token'), plus the request's 'state' if present.

    :param authentication_request: the parsed authentication request
    :param user_id: local identifier of the authenticated user
    :param extra_id_token_claims: extra claims for the ID Token, either a mapping or a callable taking
    (user_id, client_id) and returning such a mapping
    :return: the authentication response
    :raise AuthorizationError: if a subject identifier requested via the claims parameter does not match
    """
    user_info = self.userinfo[user_id]

    # A 'sub' in the user's info is used as-is instead of generating a subject identifier.
    custom_sub = user_info.get('sub')
    if custom_sub:
        self.authz_state.subject_identifiers[user_id] = {'public': custom_sub}
        sub = custom_sub
    else:
        sub = self._create_subject_identifier(
            user_id,
            authentication_request['client_id'],
            authentication_request['redirect_uri'],
        )

    self._check_subject_identifier_matches_requested(authentication_request, sub)

    if extra_id_token_claims is None:
        extra_id_token_claims = {}
    elif callable(extra_id_token_claims):
        extra_id_token_claims = extra_id_token_claims(
            user_id, authentication_request['client_id']
        )

    response = AuthorizationResponse()

    authz_code = None
    if 'code' in authentication_request['response_type']:
        authz_code = self.authz_state.create_authorization_code(
            authentication_request,
            sub,
            user_info=user_info,
            extra_id_token_claims=extra_id_token_claims,
        )
        response['code'] = authz_code

    access_token_value = None
    if 'token' in authentication_request['response_type']:
        access_token = self.authz_state.create_access_token(
            authentication_request, sub, user_info=user_info
        )
        access_token_value = access_token.value
        self._add_access_token_to_response(response, access_token)

    if 'id_token' in authentication_request['response_type']:
        requested_claims = self._get_requested_claims_in(authentication_request, 'id_token')
        if len(authentication_request['response_type']) == 1:
            # only id token is issued -> no way of doing userinfo request, so include all claims in ID Token,
            # even those requested by the scope parameter. Claims already requested through the claims
            # parameter keep their individual request (e.g. {"essential": true}) instead of being reset.
            scope_claims = scope2claims(
                authentication_request['scope'], extra_scope_dict=self.extra_scopes
            )
            for claim, claim_request in scope_claims.items():
                requested_claims.setdefault(claim, claim_request)

        user_claims = self.userinfo.get_claims_for(user_id, requested_claims)
        response['id_token'] = self._create_signed_id_token(
            authentication_request['client_id'],
            sub,
            user_claims,
            authentication_request.get('nonce'),
            authz_code,
            access_token_value,
            extra_id_token_claims,
        )
        logger.debug(
            'issued id_token=%s from requested_claims=%s userinfo=%s extra_claims=%s',
            response['id_token'],
            requested_claims,
            user_claims,
            extra_id_token_claims,
        )

    if 'state' in authentication_request:
        response['state'] = authentication_request['state']
    return response
220 |
def _add_access_token_to_response(self, response, access_token):
    # type: (oic.message.AccessTokenResponse, pyop.access_token.AccessToken) -> None
    """
    Copies the access token's value, type and lifetime into the response
    (as 'access_token', 'token_type' and 'expires_in').
    """
    fields = (
        ('access_token', access_token.value),
        ('token_type', access_token.type),
        ('expires_in', access_token.expires_in),
    )
    for name, value in fields:
        response[name] = value
229 |
def _create_subject_identifier(self, user_id, client_id, redirect_uri):
    # type: (str, str, str) -> str
    """
    Creates a subject identifier for the specified client and user, see
    "OpenID Connect Core 1.0", Section 1.2.
    :param user_id: local user identifier
    :param client_id: which client to generate a subject identifier for
    :param redirect_uri: the clients' redirect_uri, used when no sector_identifier_uri is registered
    :return: a subject identifier for the user intended for client who made the authentication request
    """
    # The provider's first supported subject type applies to clients that did not register one.
    default_subject_type = self.configuration_information['subject_types_supported'][0]
    client_metadata = self.clients[client_id]
    subject_type = client_metadata.get('subject_type', default_subject_type)
    # The sector identifier is the network location (netloc) of the sector URI.
    sector_uri = client_metadata.get('sector_identifier_uri', redirect_uri)
    return self.authz_state.get_subject_identifier(subject_type, user_id, urlparse(sector_uri).netloc)
246 |
def _get_requested_claims_in(self, authentication_request, response_method):
    # type: (AuthorizationRequest, str) -> Mapping[str, Optional[Mapping[str, Union[str, List[str]]]]]
    """
    Returns a new dict with the claims requested for one location through the 'claims' request
    parameter, see "OpenID Connect Core 1.0", Section 5.5.
    :param authentication_request: the authentication request
    :param response_method: 'id_token' or 'userinfo'
    :raise ValueError: for any other response_method
    """
    if response_method not in ('id_token', 'userinfo'):
        raise ValueError('response_method must be \'id_token\' or \'userinfo\'')

    claims_request = authentication_request['claims'] if 'claims' in authentication_request else {}
    requested = {}
    if response_method in claims_request:
        requested.update(claims_request[response_method])
    return requested
264 |
def _create_signed_id_token(self,
                            client_id,  # type: str
                            sub,  # type: str
                            user_claims=None,  # type: Optional[Mapping[str, Union[str, List[str]]]]
                            nonce=None,  # type: Optional[str]
                            authorization_code=None,  # type: Optional[str]
                            access_token_value=None,  # type: Optional[str]
                            extra_id_token_claims=None):  # type: Optional[Mapping[str, Union[str, List[str]]]]
    # type: (...) -> str
    """
    Creates a signed ID Token.
    :param client_id: who the ID Token is intended for
    :param sub: who the ID Token is regarding
    :param user_claims: any claims about the user to be included
    :param nonce: nonce from the authentication request
    :param authorization_code: the authorization code issued together with this ID Token
    :param access_token_value: the access token issued together with this ID Token
    :param extra_id_token_claims: any extra claims that should be included in the ID Token
    :return: a JWS, containing the ID Token as payload

    The claims iss, sub, aud, iat and exp are always set by the provider and override any claims of
    the same name in user_claims or extra_id_token_claims.
    """

    alg = self.clients[client_id].get('id_token_signed_response_alg',
                                      self.configuration_information['id_token_signing_alg_values_supported'][0])
    args = {}

    # c_hash/at_hash use the hash size of the signing alg (e.g. RS256 -> 256 bits), selected via 'HS<bits>'.
    hash_alg = 'HS{}'.format(alg[-3:])
    if authorization_code:
        args['c_hash'] = jws.left_hash(authorization_code.encode('utf-8'), hash_alg)
    if access_token_value:
        args['at_hash'] = jws.left_hash(access_token_value.encode('utf-8'), hash_alg)

    if user_claims:
        args.update(user_claims)

    if extra_id_token_claims:
        args.update(extra_id_token_claims)

    # Applied last: passing e.g. a 'sub' user claim alongside sub=... used to raise TypeError.
    # A single timestamp keeps exp - iat equal to the configured lifetime.
    now = int(time.time())
    args.update(iss=self.configuration_information['issuer'],
                sub=sub,
                aud=client_id,
                iat=now,
                exp=now + self.id_token_lifetime)
    id_token = IdToken(**args)

    if nonce:
        id_token['nonce'] = nonce

    logger.debug('signed id_token with kid=%s using alg=%s', getattr(self.signing_key, 'kid', None), alg)
    return id_token.to_jwt([self.signing_key], alg)
314 |
def _check_subject_identifier_matches_requested(self, authentication_request, sub):
    # type: (AuthorizationRequest, str) -> None
    """
    Verifies the subject identifier against any requested subject identifier using the claims request parameter.

    A specific 'sub' is requested as {"value": "..."} ("OpenID Connect Core 1.0", Section 5.5.1); a bare string
    is accepted too. A 'sub' member without a value (null, or e.g. {"essential": true}) only asks for the claim
    and is not checked.
    :param authentication_request: authentication request
    :param sub: subject identifier
    :raise AuthorizationError: if the subject identifier does not match the requested one
    """
    if 'claims' not in authentication_request:
        return

    claims_request = authentication_request['claims']

    def requested_sub_in(location):
        # The member for a location ('id_token'/'userinfo') may be missing or null.
        location_claims = claims_request.get(location) or {}
        requested = location_claims.get('sub')
        if hasattr(requested, 'get'):
            return requested.get('value')
        return requested

    requested_id_token_sub = requested_sub_in('id_token')
    requested_userinfo_sub = requested_sub_in('userinfo')
    if requested_id_token_sub and requested_userinfo_sub and requested_id_token_sub != requested_userinfo_sub:
        raise AuthorizationError('Requested different subject identifier for IDToken and userinfo: {} != {}'
                                 .format(requested_id_token_sub, requested_userinfo_sub))

    requested_sub = requested_id_token_sub or requested_userinfo_sub
    if requested_sub and sub != requested_sub:
        raise AuthorizationError('Requested subject identifier \'{}\' could not be matched'
                                 .format(requested_sub))
334 |
def handle_token_request(self, request_body,  # type: str
                         http_headers=None,  # type: Optional[Mapping[str, str]]
                         extra_id_token_claims=None
                         # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]]
                         ):
    # type: (...) -> oic.oic.message.AccessTokenResponse
    """
    Authenticates the client, then dispatches the token request on its grant_type:
    'authorization_code' (code exchange) or 'refresh_token' (token refresh).
    :param request_body: urlencoded token request
    :param http_headers: http headers
    :param extra_id_token_claims: extra claims to include in the signed ID Token (code exchange only)
    :raise InvalidTokenRequest: if grant_type is missing or not supported
    """
    token_request = self._verify_client_authentication(request_body, http_headers)

    if 'grant_type' not in token_request:
        raise InvalidTokenRequest('grant_type missing', token_request)

    grant_type = token_request['grant_type']
    if grant_type == 'authorization_code':
        return self._do_code_exchange(token_request, extra_id_token_claims)
    if grant_type == 'refresh_token':
        return self._do_token_refresh(token_request)

    raise InvalidTokenRequest('grant_type \'{}\' unknown'.format(grant_type), token_request,
                              oauth_error='unsupported_grant_type')
359 |
def _PKCE_verify(self,
                 token_request,  # type: AccessTokenRequest
                 authentication_request  # type: AuthorizationRequest
                 ):
    # type: (...) -> bool
    """
    Verify that the token request's code_verifier complies with the code_challenge of the
    authentication request, see RFC 7636, Section 4.6.

    The check is delegated to oic's Provider.verify_code_challenge, using the code_challenge_method
    from the authentication request (this package's AuthorizationRequest allows "plain" and "S256").

    :param token_request: the token request, expected to carry the code_verifier.
    :param authentication_request: the authentication request carrying code_challenge and code_challenge_method.
    :returns: True if the code_verifier matches; False if it is missing or does not match.
    :raise InvalidTokenRequest: if the authentication request has no code_challenge_method.
    """
    if 'code_verifier' not in token_request:
        return False

    # NOTE(review): RFC 7636 Section 4.3 defaults a missing code_challenge_method to "plain";
    # this implementation rejects the request instead.
    if 'code_challenge_method' not in authentication_request:
        raise InvalidTokenRequest("A code_challenge and code_verifier have been supplied "
                                  "but missing code_challenge_method in authentication_request", token_request)

    # OIC Provider extension returns either a boolean or a Response object containing an error. To support
    # stricter typing guidelines, only an actual True counts. Error handling belongs to the calling function.
    verification = OICProviderExtensions.verify_code_challenge(token_request['code_verifier'],
                                                              authentication_request['code_challenge'],
                                                              authentication_request['code_challenge_method'])
    return verification is True
385 |
def _verify_code_exchange_req(self,
                              token_request,  # type: AccessTokenRequest
                              authentication_request  # type: AuthorizationRequest
                              ):
    # type: (...) -> None
    """
    Validates a code exchange request against the authentication request its code was issued for.

    Checks, in order: the client_id matches, the redirect_uri matches, and, when the authentication
    request carried a code_challenge, the PKCE code_verifier.

    :param token_request: The request asking for a token given a code, and optionally a code_verifier
    :param authentication_request: The authentication request belonging to the provided code.
    :raises InvalidAuthorizationCode: if the code was issued to a different client.
    :raises InvalidTokenRequest: on a redirect_uri mismatch or a failed PKCE check.
    """
    requesting_client = token_request['client_id']
    issued_to_client = authentication_request['client_id']
    if requesting_client != issued_to_client:
        logger.info('Authorization code \'%s\' belonging to \'%s\' was used by \'%s\'',
                    token_request['code'], issued_to_client, requesting_client)
        # The client is only told the code is unknown; the actual reason is logged above.
        raise InvalidAuthorizationCode('{} unknown'.format(token_request['code']))

    given_redirect_uri = token_request['redirect_uri']
    expected_redirect_uri = authentication_request['redirect_uri']
    if given_redirect_uri != expected_redirect_uri:
        raise InvalidTokenRequest('Invalid redirect_uri: {} != {}'.format(given_redirect_uri, expected_redirect_uri),
                                  token_request)

    pkce_requested = 'code_challenge' in authentication_request
    if pkce_requested and not self._PKCE_verify(token_request, authentication_request):
        raise InvalidTokenRequest('Unexpected Code Verifier: {}'.format(authentication_request['code_challenge']),
                                  token_request)
411 |
def _do_code_exchange(self, request,  # type: Dict[str, str]
                      extra_id_token_claims=None
                      # type: Optional[Union[Mapping[str, Union[str, List[str]]], Callable[[str, str], Mapping[str, Union[str, List[str]]]]]
                      ):
    # type: (...) -> oic.message.AccessTokenResponse
    """
    Handles a token request for exchanging an authorization code for an access token
    (grant_type=authorization_code).
    :param request: parsed http request parameters
    :param extra_id_token_claims: any extra parameters to include in the signed ID Token, either as a dict-like
    object or as a callable object accepting the local user identifier and client identifier which returns
    any extra claims which might depend on the user id and/or client id. Ignored when the provider is
    stateless: the extra claims are then obtained from authz_state for the code instead.
    :return: a token response containing a signed ID Token, an Access Token, and a Refresh Token
    :raise InvalidTokenRequest: if the token request is invalid
    """
    token_request = AccessTokenRequest().from_dict(request)
    try:
        token_request.verify()
    except MessageException as e:
        raise InvalidTokenRequest(str(e), token_request) from e

    code = token_request['code']
    authentication_request = self.authz_state.get_authorization_request_for_code(code)

    self._verify_code_exchange_req(token_request, authentication_request)

    sub = self.authz_state.get_subject_identifier_for_code(code)
    # The local user id is only resolvable when the provider keeps state
    user_id = None
    if not self.stateless:
        user_id = self.authz_state.get_user_id_for_subject_identifier(sub)

    response = AccessTokenResponse()

    access_token = self.authz_state.exchange_code_for_token(code)
    self._add_access_token_to_response(response, access_token)
    refresh_token = self.authz_state.create_refresh_token(access_token.value)
    if refresh_token is not None:
        response['refresh_token'] = refresh_token

    # Resolve the extra ID Token claims without discarding what the caller passed in
    if self.stateless:
        extra_id_token_claims = {}
        extra_id_token_claims.update(self.authz_state.get_extra_id_token_claims_for_code(code))
    elif extra_id_token_claims is None:
        extra_id_token_claims = {}
    elif callable(extra_id_token_claims):
        extra_id_token_claims = extra_id_token_claims(user_id, authentication_request['client_id'])

    requested_claims = self._get_requested_claims_in(authentication_request, 'id_token')
    if self.stateless:
        user_info = self.authz_state.get_user_info_for_code(code)
        user_claims = self.userinfo.get_claims_for(None, requested_claims, user_info)
    else:
        user_claims = self.userinfo.get_claims_for(user_id, requested_claims)
    response['id_token'] = self._create_signed_id_token(authentication_request['client_id'], sub,
                                                        user_claims,
                                                        authentication_request.get('nonce'),
                                                        None, access_token.value,
                                                        extra_id_token_claims)
    logger.debug('issued id_token=%s from requested_claims=%s userinfo=%s extra_claims=%s',
                 response['id_token'], requested_claims, user_claims, extra_id_token_claims)

    return response
471 |
def _do_token_refresh(self, request):
    # type: (Mapping[str, str]) -> oic.oic.message.AccessTokenResponse
    """
    Exchange a refresh token for a new access token (grant_type=refresh_token).

    :param request: parsed http request parameters
    :return: a token response with a new Access Token and, if authz_state issued one, a new Refresh Token
    :raise InvalidTokenRequest: if the token request is invalid
    """
    refresh_request = RefreshAccessTokenRequest().from_dict(request)
    try:
        refresh_request.verify()
    except MessageException as err:
        raise InvalidTokenRequest(str(err), refresh_request) from err

    response = AccessTokenResponse()

    new_access_token, new_refresh_token = self.authz_state.use_refresh_token(
        refresh_request['refresh_token'],
        scope=refresh_request.get('scope'),
    )
    self._add_access_token_to_response(response, new_access_token)
    if new_refresh_token:
        response['refresh_token'] = new_refresh_token

    return response
495 |
def _verify_client_authentication(self, request_body, http_headers=None):
    # type: (str, Optional[Mapping[str, str]]) -> Mapping[str, str]
    """
    Verifies the client authentication.
    :param request_body: urlencoded token request
    :param http_headers: http headers of the request; the 'Authorization' header (if any) is passed on to
        verify_client_authentication
    :return: The parsed request body, with 'client_id' set to the value returned by
        verify_client_authentication.
    """
    if http_headers is None:
        http_headers = {}
    token_request = dict(parse_qsl(request_body))
    token_request['client_id'] = verify_client_authentication(self.clients, token_request,
                                                              http_headers.get('Authorization'))
    return token_request
509 |
def handle_userinfo_request(self, request=None, http_headers=None):
    # type: (Optional[str], Optional[Mapping[str, str]]) -> oic.oic.message.OpenIDSchema
    """
    Handles a userinfo request.

    The released claims are those derived from the access token's scopes plus any 'userinfo' claims
    requested in the original authentication request; 'sub' is always set.

    :param request: urlencoded request (either query string or POST body)
    :param http_headers: http headers
    :return: the userinfo response
    :raise InvalidAccessToken: if the access token is not active
    """
    headers = http_headers if http_headers is not None else {}
    bearer_token = extract_bearer_token_from_http_request(dict(parse_qsl(request)),
                                                          headers.get('Authorization'))

    introspection = self.authz_state.introspect_access_token(bearer_token)
    if not introspection['active']:
        raise InvalidAccessToken('The access token has expired')
    granted_scopes = introspection['scope'].split()

    # Stateless providers cannot map the subject back to a local user id
    user_id = None if self.stateless else self.authz_state.get_user_id_for_subject_identifier(
        introspection['sub'])

    requested_claims = scope2claims(granted_scopes, extra_scope_dict=self.extra_scopes)
    authentication_request = self.authz_state.get_authorization_request_for_access_token(bearer_token)
    requested_claims.update(self._get_requested_claims_in(authentication_request, 'userinfo'))

    if self.stateless:
        stored_user_info = self.authz_state.get_user_info_for_access_token(bearer_token)
        user_claims = self.userinfo.get_claims_for(None, requested_claims, stored_user_info)
    else:
        user_claims = self.userinfo.get_claims_for(user_id, requested_claims)

    user_claims.setdefault('sub', introspection['sub'])
    response = OpenIDSchema(**user_claims)
    logger.debug('userinfo=%s from requested_claims=%s userinfo=%s',
                 response, requested_claims, user_claims)
    return response
545 |
def _issue_new_client(self):
    """
    Generate credentials for a newly registered client.

    :return: a (client_id, client_secret) tuple; client_id comes from rndstr(12) and is retried until it is
        not already present in self.clients, client_secret is the hex form of a random UUID4
    """
    while True:
        client_id = rndstr(12)
        if client_id not in self.clients:
            break

    return client_id, uuid.uuid4().hex
555 |
def match_client_preferences_with_provider_capabilities(self, client_preferences):
    # type: (oic.message.RegistrationRequest) -> Mapping[str, Union[str, List[str]]]
    """
    Match as many of the client preferences as possible.

    Only 'response_types' and 'default_acr_values' are narrowed down to the values the provider also
    supports; every other registration parameter is copied through unchanged.

    :param client_preferences: requested preferences from client registration request
    :return: the matched preferences selected by the provider
    """
    matched_prefs = client_preferences.to_dict()
    negotiable_prefs = (name for name in ('response_types', 'default_acr_values') if name in client_preferences)
    for name in negotiable_prefs:
        supported_values = self.configuration_information[PREFERENCE2PROVIDER[name]]
        common_values = find_common_values(client_preferences[name], supported_values)
        # deal with space separated values: each common value is joined back into a single string
        matched_prefs[name] = [' '.join(value) for value in common_values]

    return matched_prefs
575 |
def handle_client_registration_request(self, request, http_headers=None):
    # type: (Optional[str], Optional[Mapping[str, str]]) -> oic.oic.message.RegistrationResponse
    """
    Handles a client registration request.

    The parsed request is passed to every validator in self.registration_request_validators before any
    state is changed. A new client is then stored in self.clients with freshly issued credentials and the
    client preferences as matched against the provider capabilities.

    :param request: JSON request from POST body
    :param http_headers: http headers (not used by this method)
    :return: the registration response with the stored client metadata and credentials
    """
    registration_req = RegistrationRequest().deserialize(request, 'json')

    for validator in self.registration_request_validators:
        validator(registration_req)
    logger.debug('parsed registration_request: %s', registration_req)

    client_id, client_secret = self._issue_new_client()
    credentials = {
        'client_id': client_id,
        'client_id_issued_at': int(time.time()),
        'client_secret': client_secret,
        'client_secret_expires_at': 0  # never expires
    }

    response_params = self.match_client_preferences_with_provider_capabilities(registration_req)
    response_params.update(credentials)
    # store a deep copy so the stored client metadata is decoupled from response_params
    self.clients[client_id] = copy.deepcopy(response_params)

    registration_resp = RegistrationResponse(**response_params)
    logger.debug('registration_resp=%s from registration_req=%s', registration_resp, registration_req)
    return registration_resp
604 |
def logout_user(self, subject_identifier=None, end_session_request=None):
    # type: (Optional[str], Optional[oic.oic.message.EndSessionRequest]) -> None
    """
    Delete all authorization state stored for a user.

    :param subject_identifier: subject identifier of the user to log out; overridden by the 'sub' of the
        id_token_hint when end_session_request contains one
    :param end_session_request: optional end session request
    :raise OAuthError: if the provider uses stateless storage (there is no state to delete)
    """
    if self.stateless:
        raise OAuthError("Logout is not supported with stateless storage provider", "invalid_request")

    session_request = end_session_request or EndSessionRequest()
    if 'id_token_hint' in session_request:
        hinted_token = IdToken().from_jwt(session_request['id_token_hint'], key=[self.signing_key])
        subject_identifier = hinted_token['sub']

    self.authz_state.delete_state_for_subject_identifier(subject_identifier)
616 |
def do_post_logout_redirect(self, end_session_request):
    # type: (oic.oic.message.EndSessionRequest) -> Optional[str]
    """
    Build the post-logout redirect for an end session request.

    A redirect is only produced when the request has a post_logout_redirect_uri AND an id_token_hint whose
    first audience is a client known in self.clients that lists exactly that URI in its
    'post_logout_redirect_uris'.

    :param end_session_request: the end session request
    :return: the value of EndSessionResponse.request() for the post_logout_redirect_uri (the response carries
        the request's 'state' when present), or None if no allowed redirect exists
    """
    if 'post_logout_redirect_uri' not in end_session_request:
        return None
    # Without an id_token_hint the client, and thus its registered URIs, cannot be determined
    if 'id_token_hint' not in end_session_request:
        return None

    id_token = IdToken().from_jwt(end_session_request['id_token_hint'], key=[self.signing_key])
    client_id = id_token['aud'][0]
    # A hint for a client that is not (or no longer) registered gets no redirect instead of a KeyError
    if not client_id or client_id not in self.clients:
        return None

    redirect_uri = end_session_request['post_logout_redirect_uri']
    if redirect_uri not in self.clients[client_id].get('post_logout_redirect_uris', []):
        return None

    end_session_response = EndSessionResponse()
    if 'state' in end_session_request:
        end_session_response['state'] = end_session_request['state']

    return end_session_response.request(redirect_uri)
639 |
--------------------------------------------------------------------------------
/src/pyop/request_validator.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from oic.exception import MessageException
4 | from oic.oic import PREFERENCE2PROVIDER
5 |
6 | from .exceptions import InvalidClientRegistrationRequest
7 | from .exceptions import InvalidAuthenticationRequest
8 | from .exceptions import InvalidRedirectURI
9 | from .util import is_allowed_response_type, find_common_values
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
def authorization_request_verify(authentication_request):
    """
    Run the message-level verification of an authentication request.

    :param authentication_request: the authentication request to verify; its verify() method is called
    :raise InvalidAuthenticationRequest: with oauth_error='invalid_request' if verify() raises a MessageException
    """
    try:
        authentication_request.verify()
    except MessageException as err:
        raise InvalidAuthenticationRequest(str(err),
                                           authentication_request,
                                           oauth_error='invalid_request') from err
24 |
25 |
def client_id_is_known(provider, authentication_request):
    """
    Verifies the client identifier is known.
    :param provider: provider instance; its `clients` mapping is consulted
    :param authentication_request: the authentication request to verify
    :raise InvalidAuthenticationRequest: (oauth_error='unauthorized_client') if the client_id is unknown
    """
    client_id = authentication_request['client_id']
    if client_id not in provider.clients:
        # lazy %-formatting: the message is only built if the record is emitted
        logger.error('Unknown client_id \'%s\'', client_id)
        raise InvalidAuthenticationRequest('Unknown client_id',
                                           authentication_request,
                                           oauth_error='unauthorized_client')
38 |
def redirect_uri_is_in_registered_redirect_uris(provider, authentication_request):
    """
    Verifies the redirect uri is registered for the client making the request.
    :param provider: provider instance
    :param authentication_request: authentication request to verify
    :raise InvalidRedirectURI: (oauth_error='invalid_request') if the client has no registered redirect uris,
        or the requested redirect uri is not one of them
    """
    try:
        allowed_redirect_uris = provider.clients[authentication_request['client_id']]['redirect_uris']
    except KeyError as e:
        logger.error('client metadata is missing redirect_uris')
        raise InvalidRedirectURI(
            'No redirect uri registered for this client',
            authentication_request,
            oauth_error="invalid_request",
        ) from e

    redirect_uri = authentication_request['redirect_uri']
    if redirect_uri not in allowed_redirect_uris:
        logger.error("Redirect uri \'%s\' is not registered for this client", redirect_uri)
        raise InvalidRedirectURI(
            'Redirect uri is not registered for this client',
            authentication_request,
            oauth_error="invalid_request",
        )
63 |
64 |
def response_type_is_in_registered_response_types(provider, authentication_request):
    """
    Verifies that the requested response type is allowed for the client making the request.
    :param provider: provider instance
    :param authentication_request: authentication request to verify
    :raise InvalidAuthenticationRequest: (oauth_error='invalid_request') if the client metadata has no
        response_types, or the requested response type is not allowed by them
    """
    error = InvalidAuthenticationRequest('Response type is not registered',
                                         authentication_request,
                                         oauth_error='invalid_request')
    try:
        allowed_response_types = provider.clients[authentication_request['client_id']]['response_types']
    except KeyError as e:
        logger.error('client metadata is missing response_types')
        raise error from e

    if not is_allowed_response_type(authentication_request['response_type'], allowed_response_types):
        logger.error('Response type \'%s\' is not registered', ' '.join(authentication_request['response_type']))
        raise error
84 |
85 |
def userinfo_claims_only_specified_when_access_token_is_issued(authentication_request):
    """
    Reject 'userinfo' claims requests when no Access Token will be issued.

    "OpenID Connect Core 1.0", Section 5.5: "When the userinfo member is used, the request MUST
    also use a response_type value that results in an Access Token being issued to the Client for
    use at the UserInfo Endpoint." Only response_type=['id_token'] is treated as issuing no Access Token.

    :param authentication_request: the authentication request to verify
    :raise InvalidAuthenticationRequest: if userinfo claims are requested with response_type='id_token'
    """
    id_token_only = authentication_request['response_type'] == ['id_token']
    asks_for_userinfo_claims = ('claims' in authentication_request
                                and 'userinfo' in authentication_request['claims'])
    if id_token_only and asks_for_userinfo_claims:
        raise InvalidAuthenticationRequest('Userinfo claims cannot be requested, when response_type=\'id_token\'',
                                           authentication_request,
                                           oauth_error='invalid_request')
102 |
103 |
def requested_scope_is_supported(provider, authentication_request):
    """
    Log a warning when the authentication request asks for scopes the provider does not support.

    The request is not rejected; unsupported scopes are only reported.

    :param provider: provider instance; provider.provider_configuration['scopes_supported'] is read
    :param authentication_request: the authentication request to check
    """
    requested_scopes = set(authentication_request['scope'])
    supported_scopes = set(provider.provider_configuration['scopes_supported'])
    requested_unsupported_scopes = requested_scopes - supported_scopes
    if requested_unsupported_scopes:
        # sorted so the log line is deterministic (set iteration order is not)
        logger.warning('Request contains unsupported/unknown scopes: %s',
                       ', '.join(sorted(requested_unsupported_scopes)))
111 |
112 |
def registration_request_verify(registration_request):
    """
    Run the message-level verification of a client registration request.

    :param registration_request: the client registration request to verify; its verify() method is called
    :raise InvalidClientRegistrationRequest: with oauth_error='invalid_request' if verify() raises a
        MessageException
    """
    try:
        registration_request.verify()
    except MessageException as err:
        raise InvalidClientRegistrationRequest(str(err),
                                               registration_request,
                                               oauth_error='invalid_request') from err
123 |
124 |
def client_preferences_match_provider_capabilities(provider, registration_request):
    """
    Verifies that every client preference in the registration request which has a corresponding provider
    capability (per PREFERENCE2PROVIDER) can be fulfilled by this provider.

    :param provider: provider instance; provider.configuration_information is read
    :param registration_request: the client registration request to verify
    :raise InvalidClientRegistrationRequest: if a preference cannot be matched
    """

    def is_satisfiable(preference_value, capability_value):
        if not isinstance(preference_value, list):
            return preference_value in capability_value
        # deal with comparing space separated values, e.g. 'response_types', without considering the order;
        # at least one requested preference must be matched
        return len(find_common_values(preference_value, capability_value)) > 0

    for preference in registration_request.keys():
        if preference not in PREFERENCE2PROVIDER:
            # metadata parameter that shouldn't be matched
            continue

        capability = PREFERENCE2PROVIDER[preference]
        if is_satisfiable(registration_request[preference], provider.configuration_information[capability]):
            continue

        raise InvalidClientRegistrationRequest(
            'Could not match client preference {}={} with provider capability {}={}'.format(
                preference, registration_request[preference], capability,
                provider.configuration_information[capability]),
            registration_request,
            oauth_error='invalid_request')
153 |
--------------------------------------------------------------------------------
/src/pyop/storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from abc import ABC, abstractmethod
4 | import copy
5 | import json
6 | import logging
7 | from datetime import datetime
8 | from urllib.parse import urlparse
9 | from urllib.parse import parse_qs
10 |
11 | from .crypto import _AESCipher
12 |
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | try:
18 | import pymongo
19 | except ImportError:
20 | _has_pymongo = False
21 | else:
22 | _has_pymongo = True
23 |
24 | try:
25 | from redis.client import Redis
26 | except ImportError:
27 | _has_redis = False
28 | else:
29 | _has_redis = True
30 |
31 |
class StorageBase(ABC):
    """
    Abstract dict-like storage interface, implemented in this module by MongoWrapper, RedisWrapper and
    StatelessWrapper.

    Subclasses provide item access, membership, iteration and `pack`; `pop` is built on top of those.
    """

    # Entry time-to-live in seconds, or None for no expiry; concrete wrappers may set this in __init__
    _ttl = None

    @abstractmethod
    def __setitem__(self, key, value):
        """Store `value` under `key`."""

    @abstractmethod
    def pack(self, value):
        """Encode `value` into a storable form (not supported by every backend)."""

    @abstractmethod
    def __getitem__(self, key):
        """Return the value stored under `key`."""

    @abstractmethod
    def __delitem__(self, key):
        """Remove `key` from the storage."""

    @abstractmethod
    def __contains__(self, key):
        """Return whether `key` is present."""

    @abstractmethod
    def items(self):
        """Iterate over (key, value) pairs."""

    def pop(self, key, default=None):
        """Remove `key` and return its value, or return `default` if the lookup raises KeyError."""
        try:
            value = self[key]
        except KeyError:
            return default
        else:
            del self[key]
            return value

    @classmethod
    def from_uri(cls, db_uri, collection, db_name=None, ttl=None, **kwargs):
        """
        Create the storage wrapper that matches the scheme of `db_uri`.

        :param db_uri: 'mongodb://...', 'redis://...' / 'unix://...' or 'stateless://...'; for the stateless
            scheme the URI password is the encryption key and an optional `alg` query parameter selects the
            algorithm
        :param collection: collection name passed to the wrapper
        :param db_name: database name (not used by the stateless wrapper)
        :param ttl: entry time-to-live (not used by the stateless wrapper)
        :param kwargs: passed as `extra_options` to the Mongo/Redis wrappers
        :raise ValueError: if the URI scheme is not supported
        """
        backend = cls.type(db_uri)

        if backend == "stateless":
            url = urlparse(db_uri)
            alg_values = parse_qs(url.query).get("alg") if url.query else None
            return StatelessWrapper(
                collection=collection,
                encryption_key=url.password,
                alg=alg_values[0] if alg_values else None,
            )

        wrapper_cls = MongoWrapper if backend == "mongodb" else RedisWrapper
        return wrapper_cls(
            db_uri=db_uri,
            db_name=db_name,
            collection=collection,
            ttl=ttl,
            extra_options=kwargs,
        )

    @classmethod
    def type(cls, db_uri):
        """
        Return the backend name for `db_uri`: 'mongodb', 'redis' (for redis:// and unix://) or 'stateless'.

        :raise ValueError: if the URI scheme is not supported
        """
        scheme = urlparse(db_uri).scheme
        if scheme in ("redis", "unix"):
            return "redis"
        if scheme in ("mongodb", "stateless"):
            return scheme
        raise ValueError(f"Invalid DB URI: {db_uri}")

    @property
    def ttl(self):
        """The configured time-to-live (None when entries do not expire)."""
        return self._ttl
113 |
114 |
class MongoWrapper(StorageBase):
    """
    Dict-like storage backed by a MongoDB collection.

    Each entry is a document {'lookup_key': key, 'data': value, 'last_modified': <datetime>}; 'lookup_key'
    has a unique index. When a TTL is configured, an index named "expiry" on 'last_modified' with
    expireAfterSeconds=ttl is created.
    """

    def __init__(self, db_uri, db_name, collection, ttl=None, extra_options=None):
        if not _has_pymongo:
            raise ImportError("pymongo module is required but it is not available")

        # Validate before connecting or touching indexes, so a bad TTL fails fast
        if not (ttl is None or (isinstance(ttl, int) and ttl >= 0)):
            raise ValueError("TTL must be a non-negative integer or None")
        self._ttl = ttl

        # Read (rather than pop) so a caller-supplied options dict is left untouched
        mongo_options = (extra_options or {}).get("mongo_kwargs") or {}

        self._db_uri = db_uri
        self._coll_name = collection
        self._db = MongoDB(db_uri, db_name=db_name, **mongo_options)
        self._coll = self._db.get_collection(collection)
        self._coll.create_index('lookup_key', unique=True)

        if ttl is not None:
            self._coll.create_index(
                'last_modified',
                expireAfterSeconds=ttl,
                name="expiry"
            )

    def __setitem__(self, key, value):
        # Upsert; refreshing 'last_modified' restarts the TTL countdown for this entry
        doc = {
            'lookup_key': key,
            'data': value,
            'last_modified': datetime.utcnow()
        }
        self._coll.replace_one({'lookup_key': key}, doc, upsert=True)

    def pack(self, value):
        raise NotImplementedError

    def __getitem__(self, key):
        doc = self._coll.find_one({'lookup_key': key})
        if not doc:
            raise KeyError(key)
        return doc['data']

    def __delitem__(self, key):
        # Deleting a non-existent key is not an error
        self._coll.delete_one({'lookup_key': key})

    def __contains__(self, key):
        # lookup_key is unique, so counting can stop at the first match
        count = self._coll.count_documents({'lookup_key': key}, limit=1)
        return bool(count)

    def items(self):
        for doc in self._coll.find():
            yield (doc['lookup_key'], doc['data'])
169 |
170 |
class RedisWrapper(StorageBase):
    """
    Simple wrapper for a dict-like storage in Redis.
    Supports JSON-serializable data types.

    Entries are stored under '<collection>:<key>' as the JSON document {"value": ...}; when a TTL is
    configured it is (re)applied on every write.
    """

    def __init__(
        self, db_uri, *, db_name=None, collection, ttl=None, extra_options=None
    ):
        if not _has_redis:
            raise ImportError("redis module is required but it is not available")

        # Validate before creating the client so a bad TTL fails fast
        if not (ttl is None or (isinstance(ttl, int) and ttl >= 0)):
            raise ValueError("TTL must be a non-negative integer or None")

        # Read (rather than pop) so a caller-supplied options dict is left untouched
        redis_kwargs = (extra_options or {}).get("redis_kwargs") or {}
        redis_options = {
            "decode_responses": True, "db": db_name, **redis_kwargs
        }

        self._db = Redis.from_url(db_uri, **redis_options)
        self._collection = collection
        self._ttl = ttl

    def _make_key(self, key):
        if not isinstance(key, str):
            raise TypeError(f"Keys must be strings, {type(key).__name__} given")

        return ":".join([self._collection, key])

    def __setitem__(self, key, value):
        # Replacing the value of a key resets the ttl counter
        encoded = json.dumps({ "value": value })
        self._db.set(self._make_key(key), encoded, ex=self.ttl)

    def pack(self, value):
        raise NotImplementedError

    def __getitem__(self, key):
        encoded = self._db.get(self._make_key(key))
        if encoded is None:
            raise KeyError(key)
        return json.loads(encoded).get("value")

    def __delitem__(self, key):
        # Deleting a non-existent key is allowed
        self._db.delete(self._make_key(key))

    def __contains__(self, key):
        return (self._db.get(self._make_key(key)) is not None)

    def items(self):
        prefix = self._collection + ":"
        # Match only this collection's keys (a bare "<collection>*" also matched e.g. "<collection>x:...")
        # and iterate incrementally with SCAN instead of the blocking KEYS command.
        for key in self._db.scan_iter(match=prefix + "*"):
            # Decode before slicing so the prefix length is counted in characters, not bytes
            if isinstance(key, bytes):
                key = key.decode()
            visible_key = key[len(prefix):]

            try:
                yield (visible_key, self[visible_key])
            except KeyError:
                # the entry expired or was deleted between the scan and the read
                pass
236 |
237 |
class StatelessWrapper(StorageBase):
    """
    Storage that keeps no server-side state: values are encrypted into the key itself.

    `pack` turns a value into an encrypted token; looking up that token decrypts it again. Writes are
    no-ops; deletion and iteration are not supported.
    """

    def __init__(self, collection, encryption_key, alg=None):
        # collection is only used as a label in log messages
        self.collection = collection
        if alg and alg.lower() != "aes256":
            raise ValueError(f"Invalid encryption algorithm: {alg}")
        self.cipher = _AESCipher(encryption_key)

    def __setitem__(self, key, value):
        """No-op: the value already lives inside the (packed) key."""

    def pack(self, value):
        """
        Encrypt `value` (dicts are JSON encoded first) into a UTF-8 string token.

        :return: the token, or None for a falsy value
        """
        if not value:
            return None
        plaintext = json.dumps(value) if isinstance(value, dict) else value
        return self.cipher.encrypt(plaintext.encode("UTF-8")).decode("UTF-8")

    def __getitem__(self, key):
        return self._unpack(key)

    def __delitem__(self, key):
        raise NotImplementedError

    def __contains__(self, key):
        return bool(self._unpack(key))

    def items(self):
        raise NotImplementedError

    def _unpack(self, value):
        """
        Decrypt a token produced by `pack`.

        :return: the decoded JSON value; the decrypted string if it is not JSON; None if `value` is falsy or
            could not be decrypted/decoded (a ValueError was raised before a non-empty plaintext existed)
        """
        plaintext = None
        if not value:
            return plaintext
        try:
            plaintext = self.cipher.decrypt(value.encode("UTF-8")).decode("UTF-8")
            return json.loads(plaintext)
        except ValueError:
            if plaintext:
                logger.debug("Value '%s' is not a dict", value)
            else:
                logger.warning("Value '%s' is invalid for %s", value, self.collection)
            return plaintext
283 |
284 |
class MongoDB(object):
    """Simple wrapper to get pymongo real objects from the settings uri"""

    def __init__(self, db_uri, db_name=None, connection_factory=None, **kwargs):
        """
        Parse `db_uri` and create the client connection.

        :param db_uri: MongoDB connection URI
        :param db_name: database name, used when the URI does not name one
        :param connection_factory: client factory; defaults to pymongo.MongoClient. A replicaSet given in
            kwargs or in the URI options switches it to pymongo.MongoReplicaSetClient.
        :param kwargs: extra keyword arguments for the connection factory
        :raise ValueError: if no URI or no database name is available
        """
        if db_uri is None:
            raise ValueError('db_uri not supplied')

        self._sanitized_uri = None
        self._parsed_uri = pymongo.uri_parser.parse_uri(db_uri)

        # A database named in the URI takes precedence over the db_name argument
        db_name = self._parsed_uri.get('database') or db_name
        if db_name is None:
            raise ValueError(
                "Database name must be provided either in the URI or as an argument"
            )
        self._database_name = self._parsed_uri['database'] = db_name

        # An explicit replicaSet=None means "no replica set"
        if 'replicaSet' in kwargs and kwargs['replicaSet'] is None:
            del kwargs['replicaSet']

        self._options = self._parsed_uri.get('options')
        if connection_factory is None:
            connection_factory = pymongo.MongoClient
        # NOTE(review): MongoReplicaSetClient was removed in PyMongo 4.0 (MongoClient handles replica sets);
        # confirm the supported pymongo version range.
        if 'replicaSet' in kwargs:
            connection_factory = pymongo.MongoReplicaSetClient
        if 'replicaSet' in self._options and self._options['replicaSet'] is not None:
            connection_factory = pymongo.MongoReplicaSetClient
            kwargs['replicaSet'] = self._options['replicaSet']

        # Replica set connections get 5 second socket/connect timeouts unless given explicitly
        if 'replicaSet' in kwargs:
            kwargs.setdefault('socketTimeoutMS', 5000)
            kwargs.setdefault('connectTimeoutMS', 5000)

        self._db_uri = _format_mongodb_uri(self._parsed_uri)

        self._connection = connection_factory(
            host=self._db_uri,
            tz_aware=True,
            **kwargs)

    def __repr__(self):
        # Use the sanitized URI: self._db_uri contains the password in clear text
        return '<{!s}: {!s} {!s}>'.format(self.__class__.__name__,
                                          self.sanitized_uri,
                                          self._database_name)

    @property
    def sanitized_uri(self):
        """
        Return the database URI we're using in a format sensible for logging etc.

        The password is replaced by 'secret' and only the first host is kept.

        :return: db_uri
        """
        if self._sanitized_uri is None:
            _parsed = copy.copy(self._parsed_uri)
            if 'username' in _parsed:
                _parsed['password'] = 'secret'
            _parsed['nodelist'] = [_parsed['nodelist'][0]]
            self._sanitized_uri = _format_mongodb_uri(_parsed)
        return self._sanitized_uri

    def get_connection(self):
        """
        Get the raw pymongo connection object.
        :return: Pymongo connection object
        """
        return self._connection

    def get_database(self, database_name=None, username=None, password=None):
        """
        Get a pymongo database handle, after authenticating.

        Authenticates using the username/password in the DB URI given to
        __init__() unless username/password is supplied as arguments.

        :param database_name: (optional) Name of database
        :param username: (optional) Username to login with
        :param password: (optional) Password to login with
        :return: Pymongo database object
        """
        if database_name is None:
            database_name = self._database_name
        if database_name is None:
            raise ValueError('No database_name supplied, and no default provided to __init__')
        db = self._connection[database_name]
        # NOTE(review): Database.authenticate was removed in PyMongo 4.0; credentials in the URI are used by
        # the client there. Confirm the supported pymongo version range.
        if username and password:
            db.authenticate(username, password)
        elif self._parsed_uri.get("username", None):
            if 'authSource' in self._options and self._options['authSource'] is not None:
                db.authenticate(
                    self._parsed_uri.get("username", None),
                    self._parsed_uri.get("password", None),
                    source=self._options['authSource']
                )
            else:
                db.authenticate(
                    self._parsed_uri.get("username", None),
                    self._parsed_uri.get("password", None)
                )
        return db

    def get_collection(self, collection, database_name=None, username=None, password=None):
        """
        Get a pymongo collection handle.

        :param collection: Name of collection
        :param database_name: (optional) Name of database
        :param username: (optional) Username to login with
        :param password: (optional) Password to login with
        :return: Pymongo collection object
        """
        _db = self.get_database(database_name, username, password)
        return _db[collection]

    def close(self):
        self._connection.close()
405 |
406 |
407 | def _format_mongodb_uri(parsed_uri):
408 | """
409 | Painstakingly reconstruct a MongoDB URI parsed using pymongo.uri_parser.parse_uri.
410 |
411 | :param parsed_uri: Result of pymongo.uri_parser.parse_uri
412 | :type parsed_uri: dict
413 |
414 | :return: New URI
415 | :rtype: str | unicode
416 | """
417 | user_pass = ''
418 | if parsed_uri.get('username') and parsed_uri.get('password'):
419 | user_pass = '{username!s}:{password!s}@'.format(**parsed_uri)
420 |
421 | _nodes = []
422 | for host, port in parsed_uri.get('nodelist'):
423 | if ':' in host and not host.endswith(']'):
424 | # IPv6 address without brackets
425 | host = '[{!s}]'.format(host)
426 | if port == 27017:
427 | _nodes.append(host)
428 | else:
429 | _nodes.append('{!s}:{!s}'.format(host, port))
430 | nodelist = ','.join(_nodes)
431 |
432 | options = ''
433 | if parsed_uri.get('options'):
434 | _opt_list = []
435 | for key, value in parsed_uri.get('options').items():
436 | if isinstance(value, bool):
437 | value = str(value).lower()
438 | _opt_list.append('{!s}={!s}'.format(key, value))
439 | options = '?' + '&'.join(_opt_list)
440 |
441 | db_name = parsed_uri.get('database') or ''
442 |
443 | res = "mongodb://{user_pass!s}{nodelist!s}/{db_name!s}{options!s}".format(
444 | user_pass=user_pass,
445 | nodelist=nodelist,
446 | db_name=db_name,
447 | # collection is ignored
448 | options=options)
449 | return res
450 |
--------------------------------------------------------------------------------
/src/pyop/subject_identifier.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 |
3 |
class SubjectIdentifierFactory(object):
    """
    Interface for implementation of an algorithm for creating pairwise subject identifiers, see

    "OpenID Connect Core 1.0", Section 8.1.

    Implementations must provide both ``create_public_identifier`` and ``create_pairwise_identifier``.
    """

    def create_public_identifier(self, user_id):
        # type: (str) -> str
        """Return the public subject identifier for ``user_id``."""
        raise NotImplementedError()

    def create_pairwise_identifier(self, user_id, sector_identifier):
        # type: (str, str) -> str
        """Return the pairwise subject identifier for ``user_id`` within ``sector_identifier``."""
        raise NotImplementedError()


class HashBasedSubjectIdentifierFactory(SubjectIdentifierFactory):
    """
    Implements a hash based algorithm for creating a pairwise subject identifier.

    Every identifier is the hex SHA-256 digest of its input concatenated with ``hash_salt``:
    - public identifiers hash ``user_id``;
    - pairwise identifiers hash ``sector_identifier + user_id``.
    """

    def __init__(self, hash_salt):
        # type: (str) -> None
        self.hash_salt = hash_salt

    def create_public_identifier(self, user_id):
        # type: (str) -> str
        return self._hash(user_id)

    def create_pairwise_identifier(self, user_id, sector_identifier):
        # type: (str, str) -> str
        # NOTE: the parts are joined without a separator, so e.g. ('bc', 'a') and ('c', 'ab') produce
        # the same digest. Changing this would alter every pairwise `sub` value already issued.
        return self._hash(sector_identifier + user_id)

    def _hash(self, data):
        # type: (str) -> str
        hash_input = data + self.hash_salt
        return hashlib.sha256(hash_input.encode('utf-8')).hexdigest()
39 |
--------------------------------------------------------------------------------
/src/pyop/userinfo.py:
--------------------------------------------------------------------------------
class Userinfo(object):
    """
    Wrapper providing a read-only interface for a database containing user info.

    The backing database must use a local identifier as key, and all userinfo that should be returned in OpenID
    Connect ID Tokens or Userinfo Responses must follow the format of OpenID Connect standard claims, see

    "OpenID Connect Core 1.0", Section 5.1
    """

    def __init__(self, db):
        # type: (Mapping[str, Mapping[str, Union[str, List[str]]]]) -> None
        self._db = db

    def __getitem__(self, item):
        return self._db[item]

    def __contains__(self, item):
        return item in self._db

    def get_claims_for(self, user_id, requested_claims, userinfo=None):
        # type: (Optional[str], Mapping[str, Optional[Mapping[str, Union[str, List[str]]]]], Optional[Mapping[str, Union[str, List[str]]]]) -> Dict[str, Union[str, List[str]]]
        """
        Filter the userinfo based on which claims were requested.

        :param user_id: user identifier; when it is falsy and no ``userinfo`` is given, no claims are returned
        :param requested_claims: see "OpenID Connect Core 1.0", Section 5.5 for structure; only its keys
            (the claim names) are used
        :param userinfo: if specified, the claims are filtered from it directly instead of first querying
            the storage for ``user_id``. An empty mapping is treated like ``None`` and falls back to the storage.
        :return: All requested claims available from the userinfo.
        """
        if not userinfo:
            userinfo = self._db[user_id] if user_id else {}
        return {claim: userinfo[claim] for claim in requested_claims if claim in userinfo}
37 |
--------------------------------------------------------------------------------
/src/pyop/util.py:
--------------------------------------------------------------------------------
def should_fragment_encode(authentication_request):
    """
    Decide whether the authentication response must be fragment encoded.

    Only the pure Authorization Code Flow (``response_type == ['code']``) is query encoded;
    every other response type is fragment encoded.
    """
    is_code_flow = authentication_request['response_type'] == ['code']
    return not is_code_flow
7 |
8 |
def is_allowed_response_type(response_type, supported_response_types):
    """Return True if ``response_type`` matches a supported type, ignoring the order of its parts."""
    requested = frozenset(response_type)
    supported = {frozenset(candidate.split()) for candidate in supported_response_types}
    return requested in supported
11 |
12 |
def find_common_values(preference_values, supported_values):
    """Return the order-insensitive, space-separated values present in both collections, as frozensets."""
    preferred = set(map(lambda value: frozenset(value.split()), preference_values))
    supported = set(map(lambda value: frozenset(value.split()), supported_values))
    return supported & preferred
17 |
18 |
def requested_scope_is_allowed(requested_scope, allowed_scope):
    """Return True if every requested scope value appears in the space-separated ``allowed_scope``."""
    allowed = frozenset(allowed_scope.split())
    return frozenset(requested_scope) <= allowed
21 |
--------------------------------------------------------------------------------
/tests/pyop/test_access_token.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pyop.access_token import extract_bearer_token_from_http_request, BearerTokenError
4 |
# Dummy bearer token value shared by the extraction tests below.
ACCESS_TOKEN = 'abcdef'
6 |
7 |
class TestExtractBearerTokenFromHttpRequest(object):
    """Tests for pulling a bearer token out of an Authorization header or request parameters."""

    def test_authorization_header(self):
        header = 'Bearer {}'.format(ACCESS_TOKEN)
        assert extract_bearer_token_from_http_request(authz_header=header) == ACCESS_TOKEN

    def test_non_bearer_authorization_header(self):
        header = 'Basic {}'.format(ACCESS_TOKEN)
        with pytest.raises(BearerTokenError):
            extract_bearer_token_from_http_request(authz_header=header)

    def test_access_token_in_request(self):
        request_params = dict(foo='bar', access_token=ACCESS_TOKEN)
        assert extract_bearer_token_from_http_request(request_params) == ACCESS_TOKEN

    def test_request_without_access_token(self):
        request_params = dict(foo='bar')
        with pytest.raises(BearerTokenError):
            extract_bearer_token_from_http_request(request_params)
29 |
--------------------------------------------------------------------------------
/tests/pyop/test_authz_state.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | import functools
3 | import time
4 | from unittest.mock import patch, Mock
5 |
6 | import pytest
7 |
8 | from pyop.message import AuthorizationRequest
9 | from pyop.authz_state import AccessToken, InvalidScope
10 | from pyop.authz_state import AuthorizationState
11 | from pyop.exceptions import InvalidSubjectIdentifier, InvalidAccessToken, InvalidAuthorizationCode, InvalidRefreshToken
12 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory
13 | import pyop.storage
14 |
# Frozen clock (local midnight of 2016-06-21 as returned by time.mktime), patched in for time.time.
MOCK_TIME = Mock(return_value=time.mktime(dt.datetime(2016, 6, 21).timetuple()))
# Values the parametrized tests expect to be rejected as subject identifiers, tokens or codes.
INVALID_INPUT = [None, '', 'noexist']
17 |
18 |
19 | class TestAuthorizationState(object):
    # Access token lifetime (seconds) passed to the factory and checked by assert_access_token.
    TEST_TOKEN_LIFETIME = 60 * 50  # 50 minutes
    # The only subject identifier accepted once set_valid_subject_identifier has been applied.
    TEST_SUBJECT_IDENTIFIER = 'sub'
22 |
23 | def set_valid_subject_identifier(self, authorization_state):
24 | is_valid_sub_mock = Mock()
25 | is_valid_sub_mock.side_effect = lambda sub: sub == self.TEST_SUBJECT_IDENTIFIER
26 | authorization_state._is_valid_subject_identifier = is_valid_sub_mock
27 |
28 | @pytest.fixture
29 | def authorization_request(self):
30 | authn_req = AuthorizationRequest(**{'scope': 'openid', 'client_id': 'client1'})
31 | return authn_req
32 |
33 | @pytest.fixture
34 | def authorization_state_factory(self):
35 | return functools.partial(AuthorizationState, HashBasedSubjectIdentifierFactory('salt'))
36 |
37 | @pytest.fixture
38 | def authorization_state(self, authorization_state_factory):
39 | return authorization_state_factory(refresh_token_lifetime=3600)
40 |
41 | @pytest.fixture
42 | def stateless_storage(self):
43 | return pyop.storage.StatelessWrapper("pyop", "abc123")
44 |
45 | def assert_access_token(self, authorization_request, access_token, access_token_db, iat):
46 | assert isinstance(access_token, AccessToken)
47 | assert access_token.expires_in == self.TEST_TOKEN_LIFETIME
48 | assert access_token.value
49 | assert access_token.BEARER_TOKEN_TYPE == 'Bearer'
50 |
51 | assert access_token.value in access_token_db
52 | self.assert_introspected_token(authorization_request, access_token_db[access_token.value], access_token, iat)
53 | assert access_token_db[access_token.value]['exp'] == iat + self.TEST_TOKEN_LIFETIME
54 |
55 | def assert_introspected_token(self, authorization_request, token_introspection, access_token, iat):
56 | auth_req = authorization_request.to_dict()
57 |
58 | assert token_introspection['scope'] == auth_req['scope']
59 | assert token_introspection['client_id'] == auth_req['client_id']
60 | assert token_introspection['token_type'] == access_token.type
61 | assert token_introspection['sub'] == self.TEST_SUBJECT_IDENTIFIER
62 | assert token_introspection['aud'] == [auth_req['client_id']]
63 | assert token_introspection['iat'] == iat
64 |
65 | @patch('time.time', MOCK_TIME)
66 | def test_create_authorization_code(self, authorization_state_factory, authorization_request):
67 | code_lifetime = 60 * 2 # two minutes
68 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime)
69 | self.set_valid_subject_identifier(authorization_state)
70 |
71 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
72 | assert authz_code in authorization_state.authorization_codes
73 | assert authorization_state.authorization_codes[authz_code]['exp'] == int(time.time()) + code_lifetime
74 | assert authorization_state.authorization_codes[authz_code]['used'] is False
75 | assert authorization_state.authorization_codes[authz_code][AuthorizationState.KEY_AUTHORIZATION_REQUEST] == \
76 | authorization_request.to_dict()
77 | assert authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER
78 |
79 | @patch('time.time', MOCK_TIME)
80 | def test_create_authorization_code_with_stateless_storage(self, authorization_state_factory, authorization_request,
81 | stateless_storage):
82 | code_lifetime = 60 * 2 # two minutes
83 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime,
84 | authorization_code_db=stateless_storage)
85 | self.set_valid_subject_identifier(authorization_state)
86 |
87 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
88 | assert authz_code in authorization_state.authorization_codes
89 | assert authorization_state.authorization_codes[authz_code]['exp'] == int(time.time()) + code_lifetime
90 | assert authorization_state.authorization_codes[authz_code]['used'] is False
91 | assert authorization_state.authorization_codes[authz_code][AuthorizationState.KEY_AUTHORIZATION_REQUEST] == \
92 | authorization_request.to_dict()
93 | assert authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER
94 |
95 | def test_create_authorization_code_with_scope_other_than_auth_req(self, authorization_state, authorization_request):
96 | scope = ['openid', 'extra']
97 | self.set_valid_subject_identifier(authorization_state)
98 |
99 | authz_code = authorization_state.create_authorization_code(authorization_request,
100 | self.TEST_SUBJECT_IDENTIFIER, scope=scope)
101 | assert authorization_state.authorization_codes[authz_code]['granted_scope'] == ' '.join(scope)
102 |
103 | @pytest.mark.parametrize('sub', INVALID_INPUT)
104 | def test_create_authorization_code_with_invalid_subject_identifier(self, sub, authorization_state,
105 | authorization_request):
106 | with pytest.raises(InvalidSubjectIdentifier):
107 | authorization_state.create_authorization_code(authorization_request, sub)
108 |
109 | @patch('time.time', MOCK_TIME)
110 | def test_create_access_token(self, authorization_state_factory, authorization_request):
111 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME)
112 | self.set_valid_subject_identifier(authorization_state)
113 |
114 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
115 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens,
116 | MOCK_TIME.return_value)
117 |
118 | @patch('time.time', MOCK_TIME)
119 | def test_create_access_token_with_stateless_storage(self, authorization_state_factory, authorization_request,
120 | stateless_storage):
121 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME,
122 | access_token_db=stateless_storage)
123 | self.set_valid_subject_identifier(authorization_state)
124 |
125 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
126 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens,
127 | MOCK_TIME.return_value)
128 |
129 | def test_create_access_token_with_scope_other_than_auth_req(self, authorization_state, authorization_request):
130 | scope = ['openid', 'extra']
131 | self.set_valid_subject_identifier(authorization_state)
132 |
133 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER,
134 | scope=scope)
135 | assert authorization_state.access_tokens[access_token.value]['scope'] == ' '.join(scope)
136 |
137 | @pytest.mark.parametrize('sub', INVALID_INPUT)
138 | def test_create_access_token_with_invalid_subject_identifier(self, sub, authorization_state, authorization_request):
139 | self.set_valid_subject_identifier(authorization_state)
140 | with pytest.raises(InvalidSubjectIdentifier):
141 | authorization_state.create_access_token(authorization_request, sub)
142 |
143 | @patch('time.time', MOCK_TIME)
144 | def test_introspect_access_token(self, authorization_state_factory, authorization_request):
145 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME)
146 | self.set_valid_subject_identifier(authorization_state)
147 |
148 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
149 | token_introspection = authorization_state.introspect_access_token(access_token.value)
150 | assert token_introspection['active'] is True
151 | self.assert_introspected_token(authorization_request, token_introspection, access_token, MOCK_TIME.return_value)
152 |
153 | def test_introspect_access_token_with_expired_token(self, authorization_state_factory, authorization_request):
154 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME)
155 | self.set_valid_subject_identifier(authorization_state)
156 |
157 | with patch('time.time', MOCK_TIME):
158 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
159 |
160 | mock_time2 = Mock()
161 | mock_time2.return_value = MOCK_TIME.return_value + self.TEST_TOKEN_LIFETIME + 1 # time after token expiration
162 | with patch('time.time', mock_time2):
163 | token_introspection = authorization_state.introspect_access_token(access_token.value)
164 | assert token_introspection['active'] is False
165 | self.assert_introspected_token(authorization_request, token_introspection, access_token, MOCK_TIME.return_value)
166 |
167 | @pytest.mark.parametrize('access_token', INVALID_INPUT)
168 | def test_introspect_access_token_with_invalid_access_token(self, access_token, authorization_state):
169 | with pytest.raises(InvalidAccessToken):
170 | authorization_state.introspect_access_token(access_token)
171 |
172 | @pytest.mark.parametrize('authz_code', INVALID_INPUT)
173 | def test_exchange_code_for_token_with_invalid_code(self, authz_code, authorization_state):
174 | with pytest.raises(InvalidAuthorizationCode):
175 | authorization_state.exchange_code_for_token(authz_code)
176 |
177 | @patch('time.time', MOCK_TIME)
178 | def test_exchange_code_for_token(self, authorization_state_factory, authorization_request):
179 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME)
180 | self.set_valid_subject_identifier(authorization_state)
181 |
182 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
183 | access_token = authorization_state.exchange_code_for_token(authz_code)
184 |
185 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens, MOCK_TIME.return_value)
186 | assert authorization_state.authorization_codes[authz_code]['used'] == True
187 |
188 | @patch('time.time', MOCK_TIME)
189 | def test_exchange_code_for_token_with_stateless_storage(self, authorization_state_factory, authorization_request,
190 | stateless_storage):
191 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME,
192 | authorization_code_db=stateless_storage,
193 | access_token_db=stateless_storage)
194 | self.set_valid_subject_identifier(authorization_state)
195 |
196 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
197 | access_token = authorization_state.exchange_code_for_token(authz_code)
198 |
199 | self.assert_access_token(authorization_request, access_token, authorization_state.access_tokens,
200 | MOCK_TIME.return_value)
201 | assert authorization_state.authorization_codes[authz_code]['sub'] == self.TEST_SUBJECT_IDENTIFIER
202 | assert authorization_state.authorization_codes[authz_code]['used'] == False
203 |
204 | def test_exchange_code_for_token_with_scope_other_than_auth_req(self, authorization_state,
205 | authorization_request):
206 | scope = ['openid', 'extra']
207 | self.set_valid_subject_identifier(authorization_state)
208 |
209 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER,
210 | scope=scope)
211 | access_token = authorization_state.exchange_code_for_token(authz_code)
212 |
213 | assert authorization_state.access_tokens[access_token.value]['scope'] == ' '.join(scope)
214 |
215 | def test_exchange_code_for_token_with_expired_token(self, authorization_state_factory, authorization_request):
216 | code_lifetime = 2
217 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime)
218 | self.set_valid_subject_identifier(authorization_state)
219 |
220 | with patch('time.time', MOCK_TIME):
221 | authz_code = authorization_state.create_authorization_code(authorization_request,
222 | self.TEST_SUBJECT_IDENTIFIER)
223 |
224 | time_mock = Mock()
225 | time_mock.return_value = MOCK_TIME.return_value + code_lifetime + 1 # time after code expiration
226 | with patch('time.time', time_mock), pytest.raises(InvalidAuthorizationCode):
227 | authorization_state.exchange_code_for_token(authz_code)
228 |
229 | def test_exchange_code_for_token_with_used_token(self, authorization_state_factory, authorization_request):
230 | code_lifetime = 2
231 | authorization_state = authorization_state_factory(authorization_code_lifetime=code_lifetime)
232 | self.set_valid_subject_identifier(authorization_state)
233 |
234 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
235 | assert authorization_state.exchange_code_for_token(authz_code) # successful use once
236 | with pytest.raises(InvalidAuthorizationCode):
237 | authorization_state.exchange_code_for_token(authz_code) # fail on second use
238 |
239 | def test_create_refresh_token(self, authorization_state, authorization_request):
240 | self.set_valid_subject_identifier(authorization_state)
241 |
242 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
243 | refresh_token = authorization_state.create_refresh_token(access_token.value)
244 |
245 | assert refresh_token in authorization_state.refresh_tokens
246 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value
247 | assert 'exp' in authorization_state.refresh_tokens[refresh_token]
248 |
249 | def test_create_refresh_token_with_stateless_storage(self, authorization_state_factory, authorization_request,
250 | stateless_storage):
251 | authorization_state = authorization_state_factory(refresh_token_lifetime=3600,
252 | authorization_code_db=stateless_storage,
253 | access_token_db=stateless_storage,
254 | refresh_token_db=stateless_storage)
255 | self.set_valid_subject_identifier(authorization_state)
256 |
257 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
258 | refresh_token = authorization_state.create_refresh_token(access_token.value)
259 |
260 | assert refresh_token in authorization_state.refresh_tokens
261 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value
262 | assert 'exp' in authorization_state.refresh_tokens[refresh_token]
263 |
264 | def test_create_refresh_token_issues_no_refresh_token_if_no_lifetime_is_specified(self, authorization_state_factory,
265 | authorization_request):
266 | authorization_state = authorization_state_factory(refresh_token_lifetime=None)
267 | self.set_valid_subject_identifier(authorization_state)
268 |
269 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
270 | refresh_token = authorization_state.create_refresh_token(access_token.value)
271 |
272 | assert refresh_token is None
273 |
274 | @pytest.mark.parametrize('access_token', INVALID_INPUT)
275 | def test_create_refresh_token_with_invalid_access_token_value(self, access_token, authorization_state):
276 | with pytest.raises(InvalidAccessToken):
277 | authorization_state.create_refresh_token(access_token)
278 |
279 | @patch('time.time', MOCK_TIME)
280 | def test_create_refresh_token_with_expiration_time(self, authorization_state_factory, authorization_request):
281 | refresh_token_lifetime = 60 * 60 * 24 # 24 hours
282 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_token_lifetime)
283 | self.set_valid_subject_identifier(authorization_state)
284 |
285 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
286 | refresh_token = authorization_state.create_refresh_token(access_token.value)
287 |
288 | assert refresh_token in authorization_state.refresh_tokens
289 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == access_token.value
290 | assert authorization_state.refresh_tokens[refresh_token]['exp'] == int(time.time()) + refresh_token_lifetime
291 |
292 | def test_use_refresh_token(self, authorization_state_factory, authorization_request):
293 | authorization_state = authorization_state_factory(access_token_lifetime=self.TEST_TOKEN_LIFETIME,
294 | refresh_token_lifetime=300)
295 | self.set_valid_subject_identifier(authorization_state)
296 |
297 | with patch('time.time', MOCK_TIME):
298 | old_access_token = authorization_state.create_access_token(authorization_request,
299 | self.TEST_SUBJECT_IDENTIFIER)
300 | refresh_token = authorization_state.create_refresh_token(old_access_token.value)
301 |
302 | mock_time2 = Mock()
303 | mock_time2.return_value = MOCK_TIME.return_value + 100
304 | with patch('time.time', mock_time2):
305 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token)
306 |
307 | assert new_refresh_token is None
308 | assert new_access_token.value != old_access_token.value
309 | assert new_access_token.type == old_access_token.type
310 |
311 | assert authorization_state.access_tokens[new_access_token.value]['exp'] > \
312 | authorization_state.access_tokens[old_access_token.value]['exp']
313 | assert authorization_state.access_tokens[new_access_token.value]['iat'] > \
314 | authorization_state.access_tokens[old_access_token.value]['iat']
315 | self.assert_access_token(authorization_request, new_access_token, authorization_state.access_tokens, mock_time2.return_value)
316 |
317 | assert authorization_state.refresh_tokens[refresh_token]['access_token'] == new_access_token.value
318 |
319 | @pytest.mark.parametrize('refresh_token', INVALID_INPUT)
320 | def test_use_refresh_token_with_invalid_refresh_token(self, refresh_token, authorization_state):
321 | with pytest.raises(InvalidRefreshToken):
322 | authorization_state.use_refresh_token(refresh_token)
323 |
324 | def test_use_refresh_token_issues_new_refresh_token_if_the_old_is_close_to_expiration(
325 | self, authorization_state_factory, authorization_request):
326 | refresh_threshold = 3600
327 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_threshold * 2,
328 | refresh_token_threshold=refresh_threshold)
329 | self.set_valid_subject_identifier(authorization_state)
330 |
331 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
332 | refresh_token = authorization_state.create_refresh_token(old_access_token.value)
333 |
334 | close_to_expiration = int(time.time()) + authorization_state.refresh_token_lifetime - 50
335 | with patch('time.time', Mock(return_value=close_to_expiration)):
336 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token)
337 |
338 | assert new_refresh_token is not None
339 | assert new_refresh_token in authorization_state.refresh_tokens
340 | assert authorization_state.refresh_tokens[new_refresh_token]['access_token'] == new_access_token.value
341 |
342 | def test_use_refresh_token_doesnt_issue_new_refresh_token_if_the_old_is_far_from_expiration(
343 | self, authorization_state_factory, authorization_request):
344 | refresh_threshold = 3600
345 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_threshold * 2,
346 | refresh_token_threshold=refresh_threshold)
347 |
348 | self.set_valid_subject_identifier(authorization_state)
349 |
350 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
351 | refresh_token = authorization_state.create_refresh_token(old_access_token.value)
352 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token)
353 |
354 | assert new_refresh_token is None
355 |
356 | def test_use_refresh_token_doesnt_issue_new_refresh_token_if_no_refresh_token_threshold_is_set(
357 | self, authorization_state_factory, authorization_request):
358 | authorization_state = authorization_state_factory(refresh_token_lifetime=400)
359 |
360 | self.set_valid_subject_identifier(authorization_state)
361 |
362 | old_access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
363 | refresh_token = authorization_state.create_refresh_token(old_access_token.value)
364 | new_access_token, new_refresh_token = authorization_state.use_refresh_token(refresh_token)
365 |
366 | assert new_refresh_token is None
367 |
368 | def test_use_refresh_token_with_expired_refresh_token(self, authorization_state_factory, authorization_request):
369 | refresh_token_lifetime = 2
370 | authorization_state = authorization_state_factory(refresh_token_lifetime=refresh_token_lifetime)
371 | self.set_valid_subject_identifier(authorization_state)
372 |
373 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
374 | with patch('time.time', MOCK_TIME):
375 | refresh_token = authorization_state.create_refresh_token(access_token.value)
376 |
377 | time_mock = Mock()
378 | time_mock.return_value = MOCK_TIME.return_value + refresh_token_lifetime + 1 # time after refresh_token expiration
379 | with patch('time.time', time_mock), pytest.raises(InvalidRefreshToken):
380 | authorization_state.use_refresh_token(refresh_token)
381 |
382 | def test_use_refresh_token_with_superset_scope(self, authorization_state, authorization_request):
383 | self.set_valid_subject_identifier(authorization_state)
384 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
385 | refresh_token = authorization_state.create_refresh_token(access_token.value)
386 | with pytest.raises(InvalidScope):
387 | authorization_state.use_refresh_token(refresh_token, scope=['openid', 'extra'])
388 |
389 | def test_use_refresh_token_with_subset_scope(self, authorization_state, authorization_request):
390 | self.set_valid_subject_identifier(authorization_state)
391 | authorization_request['scope'] = 'openid profile'
392 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
393 | refresh_token = authorization_state.create_refresh_token(access_token.value)
394 | access_token, _ = authorization_state.use_refresh_token(refresh_token, scope=['openid'])
395 |
396 | assert authorization_state.access_tokens[access_token.value]['scope'] == 'openid'
397 |
398 | def test_use_refresh_token_with_subset_scope_does_not_minimize_granted_scope(self, authorization_state,
399 | authorization_request):
400 | scope = 'openid profile'
401 | self.set_valid_subject_identifier(authorization_state)
402 | authorization_request['scope'] = scope
403 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
404 | refresh_token = authorization_state.create_refresh_token(access_token.value)
405 |
406 | # first time: issue access token with subset of granted scope
407 | access_token, _ = authorization_state.use_refresh_token(refresh_token, scope=['openid'])
408 | assert authorization_state.access_tokens[access_token.value]['scope'] == 'openid'
409 |
410 | # second time: issue access token with exactly granted scope
411 | access_token, _ = authorization_state.use_refresh_token(refresh_token)
412 | assert authorization_state.access_tokens[access_token.value]['scope'] == scope
413 |
414 | def test_use_refresh_token_without_scope(self, authorization_state, authorization_request):
415 | self.set_valid_subject_identifier(authorization_state)
416 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
417 | refresh_token = authorization_state.create_refresh_token(access_token.value)
418 | access_token, _ = authorization_state.use_refresh_token(refresh_token)
419 |
420 | assert authorization_state.access_tokens[access_token.value]['scope'] == \
421 | ' '.join(authorization_request['scope'])
422 |
423 | def test_create_subject_identifier_public(self, authorization_state):
424 | user_id = 'test_user'
425 | sub1 = authorization_state.get_subject_identifier('public', user_id)
426 | sub2 = authorization_state.get_subject_identifier('public', user_id)
427 | assert sub1 == sub2
428 | assert authorization_state.subject_identifiers[user_id]['public'] == sub1
429 |
430 | def test_create_subject_identifier_pairwise_with_diffent_redirect_uris(self, authorization_state):
431 | user_id = 'test_user'
432 | sector_identifier1 = 'client1.example.com'
433 | sector_identifier2 = 'client2.example.com'
434 | sub1 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier1)
435 | sub2 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier2)
436 | assert sub1 != sub2
437 | assert all(s in authorization_state.subject_identifiers[user_id]['pairwise'] for s in [sub1, sub2])
438 |
439 | def test_create_subject_identifier_pairwise_with_same_hostname(self, authorization_state):
440 | user_id = 'test_user'
441 | sector_identifier = 'client.example.com'
442 | sub1 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier)
443 | sub2 = authorization_state.get_subject_identifier('pairwise', user_id, sector_identifier)
444 | assert sub1 == sub2
445 | assert sub1 in authorization_state.subject_identifiers[user_id]['pairwise']
446 |
447 | def test_create_subject_identifier_pairwise_without_sector_identifier(self, authorization_state):
448 | with pytest.raises(ValueError):
449 | authorization_state.get_subject_identifier('pairwise', 'test_user', None)
450 |
451 | def test_create_subject_identifier_with_unknown_subject_type(self, authorization_state):
452 | with pytest.raises(ValueError):
453 | authorization_state.get_subject_identifier('unknown', 'test_user', None)
454 |
455 | @pytest.mark.parametrize('subject_type', [
456 | 'public',
457 | 'pairwise'
458 | ])
459 | def test_is_valid_subject_identifier(self, subject_type, authorization_state):
460 | sub = authorization_state.get_subject_identifier(subject_type, 'test_user', 'client.example.com')
461 | assert authorization_state._is_valid_subject_identifier(sub) is True
462 |
463 | def test_get_authentication_request_for_code(self, authorization_state, authorization_request):
464 | self.set_valid_subject_identifier(authorization_state)
465 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
466 | request = authorization_state.get_authorization_request_for_code(authz_code)
467 | assert request.to_dict() == authorization_request.to_dict()
468 |
469 | def test_get_authentication_request_for_access_token(self, authorization_state, authorization_request):
470 | self.set_valid_subject_identifier(authorization_state)
471 | access_token = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
472 | request = authorization_state.get_authorization_request_for_access_token(access_token.value)
473 | assert request.to_dict() == authorization_request.to_dict()
474 |
475 | def test_get_subject_identifier_for_code(self, authorization_state, authorization_request):
476 | self.set_valid_subject_identifier(authorization_state)
477 | authz_code = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
478 | sub = authorization_state.get_subject_identifier_for_code(authz_code)
479 | assert sub == self.TEST_SUBJECT_IDENTIFIER
480 |
481 | def test_remove_state_for_subject_identifier(self, authorization_state, authorization_request):
482 | self.set_valid_subject_identifier(authorization_state)
483 | authz_code1 = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
484 | authz_code2 = authorization_state.create_authorization_code(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
485 | access_token1 = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
486 | access_token2 = authorization_state.create_access_token(authorization_request, self.TEST_SUBJECT_IDENTIFIER)
487 |
488 | authorization_state.delete_state_for_subject_identifier(self.TEST_SUBJECT_IDENTIFIER)
489 |
490 | for ac in [authz_code1, authz_code2]:
491 | assert ac not in authorization_state.authorization_codes
492 | for at in [access_token1, access_token2]:
493 | assert at.value not in authorization_state.access_tokens
494 |
495 | def test_remove_state_for_unknown_subject_identifier(self, authorization_state):
496 | with pytest.raises(InvalidSubjectIdentifier):
497 | authorization_state.delete_state_for_subject_identifier('unknown')
498 |
--------------------------------------------------------------------------------
/tests/pyop/test_client_authentication.py:
--------------------------------------------------------------------------------
1 | import base64
2 |
3 | import pytest
4 |
5 | from pyop.client_authentication import verify_client_authentication
6 | from pyop.exceptions import InvalidClientAuthentication
7 |
# Credentials of the single client registered by the fixtures in this module.
TEST_CLIENT_ID, TEST_CLIENT_SECRET = 'client1', 'my_secret'
10 |
11 |
class TestVerifyClientAuthentication(object):
    """Tests for verify_client_authentication across the token endpoint auth methods."""

    def create_basic_auth(self, client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET):
        """Build an HTTP ``Authorization: Basic`` header value from client credentials."""
        credentials = client_id + ':' + client_secret
        auth = base64.urlsafe_b64encode(credentials.encode('utf-8')).decode('utf-8')
        return 'Basic {}'.format(auth)

    @pytest.fixture(autouse=True)
    def create_request_args(self):
        """Reset, before every test, the token request and a registry holding one client_secret_post client."""
        self.token_request_args = {
            'grant_type': 'authorization_code',
            'code': 'code',
            'redirect_uri': 'https://client.example.com',
            'client_id': TEST_CLIENT_ID,
            'client_secret': TEST_CLIENT_SECRET
        }

        self.clients = {
            TEST_CLIENT_ID: {
                'client_secret': TEST_CLIENT_SECRET,
                'token_endpoint_auth_method': 'client_secret_post'
            }
        }

    def test_wrong_authentication_method(self):
        # do client_secret_basic, while client_secret_post is expected
        authz_header = self.create_basic_auth()
        with pytest.raises(InvalidClientAuthentication):
            verify_client_authentication(self.clients, None, authz_header)

    def test_authentication_method_defaults_to_client_secret_basic(self):
        del self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method']
        authz_header = self.create_basic_auth()
        assert verify_client_authentication(self.clients, self.token_request_args, authz_header) == TEST_CLIENT_ID

    def test_client_secret_post(self):
        self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'client_secret_post'
        assert verify_client_authentication(self.clients, self.token_request_args) == TEST_CLIENT_ID

    def test_client_secret_basic(self):
        self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'client_secret_basic'
        authz_header = self.create_basic_auth()
        assert verify_client_authentication(self.clients, self.token_request_args, authz_header) == TEST_CLIENT_ID

    def test_unknown_client_id(self):
        self.token_request_args['client_id'] = 'unknown'
        with pytest.raises(InvalidClientAuthentication):
            # Only the raised exception matters here: the call never returns a value to compare.
            verify_client_authentication(self.clients, self.token_request_args)

    def test_wrong_client_secret(self):
        self.token_request_args['client_secret'] = 'foobar'
        with pytest.raises(InvalidClientAuthentication):
            verify_client_authentication(self.clients, self.token_request_args)

    def test_public_client_no_auth(self):
        del self.token_request_args['client_secret']
        # public client
        self.clients[TEST_CLIENT_ID]['token_endpoint_auth_method'] = 'none'
        del self.clients[TEST_CLIENT_ID]['client_secret']

        assert verify_client_authentication(self.clients, self.token_request_args, None) == TEST_CLIENT_ID

    def test_invalid_authorization_scheme(self):
        authz_header = self.create_basic_auth()
        with pytest.raises(InvalidClientAuthentication):
            verify_client_authentication(self.clients, self.token_request_args,
                                         authz_header.replace('Basic', 'invalid'))

    def test_invalid_userid_password(self):
        with pytest.raises(InvalidClientAuthentication):
            verify_client_authentication(self.clients, self.token_request_args, 'Basic invalid')
82 |
--------------------------------------------------------------------------------
/tests/pyop/test_exceptions.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlparse, parse_qsl
2 |
3 | from pyop.message import AuthorizationRequest
4 | from pyop.exceptions import InvalidAuthenticationRequest
5 |
6 |
7 | class TestInvalidAuthenticationRequest:
8 | def test_error_url_should_contain_state_from_authentication_request(self):
9 | authn_params = {'redirect_uri': 'test_redirect_uri', 'response_type': 'code', 'state': 'test_state'}
10 | authn_req = AuthorizationRequest().from_dict(authn_params)
11 | error_url = InvalidAuthenticationRequest('test', authn_req, 'invalid_request').to_error_url()
12 |
13 | error = dict(parse_qsl(urlparse(error_url).query))
14 | assert error['state'] == authn_params['state']
15 |
16 | def test_error_url_should_handle_unknown_response_type(self):
17 | authn_params = {'redirect_uri': 'test_redirect_uri', 'state': 'test_state'} # no 'response_type'
18 | authn_req = AuthorizationRequest().from_dict(authn_params)
19 |
20 | error = InvalidAuthenticationRequest('test', authn_req, 'invalid_request')
21 | assert error.to_error_url() is None
--------------------------------------------------------------------------------
/tests/pyop/test_provider.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | import json
3 | import time
4 | from collections import Counter
5 | from unittest.mock import Mock, patch
6 | from urllib.parse import urlencode
7 |
8 | import pytest
9 | from Cryptodome.PublicKey import RSA
10 | from jwkest import jws
11 | from jwkest.jwk import RSAKey
12 | from oic import rndstr
13 | from oic.oauth2.message import MissingRequiredValue, MissingRequiredAttribute
14 | from oic.oic import PREFERENCE2PROVIDER
15 | from oic.oic.message import IdToken, ClaimsRequest, Claims, EndSessionRequest, EndSessionResponse
16 |
17 | from pyop.message import AuthorizationRequest
18 | from pyop.access_token import BearerTokenError
19 | from pyop.authz_state import AuthorizationState
20 | from pyop.client_authentication import InvalidClientAuthentication
21 | from pyop.exceptions import InvalidAuthenticationRequest, AuthorizationError, InvalidTokenRequest, \
22 | InvalidClientRegistrationRequest, InvalidAccessToken, InvalidAuthorizationCode, InvalidSubjectIdentifier
23 | from pyop.provider import Provider, redirect_uri_is_in_registered_redirect_uris, \
24 | response_type_is_in_registered_response_types
25 | from pyop.subject_identifier import HashBasedSubjectIdentifierFactory
26 | from pyop.userinfo import Userinfo
27 |
# Client registrations used by the ``inject_provider`` fixture.
TEST_CLIENT_ID = 'client1'
TEST_CLIENT_SECRET = 'secret'
INVALID_TEST_CLIENT_ID = 'invalid_client'
INVALID_TEST_CLIENT_SECRET = 'invalid_secret'
TEST_REDIRECT_URI = 'https://client.example.com'

# Provider identity and the single known end user.
ISSUER = 'https://provider.example.com'
TEST_USER_ID = 'user1'

# Frozen clock: midnight on 2016-06-21 (time.mktime interprets it as local time).
# Tests patch time.time with this mock so token 'iat'/'exp' values are predictable.
_MOCK_TIMESTAMP = time.mktime(dt.datetime(2016, 6, 21).timetuple())
MOCK_TIME = Mock(return_value=_MOCK_TIMESTAMP)
37 |
38 |
def rsa_key():
    """Generate a new RS256 signing key with a random 4-character key id.

    The 1024-bit size is small, presumably chosen to keep key generation fast in tests.
    """
    private_key = RSA.generate(1024)
    return RSAKey(key=private_key, use="sig", alg="RS256", kid=rndstr(4))
41 |
42 |
def provide_configuration():
    """Return the provider configuration (issuer and endpoint paths) used by these tests."""
    return dict(
        issuer=ISSUER,
        jwks_uri='/jwks',
        authorization_endpoint='/authorization',
        token_endpoint='/token',
    )
51 |
52 |
def assert_id_token_base_claims(jws, verification_key, provider, auth_req):
    """Parse a signed ID Token and assert the claims every issued ID Token must carry.

    :param jws: the serialized ID Token (inside this function it shadows the module-level ``jws`` import)
    :param verification_key: key passed to ``IdToken.from_jwt`` to verify the token
    :param provider: the Provider that issued the token
    :param auth_req: the authentication request (any mapping with a 'nonce') the token responds to
    :return: the parsed IdToken, so callers can check additional claims
    """
    token = IdToken().from_jwt(jws, key=[verification_key])
    assert token['nonce'] == auth_req['nonce']
    assert token['iss'] == ISSUER
    user_id = provider.authz_state.get_user_id_for_subject_identifier(token['sub'])
    assert user_id == TEST_USER_ID
    issued_at = token['iat']
    # Callers patch time.time with MOCK_TIME, so the issue time is predictable.
    assert issued_at == MOCK_TIME.return_value
    assert token['exp'] == issued_at + provider.id_token_lifetime
    assert TEST_CLIENT_ID in token['aud']
    return token
63 |
64 |
@pytest.fixture
def auth_req_args(request):
    """Attach a fresh, valid authentication request argument dict to the requesting test instance."""
    request.instance.authn_request_args = dict(
        scope='openid',
        response_type='code',
        client_id=TEST_CLIENT_ID,
        redirect_uri=TEST_REDIRECT_URI,
        state='state',
        nonce='nonce',
    )
75 |
76 |
@pytest.fixture
def inject_provider(request):
    """Attach a Provider with two registered clients and one known user to the test instance."""
    def registration(redirect_uris, client_secret):
        # Registration shared by both test clients; only redirect URIs and secret differ.
        return {
            'subject_type': 'pairwise',
            'redirect_uris': redirect_uris,
            'response_types': ['code'],
            'client_secret': client_secret,
            'token_endpoint_auth_method': 'client_secret_post',
            'post_logout_redirect_uris': ['https://client.example.com/post_logout']
        }

    clients = {
        TEST_CLIENT_ID: registration([TEST_REDIRECT_URI], TEST_CLIENT_SECRET),
        # NOTE(review): redirect_uris is a bare string here, not a list as for the client above;
        # presumably harmless since this client is only used at the token endpoint — confirm.
        INVALID_TEST_CLIENT_ID: registration('http://invalid.redirect.loc', INVALID_TEST_CLIENT_SECRET),
    }

    user_claims = {
        'name': 'The T. Tester',
        'family_name': 'Tester',
        'given_name': 'The',
        'middle_name': 'Theodore',
        'nickname': 'testster',
        'email': 'testster@example.com',
    }
    userinfo = Userinfo({TEST_USER_ID: user_claims})

    request.instance.provider = Provider(rsa_key(), provide_configuration(),
                                         AuthorizationState(HashBasedSubjectIdentifierFactory('salt')),
                                         clients, userinfo)
111 |
112 |
@pytest.mark.usefixtures('inject_provider', 'auth_req_args')
class TestProviderParseAuthenticationRequest(object):
    """Tests for Provider.parse_authentication_request and its validation."""

    def _parse(self):
        # Submit the (possibly modified) request arguments as a urlencoded query string.
        return self.provider.parse_authentication_request(urlencode(self.authn_request_args))

    def test_parse_authentication_request(self):
        self.authn_request_args['nonce'] = 'nonce'
        parsed = self._parse()
        assert parsed.to_dict() == self.authn_request_args

    def test_reject_request_with_missing_required_parameter(self):
        del self.authn_request_args['redirect_uri']

        with pytest.raises(InvalidAuthenticationRequest) as exc:
            self._parse()
        assert isinstance(exc.value.__cause__, MissingRequiredAttribute)

    def test_reject_request_with_scope_without_openid(self):
        self.authn_request_args['scope'] = 'foobar'  # does not contain 'openid'

        with pytest.raises(InvalidAuthenticationRequest) as exc:
            self._parse()
        assert isinstance(exc.value.__cause__, MissingRequiredValue)

    def test_custom_validation_hook_reject(self):
        class TestException(Exception):
            pass

        def fail_all_requests(auth_req):
            raise InvalidAuthenticationRequest("Test exception", auth_req) from TestException()

        self.provider.authentication_request_validators.append(fail_all_requests)
        with pytest.raises(InvalidAuthenticationRequest) as exc:
            self._parse()

        assert isinstance(exc.value.__cause__, TestException)

    def test_redirect_uri_not_matching_registered_redirect_uris(self):
        self.authn_request_args['redirect_uri'] = 'https://something.example.com'
        with pytest.raises(InvalidAuthenticationRequest):
            self._parse()

    def test_response_type_not_matching_registered_response_types(self):
        self.authn_request_args['response_type'] = 'id_token'
        with pytest.raises(InvalidAuthenticationRequest):
            self._parse()

    def test_unknown_client_id(self):
        self.authn_request_args['client_id'] = 'unknown'
        with pytest.raises(InvalidAuthenticationRequest):
            self._parse()

    def test_include_userinfo_claims_request_with_response_type_id_token(self):
        # Requesting userinfo claims together with response_type 'id_token' alone must be rejected.
        self.authn_request_args['claims'] = ClaimsRequest(userinfo=Claims(nickname=None)).to_json()
        self.provider.clients[TEST_CLIENT_ID]['response_types'] = ['id_token']
        self.authn_request_args['response_type'] = 'id_token'
        with pytest.raises(InvalidAuthenticationRequest):
            self._parse()
170 |
171 |
@pytest.mark.usefixtures('auth_req_args')
class TestAuthenticationRequestValidators(object):
    """Validators must reject requests from clients lacking the relevant registration entries."""

    @pytest.fixture
    def provider_mock(self):
        """Provider stand-in whose client lookup returns an empty registration for any client id."""
        mocked_provider = Mock()
        mocked_provider.clients = Mock()
        mocked_provider.clients.__getitem__ = Mock(return_value={})
        return mocked_provider

    def _auth_req(self):
        # Build an AuthorizationRequest from the fixture-provided arguments.
        return AuthorizationRequest().from_dict(self.authn_request_args)

    def test_redirect_uri_is_in_registered_redirect_uris_with_no_redirect_uris(self, provider_mock):
        request = self._auth_req()
        with pytest.raises(InvalidAuthenticationRequest):
            redirect_uri_is_in_registered_redirect_uris(provider_mock, request)

    def test_response_type_is_in_registered_response_types_with_no_response_types(self, provider_mock):
        request = self._auth_req()
        with pytest.raises(InvalidAuthenticationRequest):
            response_type_is_in_registered_response_types(provider_mock, request)
190 |
191 |
@pytest.mark.usefixtures('inject_provider', 'auth_req_args')
class TestProviderAuthorize(object):
    """Tests for Provider.authorize: codes, ID Tokens, hybrid flow and requested 'sub' handling."""

    def test_authorize(self):
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID)
        assert resp['code'] in self.provider.authz_state.authorization_codes
        assert resp['state'] == self.authn_request_args['state']

    def test_authorize_with_custom_sub(self, monkeypatch):
        # A 'sub' stored in the user's userinfo record becomes the subject of the issued code.
        sub = 'test_sub1'
        monkeypatch.setitem(self.provider.userinfo._db[TEST_USER_ID], 'sub', sub)
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID)
        assert resp['code'] in self.provider.authz_state.authorization_codes
        assert resp['state'] == self.authn_request_args['state']
        assert self.provider.authz_state.authorization_codes[resp['code']]['sub'] == sub

    @patch('time.time', MOCK_TIME)
    @pytest.mark.parametrize('extra_claims', [
        {'foo': 'bar'},
        lambda user_id, client_id: {'foo': 'bar'}
    ])
    def test_authorize_with_extra_id_token_claims(self, extra_claims):
        self.authn_request_args['response_type'] = ['id_token']  # make sure ID Token is produced
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID, extra_claims)
        id_token = assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider, auth_req)
        assert id_token['foo'] == 'bar'

    def test_authorize_include_user_claims_from_scope_in_id_token_if_no_userinfo_req_can_be_made(self):
        self.authn_request_args['response_type'] = 'id_token'
        self.authn_request_args['scope'] = 'openid profile'
        self.authn_request_args['claims'] = ClaimsRequest(id_token=Claims(email={'essential': True}))
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID)

        id_token = IdToken().from_jwt(resp['id_token'], key=[self.provider.signing_key])
        # verify all claims are part of the ID Token
        assert all(id_token[claim] == value for claim, value in self.provider.userinfo[TEST_USER_ID].items())

    @patch('time.time', MOCK_TIME)
    def test_authorize_includes_requested_id_token_claims_even_if_token_request_can_be_made(self):
        self.authn_request_args['response_type'] = ['id_token', 'token']
        self.authn_request_args['claims'] = ClaimsRequest(id_token=Claims(email=None))
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID)
        id_token = assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider, auth_req)
        assert id_token['email'] == self.provider.userinfo[TEST_USER_ID]['email']

    @patch('time.time', MOCK_TIME)
    def test_hybrid_flow(self):
        self.authn_request_args['response_type'] = 'code id_token token'
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        resp = self.provider.authorize(auth_req, TEST_USER_ID, extra_id_token_claims={'foo': 'bar'})

        assert resp['state'] == self.authn_request_args['state']
        assert resp['code'] in self.provider.authz_state.authorization_codes

        assert resp['access_token'] in self.provider.authz_state.access_tokens
        assert resp['expires_in'] == self.provider.authz_state.access_token_lifetime
        assert resp['token_type'] == 'Bearer'

        # The helper already parses and verifies the token and returns it; no second decode needed.
        id_token = assert_id_token_base_claims(resp['id_token'], self.provider.signing_key, self.provider,
                                               self.authn_request_args)
        # NOTE(review): 'HS256' is passed only to select the hash for left_hash (presumably SHA-256,
        # matching the RS256 signing key) — confirm against jwkest.
        assert id_token["c_hash"] == jws.left_hash(resp['code'].encode('utf-8'), 'HS256')
        assert id_token["at_hash"] == jws.left_hash(resp['access_token'].encode('utf-8'), 'HS256')
        assert id_token['foo'] == 'bar'

    @pytest.mark.parametrize('claims_location', [
        'id_token',
        'userinfo'
    ])
    def test_with_requested_sub_not_matching(self, claims_location):
        self.authn_request_args['claims'] = ClaimsRequest(**{claims_location: Claims(sub={'value': 'nomatch'})})
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        with pytest.raises(AuthorizationError):
            self.provider.authorize(auth_req, TEST_USER_ID)

    def test_with_multiple_requested_sub(self):
        self.authn_request_args['claims'] = ClaimsRequest(userinfo=Claims(sub={'value': 'nomatch1'}),
                                                          id_token=Claims(sub={'value': 'nomatch2'}))
        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        with pytest.raises(AuthorizationError) as exc:
            self.provider.authorize(auth_req, TEST_USER_ID)

        assert 'different' in str(exc.value)
278 |
279 |
@pytest.mark.usefixtures('inject_provider', 'auth_req_args')
class TestProviderHandleTokenRequest(object):
    """Tests for the token endpoint: authorization code exchange (including PKCE) and refresh grants."""

    def _pairwise_sub(self):
        # Pairwise subject identifier of the test user for the 'client1.example.com' sector.
        return self.provider.authz_state.get_subject_identifier('pairwise', TEST_USER_ID, 'client1.example.com')

    def _send_token_request(self, request_args, **kwargs):
        # Submit request_args to the token endpoint as a urlencoded body.
        return self.provider.handle_token_request(urlencode(request_args), **kwargs)

    def create_authz_code(self, extra_auth_req_params=None):
        """Issue an authorization code for the test user.

        :param extra_auth_req_params: optional parameters (e.g. a PKCE challenge or claims request)
            merged into the authentication request arguments before the code is created.
        :return: the authorization code
        """
        sub = self._pairwise_sub()
        self.authn_request_args.update(extra_auth_req_params or {})
        request = AuthorizationRequest().from_dict(self.authn_request_args)
        return self.provider.authz_state.create_authorization_code(request, sub)

    def create_refresh_token(self):
        """Issue an access token for the test user and return a refresh token created for it."""
        sub = self._pairwise_sub()
        request = AuthorizationRequest().from_dict(self.authn_request_args)
        token = self.provider.authz_state.create_access_token(request, sub)
        return self.provider.authz_state.create_refresh_token(token.value)

    @pytest.fixture(autouse=True)
    def create_token_request_args(self):
        """Build fresh token request templates; each test fills in 'code' or 'refresh_token'."""
        client_credentials = {'client_id': TEST_CLIENT_ID, 'client_secret': TEST_CLIENT_SECRET}
        self.authorization_code_exchange_request_args = dict(
            grant_type='authorization_code',
            code=None,
            redirect_uri='https://client.example.com',
            **client_credentials
        )
        self.refresh_token_request_args = dict(
            grant_type='refresh_token',
            refresh_token=None,
            scope='openid',
            **client_credentials
        )

    @patch('time.time', MOCK_TIME)
    def test_code_exchange_request(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code()
        response = self.provider._do_code_exchange(exchange, None)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
                                    self.authn_request_args)

    @patch('time.time', MOCK_TIME)
    def test_pkce_code_exchange_request(self):
        exchange = self.authorization_code_exchange_request_args
        # S256 challenge that matches the code_verifier sent below.
        exchange['code'] = self.create_authz_code({
            "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw",
            "code_challenge_method": "S256"
        })
        exchange['code_verifier'] = "SoOEDN-mZKNhw7Mc52VXxyiqTvFB3mod36MwPru253c"
        response = self.provider._do_code_exchange(exchange, None)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
                                    self.authn_request_args)

    @patch('time.time', MOCK_TIME)
    def test_code_exchange_request_with_claims_requested_in_id_token(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code(
            extra_auth_req_params={'claims': ClaimsRequest(id_token=Claims(email=None))})
        response = self.provider._do_code_exchange(exchange, None)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        id_token = assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
                                               self.authn_request_args)
        assert id_token['email'] == self.provider.userinfo[TEST_USER_ID]['email']

    @patch('time.time', MOCK_TIME)
    @pytest.mark.parametrize('extra_claims', [
        {'foo': 'bar'},
        lambda user_id, client_id: {'foo': 'bar'}
    ])
    def test_handle_token_request_with_extra_id_token_claims(self, extra_claims):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code()
        response = self._send_token_request(exchange, extra_id_token_claims=extra_claims)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        id_token = assert_id_token_base_claims(response['id_token'], self.provider.signing_key, self.provider,
                                               self.authn_request_args)
        assert id_token['foo'] == 'bar'

    def test_handle_token_request_reject_invalid_client_authentication(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code()
        exchange['client_secret'] = 'invalid'
        with pytest.raises(InvalidClientAuthentication):
            self._send_token_request(exchange, extra_id_token_claims={'foo': 'bar'})

    def test_handle_token_request_reject_code_not_issued_to_client(self):
        exchange = self.authorization_code_exchange_request_args
        # Authenticate as a different (registered) client than the one the code is issued to.
        exchange['client_id'] = INVALID_TEST_CLIENT_ID
        exchange['client_secret'] = INVALID_TEST_CLIENT_SECRET
        exchange['code'] = self.create_authz_code()
        with pytest.raises(InvalidAuthorizationCode):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_invalid_redirect_uri_in_exchange_request(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['redirect_uri'] = 'https://invalid.com'
        exchange['code'] = self.create_authz_code()
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_invalid_grant_type(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['grant_type'] = 'invalid'
        exchange['code'] = self.create_authz_code()
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_missing_grant_type(self):
        exchange = self.authorization_code_exchange_request_args
        del exchange['grant_type']
        exchange['code'] = self.create_authz_code()
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_invalid_code_verifier(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code({
            "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
            "code_challenge_method": "S256"
        })
        exchange['code_verifier'] = "ThiS Cer_tainly Ain't Valid"
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_unsynced_requests(self):
        # A PKCE challenge was sent with the authentication request, but no verifier is sent now.
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code({
            "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
            "code_challenge_method": "S256"
        })
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_handle_token_request_reject_missing_code_challenge_method(self):
        exchange = self.authorization_code_exchange_request_args
        exchange['code'] = self.create_authz_code({
            "code_challenge": "_1f8tFjAtu6D1Df-GOyDPoMjCJdEvaSWsnqR6SLpzsw=",
        })
        with pytest.raises(InvalidTokenRequest):
            self._send_token_request(exchange)

    def test_refresh_request(self):
        # Use a state configured with a refresh token lifetime.
        self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'),
                                                       refresh_token_lifetime=600)
        refresh = self.refresh_token_request_args
        refresh['refresh_token'] = self.create_refresh_token()
        response = self._send_token_request(refresh)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        assert 'refresh_token' not in response

    def test_refresh_request_with_refresh_token_close_to_expiry_issues_new_refresh_token(self):
        self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'),
                                                       refresh_token_lifetime=10,
                                                       refresh_token_threshold=2)
        refresh = self.refresh_token_request_args
        refresh['refresh_token'] = self.create_refresh_token()

        # One second before the refresh token's lifetime runs out.
        close_to_expiration = int(time.time()) + self.provider.authz_state.refresh_token_lifetime - 1
        with patch('time.time', Mock(return_value=close_to_expiration)):
            response = self._send_token_request(refresh)
            assert response['access_token'] in self.provider.authz_state.access_tokens
            assert response['refresh_token'] in self.provider.authz_state.refresh_tokens

    def test_refresh_request_without_scope_parameter_defaults_to_scope_from_authentication_request(self):
        self.provider.authz_state = AuthorizationState(HashBasedSubjectIdentifierFactory('salt'),
                                                       refresh_token_lifetime=600)
        refresh = self.refresh_token_request_args
        refresh['refresh_token'] = self.create_refresh_token()
        del refresh['scope']
        response = self._send_token_request(refresh)
        assert response['access_token'] in self.provider.authz_state.access_tokens
        issued_token = self.provider.authz_state.access_tokens[response['access_token']]
        assert issued_token['scope'] == self.authn_request_args['scope']

    @pytest.mark.parametrize('missing_parameter', [
        'grant_type',
        'code',
        'redirect_uri'
    ])
    def test_code_exchange_request_with_missing_parameter(self, missing_parameter):
        request_args = dict(grant_type='authorization_code', code=None, redirect_uri=TEST_REDIRECT_URI)
        request_args.pop(missing_parameter)
        with pytest.raises(InvalidTokenRequest):
            self.provider._do_code_exchange(request_args)

    @pytest.mark.parametrize('missing_parameter', [
        'grant_type',
        'refresh_token',
    ])
    def test_refresh_token_request_with_missing_parameter(self, missing_parameter):
        request_args = dict(grant_type='refresh_token', refresh_token=None)
        request_args.pop(missing_parameter)
        with pytest.raises(InvalidTokenRequest):
            self.provider._do_token_refresh(request_args)
479 |
480 |
@pytest.mark.usefixtures('inject_provider', 'auth_req_args')
class TestProviderHandleUserinfoRequest(object):
    """Tests for Provider.handle_userinfo_request."""

    def create_access_token(self, extra_auth_req_params=None):
        """Issue an access token for the test user and return its value.

        :param extra_auth_req_params: optional parameters (e.g. 'scope' or 'claims') merged into the
            authentication request arguments before the token is created.
        """
        sub = self.provider.authz_state.get_subject_identifier('pairwise', TEST_USER_ID, 'client1.example.com')

        if extra_auth_req_params:
            self.authn_request_args.update(extra_auth_req_params)

        auth_req = AuthorizationRequest().from_dict(self.authn_request_args)
        access_token = self.provider.authz_state.create_access_token(auth_req, sub)
        return access_token.value

    def test_handle_userinfo(self):
        claims_request = ClaimsRequest(userinfo=Claims(email=None))
        access_token = self.create_access_token({'scope': 'openid profile', 'claims': claims_request})
        response = self.provider.handle_userinfo_request(urlencode({'access_token': access_token}))

        # 'sub' is generated, so check it separately and compare the remaining claims directly.
        response_sub = response['sub']
        del response['sub']
        assert response.to_dict() == self.provider.userinfo[TEST_USER_ID]
        assert self.provider.authz_state.get_user_id_for_subject_identifier(response_sub) == TEST_USER_ID

    def test_handle_userinfo_with_custom_sub(self, monkeypatch):
        sub = 'test_sub1'
        monkeypatch.setitem(self.provider.userinfo._db[TEST_USER_ID], 'sub', sub)
        claims_request = ClaimsRequest(userinfo=Claims(email=None))
        access_token = self.create_access_token({'scope': 'openid profile', 'claims': claims_request})
        response = self.provider.handle_userinfo_request(urlencode({'access_token': access_token}))

        assert response['sub'] == sub

    def test_handle_userinfo_rejects_request_missing_access_token(self):
        # The exception is not inspected, so it is not bound to a name.
        with pytest.raises(BearerTokenError):
            self.provider.handle_userinfo_request()

    def test_handle_userinfo_rejects_invalid_access_token(self):
        access_token = self.create_access_token()
        # Set 'exp' to 0 (the epoch) so the token is presumably treated as expired.
        self.provider.authz_state.access_tokens[access_token]['exp'] = 0
        with pytest.raises(InvalidAccessToken):
            self.provider.handle_userinfo_request(urlencode({'access_token': access_token}))
521 |
522 |
@pytest.mark.usefixtures('inject_provider')
class TestProviderHandleRegistrationRequest(object):
    def test_handle_registration_request(self):
        """A valid registration returns client credentials and echoes the registered metadata."""
        request = {'redirect_uris': ['https://client.example.com/redirect']}
        response = self.provider.handle_client_registration_request(json.dumps(request))
        assert 'client_id' in response
        assert 'client_id_issued_at' in response
        assert 'client_secret' in response
        assert response['client_secret_expires_at'] == 0
        # Every (name, value) pair from the request must appear unchanged in the response.
        assert all(item in response.items() for item in request.items())

        assert response['client_id'] in self.provider.clients

    def test_rejects_invalid_request(self):
        request = {'application_type': 'web', 'client_name': 'test client'}  # missing 'redirect_uris'
        with pytest.raises(InvalidClientRegistrationRequest):
            self.provider.handle_client_registration_request(json.dumps(request))

    @pytest.mark.parametrize('client_preference, provider_capability, client_value, provider_value', [
        ('request_object_signing_alg', 'request_object_signing_alg_values_supported',
         'HS256', ['none', 'RS256']),
        ('request_object_encryption_alg', 'request_object_encryption_alg_values_supported',
         'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']),
        ('request_object_encryption_enc', 'request_object_encryption_enc_values_supported',
         'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512']),
        ('userinfo_signed_response_alg', 'userinfo_signing_alg_values_supported',
         'HS256', ['none', 'RS256']),
        ('userinfo_encrypted_response_alg', 'userinfo_encryption_alg_values_supported',
         'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']),
        ('userinfo_encrypted_response_enc', 'userinfo_encryption_enc_values_supported',
         'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512']),
        ('id_token_signed_response_alg', 'id_token_signing_alg_values_supported',
         'HS256', ['none', 'RS256']),
        ('id_token_encrypted_response_alg', 'id_token_encryption_alg_values_supported',
         'RSA-OAEP-256', ['RSA-OAEP', 'ECDH-ES']),
        ('id_token_encrypted_response_enc', 'id_token_encryption_enc_values_supported',
         'A192CBC-HS384', ['A128CBC-HS256', 'A256CBC-HS512']),
        ('default_acr_values', 'acr_values_supported',
         ['1', '2'], ['3', '4']),
        ('subject_type', 'subject_types_supported',
         'public', ['pairwise']),
        ('token_endpoint_auth_method', 'token_endpoint_auth_methods_supported',
         'private_key_jwt', ['client_secret_post', 'client_secret_basic']),
        ('token_endpoint_auth_signing_alg', 'token_endpoint_auth_signing_alg_values_supported',
         'HS256', ['none', 'RS256']),
        ('response_types', 'response_types_supported',
         ['id_token token'], ['code', 'code token']),
        ('grant_types', 'grant_types_supported',
         'implicit', ['authorization_code'])
    ])
    def test_rejects_mismatching_request(self, client_preference, provider_capability, client_value, provider_value):
        """Registration is rejected when a client preference is not among the provider's capabilities."""
        request = {'redirect_uris': ['https://client.example.com/redirect']}
        provider_capabilities = provide_configuration()

        if client_preference.startswith(('request_object_encryption', 'id_token_encrypted', 'userinfo_encrypted')):
            # These metadata params come in alg/enc pairs: give both halves a matching
            # default so only the parameter under test can cause the mismatch.
            param = client_preference[:-4]  # drop the trailing '_alg' / '_enc'
            alg_param = param + '_alg'
            request[alg_param] = 'RSA-OAEP'
            provider_capabilities[PREFERENCE2PROVIDER[alg_param]] = ['RSA-OAEP']
            enc_param = param + '_enc'
            request[enc_param] = 'A192CBC-HS256'
            provider_capabilities[PREFERENCE2PROVIDER[enc_param]] = ['A192CBC-HS256']

        request[client_preference] = client_value
        provider_capabilities[provider_capability] = provider_value
        provider = Provider(rsa_key(), provider_capabilities,
                            AuthorizationState(HashBasedSubjectIdentifierFactory('salt')), {}, None)
        with pytest.raises(InvalidClientRegistrationRequest):
            provider.handle_client_registration_request(json.dumps(request))

    @pytest.mark.parametrize('client_preference, provider_capability, client_value, provider_value', [
        ('response_types', 'response_types_supported',
         ['code id_token token', 'code', 'id_token', 'id_token token'],
         ['code id_token token', 'id_token token']),
        ('default_acr_values', 'acr_values_supported',
         ['1', '2', '3'], ['2', '3', '4']),
    ])
    def test_matches_common_set_of_metadata_values(self, client_preference, provider_capability,
                                                   client_value, provider_value):
        """The registered values are the intersection of client preference and provider capability."""
        provider_capabilities = provide_configuration()
        provider_capabilities.update({provider_capability: provider_value})
        provider = Provider(rsa_key(), provider_capabilities,
                            AuthorizationState(HashBasedSubjectIdentifierFactory('salt')), {}, None)
        request = {'redirect_uris': ['https://client.example.com/redirect'], client_preference: client_value}
        response = provider.handle_client_registration_request(json.dumps(request))
        expected_values = set(client_value).intersection(provider_value)
        # Space-separated values are compared as unordered sets of tokens.
        assert Counter(frozenset(v.split()) for v in response[client_preference]) == \
            Counter(frozenset(v.split()) for v in expected_values)

    def test_match_space_separated_response_type_without_order(self):
        registration_request = {'redirect_uris': ['https://client.example.com/redirect'],
                                'response_types': ['id_token token']}
        # should not raise an exception
        assert self.provider.handle_client_registration_request(json.dumps(registration_request))

    def test_client_can_use_registered_space_separated_response_type_in_authentication_request(self):
        response_type = 'id_token token'
        registration_request = {'redirect_uris': ['https://client.example.com/redirect'],
                                'response_types': [response_type]}
        registration_response = self.provider.handle_client_registration_request(json.dumps(registration_request))

        authentication_request = {'client_id': registration_response['client_id'],
                                  'redirect_uri': 'https://client.example.com/redirect',
                                  'scope': 'openid',
                                  'nonce': 'nonce',
                                  'response_type': response_type}
        # should not raise an exception
        assert self.provider.parse_authentication_request(urlencode(authentication_request))
633 |
class TestProviderProviderConfiguration(object):
    def test_provider_configuration(self):
        """Every configured key, including arbitrary extra ones, is exposed by the provider."""
        configuration = provide_configuration()
        configuration.update({'foo': 'bar', 'abc': 'xyz'})
        published = Provider(None, configuration, None, None, None).provider_configuration
        missing_keys = [key for key in configuration if key not in published]
        assert not missing_keys
641 |
642 |
class TestProviderJWKS(object):
    def test_jwks(self):
        """The published JWKS holds exactly the serialized signing key."""
        provider = Provider(rsa_key(), provide_configuration(), None, None, None)
        published = provider.jwks
        assert published == {'keys': [provider.signing_key.serialize()]}
647 |
648 |
@pytest.mark.usefixtures('inject_provider')
class TestRPInitiatedLogout(object):
    def _authorize_user(self):
        """Authorize 'user1' at 'client1' (code id_token token flow); return the authorization response."""
        authn_request = AuthorizationRequest(response_type='code id_token token', scope='openid',
                                             client_id='client1',
                                             redirect_uri='https://client.example.com/redirect')
        return self.provider.authorize(authn_request, 'user1')

    def _assert_tokens_revoked(self, auth_resp):
        """Check that the access token and authorization code from ``auth_resp`` are no longer usable."""
        authz_state = self.provider.authz_state
        with pytest.raises(InvalidAccessToken):
            authz_state.introspect_access_token(auth_resp['access_token'])
        with pytest.raises(InvalidAuthorizationCode):
            authz_state.exchange_code_for_token(auth_resp['code'])

    def test_logout_user_with_subject_identifier(self):
        auth_resp = self._authorize_user()
        id_token = IdToken().from_jwt(auth_resp['id_token'], key=[self.provider.signing_key])
        self.provider.logout_user(subject_identifier=id_token['sub'])
        self._assert_tokens_revoked(auth_resp)

    def test_logout_user_with_id_token_hint(self):
        auth_resp = self._authorize_user()
        end_session = EndSessionRequest(id_token_hint=auth_resp['id_token'])
        self.provider.logout_user(end_session_request=end_session)
        self._assert_tokens_revoked(auth_resp)

    def test_logout_user_with_unknown_subject_identifier(self):
        with pytest.raises(InvalidSubjectIdentifier):
            self.provider.logout_user(subject_identifier='unknown')

    def test_post_logout_redirect(self):
        auth_resp = self._authorize_user()
        post_logout_uri = 'https://client.example.com/post_logout'
        end_session = EndSessionRequest(id_token_hint=auth_resp['id_token'],
                                        post_logout_redirect_uri=post_logout_uri,
                                        state='state')
        redirect_url = self.provider.do_post_logout_redirect(end_session)
        assert redirect_url == EndSessionResponse(state='state').request(post_logout_uri)

    def test_post_logout_redirect_without_post_logout_redirect_uri(self):
        assert self.provider.do_post_logout_redirect(EndSessionRequest()) is None

    def test_post_logout_redirect_with_unknown_client_for_post_logout_redirect_uri(self):
        end_session = EndSessionRequest(post_logout_redirect_uri='https://client.example.com/post_logout')
        assert self.provider.do_post_logout_redirect(end_session) is None

    def test_post_logout_redirect_with_unknown_post_logout_redirect_uri(self):
        auth_resp = self._authorize_user()
        end_session = EndSessionRequest(id_token_hint=auth_resp['id_token'],
                                        post_logout_redirect_uri='https://client.example.com/unknown')
        assert self.provider.do_post_logout_redirect(end_session) is None
702 |
--------------------------------------------------------------------------------
/tests/pyop/test_storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import pytest
4 |
5 | from abc import ABC, abstractmethod
6 | from contextlib import contextmanager
7 | from redis.client import Redis
8 | import datetime
9 | import fakeredis
10 | import mongomock
11 | import pymongo
12 | import time
13 |
14 | import pyop.storage
15 |
__author__ = 'lundberg'


# Backends exercised by the parametrized storage tests: a MongoDB URI with a
# string database name, and a Redis URI with an integer database.
db_specs_list = [
    dict(uri="mongodb://localhost:1234/pyop", name="pyop"),
    dict(uri="redis://localhost/0", name=0),
]
23 |
24 |
@pytest.fixture(autouse=True)
def mock_redis(monkeypatch):
    """Make ``Redis.from_url`` return in-memory fakeredis clients for every test."""
    def fake_from_url(*args, **kwargs):
        # Forward the arguments unchanged to the fake client constructor.
        return fakeredis.FakeStrictRedis(*args, **kwargs)

    monkeypatch.setattr(Redis, "from_url", fake_from_url)
30 |
@pytest.fixture(autouse=True)
def mock_mongo(monkeypatch):
    """Replace ``pymongo.MongoClient`` with mongomock's in-memory client for each test.

    Patched through ``monkeypatch`` so the real client class is restored after every
    test, instead of staying replaced for the rest of the test session.
    """
    monkeypatch.setattr(pymongo, "MongoClient", mongomock.MongoClient)
34 |
35 |
class TestStorage(object):
    """Dict-like behaviour shared by storages created through ``StorageBase.from_uri``."""

    @pytest.fixture(params=db_specs_list)
    def db(self, request):
        """Storage for one backend spec from ``db_specs_list``, using the 'test' collection."""
        return pyop.storage.StorageBase.from_uri(
            request.param["uri"], db_name=request.param["name"], collection="test"
        )

    def test_write(self, db):
        db['foo'] = 'bar'
        assert db['foo'] == 'bar'

    def test_multilevel_dict(self, db):
        db['foo'] = {}
        assert db['foo'] == {}
        db['foo'] = {'bar': 'baz'}
        assert db['foo']['bar'] == 'baz'

    def test_contains(self, db):
        db['foo'] = 'bar'
        assert 'foo' in db

    def test_pop(self, db):
        db['foo'] = 'bar'
        assert db.pop('foo') == 'bar'
        # The popped key must be gone: reading it has to raise KeyError.
        # pytest.raises also fails the test if nothing is raised at all.
        with pytest.raises(KeyError):
            db['foo']

    def test_items(self, db):
        db['foo'] = 'foorbar'
        db['bar'] = True
        db['baz'] = {'foo': 'bar'}
        items = list(db.items())
        # items() must actually yield the stored entries; otherwise the
        # per-item checks below would pass vacuously on an empty iterator.
        assert len(items) >= 3
        for key, item in items:
            assert key
            assert item

    @pytest.mark.parametrize(
        "args,kwargs",
        [
            (["mongodb://localhost/pyop"], {"collection": "test", "ttl": 3}),
            (["mongodb://localhost"], {"db_name": "pyop", "collection": "test"}),
            (["mongodb://localhost", "test", "pyop"], {}),
            (["mongodb://localhost/pyop", "test"], {}),
            (["mongodb://localhost/pyop"], {"db_name": "other", "collection": "test"}),
            (["redis://localhost"], {"collection": "test"}),
            (["redis://localhost", "test"], {}),
            (["redis://localhost"], {"db_name": 2, "collection": "test"}),
            (["unix://localhost/0"], {"collection": "test", "ttl": 3}),
        ],
    )
    def test_from_uri(self, args, kwargs):
        store = pyop.storage.StorageBase.from_uri(*args, **kwargs)
        store["test"] = "value"
        assert store["test"] == "value"

    @pytest.mark.parametrize(
        "error,args,kwargs",
        [
            (ValueError, ["mongodb://localhost"], {"collection": "test", "ttl": None}),
            (
                TypeError,
                ["mongodb://localhost", "ouch"],
                {"db_name": 3, "collection": "test", "ttl": None},
            ),
            (
                TypeError,
                ["mongodb://localhost", "ouch"],
                {"db_name": "pyop", "collection": "test", "ttl": None},
            ),
            (
                TypeError,
                ["mongodb://localhost", "pyop"],
                {"collection": "test", "ttl": None},
            ),
            (
                TypeError,
                ["redis://localhost", "ouch"],
                {"db_name": 3, "collection": "test", "ttl": None},
            ),
            (TypeError, ["redis://localhost/0"], {}),
            (TypeError, ["redis://localhost/0"], {"db_name": "pyop"}),
        ],
    )
    def test_from_uri_invalid_parameters(self, error, args, kwargs):
        with pytest.raises(error):
            pyop.storage.StorageBase.from_uri(*args, **kwargs)
123 |
124 |
class StorageTTLTest(ABC):
    """Shared TTL checks; concrete subclasses implement ``set_time`` to fake the backend clock.

    Under pytest's default collection rules this class (no ``Test`` prefix) runs only
    through its subclasses.
    """

    def prepare_db(self, uri, ttl):
        """Create a storage at ``uri`` (collection 'test') with ``ttl`` and store one entry under 'foo'."""
        self.db = pyop.storage.StorageBase.from_uri(
            uri,
            collection="test",
            ttl=ttl,
        )
        self.db["foo"] = {"bar": "baz"}

    @abstractmethod
    def set_time(self, offset, monkeypatch):
        """Use ``monkeypatch`` to make the backend see a clock ``offset`` seconds in the future."""

    @contextmanager
    def adjust_time(self, offset):
        """Shift the backend clock by ``offset`` seconds for the duration of the ``with`` block."""
        # MonkeyPatch.context() undoes every patch on exit, even if the body raises.
        with pytest.MonkeyPatch.context() as mp:
            yield self.set_time(offset, mp)

    def execute_ttl_test(self, uri, ttl):
        """Check that the entry survives half the TTL and is gone after twice the TTL."""
        self.prepare_db(uri, ttl)
        assert self.db["foo"]
        with self.adjust_time(offset=int(ttl / 2)):
            assert self.db["foo"]
        with self.adjust_time(offset=int(ttl * 2)):
            with pytest.raises(KeyError):
                self.db["foo"]

    @pytest.mark.parametrize("spec", db_specs_list)
    @pytest.mark.parametrize("ttl", ["invalid", -1, 2.3, {}])
    def test_invalid_ttl(self, spec, ttl):
        with pytest.raises(ValueError):
            self.prepare_db(spec["uri"], ttl)
160 |
161 |
class TestRedisTTL(StorageTTLTest):
    def set_time(self, offset, monkeypatch):
        """Patch ``time.time`` to return a fixed instant ``offset`` seconds from now."""
        now = time.time()

        def new_time():
            return now + offset

        monkeypatch.setattr(time, "time", new_time)

    def test_ttl(self):
        self.execute_ttl_test("redis://localhost/0", 3600)

    def test_missing_module(self, monkeypatch):
        """Without redis available, Redis URIs raise ImportError while MongoDB URIs still work."""
        # monkeypatch restores the original flag value even if an assertion below fails.
        monkeypatch.setattr(pyop.storage, "_has_redis", False)
        self.prepare_db("mongodb://localhost/0", None)
        with pytest.raises(ImportError):
            self.prepare_db("redis://localhost/0", None)
179 |
180 |
class TestMongoTTL(StorageTTLTest):
    def set_time(self, offset, monkeypatch):
        """Patch ``mongomock.utcnow`` to return a fixed instant ``offset`` seconds from now."""
        now = datetime.datetime.utcnow()

        def new_time():
            return now + datetime.timedelta(seconds=offset)

        monkeypatch.setattr(mongomock, "utcnow", new_time)

    def test_ttl(self):
        self.execute_ttl_test("mongodb://localhost/pyop", 3600)

    def test_missing_module(self, monkeypatch):
        """Without pymongo available, MongoDB URIs raise ImportError while Redis URIs still work."""
        # monkeypatch restores the original flag value even if an assertion below fails.
        monkeypatch.setattr(pyop.storage, "_has_pymongo", False)
        self.prepare_db("redis://localhost/0", None)
        with pytest.raises(ImportError):
            self.prepare_db("mongodb://localhost/0", None)
198 |
199 |
class TestStatelessWrapper(object):
    """Values are recovered from keys produced by ``pack``; plain writes are not stored."""

    @pytest.fixture
    def db(self):
        return pyop.storage.StatelessWrapper("pyop", "abc123")

    def test_write(self, db):
        db['foo'] = 'bar'
        assert db['foo'] is None

    def test_pack_and_unpack(self, db):
        val_1 = {'foo': 'bar'}
        key = db.pack(val_1)
        val_2 = db[key]
        assert val_1 == val_2

    def test_pack_with_non_dict_val(self, db):
        val_1 = 'this is not a dict'
        key = db.pack(val_1)
        val_2 = db[key]
        assert val_1 == val_2

    def test_contains(self, db):
        val_1 = {'foo': 'bar'}
        key = db.pack(val_1)
        assert key in db

    def test_items(self, db):
        db['foo'] = 'bar'
        # Only the call under test is inside pytest.raises, so the write cannot satisfy it.
        with pytest.raises(NotImplementedError):
            db.items()

    def test_delitem(self, db):
        db['foo'] = 'bar'
        # Only the call under test is inside pytest.raises, so the write cannot satisfy it.
        with pytest.raises(NotImplementedError):
            del db['foo']
--------------------------------------------------------------------------------
/tests/pyop/test_util.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pyop.message import AuthorizationRequest
4 | from pyop.util import should_fragment_encode
5 |
6 |
class TestShouldFragmentEncode(object):
    @pytest.mark.parametrize('response_type, expected', [
        ('code', False),
        ('id_token', True),
        ('id_token token', True),
        ('code id_token', True),
        ('code token', True),
        ('code id_token token', True),
    ])
    def test_by_response_type(self, response_type, expected):
        """Only plain 'code' is not fragment-encoded; any response type including a token is."""
        authn_request = AuthorizationRequest(response_type=response_type)
        assert should_fragment_encode(authn_request) is expected
19 |
--------------------------------------------------------------------------------
/tests/test_requirements.txt:
--------------------------------------------------------------------------------
1 | pytest >= 6.2
2 | pip >= 19.0
3 | responses
4 | pycryptodomex
5 | fakeredis
6 | mongomock
7 | pymongo
8 | redis
9 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py36, py37, py38, pypy3
3 |
4 | [testenv]
5 | commands = py.test -rs -vvv tests/
6 | deps = -rtests/test_requirements.txt
7 |
--------------------------------------------------------------------------------