├── .flake8 ├── .github ├── CODEOWNERS └── workflows │ ├── pypi-upload.yml │ └── test.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── server │ ├── auth0.py │ ├── cognito.py │ └── firebase.py └── src │ └── authorize_in_doc.jpg ├── fastapi_cloudauth ├── __init__.py ├── auth0.py ├── base.py ├── cognito.py ├── firebase.py ├── messages.py └── verification.py ├── mypy.ini ├── pyproject.toml ├── pytest.ini ├── scripts ├── dep.sh ├── develop.sh ├── format-imports.sh ├── format.sh ├── lint.sh ├── test.sh └── test_local.sh └── tests ├── __init__.py ├── conftest.py ├── helpers.py ├── test_auth0.py ├── test_base.py ├── test_cloudauth.py ├── test_cognito.py ├── test_firebase.py └── test_verification.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = C,E,F,W,B,B9 4 | ignore = E203, E501, W503 5 | exclude = __init__.py -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is a comment. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # In this example, @tokusumi owns any files in the /.github/ and /scripts/test.sh 5 | # directory at the root of the repository and any of its 6 | # subdirectories. 
7 | /.github/ @tokusumi 8 | /scripts/test.sh @tokusumi 9 | -------------------------------------------------------------------------------- /.github/workflows/pypi-upload.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPi 2 | on: 3 | release: 4 | types: [ created ] 5 | 6 | jobs: 7 | upload: 8 | name: upload 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Setup Python 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: '3.7' 17 | - uses: Gr1N/setup-poetry@v4 18 | with: 19 | poetry-version: 1.1.12 20 | - run: poetry run python -m pip install --upgrade pip 21 | - run: poetry install 22 | - run: poetry publish --build 23 | env: 24 | POETRY_PYPI_TOKEN_PYPI: ${{secrets.PYPI_TOKEN}} 25 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | branches: [ master ] 5 | pull_request: 6 | branches: [ master ] 7 | 8 | jobs: 9 | pytest: 10 | name: pytest 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.6, 3.7, 3.8, 3.9] 15 | poetry-version: [1.1.12] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Setup Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - uses: Gr1N/setup-poetry@v4 24 | with: 25 | poetry-version: ${{ matrix.poetry-version }} 26 | - run: poetry install 27 | - run: poetry run bash scripts/test.sh 28 | env: 29 | AUTH0_DOMAIN: ${{ secrets.AUTH0_DOMAIN }} 30 | AUTH0_CLIENTID: ${{ secrets.AUTH0_CLIENTID }} 31 | AUTH0_CLIENT_SECRET: ${{ secrets.AUTH0_CLIENT_SECRET }} 32 | AUTH0_MGMT_CLIENTID: ${{ secrets.AUTH0_MGMT_CLIENTID }} 33 | AUTH0_MGMT_CLIENT_SECRET: ${{ secrets.AUTH0_MGMT_CLIENT_SECRET }} 34 | AUTH0_AUDIENCE: ${{ secrets.AUTH0_AUDIENCE }} 35 | AUTH0_CONNECTION: ${{ 
secrets.AUTH0_CONNECTION }} 36 | COGNITO_REGION: ${{ secrets.COGNITO_REGION }} 37 | COGNITO_USERPOOLID: ${{ secrets.COGNITO_USERPOOLID }} 38 | COGNITO_APP_CLIENT_ID: ${{ secrets.COGNITO_APP_CLIENT_ID }} 39 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 40 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 41 | FIREBASE_PROJECTID: ${{ secrets.FIREBASE_PROJECTID }} 42 | FIREBASE_APIKEY: ${{ secrets.FIREBASE_APIKEY }} 43 | FIREBASE_BASE64_CREDENCIALS: ${{ secrets.FIREBASE_BASE64_CREDENCIALS }} 44 | 45 | - name: Upload coverage to Codecov 46 | uses: codecov/codecov-action@v2 47 | with: 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | file: ./coverage.xml 50 | flags: unittests 51 | name: codecov-umbrella 52 | fail_ci_if_error: true 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .venv 3 | env 4 | .vscode 5 | .mypy_cache 6 | .pytest_cache 7 | dist 8 | local 9 | poetry.lock 10 | *.egg-info 11 | .coverage 12 | scripts/load_env.sh 13 | credentials/* 14 | base64-credential -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Development Guideline 2 | 3 | ## Setup the development environment 4 | 5 | At first, you need to clone "FastAPI-CloudAuth" repository. 6 | 7 | Notice the working directory at the following descriptions is a root of cloned directory. 8 | 9 | ### Poetry 10 | 11 | "FastAPI-CloudAuth" uses [Poetry](https://github.com/python-poetry/poetry) to create and activate virtual environment for developments (used also for build and publish to Pypi, else). 12 | 13 | If "Poetry" is not available in your development environment, you can see at [official document](https://python-poetry.org/docs/) to install `poetry`. 
14 | 15 | You can confirm that "Poetry" was installed successfully as follows: 16 | 17 | ``` 18 | $ poetry --version 19 | Poetry version 1.0.8 20 | ``` 21 | 22 | And now use `poetry` to create virtual environment and install the development dependencies: 23 | 24 | ``` 25 | $ poetry install 26 | ``` 27 | 28 | ## Testing 29 | 30 | You can use shortcut script for unit-test as follows: 31 | 32 | ``` 33 | $ poetry run bash scripts/test_local.sh -m unittest 34 | ``` 35 | 36 | However, most of the "FastAPI-CloudAuth" testing is integration testing with cloud authentication services, for which additional setup is required. 37 | 38 | NOTE: Additional setup requires sensitive data such as credentials. We strongly recommend that you restrict permissions and make sure you understand what our test code actually does. 39 | 40 | Please create `load_env.sh` to help to load your environment variables for additional integration testing: 41 | 42 | ``` 43 | $ touch scripts/load_env.sh 44 | $ echo '#!/usr/bin/bash' > scripts/load_env.sh 45 | ``` 46 | 47 | ### AWS Cognito 48 | 49 | The following values are required: 50 | 51 | * Region: Region code. ex) `us-east-1`, `ap-northeast-1`, ... 52 | * Pool Id: Unique ID of your AWS Cognito User pool. ex) `region_9digit-hash` 53 | * App client id: Unique ID of App client to access your user pool. ex) `26digit-hash` 54 | * Access key ID: Required for programmatic calls to AWS from the AWS SDK, AWS CLI, etc. ex) `20digit-hash` 55 | * Secret access key: Use it with Access key ID. ex) `40digit-hash` 56 | 57 | The above values are used to: 58 | 59 | * Create test user temporarily in AWS Cognito user pool. 60 | * Get access/id token of created test user by AWS Cognito. 61 | * Delete test user after testing. 62 | 63 | #### AWS Access Key 64 | 65 | Notice that "FastAPI-CloudAuth" uses AWS SDK for Python3 ("[boto3](https://aws.amazon.com/sdk-for-python/)") for management of the test user. This requires AWS Access Key.
66 | 67 | Please read [How do I create an AWS access key](https://aws.amazon.com/premiumsupport/knowledge-center/create-access-key/) and acquire valid "Access key ID" and "Secret access key". 68 | 69 | #### Create your user pool 70 | 71 | Go to [AWS Cognito](https://console.aws.amazon.com/cognito/users/). 72 | 73 | Click "Create a user pool", and set it up with: 74 | 75 | * Required attributes: none 76 | * Username attributes: email 77 | * Email Delivery through Amazon SES: No 78 | * App clients: create new one for testing. When you create it: 79 | * Turn off "Generate client secret". 80 | * Turn on "ALLOW_ADMIN_USER_PASSWORD_AUTH" in "Auth Flows Configuration". 81 | * and others are default 82 | 83 | Then, if your user pool is created successfully, it shows "Pool Id". 84 | 85 | Click "App clients" in General settings on the left side, it shows "App client id". 86 | 87 | #### Testing for AWS Cognito 88 | 89 | Add these values in `load_env.sh` as follows (fill in each value with the one acquired above): 90 | 91 | ``` 92 | export COGNITO_REGION= 93 | export COGNITO_USERPOOLID= 94 | export COGNITO_APP_CLIENT_ID= 95 | export AWS_ACCESS_KEY_ID= 96 | export AWS_SECRET_ACCESS_KEY= 97 | ``` 98 | 99 | Finally, you can run testing for only "AWS Cognito" as follows: 100 | 101 | ``` 102 | $ poetry run bash scripts/test_local.sh -m cognito 103 | ``` 104 | 105 | ### Auth0 106 | 107 | The following values are required: 108 | 109 | * Domain: domain of `Default App` 110 | * Client ID: client id of `Default App` 111 | * Client Secret: client secret of `Default App` 112 | * Management Client ID: client ID of custom application authorized with management API 113 | * Management Client Secret: client secret of custom application authorized with management API 114 | * Identifier: The identifier (audience) of any custom dummy API 115 | 116 | The above values are used to: 117 | 118 | * Create test user temporarily in Auth0. 119 | * Get id token of created test user by Auth0.
120 | * Delete test user after testing. 121 | 122 | #### Setup Auth0 123 | 124 | At first, you need to sign-up/log-in [Auth0](https://auth0.com/) 125 | 126 | Click the user icon at the top right, and type "Username-Password-Authentication" into `Tenant Settings`>`General`>`API Authorization Settings`>`Default Directory` and save it. 127 | 128 | Next, go to `Default App` settings from Applications at side bar, click "Show Advanced Settings", turn on `Grant Types`>`password` and save the changes. It shows "Domain", "Client ID" and "Client Secret" there. 129 | 130 | Next, create new application from application page at side bar with: 131 | 132 | * Enter any name (noted as "Management APP" here) 133 | * Choose `Machine to Machine Applications` as application's type 134 | * Select "Auth0 Management API" as authorized API with "read:users", "update:users", "delete:users", "create:users" scopes 135 | 136 | The settings page of the created application shows "Client ID" and "Client Secret"; they are used as "Management Client ID" and "Management Client Secret". 137 | 138 | Finally, go to APIs at side bar to create new API with: 139 | 140 | * Enter any name (ex: "Dummy API") 141 | * Type any identifier, URL is recommended (ex: "https://dummy-api/") 142 | 143 | Once the API is created successfully, "Identifier" is shown just below the API name (same as identifier you typed). 144 | 145 | Then change/add settings as follows: 146 | 147 | * In `Settings`>`RBAC Settings`, turn on `Enable RBAC` and `Add Permissions in the Access Token` and save it. 148 | * In `Permissions`, add new scope "read:test" and "write:test" (add something in description).
149 | 150 | #### Testing for Auth0 151 | 152 | Add these values in `load_env.sh` as follows (fill in each value with the one acquired above): 153 | 154 | ``` 155 | export AUTH0_DOMAIN= 156 | export AUTH0_CLIENTID= 157 | export AUTH0_CLIENT_SECRET= 158 | export AUTH0_MGMT_CLIENTID= 159 | export AUTH0_MGMT_CLIENT_SECRET= 160 | export AUTH0_AUDIENCE= 161 | ``` 162 | 163 | Then, you can run testing for only "Auth0" as follows: 164 | 165 | ``` 166 | $ poetry run bash scripts/test_local.sh -m auth0 167 | ``` 168 | 169 | ### Firebase 170 | 171 | The following values are required: 172 | 173 | * Firebase project ID: the unique identifier for your Firebase project, which can be found in the URL of that project's console. 174 | * Web API key: Required to log in to your Firebase Authentication service via HTTP request (for getting id token). ex) `39digit-hash` 175 | * base64 encoding credential: Required for Firebase Admin SDK. ex) `base64-encoding-large-string` 176 | 177 | The above values are used to: 178 | 179 | * Create test user temporarily in Firebase Authentication. 180 | * Get id token of created test user by Firebase Authentication. 181 | * Delete test user after testing. 182 | 183 | #### Create your project 184 | 185 | Go to [Firebase](https://console.firebase.google.com/) and create new project. 186 | 187 | Then you can go to manage "Authentication" (from side bar). 188 | 189 | Click the "Sign-in method" tab, and enable the email/password provider. (Notice that our tests create the test user with admin permission and do not send a verification email.) 190 | 191 | Then, go to "General" tab in project settings page and "Web API key" is listed. 192 | 193 | Notice that "FastAPI-CloudAuth" uses Firebase Admin SDK for Python3 ("[Firebase Admin Python SDK](https://firebase.google.com/docs/reference/admin/python)") for management of the test user. This requires "Google services account".
194 | 195 | Click "Service accounts" tab in project settings, and "Generate new private key" for Firebase Admin SDK; the JSON file download then starts. 196 | 197 | Make sure to download the credential JSON file (here noted filename as `service-cred.json`) and base64-encode it with the following command: 198 | 199 | ``` 200 | $ cat service-cred.json | base64 -w 0 > base64-credential 201 | ``` 202 | 203 | The string in `base64-credential` is "base64 encoding credential" 204 | 205 | #### Testing for Firebase Authentication 206 | 207 | Add these values in `load_env.sh` as follows (fill in each value with the one acquired above): 208 | 209 | ``` 210 | export FIREBASE_PROJECTID= 211 | export FIREBASE_APIKEY= 212 | export FIREBASE_BASE64_CREDENCIALS= 213 | ``` 214 | 215 | Then, you can run testing for only "Firebase Authentication" as follows: 216 | 217 | ``` 218 | $ poetry run bash scripts/test_local.sh -m firebase 219 | ``` 220 | 221 | ### Testing all at once 222 | 223 | You can run all tests with a single command as follows: 224 | 225 | ``` 226 | $ poetry run bash scripts/test_local.sh 227 | ``` 228 | 229 | ## GitHub Actions 230 | 231 | If you follow above setup, you would be able to run GitHub Action in your fork repository. 232 | 233 | At first, fork "FastAPI-CloudAuth". Add all values into GitHub Secrets with same key-value pairs. 234 | 235 | When you commit to your forked "master" branch or open a pull request into the "master" branch, the workflows run.
236 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tomoro Tokusumi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI Cloud Auth 2 | 3 | ![Tests](https://github.com/tokusumi/fastapi-cloudauth/workflows/Tests/badge.svg) 4 | [![codecov](https://codecov.io/gh/tokusumi/fastapi-cloudauth/branch/master/graph/badge.svg)](https://codecov.io/gh/tokusumi/fastapi-cloudauth) 5 | [![PyPI version](https://badge.fury.io/py/fastapi-cloudauth.svg)](https://badge.fury.io/py/fastapi-cloudauth) 6 | 7 | fastapi-cloudauth standardizes and simplifies the integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication). 8 | 9 | ## Features 10 | 11 | * [X] Verify access/id token: standard JWT validation (signature, expiration), token audience claims, etc. 12 | * [X] Verify permissions based on scope (or groups) within access token and extract user info 13 | * [X] Get the detail of login user info (name, email, etc.) within ID token 14 | * [X] Dependency injection for verification/getting user, powered by [FastAPI](https://github.com/tiangolo/fastapi) 15 | * [X] Support for: 16 | * [X] [AWS Cognito](https://aws.amazon.com/jp/cognito/) 17 | * [X] [Auth0](https://auth0.com/jp/) 18 | * [x] [Firebase Auth](https://firebase.google.com/docs/auth) (Only ID token) 19 | 20 | ## Requirements 21 | 22 | Python 3.6+ 23 | 24 | ## Install 25 | 26 | ```console 27 | $ pip install fastapi-cloudauth 28 | ``` 29 | 30 | ## Example (AWS Cognito) 31 | 32 | ### Pre-requirements 33 | 34 | * Check `region`, `userPoolID` and `AppClientID` of AWS Cognito that you manage to 35 | * Create a user's assigned `read:users` permission in AWS Cognito 36 | * Get Access/ID token for the created user 37 | 38 | NOTE: access token is valid for verification, scope-based authentication, and getting user info (optional). ID token is valid for verification and getting full user info from claims. 
39 | 40 | ### Create it 41 | 42 | Create a *main.py* file with the following content: 43 | 44 | ```python3 45 | import os 46 | from pydantic import BaseModel 47 | from fastapi import FastAPI, Depends 48 | from fastapi_cloudauth.cognito import Cognito, CognitoCurrentUser, CognitoClaims 49 | 50 | app = FastAPI() 51 | auth = Cognito( 52 | region=os.environ["REGION"], 53 | userPoolId=os.environ["USERPOOLID"], 54 | client_id=os.environ["APPCLIENTID"] 55 | ) 56 | 57 | @app.get("/", dependencies=[Depends(auth.scope(["read:users"]))]) 58 | def secure(): 59 | # access token is valid 60 | return "Hello" 61 | 62 | 63 | class AccessUser(BaseModel): 64 | sub: str 65 | 66 | 67 | @app.get("/access/") 68 | def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): 69 | # access token is valid and getting user info from access token 70 | return f"Hello", {current_user.sub} 71 | 72 | 73 | get_current_user = CognitoCurrentUser( 74 | region=os.environ["REGION"], 75 | userPoolId=os.environ["USERPOOLID"], 76 | client_id=os.environ["APPCLIENTID"] 77 | ) 78 | 79 | 80 | @app.get("/user/") 81 | def secure_user(current_user: CognitoClaims = Depends(get_current_user)): 82 | # ID token is valid and getting user info from ID token 83 | return f"Hello, {current_user.username}" 84 | ``` 85 | 86 | Run the server with: 87 | 88 | ```console 89 | $ uvicorn main:app 90 | 91 | INFO: Started server process [15332] 92 | INFO: Waiting for application startup. 93 | INFO: Application startup complete. 94 | INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) 95 | ``` 96 | 97 | ### Interactive API Doc 98 | 99 | Go to http://127.0.0.1:8000/docs. 100 | 101 | You will see the automatic interactive API documentation (provided by Swagger UI). 102 | 103 | `Authorize` :unlock: button can be available at the endpoint's injected dependency. 104 | 105 | You can supply a token and try the endpoint interactively. 
106 | 107 | ![Swagger UI](https://raw.githubusercontent.com/tokusumi/fastapi-cloudauth/master/docs/src/authorize_in_doc.jpg) 108 | 109 | 110 | ## Example (Auth0) 111 | 112 | ### Pre-requirement 113 | 114 | * Check `domain`, `customAPI` (Audience) and `ClientID` of Auth0 that you manage to 115 | * Create a user assigned `read:users` permission in Auth0 116 | * Get Access/ID token for the created user 117 | 118 | ### Create it 119 | 120 | Create a file main.py with: 121 | 122 | ```python3 123 | import os 124 | from pydantic import BaseModel 125 | from fastapi import FastAPI, Depends 126 | from fastapi_cloudauth.auth0 import Auth0, Auth0CurrentUser, Auth0Claims 127 | 128 | app = FastAPI() 129 | 130 | auth = Auth0(domain=os.environ["DOMAIN"], customAPI=os.environ["CUSTOMAPI"]) 131 | 132 | 133 | @app.get("/", dependencies=[Depends(auth.scope(["read:users"]))]) 134 | def secure(): 135 | # access token is valid 136 | return "Hello" 137 | 138 | 139 | class AccessUser(BaseModel): 140 | sub: str 141 | 142 | 143 | @app.get("/access/") 144 | def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): 145 | # access token is valid and getting user info from access token 146 | return f"Hello", {current_user.sub} 147 | 148 | 149 | get_current_user = Auth0CurrentUser( 150 | domain=os.environ["DOMAIN"], 151 | client_id=os.environ["CLIENTID"] 152 | ) 153 | 154 | 155 | @app.get("/user/") 156 | def secure_user(current_user: Auth0Claims = Depends(get_current_user)): 157 | # ID token is valid and getting user info from ID token 158 | return f"Hello, {current_user.username}" 159 | ``` 160 | 161 | Try to run the server and see interactive UI in the same way. 
162 | 163 | 164 | ## Example (Firebase Authentication) 165 | 166 | ### Pre-requirement 167 | 168 | * Create a user in Firebase Authentication and get `project ID` 169 | * Get ID token for the created user 170 | 171 | ### Create it 172 | 173 | Create a file main.py with: 174 | 175 | ```python3 176 | from fastapi import FastAPI, Depends 177 | from fastapi_cloudauth.firebase import FirebaseCurrentUser, FirebaseClaims 178 | 179 | app = FastAPI() 180 | 181 | get_current_user = FirebaseCurrentUser( 182 | project_id=os.environ["PROJECT_ID"] 183 | ) 184 | 185 | 186 | @app.get("/user/") 187 | def secure_user(current_user: FirebaseClaims = Depends(get_current_user)): 188 | # ID token is valid and getting user info from ID token 189 | return f"Hello, {current_user.user_id}" 190 | ``` 191 | 192 | Try to run the server and see the interactive UI in the same way. 193 | 194 | ## Additional User Information 195 | 196 | We can get values for the current user from access/ID token by writing a few lines. 197 | 198 | ### Custom Claims 199 | 200 | For Auth0, the ID token contains the following extra values (Ref at [Auth0 official doc](https://auth0.com/docs/tokens)): 201 | 202 | ```json 203 | { 204 | "iss": "http://YOUR_DOMAIN/", 205 | "sub": "auth0|123456", 206 | "aud": "YOUR_CLIENT_ID", 207 | "exp": 1311281970, 208 | "iat": 1311280970, 209 | "name": "Jane Doe", 210 | "given_name": "Jane", 211 | "family_name": "Doe", 212 | "gender": "female", 213 | "birthdate": "0000-10-31", 214 | "email": "janedoe@example.com", 215 | "picture": "http://example.com/janedoe/me.jpg" 216 | } 217 | ``` 218 | 219 | By default, `Auth0CurrentUser` gives `pydantic.BaseModel` object, which has `username` (name) and `email` fields. 220 | 221 | Here is sample code for extracting extra user information (adding `user_id`) from ID token: 222 | 223 | ```python3 224 | from pydantic import Field 225 | from fastapi_cloudauth.auth0 import Auth0Claims # base current user info model (inheriting `pydantic`). 
226 | 227 | # extend current user info model by `pydantic`. 228 | class CustomAuth0Claims(Auth0Claims): 229 | user_id: str = Field(alias="sub") 230 | 231 | get_current_user = Auth0CurrentUser(domain=DOMAIN, client_id=CLIENTID) 232 | get_current_user.user_info = CustomAuth0Claims # override user info model with a custom one. 233 | ``` 234 | 235 | Or, we can set new custom claims as follows: 236 | 237 | ```python3 238 | get_user_detail = get_current_user.claim(CustomAuth0Claims) 239 | 240 | @app.get("/new/") 241 | async def detail(user: CustomAuth0Claims = Depends(get_user_detail)): 242 | return f"Hello, {user.user_id}" 243 | ``` 244 | 245 | ### Raw payload 246 | 247 | If you don't require `pydantic` data serialization (validation), `FastAPI-CloudAuth` has an option to extract the raw payload. 248 | 249 | All you need is: 250 | 251 | ```python3 252 | get_raw_info = get_current_user.claim(None) 253 | 254 | @app.get("/new/") 255 | async def raw_detail(user = Depends(get_raw_info)): 256 | # user has all items (ex. iss, sub, aud, exp, ... it depends on passed token) 257 | return f"Hello, {user.get('sub')}" 258 | ``` 259 | 260 | ## Additional scopes 261 | 262 | Advanced user-SCOPE verification to protect your API. 
263 | 264 | Supports: 265 | 266 | - all (default): All of the configured scopes are required 267 | - any: At least one of the configured scopes is required 268 | 269 | Use as (`auth` is this instance and `app` is fastapi.FastAPI instance): 270 | 271 | ```python3 272 | from fastapi import Depends 273 | from fastapi_cloudauth import Operator 274 | 275 | @app.get("/", dependencies=[Depends(auth.scope(["allowned", "scopes"]))]) 276 | def api_all_scope(): 277 | return "user has 'allowned' and 'scopes' scopes" 278 | 279 | @app.get("/", dependencies=[Depends(auth.scope(["allowned", "scopes"], op=Operator._any))]) 280 | def api_any_scope(): 281 | return "user has at least one of scopes (allowned, scopes)" 282 | ``` 283 | 284 | ## Development - Contributing 285 | 286 | Please read [CONTRIBUTING](./CONTRIBUTING.md) for how to set up the development environment and testing. 287 | -------------------------------------------------------------------------------- /docs/server/auth0.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydantic import BaseModel 3 | from fastapi import FastAPI, Depends 4 | from fastapi_cloudauth.auth0 import Auth0, Auth0CurrentUser, Auth0Claims 5 | 6 | tags_metadata = [ 7 | { 8 | "name": "Auth0", 9 | "description": "Operations with access/ID token, provided by Auth0.", 10 | } 11 | ] 12 | 13 | app = FastAPI( 14 | title="FastAPI CloudAuth Project", 15 | description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", 16 | openapi_tags=tags_metadata, 17 | ) 18 | 19 | auth = Auth0(domain=os.environ["AUTH0_DOMAIN"]) 20 | 21 | 22 | @app.get("/", dependencies=[Depends(auth.scope("read:users"))], tags=["Auth0"]) 23 | def secure(): 24 | # access token is valid 25 | return "Hello" 26 | 27 | 28 | class AccessUser(BaseModel): 29 | sub: str 30 | 31 | 32 | @app.get("/access/", tags=["Auth0"]) 33 | def secure_access(current_user: AccessUser = 
Depends(auth.claim(AccessUser))): 34 | # access token is valid and getting user info from access token 35 | return f"Hello", {current_user.sub} 36 | 37 | 38 | get_current_user = Auth0CurrentUser(domain=os.environ["AUTH0_DOMAIN"]) 39 | 40 | 41 | @app.get("/user/", tags=["Auth0"]) 42 | def secure_user(current_user: Auth0Claims = Depends(get_current_user)): 43 | # ID token is valid and getting user info from ID token 44 | return f"Hello, {current_user.username}" 45 | -------------------------------------------------------------------------------- /docs/server/cognito.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydantic import BaseModel 3 | from fastapi import FastAPI, Depends 4 | from fastapi_cloudauth.cognito import Cognito, CognitoCurrentUser, CognitoClaims 5 | 6 | tags_metadata = [ 7 | { 8 | "name": "Cognito", 9 | "description": "Operations with access/ID token, provided by AWS Cognito.", 10 | } 11 | ] 12 | 13 | app = FastAPI( 14 | title="FastAPI CloudAuth Project", 15 | description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", 16 | openapi_tags=tags_metadata, 17 | ) 18 | 19 | auth = Cognito( 20 | region=os.environ["COGNITO_REGION"], userPoolId=os.environ["COGNITO_USERPOOLID"] 21 | ) 22 | 23 | 24 | @app.get("/", dependencies=[Depends(auth.scope("read:users"))], tags=["Cognito"]) 25 | def secure(): 26 | # access token is valid 27 | return "Hello" 28 | 29 | 30 | class AccessUser(BaseModel): 31 | sub: str 32 | 33 | 34 | @app.get("/access/", tags=["Cognito"]) 35 | def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): 36 | # access token is valid and getting user info from access token 37 | return f"Hello", {current_user.sub} 38 | 39 | 40 | get_current_user = CognitoCurrentUser( 41 | region=os.environ["COGNITO_REGION"], userPoolId=os.environ["COGNITO_USERPOOLID"] 42 | ) 43 | 44 | 45 | @app.get("/user/", 
tags=["Cognito"]) 46 | def secure_user(current_user: CognitoClaims = Depends(get_current_user)): 47 | # ID token is valid and getting user info from ID token 48 | return f"Hello, {current_user.username}" 49 | -------------------------------------------------------------------------------- /docs/server/firebase.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Depends 2 | from fastapi_cloudauth.firebase import FirebaseCurrentUser, FirebaseClaims 3 | 4 | tags_metadata = [ 5 | { 6 | "name": "Firebase", 7 | "description": "Operations with access/ID token, provided by Firebase Authentication.", 8 | } 9 | ] 10 | 11 | app = FastAPI( 12 | title="FastAPI CloudAuth Project", 13 | description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", 14 | openapi_tags=tags_metadata, 15 | ) 16 | 17 | get_current_user = FirebaseCurrentUser() 18 | 19 | 20 | @app.get("/user/", tags=["Firebase"]) 21 | def secure_user(current_user: FirebaseClaims = Depends(get_current_user)): 22 | # ID token is valid and getting user info from ID token 23 | return f"Hello, {current_user.user_id}" 24 | -------------------------------------------------------------------------------- /docs/src/authorize_in_doc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tokusumi/fastapi-cloudauth/594153706391258d80590a31e31666f260519a83/docs/src/authorize_in_doc.jpg -------------------------------------------------------------------------------- /fastapi_cloudauth/__init__.py: -------------------------------------------------------------------------------- 1 | from .auth0 import Auth0, Auth0CurrentUser 2 | from .cognito import Cognito, CognitoCurrentUser 3 | from .firebase import FirebaseCurrentUser 4 | -------------------------------------------------------------------------------- /fastapi_cloudauth/auth0.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import requests 4 | from fastapi.exceptions import HTTPException 5 | from jose import jwk 6 | from jose.backends.base import Key 7 | from pydantic import BaseModel, Field 8 | from starlette import status 9 | 10 | from .base import ScopedAuth, UserInfoAuth 11 | from .messages import NOT_VERIFIED 12 | from .verification import JWKS as BaseJWKS 13 | from .verification import ExtraVerifier 14 | 15 | 16 | def get_issuer(domain: str) -> str: 17 | url = f"https://{domain}/.well-known/openid-configuration" 18 | openid_config = requests.get(url).json() 19 | return str(openid_config.get("issuer", "")) 20 | 21 | 22 | class JWKS(BaseJWKS): 23 | def _construct(self, jwks: Dict[str, Any]) -> Dict[str, Key]: 24 | return {_jwk["kid"]: jwk.construct(_jwk) for _jwk in jwks.get("keys", [])} 25 | 26 | 27 | class Auth0(ScopedAuth): 28 | """ 29 | Verify access token of auth0 30 | """ 31 | 32 | user_info = None 33 | 34 | def __init__( 35 | self, 36 | domain: str, 37 | customAPI: str, 38 | issuer: Optional[str] = None, 39 | scope_key: Optional[str] = "permissions", 40 | auto_error: bool = True, 41 | ): 42 | url = f"https://{domain}/.well-known/jwks.json" 43 | jwks = JWKS(url=url) 44 | if issuer is None: 45 | issuer = get_issuer(domain) 46 | super().__init__( 47 | jwks, 48 | audience=customAPI, 49 | issuer=issuer, 50 | scope_key=scope_key, 51 | auto_error=auto_error, 52 | extra=Auth0ExtraVerifier(), 53 | ) 54 | 55 | 56 | class Auth0Claims(BaseModel): 57 | username: str = Field(alias="name") 58 | email: str = Field(None, alias="email") 59 | 60 | 61 | class Auth0CurrentUser(UserInfoAuth): 62 | """ 63 | Verify ID token and get user info of Auth0 64 | """ 65 | 66 | user_info = Auth0Claims 67 | 68 | def __init__( 69 | self, 70 | domain: str, 71 | client_id: str, 72 | nonce: Optional[str] = None, 73 | issuer: Optional[str] = None, 74 | *args: Any, 75 | **kwargs: 
class Auth0ExtraVerifier(ExtraVerifier):
    """Extra claim verification for Auth0 tokens: checks the ``nonce`` claim."""

    def __init__(self, nonce: Optional[str] = None):
        # Expected nonce; ``None`` means no nonce is expected.
        self._nonce = nonce

    def __call__(self, claims: Dict[str, str], auto_error: bool = True) -> bool:
        """Check the token's ``nonce`` claim against the expected nonce.

        Args:
            claims: Claims taken from the token.
            auto_error: Raise ``HTTPException`` (401) on mismatch instead of
                returning ``False``.

        Returns:
            ``True`` when the nonce is acceptable, ``False`` otherwise
            (only when ``auto_error`` is false).

        Raises:
            HTTPException: On mismatch when ``auto_error`` is true.
        """
        # TODO: check the aud more

        # check the nonce
        if "nonce" in claims:
            matched = claims["nonce"] == self._nonce
        else:
            # A token without a nonce is acceptable only when none is
            # expected. Previously a missing claim was silently accepted even
            # though a nonce had been configured.
            matched = self._nonce is None

        if matched:
            return True
        if auto_error:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail=NOT_VERIFIED
            )
        return False
def clone(self, instance: "CloudAuth") -> "CloudAuth":
    """Create a clone of ``instance``.

    The verifier is detached while the instance is deep-copied, because in
    some cases a Verifier cannot be pickled (deep-copied). The clone gets its
    own verifier from ``verifier.clone``.

    Args:
        instance: The ``CloudAuth`` object to copy.

    Returns:
        A deep copy of ``instance`` carrying a cloned verifier.

    Raises:
        TypeError: If ``instance`` is not a ``CloudAuth``.
    """
    if not isinstance(instance, CloudAuth):
        raise TypeError(
            "Only subclass of CloudAuth can be cloned"
        )  # pragma: no cover

    _verifier = instance.verifier
    instance.verifier = None  # type: ignore
    try:
        clone = deepcopy(instance)
    finally:
        # Always re-attach the verifier; otherwise a failing deepcopy would
        # leave the original instance without one.
        instance.verifier = _verifier
    clone.verifier = _verifier.clone(_verifier)
    return clone
class UserInfoAuth(CloudAuth):
    """
    Verify an `ID token` and extract user information from its claims
    """

    user_info: Optional[Type[BaseModel]] = None

    def __init__(
        self,
        jwks: JWKS,
        *,
        user_info: Optional[Type[BaseModel]] = None,
        audience: Optional[Union[str, List[str]]] = None,
        issuer: Optional[str] = None,
        auto_error: bool = True,
        extra: Optional[ExtraVerifier] = None,
        **kwargs: Any
    ) -> None:
        # Additional keyword arguments are accepted but not used here.
        self.auto_error = auto_error
        self.user_info = user_info
        self._verifier = JWKsVerifier(
            jwks,
            audience=audience,
            issuer=issuer,
            auto_error=auto_error,
            extra=extra,
        )

    @property
    def verifier(self) -> JWKsVerifier:
        return self._verifier

    @verifier.setter
    def verifier(self, verifier: JWKsVerifier) -> None:
        self._verifier = verifier

    def _clone(self) -> "UserInfoAuth":
        """Clone this instance, narrowing the result type to UserInfoAuth."""
        duplicate = super().clone(self)
        if not isinstance(duplicate, UserInfoAuth):
            raise NotImplementedError  # pragma: no cover
        return duplicate

    def claim(self, schema: Optional[Type[BaseModel]] = None) -> "UserInfoAuth":
        """Return a copy of this dependency that parses claims with `schema`.
        Example (`auth` is this instance, `app` is a fastapi.FastAPI instance):
        ```
        from fastapi import Depends
        from pydantic import BaseModel

        class CustomClaim(BaseModel):
            sub: str

        @app.get("/")
        def api(user: CustomClaim = Depends(auth.claim(CustomClaim))):
            return user
        ```
        """
        duplicate = self._clone()
        duplicate.user_info = schema
        return duplicate

    async def call(
        self, http_auth: HTTPAuthorizationCredentials
    ) -> Optional[Union[BaseModel, Dict[str, Any]]]:
        """Build the current user from the claims of an already verified token.

        Returns the raw claims when no `user_info` model is set, otherwise the
        claims parsed by that model. When parsing fails, raises a 401
        HTTPException if `auto_error` is set and returns None otherwise.
        Example (`auth` is an instance of a subclass, `app` is a fastapi.FastAPI instance):
        ```
        from fastapi import Depends

        @app.get("/")
        def api(current_user = Depends(auth)):
            return current_user
        ```
        """
        claims: Dict[str, Any] = jwt.get_unverified_claims(http_auth.credentials)

        model = self.user_info
        if not model:
            return claims

        try:
            return model.parse_obj(claims)
        except ValidationError:
            if not self.auto_error:
                return None
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=NOT_VALIDATED_CLAIMS,
            )
def scope(
    self, scope_name: Optional[Union[str, List[str]]], op: Operator = Operator._all
) -> "ScopedAuth":
    """Return a copy of this dependency that requires the given scope(s).

    A single string is treated as a one-element list. `op` is stored on the
    clone's verifier and selects whether all or any of the scopes must match.
    Raises AttributeError when the clone has no `scope_key`.
    Example (`auth` is this instance, `app` is a fastapi.FastAPI instance):
    ```
    from fastapi import Depends
    from fastapi_cloudauth import Operator

    @app.get("/", dependencies=[Depends(auth.scope(["allowed", "scopes"]))])
    def api_all_scope():
        return "user has all scopes"

    @app.get("/", dependencies=[Depends(auth.scope(["allowed", "scopes"], op=Operator._any))])
    def api_any_scope():
        return "user has any scopes"
    ```
    """
    scoped = self._clone()
    required = [scope_name] if isinstance(scope_name, str) else scope_name
    scoped.scope_name = required
    scoped._verifier.op = op
    if scoped.scope_key:
        return scoped
    raise AttributeError("declaire scope_key to set scope")
class Cognito(ScopedAuth):
    """
    Verify access token of AWS Cognito
    """

    user_info = None

    def __init__(
        self,
        region: str,
        userPoolId: str,
        client_id: str,
        scope_key: Optional[str] = "cognito:groups",
        auto_error: bool = True,
    ):
        # The user-pool URL is both the issuer and the base of the JWKS URL.
        pool_url = f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}"
        super().__init__(
            JWKS(url=f"{pool_url}/.well-known/jwks.json"),
            audience=client_id,
            issuer=pool_url,
            scope_key=scope_key,
            auto_error=auto_error,
            extra=CognitoExtraVerifier(
                client_id=client_id,
                issuer=pool_url,
                token_use={"access"},
            ),
        )
class JWKS(BaseJWKS):
    """JWKS variant for Firebase: keys arrive as a ``{kid: public key}`` mapping."""

    def _construct(self, jwks: Dict[str, Any]) -> Dict[str, Key]:
        """Build one RS256 key per ``kid`` entry of the provider response."""
        return {
            kid: jwk.construct(publickey, algorithm="RS256")
            for kid, publickey in jwks.items()
        }

    def _set_expiration(self, resp: requests.Response) -> Optional[datetime]:
        """Derive the key-set expiry from the response's ``expires`` header.

        Returns:
            The parsed expiry, or ``None`` when the header is missing or
            cannot be parsed (no expiry is set in that case).
        """
        expires_header = resp.headers.get("expires")
        if not expires_header:
            return None
        try:
            return parsedate_to_datetime(expires_header)
        except (TypeError, ValueError):
            # Guard against an invalid header value and do not set an expiry.
            # parsedate_to_datetime raises TypeError on Python < 3.10 and
            # ValueError on newer versions, so both must be handled.
            # This won't happen unless Firebase messes up...
            return None
class Verifier(ABC):
    """Interface for objects that verify the token carried by a bearer credential."""

    @property
    @abstractmethod
    def auto_error(self) -> bool:
        """Whether verification failures are raised instead of reported as ``False``."""
        ...  # pragma: no cover

    @abstractmethod
    async def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool:
        """Verify the token in ``http_auth`` and report whether it is valid."""
        ...  # pragma: no cover

    @abstractmethod
    def clone(self, instance: "Verifier") -> "Verifier":
        """Create a clone of ``instance``."""
        ...  # pragma: no cover
async def refresh_keys(self) -> bool:
    """Refresh the JWKS, allowing only one refresh at a time.

    The ``__refreshing`` event acts as the lock: it is cleared while a
    refresh runs and set again afterwards. Callers that find it cleared wait
    for the running refresh instead of starting another one.

    Returns:
        Always ``True`` once the keys have been refreshed (or the running
        refresh has finished).

    Raises:
        Whatever ``_refresh_keys`` raises; the lock is released first.
    """
    if self.__refreshing.is_set():
        # refresh jwks
        # Ensure only one key refresh can happen at once.
        # This prevents a dogpile of requests the second the keys expire
        # from causing a bunch of refreshes (each one is an http request).
        self.__refreshing.clear()
        try:
            # Re-query the keys from provider
            self._refresh_keys()
        finally:
            # Remove the lock even when the query fails; otherwise every
            # later call would wait on the event forever.
            self.__refreshing.set()
    else:
        # Other task for refresh is still working.
        # Only wait for that to pick publickey from the latest JWKS.
        # (Currently unreachable because the re-query is not awaitable.)
        await self.__refreshing.wait()

    return True
110 | # This allows gradual roll-out of the keys and should prevent any 111 | # request from failing. 112 | # The only scenario which will result in failing requests is if 113 | # there are zero requests for the entire duration of the roll-out 114 | # (observed to be around 1 week), followed by a burst of multiple 115 | # requests at once. 116 | jwks_resp = requests.get(self.__url) 117 | 118 | # Reset the keys and the expiry date. 119 | self.__keys = self._construct(jwks_resp.json()) 120 | self.__expires = self._set_expiration(jwks_resp) 121 | 122 | def _construct(self, jwks: Dict[str, Any]) -> Dict[str, Key]: 123 | raise NotImplementedError # pragma: no cover 124 | 125 | def _set_expiration(self, resp: requests.Response) -> Optional[datetime]: 126 | return None 127 | 128 | @classmethod 129 | def null(cls: Type["JWKS"]) -> "JWKS": 130 | return cls(url="", fixed_keys={}) 131 | 132 | @property 133 | def expires(self) -> Optional[datetime]: 134 | return self.__expires 135 | 136 | 137 | class JWKsVerifier(Verifier): 138 | def __init__( 139 | self, 140 | jwks: JWKS, 141 | audience: Optional[Union[str, List[str]]] = None, 142 | issuer: Optional[str] = None, 143 | auto_error: bool = True, 144 | *args: Any, 145 | extra: Optional[ExtraVerifier] = None, 146 | **kwargs: Any 147 | ) -> None: 148 | """ 149 | auto-error: if False, return payload as b'null' for invalid token. 
def _verify_claims(self, http_auth: HTTPAuthorizationCredentials) -> bool:
    """Validate the token's standard claims, ``iat`` and any extra checks.

    Signature verification is disabled in ``jwt.decode`` here; only the
    claims (expiration, audience, issuer, ...) are checked.

    Returns:
        ``True`` when every check passes, ``False`` otherwise (only when
        ``auto_error`` is false).

    Raises:
        HTTPException: 401 on any failed check when ``auto_error`` is true.
    """
    try:
        # check the expiration, issuer
        jwt.decode(
            http_auth.credentials,
            "",
            audience=self._aud,
            issuer=self._iss,
            options={
                "verify_signature": False,
                "verify_sub": False,
                "verify_at_hash": False,
            },  # done
        )
    except (jwt.ExpiredSignatureError, jwt.JWTClaimsError) as e:
        if self.auto_error:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail=NOT_VERIFIED
            ) from e
        return False
    except JWTError as e:
        if self.auto_error:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail=NOT_AUTHENTICATED
            ) from e
        return False

    claims = jwt.get_unverified_claims(http_auth.credentials)

    # iat validation: a token issued in the future is rejected
    if claims.get("iat"):
        iat = int(claims["iat"])
        now = timegm(datetime.utcnow().utctimetuple())
        if now < iat:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail=NOT_VERIFIED
                )
            return False

    if self._extra_verifier:
        # check extra claims validation
        return self._extra_verifier(claims=claims, auto_error=self.auto_error)

    # Return a real bool: the decoded claims dict used to be returned here,
    # so an empty claim set was reported as "not verified" without an error.
    return True
"any" 288 | 289 | 290 | class ScopedJWKsVerifier(JWKsVerifier): 291 | def __init__( 292 | self, 293 | jwks: JWKS, 294 | audience: Optional[Union[str, List[str]]] = None, 295 | issuer: Optional[str] = None, 296 | scope_key: Optional[str] = None, 297 | scope_name: Optional[List[str]] = None, 298 | op: Operator = Operator._all, 299 | auto_error: bool = True, 300 | extra: Optional[ExtraVerifier] = None, 301 | *args: Any, 302 | **kwargs: Any 303 | ) -> None: 304 | """ 305 | auto-error: if False, return payload as b'null' for invalid token. 306 | """ 307 | super().__init__( 308 | jwks, auto_error=auto_error, extra=extra, audience=audience, issuer=issuer 309 | ) 310 | self.scope_name = None if not scope_name else set(scope_name) 311 | self.scope_key = scope_key 312 | self.op = op 313 | 314 | def clone(self, instance: "ScopedJWKsVerifier") -> "ScopedJWKsVerifier": # type: ignore[override] 315 | cloned = super().clone(instance) 316 | if isinstance(cloned, ScopedJWKsVerifier): 317 | return cloned 318 | raise NotImplementedError # pragma: no cover 319 | 320 | def _verify_scope(self, http_auth: HTTPAuthorizationCredentials) -> bool: 321 | try: 322 | claims = jwt.get_unverified_claims(http_auth.credentials) 323 | except JWTError as e: 324 | if self.auto_error: 325 | raise HTTPException( 326 | status_code=status.HTTP_401_UNAUTHORIZED, detail=NOT_AUTHENTICATED 327 | ) from e 328 | else: 329 | return False 330 | 331 | scopes = claims.get(self.scope_key) 332 | if self.scope_name is None: 333 | # scope is not required 334 | return True 335 | 336 | matched = True 337 | if isinstance(scopes, str): 338 | scopes = {scope.strip() for scope in scopes.split()} 339 | else: 340 | try: 341 | scopes = set(scopes) 342 | except TypeError: 343 | matched = False 344 | if matched: 345 | if self.op == Operator._any: 346 | # any 347 | matched = len(self.scope_name & scopes) > 0 348 | else: 349 | # all 350 | matched = self.scope_name.issubset(scopes) 351 | if not matched: 352 | if self.auto_error: 353 
async def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool:
    """Run the base token verification, then the scope check if scopes are required."""
    if not await super().verify_token(http_auth):
        return False

    if not self.scope_name:
        # no scope requirement configured
        return True

    return self._verify_scope(http_auth)
5 | authors = ["tokusumi "] 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/tokusumi/fastapi-cloudauth" 9 | 10 | include = [ 11 | "LICENSE", 12 | ] 13 | keywords = ["FastAPI", "authentication", "Auth0", "AWS Cognito", "Firebase Authentication"] 14 | classifiers = [ 15 | "Environment :: Web Environment", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python :: 3 :: Only", 19 | "Programming Language :: Python :: 3.6", 20 | "Programming Language :: Python :: 3.7", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Topic :: Security", 24 | "Typing :: Typed", 25 | ] 26 | 27 | [tool.poetry.dependencies] 28 | python = ">=3.6.2,<4.0" 29 | fastapi = ">= 0.60.1, < 1.0" 30 | python-jose = {version = ">=3.3.0,<4.0.0", extras = ["cryptography"]} 31 | requests = ">=2.24.0,<3.0.0" 32 | 33 | [tool.poetry.dev-dependencies] 34 | pytest = ">=6.2.4,<7.0.0" 35 | pytest-cov = ">=2.12.0,<4.0.0" 36 | flake8 = ">=3.8.3,<4.0.0" 37 | mypy = "0.910" 38 | black = "21.9b0" 39 | isort = ">=5.0.6,<6.0.0" 40 | uvicorn = ">=0.12.0,<0.14.0" 41 | botocore = ">=1.17.32" 42 | boto3 = ">=1.14.32" 43 | authlib = ">=0.15.2" 44 | firebase-admin = ">=4.4.0" 45 | auth0-python = ">=3.14.0" 46 | pytest-mock = ">=3.5.1" 47 | pytest-asyncio = ">=0.14.0" 48 | autoflake = ">=1.4.0,<2.0.0" 49 | types-requests = ">=2.26.3,<3.0.0" 50 | 51 | [tool.isort] 52 | profile = "black" 53 | known_third_party = ["fastapi", "pydantic", "starlette"] 54 | 55 | [build-system] 56 | requires = ["poetry>=1.1.12"] 57 | build-backend = "poetry.masonry.api" 58 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | unittest:Run unittest 4 | auth0:Run the integration tests for auth0. Some env is required. 
5 | cognito:Run the integration tests for AWS Cognito. Some env is required. 6 | firebase:Run the integration tests for Firebase Authentication. Some env is required. -------------------------------------------------------------------------------- /scripts/dep.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # check dependency problem via install fastapi-cloudauth with wheel 3 | set -e 4 | 5 | poetry build 6 | python3 -m venv env 7 | source env/bin/activate 8 | python3 -m pip install -U pip 9 | python3 -m pip install wheel dist/fastapi_cloudauth-*.whl 10 | 11 | # if not installed cryptography with python-jose, this line fails 12 | python3 -c 'from fastapi_cloudauth.firebase import *; print(JWKS(url="https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"))' 13 | echo "Success" -------------------------------------------------------------------------------- /scripts/develop.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source ./scripts/load_env.sh 4 | uvicorn docs.server.auth0:app 5 | uvicorn docs.server.cognito:app 6 | uvicorn docs.server.firebase:app -------------------------------------------------------------------------------- /scripts/format-imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | set -x 3 | 4 | # Sort imports one per line, so autoflake can remove unused imports 5 | isort fastapi_cloudauth tests scripts --force-single-line-imports 6 | sh ./scripts/format.sh -------------------------------------------------------------------------------- /scripts/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -e 2 | set -x 3 | 4 | autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place fastapi_cloudauth tests scripts --exclude=__init__.py 5 | black 
fastapi_cloudauth tests scripts 6 | isort fastapi_cloudauth tests scripts -------------------------------------------------------------------------------- /scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | mypy fastapi_cloudauth 7 | flake8 fastapi_cloudauth tests 8 | black fastapi_cloudauth tests --check 9 | isort fastapi_cloudauth tests scripts --check-only -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | bash ./scripts/lint.sh 7 | pytest --cov=fastapi_cloudauth --cov=tests --cov-report=xml --disable-warnings tests/ -------------------------------------------------------------------------------- /scripts/test_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | bash ./scripts/lint.sh 5 | source ./scripts/load_env.sh 6 | pytest --cov=fastapi_cloudauth --cov=tests --cov-report=term-missing tests ${@} 7 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tokusumi/fastapi-cloudauth/594153706391258d80590a31e31666f260519a83/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.register_assert_rewrite("tests.helpers") 4 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from dataclasses import 
@dataclass
class Auths:
    """Bundle of auth dependencies and claim models a provider exposes to the shared tests.

    A provider-specific test client builds one instance in its ``setup`` and
    stores it on ``TESTAUTH``; the shared endpoint factories then read these
    fields to wire FastAPI routes.
    """

    # Access-token dependency and its "no error" twin (the Auth0 client builds
    # the ``_ne`` variant with ``auto_error=False``).
    protect_auth: ScopedAuth
    protect_auth_ne: ScopedAuth
    # ID-token (current user) dependency and its "no error" twin.
    ms_auth: UserInfoAuth
    ms_auth_ne: UserInfoAuth
    # Current-user dependencies whose claim model cannot be satisfied by a real
    # token (the Auth0 client adds a required ``fake_field``), plus the
    # "no error" twin.
    invalid_ms_auth: UserInfoAuth
    invalid_ms_auth_ne: UserInfoAuth
    # NOTE(review): the callers visible in this test-suite pass the model
    # *classes* here (they are used as ``response_model``), not instances --
    # the annotation probably wants ``Type[BaseModel]``; confirm.
    valid_claim: BaseModel
    invalid_claim: BaseModel
def decode_token(token: str) -> Tuple[Dict[str, Any], Dict[str, Any], List[str]]:
    """Decode the header and payload of a JWT without verifying its signature.

    Args:
        token: compact JWT serialization, ``header.payload[.signature...]``.

    Returns:
        A tuple ``(header, payload, rest)`` where ``header`` and ``payload`` are
        the parsed JSON objects and ``rest`` is the list of remaining
        dot-separated segments (normally just the signature), left undecoded.

    Raises:
        ValueError: if ``token`` has fewer than two dot-separated segments or a
            segment is not valid base64 / JSON.
    """
    header, payload, *rest = token.split(".")

    def _load_segment(segment: str) -> Dict[str, Any]:
        # JWT segments are base64url-encoded with the trailing "=" stripped, so
        # pad back up to a multiple of four and decode with the URL-safe
        # alphabet; the standard decoder silently drops "-" and "_".
        padded = segment + "=" * (-len(segment) % 4)
        return json.loads(base64.urlsafe_b64decode(padded).decode())

    return _load_segment(header), _load_segment(payload), rest
def assert_env():
    """Abort early when a required Auth0 environment variable is unset or empty.

    The settings are checked in a fixed order and the first missing one is
    reported through an ``AssertionError`` naming the variable.
    """
    required = (
        (DOMAIN, "'AUTH0_DOMAIN' is not defined. Set environment variables"),
        (
            MGMT_CLIENTID,
            "'AUTH0_MGMT_CLIENTID' is not defined. Set environment variables",
        ),
        (
            MGMT_CLIENT_SECRET,
            "'AUTH0_MGMT_CLIENT_SECRET' is not defined. Set environment variables",
        ),
        (CLIENTID, "'AUTH0_CLIENTID' is not defined. Set environment variables"),
        (
            CLIENT_SECRET,
            "'AUTH0_CLIENT_SECRET' is not defined. Set environment variables",
        ),
        (AUDIENCE, "'AUTH0_AUDIENCE' is not defined. Set environment variables"),
    )
    for setting, message in required:
        assert setting, message
def add_test_user(
    auth0: Auth0sdk,
    username=f"test_user{info.major}{info.minor}@example.com",
    password="testPass1-",
    scopes: Optional[List[str]] = None,
):
    """Create a test user via the Auth0 signup endpoint and optionally grant it permissions.

    Requirements:
        CLIENTID: client id of `Default App`. See Applications in Auth0 dashboard
        AUDIENCE: create custom API in Auth0 dashboard and add custom permission (`read:test`).
            Then, assign that identifier as AUDIENCE.

    Args:
        auth0: management SDK client used to assign permissions.
        username: e-mail address used both as the e-mail and the username.
        password: password for the new user.
        scopes: permission names (on the ``AUDIENCE`` API) to assign; nothing
            is assigned when empty or ``None``.

    Raises:
        requests.HTTPError: if the signup request is rejected.
    """
    resp = requests.post(
        f"https://{DOMAIN}/dbconnections/signup",
        {
            "client_id": CLIENTID,
            "email": username,
            "password": password,
            "connection": CONNECTION,
            "username": username,
        },
    )
    # Report a rejected signup as the HTTP error it is, instead of the opaque
    # ``KeyError: '_id'`` the lookup below would otherwise raise.
    resp.raise_for_status()
    user_id = f"auth0|{resp.json()['_id']}"

    if scopes:
        auth0.users.add_permissions(
            user_id,
            [
                {"permission_name": scope, "resource_server_identifier": AUDIENCE}
                for scope in scopes
            ],
        )
127 | 128 | NOTE: the followings setting in Auth0 dashboard is required 129 | - sidebar > Applications > settings > Advanced settings > grant: click `password` on 130 | - top right icon > Set General > API Authorization Settings > Default Directory to Username-Password-Authentication 131 | """ 132 | resp = requests.post( 133 | f"https://{DOMAIN}/oauth/token", 134 | headers={"content-type": "application/x-www-form-urlencoded"}, 135 | data={ 136 | "grant_type": "password", 137 | "username": username, 138 | "password": password, 139 | "client_id": CLIENTID, 140 | "client_secret": CLIENT_SECRET, 141 | "audience": AUDIENCE, 142 | }, 143 | ) 144 | access_token = resp.json().get("access_token") 145 | return access_token 146 | 147 | 148 | def get_id_token( 149 | username=f"test_user{info.major}{info.minor}@example.com", 150 | password="testPass1-", 151 | ) -> Optional[str]: 152 | """ 153 | Requirements: 154 | DOMAIN: domain of Auth0 Dashboard Backend Management Client's Applications 155 | CLIENTID: Set client id of `Default App` in environment variable. See Applications in Auth0 dashboard 156 | CLIENT_SECRET: Set client secret of `Default App` in environment variable 157 | AUDIENCE: In Auth0 dashboard, create custom applications and API, 158 | and add permission `read:test` into that API, 159 | and then copy the audience (identifier) in environment variable. 
class Auth0Client(BaseTestCloudAuth):
    """Auth0 fixture for the shared cloud-auth test cases.

    Creates two throw-away users (one holding only the first requested scope,
    one holding all of them), fetches their tokens and exposes the Auth0
    dependencies under test through ``TESTAUTH``.

    NOTE: RBAC must be enabled (presumably on the Auth0 API, so that access
    tokens carry the ``permissions`` claim checked in ``decode`` -- confirm).
    """

    username = f"test_user{info.major}{info.minor}@example.com"
    password = "testPass1-"

    def setup(self, scope: Iterable[str]) -> None:
        """Recreate the test users, collect their tokens and build ``TESTAUTH``."""
        assert_env()
        sdk = init()

        self.scope = scope
        if self.scope:
            prefix = "-".join(self.scope).replace(":", "-")
            self.scope_username = f"{prefix}{self.username}"
        else:
            self.scope_username = self.username

        # user holding only the first scope
        plain_login = {"username": self.username, "password": self.password}
        delete_user(sdk, **plain_login)
        add_test_user(sdk, scopes=[self.scope[0]], **plain_login)
        self.ACCESS_TOKEN = get_access_token(**plain_login)
        self.ID_TOKEN = get_id_token(**plain_login)

        # user holding every requested scope
        scoped_login = {"username": self.scope_username, "password": self.password}
        delete_user(sdk, username=self.scope_username)
        add_test_user(sdk, scopes=self.scope, **scoped_login)
        self.SCOPE_ACCESS_TOKEN = get_access_token(**scoped_login)

        self.auth0sdk = sdk

        class Auth0InvalidClaims(Auth0Claims):
            fake_field: str

        class Auth0FakeCurrentUser(Auth0CurrentUser):
            user_info = Auth0InvalidClaims

        assert DOMAIN and AUDIENCE and CLIENTID
        api_kwargs = {"domain": DOMAIN, "customAPI": AUDIENCE}
        app_kwargs = {"domain": DOMAIN, "client_id": CLIENTID}
        self.TESTAUTH = Auths(
            protect_auth=Auth0(**api_kwargs),
            protect_auth_ne=Auth0(auto_error=False, **api_kwargs),
            ms_auth=Auth0CurrentUser(**app_kwargs),
            ms_auth_ne=Auth0CurrentUser(auto_error=False, **app_kwargs),
            invalid_ms_auth=Auth0FakeCurrentUser(**app_kwargs),
            invalid_ms_auth_ne=Auth0FakeCurrentUser(auto_error=False, **app_kwargs),
            valid_claim=Auth0Claims,
            invalid_claim=Auth0InvalidClaims,
        )

    def teardown(self):
        """Delete both users created by ``setup``."""
        for name in (self.username, self.scope_username):
            delete_user(self.auth0sdk, name)

    def decode(self):
        """Check the ``typ`` header and the expected claims of each fetched token."""
        # access token: only the first scope was granted
        access_header, access_payload, *_ = decode_token(self.ACCESS_TOKEN)
        assert access_header.get("typ") == "JWT"
        assert [self.scope[0]] == access_payload.get("permissions")

        # scope access token: every scope was granted (order not significant)
        scoped_header, scoped_payload, *_ = decode_token(self.SCOPE_ACCESS_TOKEN)
        assert scoped_header.get("typ") == "JWT"
        assert set(self.scope) == set(scoped_payload.get("permissions"))

        # id token
        identity_header, identity_payload, *_ = decode_token(self.ID_TOKEN)
        assert identity_header.get("typ") == "JWT"
        assert identity_payload.get("email") == self.username
Auth0(domain=domain, customAPI=customAPI, issuer=issuer) 279 | verifier = auth._verifier 280 | auth_no_error = Auth0( 281 | domain=domain, customAPI=customAPI, issuer=issuer, auto_error=False 282 | ) 283 | verifier_no_error = auth_no_error._verifier 284 | 285 | # correct 286 | token = jwt.encode( 287 | { 288 | "sub": "dummy-ID", 289 | "exp": datetime.utcnow() + timedelta(hours=10), 290 | "iat": datetime.utcnow() - timedelta(hours=10), 291 | "aud": customAPI, 292 | "iss": issuer, 293 | }, 294 | "dummy_secret", 295 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 296 | ) 297 | verifier._verify_claims(HTTPAuthorizationCredentials(scheme="a", credentials=token)) 298 | verifier_no_error._verify_claims( 299 | HTTPAuthorizationCredentials(scheme="a", credentials=token) 300 | ) 301 | # Testing for validation of JWT standard claims 302 | 303 | # invalid iss 304 | token = jwt.encode( 305 | { 306 | "sub": "dummy-ID", 307 | "exp": datetime.utcnow() + timedelta(hours=10), 308 | "iat": datetime.utcnow() - timedelta(hours=10), 309 | "aud": customAPI, 310 | "iss": "invalid" + issuer, 311 | }, 312 | "dummy_secret", 313 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 314 | ) 315 | e = _assert_verifier(token, verifier) 316 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 317 | _assert_verifier_no_error(token, verifier_no_error) 318 | 319 | # invalid expiration 320 | token = jwt.encode( 321 | { 322 | "sub": "dummy-ID", 323 | "exp": datetime.utcnow() - timedelta(hours=5), 324 | "iat": datetime.utcnow() - timedelta(hours=10), 325 | "aud": customAPI, 326 | "iss": issuer, 327 | }, 328 | "dummy_secret", 329 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 330 | ) 331 | e = _assert_verifier(token, verifier) 332 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 333 | _assert_verifier_no_error(token, verifier_no_error) 334 | 335 | # Testing for access token specific verification 336 | # invalid aud 
@pytest.mark.unittest
def test_extra_verify_id_token():
    """
    Testing for ID token validation:
    - validate standard claims: Token expiration (exp) and Token issuer (iss)
    - verify token audience (aud) claims: same as Client ID
    - verify Nonce
    Ref: https://auth0.com/docs/tokens/id-tokens/validate-id-tokens

    Every rejected token differs from the accepted one in exactly one claim,
    so a rejection can only be caused by the claim under test.
    """
    domain = DOMAIN
    client_id = "dummy-client-ID"
    nonce = "dummy-nonce"
    issuer = "https://dummy"
    auth = Auth0CurrentUser(
        domain=domain, client_id=client_id, nonce=nonce, issuer=issuer
    )
    verifier = auth._verifier
    auth_no_error = Auth0CurrentUser(
        domain=domain, client_id=client_id, nonce=nonce, issuer=issuer, auto_error=False
    )
    verifier_no_error = auth_no_error._verifier

    def make_token(**overrides) -> str:
        # Start from a fully valid claim set (including the nonce) and replace
        # only the claims named by the caller.
        claims = {
            "sub": "dummy-ID",
            "exp": datetime.utcnow() + timedelta(hours=10),
            "iat": datetime.utcnow() - timedelta(hours=10),
            "aud": client_id,
            "nonce": nonce,
            "iss": issuer,
        }
        claims.update(overrides)
        return jwt.encode(
            claims,
            "dummy_secret",
            headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"},
        )

    def assert_rejected(token: str) -> None:
        # auto_error verifier raises 401; the no-error verifier returns False
        e = _assert_verifier(token, verifier)
        assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED
        _assert_verifier_no_error(token, verifier_no_error)

    # correct
    token = make_token()
    verifier._verify_claims(HTTPAuthorizationCredentials(scheme="a", credentials=token))
    verifier_no_error._verify_claims(
        HTTPAuthorizationCredentials(scheme="a", credentials=token)
    )

    # Testing for validation of JWT standard claims

    # invalid iss
    assert_rejected(make_token(iss="invalid" + issuer))

    # invalid expiration
    assert_rejected(make_token(exp=datetime.utcnow() - timedelta(hours=5)))

    # Testing for ID token specific verification

    # invalid aud: aud must be same as Client ID
    assert_rejected(make_token(aud=client_id + "incorrect"))

    # invalid nonce
    assert_rejected(make_token(nonce=nonce + "invalid"))
@pytest.mark.unittest
@pytest.mark.parametrize(
    "scopes",
    [
        "user-assigned-scope",
        "xxx:xxx user-assigned-scope yyy:yyy",
        ["xxx:xxx", "user-assigned-scope", "yyy:yyy"],
    ],
)
def test_validation_scope(mocker, scopes):
    """Scope check passes for a granted scope and fails (or returns falsy) otherwise.

    The token claims are mocked, so the scope claim may be a single string, a
    space-separated string or a list.
    """
    claim_key = "dummy key"
    mocker.patch(
        "fastapi_cloudauth.verification.jwt.get_unverified_claims",
        return_value={claim_key: scopes},
    )
    base = ScopedAuth(jwks=JWKS.null())
    base.scope_key = claim_key

    def blank_credentials():
        # content is irrelevant: the claims lookup is patched above
        return HTTPAuthorizationCredentials(scheme="", credentials="")

    # required scope is among the user's scopes
    granted = base.scope("user-assigned-scope")
    assert granted.verifier._verify_scope(blank_credentials())

    # required scope is missing: raises while auto_error is on ...
    missing = base.scope("user-assigned-scope-invalid")
    with pytest.raises(HTTPException):
        missing.verifier._verify_scope(blank_credentials())

    # ... and returns a falsy value once auto_error is off
    missing.verifier.auto_error = False
    assert not missing.verifier._verify_scope(blank_credentials())
@pytest.mark.unittest
@pytest.mark.asyncio
@pytest.mark.parametrize("auth", [UserInfoAuth, ScopedAuth])
async def test_assign_user_info(auth):
    """The user-info schema can be set in three ways.

    1. pass it as ``user_info`` when creating the instance
    2. pass it to the ``.claim`` method
    3. assign it to the ``user_info`` attribute
    """

    class SubSchema(BaseModel):
        sub: str

    class NameSchema(BaseModel):
        name: str

    class IatSchema(BaseModel):
        iat: int

    # authorized token
    credentials = HTTPAuthorizationCredentials(
        scheme="a",
        credentials="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Im5hbWUiLCJpYXQiOjE1MTYyMzkwMjJ9.3ZEDmhWNZWDbJDPDlZX_I3oaalNYXdoT-bKLxIxQK4U",
    )

    # 1. constructor argument
    user = auth(jwks=JWKS.null(), user_info=IatSchema)
    from_constructor = await user.call(credentials)
    assert from_constructor == IatSchema(iat=1516239022)

    # 2. ``.claim`` method
    from_claim = await user.claim(SubSchema).call(credentials)
    assert from_claim == SubSchema(sub="1234567890")

    # 3. attribute assignment
    user.user_info = NameSchema
    from_attribute = await user.call(credentials)
    assert from_attribute == NameSchema(name="name")
class BaseTestCase:
    """Shared scaffolding for provider test classes.

    Subclasses set ``cloud_auth`` to a provider client class and switch on
    ``verify_access_token`` / ``verify_id_token`` to choose which endpoints the
    test application gets.
    """

    scope = ("read:test", "write:test")
    verify_access_token = False
    verify_id_token = False
    ACCESS_TOKEN = ""
    SCOPE_ACCESS_TOKEN = ""
    ID_TOKEN = ""
    client: TestClient
    cloud_auth: Type[Base]
    _cloud_auth: Base

    @classmethod
    def setup_class(cls):
        """Set credentials, create the test users and build the test application."""
        cls._cloud_auth = cls.cloud_auth()
        provider = cls._cloud_auth
        provider.setup(cls.scope)

        # copy the access tokens and the id token onto the test class
        for token_name in ("ACCESS_TOKEN", "SCOPE_ACCESS_TOKEN", "ID_TOKEN"):
            setattr(cls, token_name, getattr(provider, token_name))

        # application under test
        application = FastAPI()
        if cls.verify_access_token:
            application = add_endpoint_for_accesstoken(
                application, provider, cls.scope
            )
        if cls.verify_id_token:
            application = add_endpoint_for_idtoken(application, provider)
        cls.client = TestClient(application)

    @classmethod
    def teardown_class(cls):
        """Delete the test users."""
        cls._cloud_auth.teardown()

    def test_decode_token(self):
        """Delegate the token content checks to the provider client."""
        self._cloud_auth.decode()
class AccessTokenTestCase(BaseTestCase):
    """Access-token scenarios run against the endpoints of the test application.

    Paths containing ``no-error`` hit dependencies created without automatic
    errors, so they are expected to answer 200 even for bad tokens.
    """

    verify_access_token = True

    @classmethod
    def success_case(cls, path: str, token: str = "") -> Response:
        """GET ``path`` with ``token`` and require a 200 response."""
        return assert_get_response(
            client=cls.client, endpoint=path, token=token, status_code=200
        )

    def userinfo_success_case(self, path: str, token: str = "") -> Response:
        """Like ``success_case``, and additionally require every returned field to be truthy."""
        response = self.success_case(path, token)
        assert all(response.json().values()), f"{response.content} failed to parse"
        return response

    def failure_case(
        self, path: str, token: str = "", detail: str = "", status=401
    ) -> Response:
        """GET ``path`` with ``token`` and require ``status`` (and ``detail`` when given)."""
        expectation = {"status_code": status, "detail": detail}
        return assert_get_response(
            client=self.client, endpoint=path, token=token, **expectation
        )

    def test_valid_token(self):
        self.success_case("/", self.ACCESS_TOKEN)

    def test_no_token(self):
        self.failure_case("/")
        # not auto_error
        self.success_case("no-error")

    def test_malformed_token(self):
        # given malformed token
        malformed = "invaid-format-token"
        self.failure_case("/", malformed, detail=NOT_AUTHENTICATED)
        # not auto_error
        self.success_case("no-error", malformed)

    def test_incompatible_kid_token(self):
        # swap in a header whose kid matches no published key
        forged_header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjIzMDQ5ODE1MWMyMTRiNzg4ZGQ5N2YyMmI4NTQxMGE1In0."
        forged = forged_header + self.ACCESS_TOKEN.split(".", 1)[-1]
        self.failure_case("/", forged, detail=NO_PUBLICKEY)
        # not auto_error
        self.success_case("no-error", forged)

    def test_no_kid_token(self):
        # swap in a header that carries no kid at all
        kidless_header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
        forged = kidless_header + self.ACCESS_TOKEN.split(".", 1)[-1]
        self.failure_case("/", forged, detail=NOT_AUTHENTICATED)
        # not auto_error
        self.success_case("no-error", forged)

    def test_not_verified_token(self):
        # tamper with the last characters of the signature
        tampered = self.ACCESS_TOKEN[:-3] + "aaa"
        self.failure_case("/", tampered, detail=NOT_VERIFIED)
        # not auto_error
        self.success_case("no-error", tampered)

    def test_valid_scope(self):
        self.success_case("/scope/", self.SCOPE_ACCESS_TOKEN)

    def test_valid_scope_any(self):
        # access token must include a part of scopes in SCOPE_ACCESS_TOKEN
        self.success_case("/scope-any/", self.ACCESS_TOKEN)

    def test_invalid_scope(self):
        token = self.ACCESS_TOKEN
        self.failure_case("/scope/", token, detail=SCOPE_NOT_MATCHED, status=403)
        self.success_case("/scope/no-error/", token)

    def test_malformed_token_for_scope(self):
        # given malformed token
        malformed = "invaid-format-token"
        self.failure_case("/scope/", malformed, detail=NOT_AUTHENTICATED)
        # not auto_error
        self.success_case("/scope/no-error", malformed)

    def test_valid_token_extraction(self):
        self.userinfo_success_case("/access/user", self.ACCESS_TOKEN)

    def test_no_token_extraction(self):
        self.failure_case("/access/user")
        # not auto_error
        self.success_case("/access/user/no-error")

    def test_insufficient_user_info_from_access_token(self):
        # verified, but the token does not contain the requested user info
        token = self.ACCESS_TOKEN
        self.failure_case("/access/user/invalid/", token, detail=NOT_VALIDATED_CLAIMS)
        # not auto_error
        self.success_case("/access/user/invalid/no-error", token)
self.failure_case("/user/") 286 | # not auto_error 287 | self.success_case("/user/no-error") 288 | 289 | def test_malformed_token_for_scope(self): 290 | # given malformed token 291 | self.failure_case("/user/", "invaid-format-token", detail=NOT_AUTHENTICATED) 292 | # not auto_error 293 | self.success_case("/user/no-error", "invaid-format-token") 294 | 295 | def test_incompatible_kid_id_token(self): 296 | # manipulate header 297 | token = self.ID_TOKEN.split(".", 1)[-1] 298 | token = ( 299 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjIzMDQ5ODE1MWMyMTRiNzg4ZGQ5N2YyMmI4NTQxMGE1In0." 300 | + token 301 | ) 302 | self.failure_case("/user/", token, detail=NO_PUBLICKEY) 303 | # not auto_error 304 | self.success_case("/user/no-error/", token) 305 | 306 | def test_no_kid_id_token(self): 307 | # manipulate header 308 | token = self.ID_TOKEN.split(".", 1)[-1] 309 | token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + token 310 | self.failure_case("/user/", token, detail=NOT_AUTHENTICATED) 311 | # not auto_error 312 | self.success_case("/user/no-error", token) 313 | 314 | def test_not_verified_id_token(self): 315 | # manipulate public_key 316 | token = f"{self.ID_TOKEN}"[:-3] + "aaa" 317 | self.failure_case("/user/", token, detail=NOT_VERIFIED) 318 | # not auto_error 319 | self.success_case("/user/no-error", token) 320 | 321 | def test_insufficient_current_user_info(self): 322 | # verified but token does not contains user info 323 | self.failure_case("/user/invalid/", self.ID_TOKEN, detail=NOT_VALIDATED_CLAIMS) 324 | # not auto_error 325 | self.success_case("/user/invalid/no-error", self.ID_TOKEN) 326 | 327 | 328 | @pytest.mark.auth0 329 | class TestAuth0(AccessTokenTestCase, IdTokenTestCase): 330 | cloud_auth = Auth0Client 331 | 332 | 333 | @pytest.mark.cognito 334 | class TestCognito(AccessTokenTestCase, IdTokenTestCase): 335 | cloud_auth = CognitoClient 336 | 337 | 338 | @pytest.mark.firebase 339 | class TestFirebase(IdTokenTestCase): 340 | cloud_auth = FirebaseClient 341 | 
-------------------------------------------------------------------------------- /tests/test_cognito.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime, timedelta 3 | from sys import version_info as info 4 | from typing import Iterable, List, Optional 5 | 6 | import boto3 7 | import pytest 8 | from botocore.exceptions import ClientError 9 | from fastapi.security.http import HTTPAuthorizationCredentials 10 | from jose import jwt 11 | from starlette.status import HTTP_401_UNAUTHORIZED 12 | 13 | from fastapi_cloudauth import Cognito, CognitoCurrentUser 14 | from fastapi_cloudauth.cognito import CognitoClaims 15 | from fastapi_cloudauth.messages import NOT_VERIFIED 16 | from tests.helpers import ( 17 | Auths, 18 | BaseTestCloudAuth, 19 | _assert_verifier, 20 | _assert_verifier_no_error, 21 | decode_token, 22 | ) 23 | 24 | REGION = os.getenv("COGNITO_REGION") 25 | USERPOOLID = os.getenv("COGNITO_USERPOOLID") 26 | CLIENTID = os.getenv("COGNITO_APP_CLIENT_ID") 27 | 28 | 29 | def assert_env(): 30 | assert REGION, "'COGNITO_REGION' is not defined. Set environment variables" 31 | assert USERPOOLID, "'COGNITO_USERPOOLID' is not defined. Set environment variables" 32 | assert CLIENTID, "'COGNITO_APP_CLIENT_ID' is not defined. Set environment variables" 33 | assert CLIENTID, "'COGNITO_APP_CLIENT_ID' is not defined. Set environment variables" 34 | assert os.getenv( 35 | "AWS_ACCESS_KEY_ID" 36 | ), "'AWS_ACCESS_KEY_ID' is not defined. Set environment variables" 37 | assert os.getenv( 38 | "AWS_SECRET_ACCESS_KEY" 39 | ), "'AWS_SECRET_ACCESS_KEY' is not defined. 
Set environment variables" 40 | 41 | 42 | def initialize(): 43 | client = boto3.client("cognito-idp", region_name=REGION) 44 | return client 45 | 46 | 47 | def add_test_user( 48 | client, 49 | username=f"test_user{info.major}{info.minor}@example.com", 50 | password="testPass1-", 51 | scopes: Optional[List[str]] = None, 52 | ): 53 | client.sign_up( 54 | ClientId=CLIENTID, 55 | Username=username, 56 | Password=password, 57 | UserAttributes=[{"Name": "email", "Value": username}], 58 | ) 59 | client.admin_confirm_sign_up(UserPoolId=USERPOOLID, Username=username) 60 | if scopes: 61 | for scope in scopes: 62 | try: 63 | client.create_group(GroupName=scope, UserPoolId=USERPOOLID) 64 | except ClientError: # pragma: no cover 65 | pass # pragma: no cover 66 | client.admin_add_user_to_group( 67 | UserPoolId=USERPOOLID, 68 | Username=username, 69 | GroupName=scope, 70 | ) 71 | 72 | 73 | def get_cognito_token( 74 | client, 75 | username=f"test_user{info.major}{info.minor}@example.com", 76 | password="testPass1-", 77 | ): 78 | resp = client.admin_initiate_auth( 79 | UserPoolId=USERPOOLID, 80 | ClientId=CLIENTID, 81 | AuthFlow="ADMIN_USER_PASSWORD_AUTH", 82 | AuthParameters={"USERNAME": username, "PASSWORD": password}, 83 | ) 84 | access_token = resp["AuthenticationResult"]["AccessToken"] 85 | id_token = resp["AuthenticationResult"]["IdToken"] 86 | return access_token, id_token 87 | 88 | 89 | def delete_cognito_user( 90 | client, 91 | username=f"test_user{info.major}{info.minor}@example.com", 92 | ): 93 | try: 94 | client.admin_delete_user(UserPoolId=USERPOOLID, Username=username) 95 | except Exception: # pragma: no cover 96 | pass # pragma: no cover 97 | 98 | 99 | class CognitoClient(BaseTestCloudAuth): 100 | scope_user = f"test_scope{info.major}{info.minor}@example.com" 101 | user = f"test_user{info.major}{info.minor}@example.com" 102 | password = "testPass1-" 103 | 104 | def setup(self, scope: Iterable[str]) -> None: 105 | assert_env() 106 | 107 | self.scope = scope 108 | 
region = REGION 109 | userPoolId = USERPOOLID 110 | 111 | class CognitoInvalidClaims(CognitoClaims): 112 | fake_field: str 113 | 114 | class CognitoFakeCurrentUser(CognitoCurrentUser): 115 | user_info = CognitoInvalidClaims 116 | 117 | self.TESTAUTH = Auths( 118 | protect_auth=Cognito( 119 | region=region, userPoolId=userPoolId, client_id=CLIENTID 120 | ), 121 | protect_auth_ne=Cognito( 122 | region=region, 123 | userPoolId=userPoolId, 124 | client_id=CLIENTID, 125 | auto_error=False, 126 | ), 127 | ms_auth=CognitoCurrentUser( 128 | region=region, userPoolId=userPoolId, client_id=CLIENTID 129 | ), 130 | ms_auth_ne=CognitoCurrentUser( 131 | region=region, 132 | userPoolId=userPoolId, 133 | client_id=CLIENTID, 134 | auto_error=False, 135 | ), 136 | invalid_ms_auth=CognitoFakeCurrentUser( 137 | region=region, userPoolId=userPoolId, client_id=CLIENTID 138 | ), 139 | invalid_ms_auth_ne=CognitoFakeCurrentUser( 140 | region=region, 141 | userPoolId=userPoolId, 142 | client_id=CLIENTID, 143 | auto_error=False, 144 | ), 145 | valid_claim=CognitoClaims, 146 | invalid_claim=CognitoInvalidClaims, 147 | ) 148 | 149 | self.client = initialize() 150 | 151 | delete_cognito_user(self.client, self.user) 152 | add_test_user(self.client, self.user, self.password, scopes=[self.scope[0]]) 153 | self.ACCESS_TOKEN, self.ID_TOKEN = get_cognito_token( 154 | self.client, self.user, self.password 155 | ) 156 | 157 | delete_cognito_user(self.client, self.scope_user) 158 | add_test_user(self.client, self.scope_user, self.password, scopes=self.scope) 159 | self.SCOPE_ACCESS_TOKEN, self.SCOPE_ID_TOKEN = get_cognito_token( 160 | self.client, self.scope_user, self.password 161 | ) 162 | 163 | def teardown(self): 164 | delete_cognito_user(self.client, self.user) 165 | delete_cognito_user(self.client, self.scope_user) 166 | 167 | def decode(self): 168 | # access token 169 | header, payload, *_ = decode_token(self.ACCESS_TOKEN) 170 | assert [self.scope[0]] == payload.get("cognito:groups") 171 | 172 | 
# scope token 173 | scope_header, scope_payload, *_ = decode_token(self.SCOPE_ACCESS_TOKEN) 174 | assert set(self.scope) == set(scope_payload.get("cognito:groups")) 175 | 176 | # id token 177 | id_header, id_payload, *_ = decode_token(self.ID_TOKEN) 178 | assert id_payload.get("email") == self.user 179 | 180 | 181 | @pytest.mark.unittest 182 | def test_extra_verify_access_token(): 183 | """ 184 | Testing for access token validation: 185 | - validate standard claims: 186 | - exp: Token expiration 187 | - aud: audience should match the app client ID 188 | - iss: Token issuer should match your user pool 189 | - token_use: should match `access` 190 | Ref: https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-step-3 191 | """ 192 | region = REGION 193 | userPoolId = USERPOOLID 194 | client_id = "dummyclientid" 195 | auth = Cognito(region=region, userPoolId=userPoolId, client_id=client_id) 196 | verifier = auth._verifier 197 | auth_no_error = Cognito( 198 | region=region, userPoolId=userPoolId, client_id=client_id, auto_error=False 199 | ) 200 | verifier_no_error = auth_no_error._verifier 201 | 202 | # correct 203 | token = jwt.encode( 204 | { 205 | "sub": "dummy-ID", 206 | "exp": datetime.utcnow() + timedelta(hours=10), 207 | "iat": datetime.utcnow() - timedelta(hours=10), 208 | "aud": client_id, 209 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 210 | "token_use": "access", 211 | }, 212 | "dummy_secret", 213 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 214 | ) 215 | verifier._verify_claims(HTTPAuthorizationCredentials(scheme="a", credentials=token)) 216 | verifier_no_error._verify_claims( 217 | HTTPAuthorizationCredentials(scheme="a", credentials=token) 218 | ) 219 | 220 | # invalid exp 221 | token = jwt.encode( 222 | { 223 | "sub": "dummy-ID", 224 | "exp": datetime.utcnow() - timedelta(hours=5), 225 | "iat": datetime.utcnow() -
timedelta(hours=10), 226 | "aud": client_id, 227 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 228 | "token_use": "access", 229 | }, 230 | "dummy_secret", 231 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 232 | ) 233 | e = _assert_verifier(token, verifier) 234 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 235 | _assert_verifier_no_error(token, verifier_no_error) 236 | 237 | # invalid aud 238 | token = jwt.encode( 239 | { 240 | "sub": "dummy-ID", 241 | "exp": datetime.utcnow() + timedelta(hours=10), 242 | "iat": datetime.utcnow() - timedelta(hours=10), 243 | "aud": client_id + "incorrect", 244 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 245 | "token_use": "access", 246 | }, 247 | "dummy_secret", 248 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 249 | ) 250 | e = _assert_verifier(token, verifier) 251 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 252 | _assert_verifier_no_error(token, verifier_no_error) 253 | 254 | # invalid iss 255 | token = jwt.encode( 256 | { 257 | "sub": "dummy-ID", 258 | "exp": datetime.utcnow() + timedelta(hours=10), 259 | "iat": datetime.utcnow() - timedelta(hours=10), 260 | "aud": client_id, 261 | "iss": "invalid" 262 | + f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}-invalid", 263 | "token_use": "access", 264 | }, 265 | "dummy_secret", 266 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 267 | ) 268 | e = _assert_verifier(token, verifier) 269 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 270 | _assert_verifier_no_error(token, verifier_no_error) 271 | 272 | # invalid token-use 273 | token = jwt.encode( 274 | { 275 | "sub": "dummy-ID", 276 | "exp": datetime.utcnow() + timedelta(hours=10), 277 | "iat": datetime.utcnow() - timedelta(hours=10), 278 | "aud": client_id, 279 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}-invalid", 280 | 
"token_use": "id", 281 | }, 282 | "dummy_secret", 283 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 284 | ) 285 | e = _assert_verifier(token, verifier) 286 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 287 | _assert_verifier_no_error(token, verifier_no_error) 288 | 289 | 290 | @pytest.mark.unittest 291 | def test_extra_verify_id_token(): 292 | """ 293 | Testing for ID token validation: 294 | - validate standard claims: 295 | - exp: Token expiration 296 | - aud: audience should match the app client ID 297 | - iss: Token issuer should match your user pool 298 | - token_use: should match `id` 299 | Ref: https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html#amazon-cognito-user-pools-using-tokens-step-3 300 | """ 301 | region = REGION 302 | userPoolId = USERPOOLID 303 | client_id = "dummyclientid" 304 | auth = CognitoCurrentUser(region=region, userPoolId=userPoolId, client_id=client_id) 305 | verifier = auth._verifier 306 | auth_no_error = CognitoCurrentUser( 307 | region=region, userPoolId=userPoolId, client_id=client_id, auto_error=False 308 | ) 309 | verifier_no_error = auth_no_error._verifier 310 | 311 | # correct 312 | token = jwt.encode( 313 | { 314 | "at_hash": "some-hash-that-isnt-checked", 315 | "sub": "dummy-ID", 316 | "exp": datetime.utcnow() + timedelta(hours=10), 317 | "iat": datetime.utcnow() - timedelta(hours=10), 318 | "aud": client_id, 319 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 320 | "token_use": "id", 321 | }, 322 | "dummy_secret", 323 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 324 | ) 325 | verifier._verify_claims(HTTPAuthorizationCredentials(scheme="a", credentials=token)) 326 | verifier_no_error._verify_claims( 327 | HTTPAuthorizationCredentials(scheme="a", credentials=token) 328 | ) 329 | # invalid exp 330 | token = jwt.encode( 331 | { 332 | "sub": "dummy-ID", 333 | "exp": datetime.utcnow() 
- timedelta(hours=5), 334 | "iat": datetime.utcnow() - timedelta(hours=10), 335 | "aud": client_id, 336 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 337 | "token_use": "id", 338 | }, 339 | "dummy_secret", 340 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 341 | ) 342 | e = _assert_verifier(token, verifier) 343 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 344 | _assert_verifier_no_error(token, verifier_no_error) 345 | 346 | # invalid aud 347 | token = jwt.encode( 348 | { 349 | "sub": "dummy-ID", 350 | "exp": datetime.utcnow() + timedelta(hours=10), 351 | "iat": datetime.utcnow() - timedelta(hours=10), 352 | "aud": client_id + "incorrect", 353 | "iss": f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 354 | "token_use": "id", 355 | }, 356 | "dummy_secret", 357 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 358 | ) 359 | e = _assert_verifier(token, verifier) 360 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 361 | _assert_verifier_no_error(token, verifier_no_error) 362 | 363 | # invalid iss 364 | token = jwt.encode( 365 | { 366 | "sub": "dummy-ID", 367 | "exp": datetime.utcnow() + timedelta(hours=10), 368 | "iat": datetime.utcnow() - timedelta(hours=10), 369 | "aud": client_id, 370 | "iss": "invalid" 371 | + f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}-invalid", 372 | "token_use": "id", 373 | }, 374 | "dummy_secret", 375 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 376 | ) 377 | e = _assert_verifier(token, verifier) 378 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 379 | _assert_verifier_no_error(token, verifier_no_error) 380 | 381 | # invalid token-use 382 | token = jwt.encode( 383 | { 384 | "sub": "dummy-ID", 385 | "exp": datetime.utcnow() + timedelta(hours=10), 386 | "iat": datetime.utcnow() - timedelta(hours=10), 387 | "aud": client_id, 388 | "iss": 
f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}", 389 | "token_use": "access", 390 | }, 391 | "dummy_secret", 392 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 393 | ) 394 | e = _assert_verifier(token, verifier) 395 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 396 | _assert_verifier_no_error(token, verifier_no_error) 397 | -------------------------------------------------------------------------------- /tests/test_firebase.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import tempfile 5 | from calendar import timegm 6 | from datetime import datetime, timedelta 7 | from sys import version_info as info 8 | from typing import Iterable 9 | 10 | import firebase_admin 11 | import pytest 12 | import requests 13 | from fastapi.security.http import HTTPAuthorizationCredentials 14 | from firebase_admin import auth, credentials 15 | from jose import jwt 16 | from starlette.status import HTTP_401_UNAUTHORIZED 17 | 18 | from fastapi_cloudauth import FirebaseCurrentUser 19 | from fastapi_cloudauth.firebase import FirebaseClaims 20 | from fastapi_cloudauth.messages import NOT_VERIFIED 21 | from tests.helpers import ( 22 | Auths, 23 | BaseTestCloudAuth, 24 | _assert_verifier, 25 | _assert_verifier_no_error, 26 | decode_token, 27 | ) 28 | 29 | PROJECT_ID = os.getenv("FIREBASE_PROJECTID") 30 | API_KEY = os.getenv("FIREBASE_APIKEY") 31 | BASE64_CREDENTIAL = os.getenv("FIREBASE_BASE64_CREDENCIALS") 32 | _verify_password_url = ( 33 | "https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword" 34 | ) 35 | 36 | 37 | def assert_env(): 38 | assert API_KEY, "'FIREBASE_APIKEY' is not defined. Set environment variables" 39 | assert ( 40 | BASE64_CREDENTIAL 41 | ), "'FIREBASE_BASE64_CREDENCIALS' is not defined. 
Set environment variables" 42 | 43 | 44 | def initialize(): 45 | """set credentials (intermediate credential file is created)""" 46 | credentials_base64 = BASE64_CREDENTIAL 47 | credentials_str = base64.b64decode(credentials_base64) 48 | credentials_json = json.loads(credentials_str) 49 | 50 | tmpdir = tempfile.TemporaryDirectory() 51 | credentials_path = os.path.join(tmpdir.name, "sa.json") 52 | with open( 53 | credentials_path, 54 | "w", 55 | ) as f: 56 | json.dump(credentials_json, f) 57 | 58 | cred = credentials.Certificate(credentials_path) 59 | firebase_admin.initialize_app(cred) 60 | 61 | 62 | def add_test_user(email, password, uid): 63 | auth.create_user(email=email, password=password, uid=uid) 64 | 65 | 66 | def delete_test_user(uid): 67 | try: 68 | auth.delete_user(uid) 69 | except firebase_admin._auth_utils.UserNotFoundError: 70 | pass 71 | 72 | 73 | def get_tokens(email, password, uid): 74 | # get access token 75 | access_token_bytes = auth.create_custom_token(uid) 76 | ACCESS_TOKEN = access_token_bytes.decode("utf-8") 77 | 78 | # get ID token (sign-in with password using FIREBASE AUTH REST API) 79 | body = {"email": email, "password": password, "returnSecureToken": True} 80 | params = {"key": API_KEY} 81 | resp = requests.request("post", _verify_password_url, params=params, json=body) 82 | resp.raise_for_status() 83 | ID_TOKEN = resp.json().get("idToken") 84 | 85 | return ACCESS_TOKEN, ID_TOKEN 86 | 87 | 88 | def get_test_client(): 89 | class FirebaseInvalidClaims(FirebaseClaims): 90 | fake_field: str 91 | 92 | class FirebaseFakeCurrentUser(FirebaseCurrentUser): 93 | user_info = FirebaseInvalidClaims 94 | 95 | return Auths( 96 | protect_auth=None, 97 | protect_auth_ne=None, 98 | ms_auth=FirebaseCurrentUser(project_id=PROJECT_ID), 99 | ms_auth_ne=FirebaseCurrentUser(project_id=PROJECT_ID, auto_error=False), 100 | invalid_ms_auth=FirebaseFakeCurrentUser(project_id=PROJECT_ID), 101 | invalid_ms_auth_ne=FirebaseFakeCurrentUser( 102 | project_id=PROJECT_ID, 
auto_error=False 103 | ), 104 | valid_claim=FirebaseClaims, 105 | invalid_claim=FirebaseInvalidClaims, 106 | ) 107 | 108 | 109 | class FirebaseClient(BaseTestCloudAuth): 110 | def setup(self, scope: Iterable[str]) -> None: 111 | """set credentials and create test user""" 112 | assert_env() 113 | 114 | self.email = f"fastapi-cloudauth-user-py{info.major}{info.minor}@example.com" 115 | self.password = "secretPassword" 116 | self.uid = f"fastapi-cloudauth-test-uid-py{info.major}{info.minor}" 117 | 118 | initialize() 119 | 120 | delete_test_user(self.uid) 121 | 122 | # create test user 123 | add_test_user(self.email, self.password, self.uid) 124 | 125 | # get access token and id token 126 | self.ACCESS_TOKEN, self.ID_TOKEN = get_tokens( 127 | self.email, self.password, self.uid 128 | ) 129 | 130 | # set application for testing 131 | self.TESTAUTH = get_test_client() 132 | 133 | def teardown(self): 134 | """delete test user""" 135 | delete_test_user(self.uid) 136 | 137 | def decode(self): 138 | # access token 139 | header, payload, *_ = decode_token(self.ACCESS_TOKEN) 140 | assert header.get("typ") == "JWT" 141 | assert payload.get("uid") == self.uid 142 | 143 | # id token 144 | id_header, id_payload, *_ = decode_token(self.ID_TOKEN) 145 | assert id_header.get("typ") == "JWT" 146 | assert id_payload.get("email") == self.email 147 | assert id_payload.get("user_id") == self.uid 148 | 149 | 150 | @pytest.mark.unittest 151 | def test_extra_verify_token(): 152 | """ 153 | Testing for ID token validation: 154 | - validate standard claims: 155 | - exp: Token expiration 156 | - iat: issued-at time is in the past 157 | - aud: audience is same as project ID 158 | - iss: Token issuer 159 | - sub: not null string or user id 160 | - auth_time: authorization time is in the past 161 | Ref: https://firebase.google.com/docs/auth/admin/verify-id-tokens#verify_id_tokens_using_a_third-party_jwt_library 162 | """ 163 | pjt_id = "dummy" 164 | auth = FirebaseCurrentUser(pjt_id) 165 | verifier = auth._verifier 166 |
auth_no_error = FirebaseCurrentUser(pjt_id, auto_error=False) 167 | verifier_no_error = auth_no_error._verifier 168 | 169 | # correct 170 | token = jwt.encode( 171 | { 172 | "sub": "dummy-ID", 173 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 174 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 175 | "auth_time": timegm( 176 | (datetime.utcnow() - timedelta(hours=10)).utctimetuple() 177 | ), 178 | "aud": pjt_id, 179 | "iss": f"https://securetoken.google.com/{pjt_id}", 180 | }, 181 | "dummy_secret", 182 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 183 | ) 184 | verifier._verify_claims(HTTPAuthorizationCredentials(scheme="a", credentials=token)) 185 | verifier_no_error._verify_claims( 186 | HTTPAuthorizationCredentials(scheme="a", credentials=token) 187 | ) 188 | 189 | # invalid exp 190 | token = jwt.encode( 191 | { 192 | "sub": "dummy-ID", 193 | "exp": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 194 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 195 | "auth_time": timegm( 196 | (datetime.utcnow() - timedelta(hours=11)).utctimetuple() 197 | ), 198 | "aud": pjt_id, 199 | "iss": f"https://securetoken.google.com/{pjt_id}", 200 | }, 201 | "dummy_secret", 202 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 203 | ) 204 | e = _assert_verifier(token, verifier) 205 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 206 | _assert_verifier_no_error(token, verifier_no_error) 207 | 208 | # invalid iat 209 | token = jwt.encode( 210 | { 211 | "sub": "dummy-ID", 212 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 213 | "iat": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 214 | "auth_time": timegm( 215 | (datetime.utcnow() - timedelta(hours=11)).utctimetuple() 216 | ), 217 | "aud": pjt_id, 218 | "iss": f"https://securetoken.google.com/{pjt_id}", 219 | }, 220 | "dummy_secret", 
221 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 222 | ) 223 | e = _assert_verifier(token, verifier) 224 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 225 | _assert_verifier_no_error(token, verifier_no_error) 226 | 227 | # invalid aud 228 | token = jwt.encode( 229 | { 230 | "sub": "dummy-ID", 231 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 232 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 233 | "auth_time": timegm( 234 | (datetime.utcnow() - timedelta(hours=11)).utctimetuple() 235 | ), 236 | "aud": pjt_id + "incorrect", 237 | "iss": f"https://securetoken.google.com/{pjt_id}", 238 | }, 239 | "dummy_secret", 240 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 241 | ) 242 | e = _assert_verifier(token, verifier) 243 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 244 | _assert_verifier_no_error(token, verifier_no_error) 245 | 246 | # invalid iss 247 | token = jwt.encode( 248 | { 249 | "sub": "dummy-ID", 250 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 251 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 252 | "auth_time": timegm( 253 | (datetime.utcnow() - timedelta(hours=11)).utctimetuple() 254 | ), 255 | "aud": pjt_id, 256 | "iss": f"https://securetoken.google.com/{pjt_id}-extra", 257 | }, 258 | "dummy_secret", 259 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 260 | ) 261 | e = _assert_verifier(token, verifier) 262 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 263 | _assert_verifier_no_error(token, verifier_no_error) 264 | 265 | # invalid sub 266 | token = jwt.encode( 267 | { 268 | "sub": "", 269 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 270 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 271 | "auth_time": timegm( 272 | (datetime.utcnow() - 
timedelta(hours=11)).utctimetuple() 273 | ), 274 | "aud": pjt_id, 275 | "iss": f"https://securetoken.google.com/{pjt_id}-extra", 276 | }, 277 | "dummy_secret", 278 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 279 | ) 280 | e = _assert_verifier(token, verifier) 281 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 282 | _assert_verifier_no_error(token, verifier_no_error) 283 | 284 | # invalid auth_time 285 | token = jwt.encode( 286 | { 287 | "sub": "dummy-ID", 288 | "exp": timegm((datetime.utcnow() + timedelta(hours=10)).utctimetuple()), 289 | "iat": timegm((datetime.utcnow() - timedelta(hours=10)).utctimetuple()), 290 | "auth_time": timegm( 291 | (datetime.utcnow() + timedelta(hours=3)).utctimetuple() 292 | ), 293 | "aud": pjt_id, 294 | "iss": f"https://securetoken.google.com/{pjt_id}", 295 | }, 296 | "dummy_secret", 297 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 298 | ) 299 | e = _assert_verifier(token, verifier) 300 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == NOT_VERIFIED 301 | _assert_verifier_no_error(token, verifier_no_error) 302 | -------------------------------------------------------------------------------- /tests/test_verification.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import datetime, timedelta 3 | from email.utils import format_datetime, parsedate_to_datetime 4 | from typing import Any, Dict, Optional 5 | 6 | import pytest 7 | from fastapi import HTTPException 8 | from fastapi.security import HTTPAuthorizationCredentials 9 | from jose import jwt 10 | from jose.backends.base import Key 11 | from requests.models import Response 12 | from starlette.status import HTTP_401_UNAUTHORIZED 13 | 14 | from fastapi_cloudauth import messages 15 | from fastapi_cloudauth.verification import ( 16 | JWKS, 17 | JWKsVerifier, 18 | Operator, 19 | ScopedJWKsVerifier, 20 | ) 21 | 22 | from .helpers import 
_assert_verifier, _assert_verifier_no_error 23 | 24 | 25 | @pytest.mark.unittest 26 | @pytest.mark.asyncio 27 | async def test_malformed_token_handling(): 28 | http_auth_with_malformed_token = HTTPAuthorizationCredentials( 29 | scheme="a", 30 | credentials="malformed-token", 31 | ) 32 | 33 | verifier = JWKsVerifier(jwks=JWKS.null()) 34 | with pytest.raises(HTTPException): 35 | await verifier._get_publickey(http_auth_with_malformed_token) 36 | with pytest.raises(HTTPException): 37 | await verifier.verify_token(http_auth_with_malformed_token) 38 | 39 | verifier = JWKsVerifier(jwks=JWKS.null(), auto_error=False) 40 | assert not await verifier._get_publickey(http_auth_with_malformed_token) 41 | assert not await verifier.verify_token(http_auth_with_malformed_token) 42 | 43 | verifier = ScopedJWKsVerifier(jwks=JWKS.null()) 44 | with pytest.raises(HTTPException): 45 | verifier._verify_scope(http_auth_with_malformed_token) 46 | with pytest.raises(HTTPException): 47 | await verifier.verify_token(http_auth_with_malformed_token) 48 | 49 | verifier = ScopedJWKsVerifier(jwks=JWKS.null(), auto_error=False) 50 | assert not verifier._verify_scope(http_auth_with_malformed_token) 51 | assert not await verifier.verify_token(http_auth_with_malformed_token) 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_jwks_test_mode(): 56 | # instantiate null jwks obj (no querying jwks) 57 | _jwks = JWKS.null() 58 | 59 | # instantiate fixed jwks obj (no querying jwks) 60 | dummy = Key(None, None) 61 | _jwks = JWKS(fixed_keys={"test": dummy}) 62 | assert await _jwks.get_publickey("test") == dummy 63 | 64 | 65 | class DummyResp(Response): 66 | def __init__(self, expires: datetime) -> None: 67 | super().__init__() 68 | self.headers["Expires"] = format_datetime(expires) 69 | 70 | @property 71 | def json(self): 72 | return lambda: {} 73 | 74 | 75 | class DummyDecodeJWKS(JWKS): 76 | def _construct(self, jwks: Dict[str, Any]) -> Dict[str, Key]: 77 | return {} 78 | 79 | def _set_expiration(self, 
class DummyDecodeCntJWKS(JWKS):
    """JWKS test double that counts how often the key set is (re)built.

    Every call to ``_construct`` bumps an internal counter and publishes it
    under the key id ``"cnt"``, so a test can read the number of rebuilds
    back through ``get_publickey("cnt")``.

    Args:
        url: JWKS endpoint handed to the base class (``requests.get`` is
            patched in the tests that use this class).
    """

    def __init__(self, url: str = ""):
        # The counter has to exist before the base initializer runs.
        # NOTE(review): presumably ``JWKS.__init__`` already calls
        # ``_construct`` once -- the test using this class expects the
        # counter to read 2 after a single refresh. Confirm against JWKS.
        self._counter = 0
        super().__init__(url=url)

    def _construct(self, jwks: Dict[str, Any]) -> Dict[str, Key]:
        """Record one more rebuild and return the counter as the "key" set.

        Args:
            jwks: Decoded JWKS payload; ignored by this double.

        Returns:
            ``{"cnt": <number of rebuilds so far>}``. The value is an int
            rather than a real ``Key``; that is deliberate for the test.
        """
        # A bare ``asyncio.sleep(0.5)`` call used to sit here. In a
        # synchronous method that only creates a coroutine object that is
        # never awaited: no delay happens and Python emits a
        # "coroutine ... was never awaited" RuntimeWarning. It is removed.
        self._counter += 1
        return {"cnt": self._counter}

    def _set_expiration(self, resp: Response) -> Optional[datetime]:
        """Return the expiry parsed from the response's ``expires`` header.

        Args:
            resp: Response whose ``expires`` header holds a date string.

        Returns:
            The header value parsed with ``parsedate_to_datetime``.
        """
        expires_header = resp.headers.get("expires")
        return parsedate_to_datetime(expires_header)
@pytest.mark.unittest
def test_verify_scope_exeption(mocker):
    """Check ``_verify_scope`` with no required scope and with a malformed claim.

    * With ``scope_name=None`` the check passes.
    * When the scope claim is not a string/list (here the int ``100``) the
      verifier raises ``HTTPException``, or returns a falsy value when
      ``auto_error=False``.
    """
    claims_target = "fastapi_cloudauth.verification.jwt.get_unverified_claims"
    scope_key = "dummy key"

    mocker.patch(claims_target, return_value={"dummy key": "read:test"})
    credentials = HTTPAuthorizationCredentials(
        scheme="a",
        credentials="dummy-token",
    )

    # No scope required by the API.
    no_scope_verifier = ScopedJWKsVerifier(
        jwks=JWKS.null(), scope_key=scope_key, scope_name=None
    )
    assert no_scope_verifier._verify_scope(credentials)

    # From here on the token carries a scope claim of the wrong type.
    mocker.patch(claims_target, return_value={"dummy key": 100})

    raising_verifier = ScopedJWKsVerifier(
        jwks=JWKS.null(), scope_key=scope_key, scope_name=["read:test"]
    )
    with pytest.raises(HTTPException):
        raising_verifier._verify_scope(credentials)

    # Same situation with auto_error disabled: falsy result instead of raising.
    silent_verifier = ScopedJWKsVerifier(
        jwks=JWKS.null(),
        scope_key=scope_key,
        scope_name=["read:test"],
        auto_error=False,
    )
    assert not silent_verifier._verify_scope(credentials)
@pytest.mark.unittest
@pytest.mark.parametrize(
    "scopes",
    ["xxx:xxx yyy:yyy", ["xxx:xxx", "yyy:yyy"]],
)
def test_scope_match_any(mocker, scopes):
    """Check ``Operator._any`` scope matching against a two-scope token.

    The user's scopes arrive either as a space-separated string or as a
    list (parametrized). The check must be truthy whenever at least one
    required scope is among them, and falsy when none is.
    """
    scope_key = "dummy key"
    credentials = HTTPAuthorizationCredentials(
        scheme="a",
        credentials="dummy-token",
    )

    # The token is never decoded for real; its claims are injected here.
    mocker.patch(
        "fastapi_cloudauth.verification.jwt.get_unverified_claims",
        return_value={"dummy key": scopes},
    )
    null_jwks = JWKS.null()

    # (scopes required by the API, whether the check is expected to pass)
    cases = [
        (["xxx:xxx"], True),  # api scope < user scope
        (["xxx:xxx", "yyy:yyy"], True),  # api scope == user scope (in order)
        (["yyy:yyy", "xxx:xxx"], True),  # api scope == user scope (disorder)
        (["yyy:yyy", "xxx:xxx", "zzz:zzz"], True),  # api scope > user scope
        (["zzz:zzz"], False),  # api scope ^ user scope
    ]
    for required_scopes, should_pass in cases:
        verifier = ScopedJWKsVerifier(
            scope_name=required_scopes,
            op=Operator._any,
            jwks=null_jwks,
            scope_key=scope_key,
            auto_error=False,
        )
        outcome = verifier._verify_scope(credentials)
        if should_pass:
            assert outcome
        else:
            assert not outcome
| assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == messages.NOT_VERIFIED 351 | _assert_verifier_no_error(token, verifier_no_error) 352 | 353 | # token created at future 354 | token = jwt.encode( 355 | { 356 | "sub": "dummy-ID", 357 | "exp": datetime.utcnow() + timedelta(hours=10), 358 | "iat": datetime.utcnow() + timedelta(hours=10), 359 | }, 360 | "dummy_secret", 361 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 362 | ) 363 | e = _assert_verifier(token, verifier) 364 | assert e.status_code == HTTP_401_UNAUTHORIZED and e.detail == messages.NOT_VERIFIED 365 | _assert_verifier_no_error(token, verifier_no_error) 366 | 367 | # invalid format 368 | token = jwt.encode( 369 | { 370 | "sub": "dummy-ID", 371 | "exp": datetime.utcnow() + timedelta(hours=10), 372 | "iat": datetime.utcnow(), 373 | }, 374 | "dummy_secret", 375 | headers={"alg": "HS256", "typ": "JWT", "kid": "dummy-kid"}, 376 | ) 377 | token = token.split(".")[0] 378 | e = _assert_verifier(token, verifier) 379 | assert ( 380 | e.status_code == HTTP_401_UNAUTHORIZED 381 | and e.detail == messages.NOT_AUTHENTICATED 382 | ) 383 | _assert_verifier_no_error(token, verifier_no_error) 384 | --------------------------------------------------------------------------------