├── .github
└── workflows
│ └── main.yml
├── .gitignore
├── .swift-format
├── .travis.yml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── Package.swift
├── README.md
├── Sources
├── .swiftlint.yml
└── VaporOAuth
│ ├── DefaultImplementations
│ ├── EmptyAuthorizationHandler.swift
│ ├── EmptyCodeManager.swift
│ ├── EmptyResourceServerRetriever.swift
│ ├── EmptyUserManager.swift
│ └── StaticClientRetriever.swift
│ ├── Helper
│ ├── OAuthHelper+local.swift
│ ├── OAuthHelper+remote.swift
│ └── OAuthHelper.swift
│ ├── Middleware
│ ├── OAuth2ScopeMiddleware.swift
│ ├── OAuth2TokenIntrospectionMiddleware.swift
│ └── TokenIntrospectionAuthenticationMiddleware.swift
│ ├── Models
│ ├── OAuthClient.swift
│ ├── OAuthCode.swift
│ ├── OAuthResourceServer.swift
│ ├── OAuthUser.swift
│ └── Tokens
│ │ ├── AccessToken.swift
│ │ └── RefreshToken.swift
│ ├── OAuth2.swift
│ ├── Protocols
│ ├── AuthorizeHandler.swift
│ ├── ClientRetriever.swift
│ ├── CodeManager.swift
│ ├── ResourceServerRetriever.swift
│ ├── TokenManager.swift
│ └── UserManager.swift
│ ├── RouteHandlers
│ ├── AuthorizeGetHandler.swift
│ ├── AuthorizePostHandler.swift
│ ├── TokenHandler.swift
│ ├── TokenHandlers
│ │ ├── AuthCodeTokenHandler.swift
│ │ ├── ClientCredentialsTokenHandler.swift
│ │ ├── PasswordTokenHandler.swift
│ │ ├── RefreshTokenHandler.swift
│ │ └── TokenResponseGenerator.swift
│ └── TokenIntrospectionHandler.swift
│ ├── Utilities
│ ├── OAuthFlowType.swift
│ ├── StringDefines.swift
│ └── TokenAuthenticator.swift
│ └── Validators
│ ├── ClientValidator.swift
│ ├── CodeValidator.swift
│ ├── ResourceServerAuthenticator.swift
│ └── ScopeValidator.swift
├── Tests
└── VaporOAuthTests
│ ├── Application+testable.swift
│ ├── AuthorizationTests
│ ├── AuthorizationRequestTests.swift
│ └── AuthorizationResponseTests.swift
│ ├── DefaultImplementationTests
│ └── DefaultImplementationTests.swift
│ ├── Fakes
│ ├── AccessToken.swift
│ ├── CapturingAuthorizeHandler.swift
│ ├── CapturingLogger.swift
│ ├── FakeAuthenticationMiddleware.swift
│ ├── FakeClientGetter.swift
│ ├── FakeCodeManager.swift
│ ├── FakeResourceServerRetriever.swift
│ ├── FakeSessions.swift
│ ├── FakeTokenManager.swift
│ ├── FakeUserManager.swift
│ ├── RefreshToken.swift
│ ├── StubCodeManager.swift
│ ├── StubTokenManager.swift
│ └── StubUserManager.swift
│ ├── GrantTests
│ ├── AuthorizationCodeTokenTests.swift
│ ├── ClientCredentialsTokenTests.swift
│ ├── ImplicitGrantTests.swift
│ ├── PasswordGrantTokenTests.swift
│ └── TokenRefreshTests.swift
│ ├── Helpers
│ ├── HTTPHeaders+location.swift
│ ├── Responses.swift
│ └── TestDataBuilder.swift
│ ├── IntegrationTests
│ └── AuthCodeResourceServerTests.swift
│ └── TokenIntrospectionTests
│ └── TokenIntrospectionTests.swift
├── codecov.yml
└── docker-test.sh
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Vapor OAuth
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | tags:
7 | - '*'
8 | pull_request:
9 | branches: '*'
10 |
11 | jobs:
12 | ubuntu_test:
13 | name: Ubuntu Build & Test
14 | runs-on: ubuntu-22.04
15 | container: swift:6.0-jammy
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Build
19 | run: swift build -v
20 | - name: Run tests
21 | run: swift test
22 | macos_test:
23 | name: macOS Build & Test
24 | runs-on: macos-15
25 | steps:
26 | - name: Select appropriate Xcode version
27 | uses: maxim-lobanov/setup-xcode@v1
28 | with:
29 | xcode-version: latest-stable
30 | - uses: actions/checkout@v4
31 | - name: Build
32 | run: swift build -v
33 | - name: Run tests
34 | run: swift test
35 | format:
36 | name: Lint Formatting
37 | runs-on: ubuntu-22.04
38 | container: swift:6.0-jammy
39 | steps:
40 | - uses: actions/checkout@v4
41 | - name: Lint
42 | run: swift format lint --strict --recursive .
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | Packages
2 | .build
3 | xcuserdata
4 | *.xcodeproj
5 | DerivedData/
6 | .DS_Store
7 | db.sqlite
8 | .swiftpm
9 | Package.resolved
10 |
--------------------------------------------------------------------------------
/.swift-format:
--------------------------------------------------------------------------------
1 | {
2 | "fileScopedDeclarationPrivacy": {
3 | "accessLevel": "private"
4 | },
5 | "indentation": {
6 | "spaces": 4
7 | },
8 | "indentConditionalCompilationBlocks": true,
9 | "indentSwitchCaseLabels": false,
10 | "lineBreakAroundMultilineExpressionChainComponents": false,
11 | "lineBreakBeforeControlFlowKeywords": false,
12 | "lineBreakBeforeEachArgument": false,
13 | "lineBreakBeforeEachGenericRequirement": false,
14 | "lineLength": 140,
15 | "maximumBlankLines": 1,
16 | "multiElementCollectionTrailingCommas": true,
17 | "noAssignmentInExpressions": {
18 | "allowedFunctions": [
19 | "XCTAssertNoThrow"
20 | ]
21 | },
22 | "prioritizeKeepingFunctionOutputTogether": false,
23 | "respectsExistingLineBreaks": true,
24 | "rules": {
25 | "AllPublicDeclarationsHaveDocumentation": false,
26 | "AlwaysUseLiteralForEmptyCollectionInit": false,
27 | "AlwaysUseLowerCamelCase": true,
28 | "AmbiguousTrailingClosureOverload": true,
29 | "BeginDocumentationCommentWithOneLineSummary": false,
30 | "DoNotUseSemicolons": true,
31 | "DontRepeatTypeInStaticProperties": true,
32 | "FileScopedDeclarationPrivacy": true,
33 | "FullyIndirectEnum": true,
34 | "GroupNumericLiterals": true,
35 | "IdentifiersMustBeASCII": true,
36 | "NeverForceUnwrap": false,
37 | "NeverUseForceTry": false,
38 | "NeverUseImplicitlyUnwrappedOptionals": false,
39 | "NoAccessLevelOnExtensionDeclaration": true,
40 | "NoAssignmentInExpressions": true,
41 | "NoBlockComments": true,
42 | "NoCasesWithOnlyFallthrough": true,
43 | "NoEmptyTrailingClosureParentheses": true,
44 | "NoLabelsInCasePatterns": true,
45 | "NoLeadingUnderscores": false,
46 | "NoParensAroundConditions": true,
47 | "NoPlaygroundLiterals": true,
48 | "NoVoidReturnOnFunctionSignature": true,
49 | "OmitExplicitReturns": false,
50 | "OneCasePerLine": true,
51 | "OneVariableDeclarationPerLine": true,
52 | "OnlyOneTrailingClosureArgument": true,
53 | "OrderedImports": true,
54 | "ReplaceForEachWithForLoop": true,
55 | "ReturnVoidInsteadOfEmptyTuple": true,
56 | "TypeNamesShouldBeCapitalized": true,
57 | "UseEarlyExits": false,
58 | "UseExplicitNilCheckInConditions": true,
59 | "UseLetInEveryBoundCaseVariable": true,
60 | "UseShorthandTypeNames": true,
61 | "UseSingleLinePropertyGetter": true,
62 | "UseSynthesizedInitializer": true,
63 | "UseTripleSlashForDocumentationComments": true,
64 | "UseWhereClausesInForLoops": false,
65 | "ValidateDocumentationComments": false
66 | },
67 | "spacesAroundRangeFormationOperators": false,
68 | "tabWidth": 4,
69 | "version": 1
70 | }
71 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | os:
2 | - linux
3 | - osx
4 | language: generic
5 | sudo: required
6 | dist: trusty
7 |
8 | osx_image: xcode9.2
9 | before_install:
10 | - if [ $TRAVIS_OS_NAME == "osx" ]; then
11 | brew update;
12 | brew tap vapor/tap;
13 | brew update;
14 | brew install vapor;
15 | else
16 | eval "$(curl -sL https://apt.vapor.sh)";
17 | sudo apt-get install vapor;
18 | sudo chmod -R a+rx /usr/;
19 | fi
20 |
21 | script:
22 | - swift build
23 | - swift build -c release
24 | - swift test
25 |
26 | after_success:
27 | - eval "$(curl -sL https://raw.githubusercontent.com/vapor-community/swift/swift-4-codecov/codecov-swift4)"
28 |
29 | notifications:
30 | email:
31 | on_success: change
32 | on_failure: change
33 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
6 |
7 | ## Our Standards
8 |
9 | Examples of behavior that contributes to creating a positive environment include:
10 |
11 | * Using welcoming and inclusive language
12 | * Being respectful of differing viewpoints and experiences
13 | * Gracefully accepting constructive criticism
14 | * Focusing on what is best for the community
15 | * Showing empathy towards other community members
16 |
17 | Examples of unacceptable behavior by participants include:
18 |
19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances
20 | * Trolling, insulting/derogatory comments, and personal or political attacks
21 | * Public or private harassment
22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission
23 | * Other conduct which could reasonably be considered inappropriate in a professional setting
24 |
25 | ## Our Responsibilities
26 |
27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
28 |
29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
30 |
31 | ## Scope
32 |
33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
34 |
35 | ## Enforcement
36 |
37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at support@brokenhands.io. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
38 |
39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
40 |
41 | ## Attribution
42 |
43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]
44 |
45 | [homepage]: http://contributor-covenant.org
46 | [version]: http://contributor-covenant.org/version/1/4/
47 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to a Broken Hands project
2 |
3 | :+1::tada: Thank you for wanting to contribute to this project! :tada::+1:
4 |
5 | We ask that you follow a few guidelines when contributing to one of our projects.
6 |
7 | ## Code of Conduct
8 |
9 | This project and everyone participating in it is governed by the [Broken Hands Code of Conduct](https://github.com/brokenhandsio/SteamPress/blob/master/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to [support@brokenhands.io](mailto:support@brokenhands.io).
10 |
11 | # How Can I Contribute?
12 |
13 | ### Reporting Bugs
14 |
15 | This section guides you through submitting a bug report for a Broken Hands project. Following these guidelines helps maintainers and the community understand your report :pencil:, reproduce the behavior :computer: :computer:, and find related reports :mag_right:.
16 |
17 | Before creating bug reports, please check [this list](#before-submitting-a-bug-report) as you might find out that you don't need to create one. When you are creating a bug report, please [include as many details as possible](#how-do-i-submit-a-good-bug-report).
18 |
19 | > **Note:** If you find a **Closed** issue that seems like it is the same thing that you're experiencing, open a new issue and include a link to the original issue in the body of your new one.
20 |
21 | #### Before Submitting A Bug Report
22 |
23 | * **Perform a [cursory search](https://github.com/issues?q=+is%3Aissue+user%3Abrokenhandsio)** to see if the problem has already been reported. If it has **and the issue is still open**, add a comment to the existing issue instead of opening a new one.
24 |
25 | #### How Do I Submit A (Good) Bug Report?
26 |
27 | Bugs are tracked as [GitHub issues](https://guides.github.com/features/issues/). Create an issue on the repository and provide the following information by filling in the issue form.
28 |
29 | Explain the problem and include additional details to help maintainers reproduce the problem:
30 |
31 | * **Use a clear and descriptive title** for the issue to identify the problem.
32 | * **Describe the exact steps which reproduce the problem** in as many details as possible. This usually means including some code, as well as __full__ error messages if applicable.
33 | * **Provide specific examples to demonstrate the steps**. Include links to files or GitHub projects, or copy/pasteable snippets, which you use in those examples. If you're providing snippets in the issue, use [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
34 | * **Describe the behavior you observed after following the steps** and point out what exactly is the problem with that behavior.
35 | * **Explain which behavior you expected to see instead and why.**
36 | * **If the problem wasn't triggered by a specific action**, describe what you were doing before the problem happened and share more information using the guidelines below.
37 |
38 | ### Suggesting Enhancements
39 |
40 | This section guides you through submitting an enhancement suggestion for a Broken Hands project, including completely new features and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion :pencil: and find related suggestions :mag_right:.
41 |
42 | Before creating enhancement suggestions, please check [this list](#before-submitting-an-enhancement-suggestion) as you might find out that you don't need to create one. When you are creating an enhancement suggestion, please [include as many details as possible](#how-do-i-submit-a-good-enhancement-suggestion). Fill in the issue form, including the steps that you imagine you would take if the feature you're requesting existed.
43 |
44 | #### Before Submitting An Enhancement Suggestion
45 |
46 | * **Perform a [cursory search](https://github.com/issues?q=+is%3Aissue+user%3Abrokenhandsio)** to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one.
47 |
48 | #### How Do I Submit A (Good) Enhancement Suggestion?
49 |
50 | Enhancement suggestions are tracked as [GitHub issues](https://guides.github.com/features/issues/). Create an issue and provide the following information:
51 |
52 | * **Use a clear and descriptive title** for the issue to identify the suggestion.
53 | * **Provide a step-by-step description of the suggested enhancement** in as many details as possible.
54 | * **Provide specific examples to demonstrate the steps**. Include copy/pasteable snippets which you use in those examples, as [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
55 | * **Describe the current behavior** and **explain which behavior you expected to see instead** and why.
56 | * **Explain why this enhancement would be useful** to other users and isn't something that can or should be implemented as a separate package.
57 |
58 | ### Pull Requests
59 |
60 | * Do not include issue numbers in the PR title
61 | * End all files with a newline
62 | * All new code should be run through `swiftlint`
63 | * All code must run on both Linux and macOS
64 | * All new code must be covered by tests
65 | * All bug fixes must be accompanied by a test which would fail if the bug fix was not implemented
66 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Runs the package's test suite in a container.
2 | # The base image matches the toolchain used by the CI workflow (swift:6.0-jammy).
3 | FROM swift:6.0-jammy
4 |
5 | WORKDIR /package
6 |
7 | COPY . ./
8 |
9 | # `swift package resolve` replaces the obsolete `swift package --enable-prefetching fetch` invocation.
10 | RUN swift package resolve
11 | RUN swift package clean
12 | CMD swift test
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Broken Hands
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version:5.6
2 | import PackageDescription
3 |
4 | let package = Package(  // Manifest: one library product, one source target, one test target.
5 | name: "vapor-oauth",
6 | platforms: [
7 | .macOS(.v12)
8 | ],
9 | products: [
10 | .library(
11 | name: "OAuth",  // Product name; the module clients import is the target name below.
12 | targets: ["VaporOAuth"]
13 | )
14 | ],
15 | dependencies: [
16 | .package(url: "https://github.com/vapor/vapor.git", from: "4.111.0")
17 | ],
18 | targets: [
19 | .target(
20 | name: "VaporOAuth",
21 | dependencies: [.product(name: "Vapor", package: "vapor")]
22 | ),
23 | .testTarget(
24 | name: "VaporOAuthTests",
25 | dependencies: [
26 | .target(name: "VaporOAuth"),
27 | .product(name: "XCTVapor", package: "vapor"),  // Only the test target depends on XCTVapor.
28 | ]),
29 | ]
30 | )
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | Vapor OAuth is an OAuth2 Provider Library written for Vapor. You can integrate the library into your server to provide authorization for applications to connect to your APIs.
20 |
21 | It follows both [RFC 6749](https://tools.ietf.org/html/rfc6749) and [RFC 6750](https://tools.ietf.org/html/rfc6750) and there is an extensive test suite to make sure it adheres to the specification.
22 |
23 | It also implements the [RFC 7662](https://tools.ietf.org/html/rfc7662) specification for Token Introspection, which is useful for microservices with a shared, central authorization server.
24 |
25 | Vapor OAuth supports the standard grant types:
26 |
27 | * Authorization Code
28 | * Client Credentials
29 | * Implicit Grant
30 | * Password Credentials
31 |
32 | For an excellent description on how the standard OAuth flows work, and what to expect when using and implementing them, have a look at https://www.oauth.com.
33 |
34 | # Usage
35 |
36 | ## Getting Started
37 |
38 | Vapor OAuth can be added to your Vapor app with a simple provider. To get started, first add the library to your `Package.swift` dependencies:
39 |
40 | ```swift
41 | dependencies: [
42 | ...,
43 | .package(url: "https://github.com/brokenhandsio/vapor-oauth", from: "0.6.0")
44 | ]
45 | ```
46 |
47 | Next import the library into where you set up your `Droplet`:
48 |
49 | ```swift
50 | import VaporOAuth
51 | ```
52 |
53 | Then add the provider to your `Config`:
54 |
55 | ```swift
56 | try addProvider(VaporOAuth.Provider(codeManager: MyCodeManager(), tokenManager: MyTokenManager(), clientRetriever: MyClientRetriever(), authorizeHandler: MyAuthHandler(), userManager: MyUserManager(), validScopes: ["view_profile", "edit_profile"], resourceServerRetriever: MyResourceServerRetriever()))
57 | ```
58 |
59 | To integrate the library, you need to set up a number of things, which implement the various protocols required:
60 |
61 | * `CodeManager` - this is responsible for generating and managing OAuth Codes. It is only required for the Authorization Code flow, so if you do not want to support this grant, you can leave out this parameter and use the default implementation
62 | * `TokenManager` - this is responsible for generating and managing Access and Refresh Tokens. You can either store these in memory, in Fluent, or with any backend.
63 | * `ClientRetriever` - this is responsible for getting all of the clients you want to support in your app. If you want to be able to dynamically add clients then you will need to make sure you can do that with your implementation. If you only want to support a set group of clients, you can use the `StaticClientRetriever` which is provided for you
64 | * `AuthorizeHandler` - this is responsible for allowing users to allow/deny authorization requests. See below for more details. If you do not want to support this grant type you can exclude this parameter and use the default implementation
65 | * `UserManager` - this is responsible for authenticating and getting users for the Password Credentials flow. If you do not want to support this flow, you can exclude this parameter and use the default implementation.
66 | * `validScopes` - this is an optional array of scopes that you wish to support in your system.
67 | * `ResourceServerRetriever` - this is only required if using the Token Introspection Endpoint and is what is used to authenticate resource servers trying to access the endpoint
68 |
69 | Note that there are a number of default implementations for the different required protocols for Fluent in the [Vapor OAuth Fluent package](https://github.com/brokenhandsio/vapor-oauth-fluent).
70 |
71 | The Provider will then register endpoints for authorization and tokens at `/oauth/authorize` and `/oauth/token`
72 |
73 | ## Protecting Endpoints
74 |
75 | Vapor OAuth has a helper extension on `Request` to allow you to easily protect your API routes. For instance, let's say that you want to ensure that one route is accessed only with tokens with the `profile` scope, you can do:
76 |
77 | ```swift
78 | try request.oauth.assertScopes(["profile"])
79 | ```
80 |
81 | This will throw a 401 error if the token is not valid or does not contain the `profile` scope. This is so common, that there is a dedicated `OAuth2ScopeMiddleware` for this behaviour. You just need to initialise this with an array of scopes that must be required for that `protect` group. If you initialise it with a `nil` array, then it will just make sure that the token is valid.
82 |
83 | You can also get the user with `try request.oauth.user()`.
84 |
85 | ### Protecting Resource Servers With Remote Auth Server
86 |
87 | If you have resource servers that are not the same server as the OAuth server that you wish to protect using the Token Introspection Endpoint, things are slightly different. See the [Token Introspection](#token-introspection) section for more information.
88 |
89 | # Grant Types
90 |
91 | ## Authorization Code Grant
92 |
93 | The Authorization Code flow is the most common flow used with OAuth. It is what most web applications will use for authorization with an OAuth Resource Server. The basic outline of this grant type is:
94 |
95 | 1. A client (another app) redirects a resource owner (a user that holds information with you) to your Vapor app.
96 | 2. Your Vapor app then authenticates the user and asks the user whether they want to allow the client access to the scopes requested (think logging into something with your Facebook account - it's this method).
97 | 3. If the user approves the application then the OAuth server redirects back to the client with an OAuth Code (that is typically valid for 60s or so)
98 | 4. The client can then exchange that code for an access and refresh token
99 | 5. The client can use the access token to make requests to the Resource Server (the OAuth server, or your web app)
100 |
101 | ### Implementation Details
102 |
103 | As well as implementing the Code Manager, Token Manager, and Client Retriever, the most important part to implement is the `AuthorizeHandler`. Your authorize handler is responsible for letting the user decide whether they should let an application have access to their account. It should be [clear and easy](https://www.oauth.com/oauth2-servers/authorization/the-authorization-interface/) to understand what is going on and should be clear what the application is requesting access to.
104 |
105 | It is your responsibility to ensure that the user is logged in and handling the case when they are not. An example implementation for the authorize handler may look something like:
106 |
107 | ```swift
108 | func handleAuthorizationRequest(_ request: Request, authorizationGetRequestObject: AuthorizationGetRequestObject) throws -> ResponseRepresentable {
109 | guard request.auth.isAuthenticated(FluentOAuthUser.self) else {
110 | let redirectCookie = Cookie(name: "OAuthRedirect", value: request.uri.description)
111 | let response = Response(redirect: "/login")
112 | response.cookies.insert(redirectCookie)
113 | return response
114 | }
115 |
116 | var parameters = Node([:], in: nil)
117 | let client = clientRetriever.getClient(clientID: authorizationGetRequestObject.clientID)
118 |
119 | try parameters.set("csrf_token", authorizationGetRequestObject.csrfToken)
120 | try parameters.set("scopes", authorizationGetRequestObject.scopes)
121 | try parameters.set("client_name", client.clientName)
122 | try parameters.set("client_image", client.clientImage)
123 | try parameters.set("user", request.auth.user)
124 |
125 | return try view.make("authorizeApplication", parameters)
126 | }
127 | ```
128 |
129 | You need to add the [`SessionsMiddleware`](https://docs.vapor.codes/2.0/sessions/sessions/) to your application for this flow to complete in order for the CSRF protection to work.
130 |
131 | When submitting the authorize form back to Vapor OAuth, in the form data it must include:
132 |
133 | * `applicationAuthorized` - a boolean value to signify if the user allowed access to the client or not
134 | * `csrfToken` - the CSRF token supplied in the handler to protect against CSRF attacks
135 |
136 | ## Implicit Grant
137 |
138 | The Implicit Grant is almost identical to the Authorization Code flow, except instead of being redirected back with a code which you then exchange for a token, you get redirected back with the token in the fragment. It is up to the client (such as an iOS application) to then parse the token out of the redirect URI fragment.
139 |
140 | This flow was designed for clients where you couldn't guarantee the security of the client secret, client-side apps, but has fallen out of favour recently and it is generally recommended to use the Authorization Code flow without a client secret instead.
141 |
142 | ## Resource Owner Password Credentials Grant
143 |
144 | The Password Credentials flow should only be used for first party applications, and Vapor OAuth mandates this. This flow allows the client to collect the username and password of the user and submit them directly to the OAuth server to get a token.
145 |
146 | Note that if you are using the password flow, as per [the specification](https://tools.ietf.org/html/rfc6749#section-4.3.2), you must secure your endpoint against brute force attacks with rate limiting or generating alerts. The library will output a warning message to the console for any unauthorized attempts, which you can use for this purpose. The message is in the form of `LOGIN WARNING: Invalid login attempt for user <username>`.
147 |
148 | ## Client Credentials Grant
149 |
150 | Client Credentials is a userless flow and is designed for servers accessing other servers without the need for a user. Access is granted based upon the authentication of the client requesting access.
151 |
152 | ## Token Introspection
153 |
154 | If running a microservices architecture it is useful to have a single server that handles authorization, which all the other resource servers query. To do this, you can use the Token Introspection Endpoint extension. In Vapor OAuth, this adds an endpoint at `/oauth/token_info` that you can post tokens to.
155 |
156 | You can send a POST request to this endpoint with a single parameter, `token`, which contains the OAuth token you want to check. If it is valid and active, then it will return a JSON payload, that looks similar to:
157 |
158 | ```json
159 | {
160 | "active": true,
161 | "client_id": "ABDED0123456",
162 | "scope": "email profile",
163 | "exp": 1503445858,
164 | "user_id": "12345678",
165 | "username": "hansolo",
166 | "email_address": "hansolo@therebelalliance.com"
167 | }
168 | ```
169 |
170 | If the token has expired or does not exist then it will simply return:
171 |
172 | ```json
173 | {
174 | "active": false
175 | }
176 | ```
177 |
178 | This endpoint is protected using HTTP Basic Authentication so you need to send an `Authorization: Basic abc` header with the request. This will check the `ResourceServerRetriever` for the username and password sent.
179 |
180 | **Note:** as per [the spec](https://tools.ietf.org/html/rfc7662#section-4) - the token introspection endpoint MUST be protected by HTTPS - this means the server must be behind a TLS certificate (commonly known as SSL). Vapor OAuth leaves this up to the integrating library to implement.
181 |
182 | ### Protecting Endpoints
183 |
184 | To protect resources on other servers with OAuth using the Token Introspection endpoint, you either need to use the `OAuth2TokenIntrospectionMiddleware` on your routes that you want to protect, or you need to manually set up the `Helper` object (the middleware does this for you). Both the middleware and helper setup require:
185 |
186 | * `tokenIntrospectionEndpoint` - the endpoint where the token can be validated
187 | * `client` - the `Droplet`'s client to send the token validation request with
188 | * `resourceServerUsername` - the username of the resource server
189 | * `resourceServerPassword` - the password of the resource server
190 |
191 | Once either of these has been set up, you can then call `request.oauth.user()` or `request.oauth.assertScopes()` like normal.
192 |
--------------------------------------------------------------------------------
/Sources/.swiftlint.yml:
--------------------------------------------------------------------------------
1 | line_length: 140
2 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/DefaultImplementations/EmptyAuthorizationHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Placeholder `AuthorizeHandler` that answers every call with an empty-bodied response.
public struct EmptyAuthorizationHandler: AuthorizeHandler {
    public init() {}

    /// Ignores the request and the authorization request object and returns an empty response.
    public func handleAuthorizationRequest(
        _ request: Request,
        authorizationRequestObject: AuthorizationRequestObject
    ) async throws -> Response {
        return emptyResponse()
    }

    /// Ignores the error type and returns an empty response.
    public func handleAuthorizationError(_ errorType: AuthorizationError) async throws -> Response {
        return emptyResponse()
    }

    /// Builds the response shared by both handler methods.
    private func emptyResponse() -> Response {
        Response(body: "")
    }
}
17 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift:
--------------------------------------------------------------------------------
/// Placeholder `CodeManager` that stores nothing: lookups find no code and generation yields an empty string.
public struct EmptyCodeManager: CodeManager {
    public init() {}

    /// Always reports that the code does not exist.
    public func getCode(_ code: String) -> OAuthCode? { nil }

    /// Always returns an empty code string; none of the arguments are used.
    public func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) throws -> String {
        ""
    }

    /// No-op: there is no stored code to mark as used.
    public func codeUsed(_ code: OAuthCode) {}
}
19 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/DefaultImplementations/EmptyResourceServerRetriever.swift:
--------------------------------------------------------------------------------
/// Placeholder `ResourceServerRetriever` that knows no resource servers.
public struct EmptyResourceServerRetriever: ResourceServerRetriever {
    public init() {}

    /// Always returns `nil`, whatever username is supplied.
    public func getServer(_ username: String) async throws -> OAuthResourceServer? { nil }
}
9 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/DefaultImplementations/EmptyUserManager.swift:
--------------------------------------------------------------------------------
/// Placeholder `UserManager` that knows no users and authenticates nobody.
public struct EmptyUserManager: UserManager {
    public init() {}

    /// Always returns `nil`: no user exists for any ID.
    public func getUser(userID: String) async throws -> OAuthUser? { nil }

    /// Always returns `nil`: no username/password pair is accepted.
    public func authenticateUser(username: String, password: String) async throws -> String? { nil }
}
13 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/DefaultImplementations/StaticClientRetriever.swift:
--------------------------------------------------------------------------------
/// `ClientRetriever` backed by a fixed, in-memory set of clients supplied at initialisation.
public struct StaticClientRetriever: ClientRetriever {

    /// Clients indexed by their `clientID`.
    let clients: [String: OAuthClient]

    /// Indexes `clients` by `clientID`.
    ///
    /// If several clients share an ID, the one appearing last in the array wins.
    /// - Parameter clients: The clients this retriever should know about.
    public init(clients: [OAuthClient]) {
        // `reduce(into:)` mutates a single dictionary in place; the plain `reduce`
        // form copies the accumulated dictionary on every element.
        self.clients = clients.reduce(into: [String: OAuthClient]()) { index, client in
            index[client.clientID] = client
        }
    }

    /// Looks up a client by ID.
    /// - Returns: The matching client, or `nil` when the ID is unknown.
    public func getClient(clientID: String) async throws -> OAuthClient? {
        clients[clientID]
    }
}
17 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Helper/OAuthHelper+local.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
extension OAuthHelper {
    /// Builds a helper that validates bearer tokens against the supplied local managers.
    ///
    /// Both closures throw `Abort(.forbidden)` when the manager they need is `nil`, and
    /// `Abort(.unauthorized)` when the token, its scopes, or its user cannot be validated.
    ///
    /// - Parameters:
    ///   - tokenAuthenticator: Checks an access token against the required scopes.
    ///   - userManager: Resolves the user referenced by an access token.
    ///   - tokenManager: Looks up access tokens by their string value.
    public static func local(
        tokenAuthenticator: TokenAuthenticator?,
        userManager: UserManager?,
        tokenManager: TokenManager?
    ) -> Self {
        OAuthHelper(
            assertScopes: { requiredScopes, request in
                guard let authenticator = tokenAuthenticator else { throw Abort(.forbidden) }

                let token = try await getToken(tokenManager: tokenManager, request: request)

                if !authenticator.validateAccessToken(token, requiredScopes: requiredScopes) {
                    throw Abort(.unauthorized)
                }
            },
            user: { request in
                guard let users = userManager else { throw Abort(.forbidden) }

                let token = try await getToken(tokenManager: tokenManager, request: request)

                // A token without a user ID, or one whose user no longer resolves, is rejected.
                guard let userID = token.userID,
                    let resolvedUser = try await users.getUser(userID: userID)
                else {
                    throw Abort(.unauthorized)
                }

                return resolvedUser
            }
        )
    }

    /// Reads the bearer token from `request` and returns the matching, unexpired access token.
    ///
    /// - Throws: `Abort(.forbidden)` when `tokenManager` is `nil` (or the header is rejected by
    ///   `getOAuthToken()`), `Abort(.unauthorized)` when the token is unknown or has expired.
    private static func getToken(tokenManager: TokenManager?, request: Request) async throws -> AccessToken {
        guard let manager = tokenManager else { throw Abort(.forbidden) }

        let bearer = try request.getOAuthToken()

        guard let accessToken = try await manager.getAccessToken(bearer),
            accessToken.expiryTime >= Date()
        else {
            throw Abort(.unauthorized)
        }

        return accessToken
    }
}
59 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Helper/OAuthHelper+remote.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
extension OAuthHelper {
    /// Key under which the introspection result is cached on the request it belongs to.
    private struct RemoteTokenResponseKey: StorageKey {
        typealias Value = RemoteTokenResponse
    }

    /// Builds a helper that validates bearer tokens by calling a remote token introspection endpoint.
    ///
    /// The introspection result is cached in the *request's* storage, so `assertScopes` and `user`
    /// share one remote call per request. It is deliberately not cached on the helper itself:
    /// the helper outlives individual requests, and a helper-level cache would hand the first
    /// caller's scopes and user to every later request regardless of the token it presents.
    ///
    /// - Parameters:
    ///   - tokenIntrospectionEndpoint: URL of the introspection endpoint.
    ///   - client: HTTP client used to send the introspection request.
    ///   - resourceServerUsername: Username sent via HTTP Basic authentication.
    ///   - resourceServerPassword: Password sent via HTTP Basic authentication.
    public static func remote(
        tokenIntrospectionEndpoint: String,
        client: Client,
        resourceServerUsername: String,
        resourceServerPassword: String
    ) -> Self {
        let introspect: (Request) async throws -> RemoteTokenResponse = { request in
            try await introspectToken(
                for: request,
                tokenIntrospectionEndpoint: tokenIntrospectionEndpoint,
                client: client,
                resourceServerUsername: resourceServerUsername,
                resourceServerPassword: resourceServerPassword
            )
        }

        return OAuthHelper(
            assertScopes: { scopes, request in
                // Always introspect, even without required scopes: the token must still be active.
                let tokenInfo = try await introspect(request)

                guard let requiredScopes = scopes else { return }

                guard let tokenScopes = tokenInfo.scopes,
                    requiredScopes.allSatisfy({ tokenScopes.contains($0) })
                else {
                    throw Abort(.unauthorized)
                }
            },
            user: { request in
                guard let user = try await introspect(request).user else {
                    throw Abort(.unauthorized)
                }

                return user
            }
        )
    }

    /// Returns the introspection result for the bearer token on `request`, querying the remote
    /// endpoint on first use and reusing the value stored on the request afterwards.
    ///
    /// - Throws: Whatever `getOAuthToken()` or the HTTP call throws; `Abort(.unauthorized)` when the
    ///   endpoint does not report the token as active; `Abort(.internalServerError)` when the
    ///   response carries a user ID but no username.
    private static func introspectToken(
        for request: Request,
        tokenIntrospectionEndpoint: String,
        client: Client,
        resourceServerUsername: String,
        resourceServerPassword: String
    ) async throws -> RemoteTokenResponse {
        if let cached = request.storage[RemoteTokenResponseKey.self] {
            return cached
        }

        let token = try request.getOAuthToken()

        var headers = HTTPHeaders()
        headers.basicAuthorization = .init(
            username: resourceServerUsername,
            password: resourceServerPassword
        )

        struct Token: Content {
            let token: String
        }
        let tokenInfoResponse = try await client.post(
            URI(string: tokenIntrospectionEndpoint),
            headers: headers,
            content: Token(token: token)
        ).get()

        let tokenInfoJSON = tokenInfoResponse.content

        guard let tokenActive: Bool = tokenInfoJSON[OAuthResponseParameters.active], tokenActive else {
            throw Abort(.unauthorized)
        }

        var scopes: [String]?
        var oauthUser: OAuthUser?

        // Scopes arrive as one space-separated string.
        if let tokenScopes: String = tokenInfoJSON[OAuthResponseParameters.scope] {
            scopes = tokenScopes.components(separatedBy: " ")
        }

        if let userID: String = tokenInfoJSON[OAuthResponseParameters.userID] {
            guard let username: String = tokenInfoJSON[OAuthResponseParameters.username] else {
                throw Abort(.internalServerError)
            }
            // The introspection response carries no password, so an empty one is used.
            oauthUser = OAuthUser(
                userID: userID, username: username,
                emailAddress: tokenInfoJSON[String.self, at: OAuthResponseParameters.email],
                password: "")
        }

        let response = RemoteTokenResponse(scopes: scopes, user: oauthUser)
        request.storage[RemoteTokenResponseKey.self] = response
        return response
    }
}
117 |
/// Subset of a token introspection response kept by the remote helper.
struct RemoteTokenResponse {
    /// Scopes reported for the token, or `nil` when the response contained none.
    let scopes: [String]?
    /// User reported for the token, or `nil` when the response contained no user ID.
    let user: OAuthUser?

    init(scopes: [String]?, user: OAuthUser?) {
        self.scopes = scopes
        self.user = user
    }
}
122 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Helper/OAuthHelper.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Pair of closures used to check OAuth scopes and resolve the user for a request.
public struct OAuthHelper {
    /// Throws when the request's token does not satisfy the given (optional) required scopes.
    public var assertScopes: ([String]?, Request) async throws -> Void

    /// Returns the user associated with the request's token, or throws.
    public var user: (Request) async throws -> OAuthUser

    /// Creates a helper from the two closures; both are stored as given.
    public init(
        assertScopes: @escaping ([String]?, Request) async throws -> Void,
        user: @escaping (Request) async throws -> OAuthUser
    ) {
        self.user = user
        self.assertScopes = assertScopes
    }
}
15 |
extension Application {
    /// Storage key for the application-wide `OAuthHelper`.
    struct OAuthHelperKey: StorageKey {
        typealias Value = OAuthHelper
    }

    /// The application-wide `OAuthHelper`.
    ///
    /// Reading this before a helper has been assigned terminates the process via `fatalError`.
    public var oAuthHelper: OAuthHelper {
        get {
            if let storedHelper = storage[OAuthHelperKey.self] {
                return storedHelper
            }
            fatalError("OAuthHelperKey not set up. Use app.oAuthHelper = ...")
        }
        set { storage[OAuthHelperKey.self] = newValue }
    }
}
33 |
extension Request {
    /// Convenience accessor forwarding to the owning application's `oAuthHelper`.
    public var oAuthHelper: OAuthHelper {
        return application.oAuthHelper
    }
}
37 |
extension Request {
    /// Extracts the bearer token from the `Authorization` header.
    ///
    /// The scheme is matched case-insensitively and must be followed by a space.
    /// - Returns: Everything after the seven-character `"bearer "` prefix.
    /// - Throws: `Abort(.forbidden)` when the header is missing, uses another scheme,
    ///   or carries an empty token.
    func getOAuthToken() throws -> String {
        let schemePrefix = "bearer "

        guard let header = headers.first(name: .authorization),
            header.lowercased().hasPrefix(schemePrefix)
        else {
            throw Abort(.forbidden)
        }

        // Slice the original header (not the lowercased copy) so the token keeps its case.
        let token = String(header.dropFirst(7))

        if token.isEmpty {
            throw Abort(.forbidden)
        }

        return token
    }
}
57 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Middleware/OAuth2ScopeMiddleware.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Middleware that asserts the configured scopes through the request's `OAuthHelper`
/// before handing the request to the next responder.
public struct OAuth2ScopeMiddleware: AsyncMiddleware {
    /// Scopes passed to `assertScopes`; `nil` is forwarded unchanged.
    let requiredScopes: [String]?

    public init(requiredScopes: [String]?) {
        self.requiredScopes = requiredScopes
    }

    /// Runs the scope assertion; any error it throws stops the chain.
    public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
        let helper = request.oAuthHelper
        try await helper.assertScopes(requiredScopes, request)

        let response = try await next.respond(to: request)
        return response
    }
}
16 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Middleware/OAuth2TokenIntrospectionMiddleware.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Middleware that asserts the configured scopes through the request's `OAuthHelper`
/// before handing the request to the next responder.
public struct OAuth2TokenIntrospectionMiddleware: AsyncMiddleware {
    /// Scopes passed to `assertScopes`; `nil` is forwarded unchanged.
    let requiredScopes: [String]?

    public init(requiredScopes: [String]?) {
        self.requiredScopes = requiredScopes
    }

    /// Runs the scope assertion; any error it throws stops the chain.
    public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
        let helper = request.oAuthHelper
        try await helper.assertScopes(requiredScopes, request)

        let response = try await next.respond(to: request)
        return response
    }
}
16 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Middleware/TokenIntrospectionAuthenticationMiddleware.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Middleware that requires HTTP Basic credentials and checks them with the
/// `ResourceServerAuthenticator` before continuing the chain.
struct TokenIntrospectionAuthMiddleware: AsyncMiddleware {
    let resourceServerAuthenticator: ResourceServerAuthenticator

    /// - Throws: `Abort(.unauthorized)` when no Basic credentials are present, or whatever
    ///   the authenticator throws for the supplied credentials.
    func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
        if let credentials = request.headers.basicAuthorization {
            try await resourceServerAuthenticator.authenticate(credentials: credentials)
            return try await next.respond(to: request)
        }

        throw Abort(.unauthorized)
    }
}
16 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/OAuthClient.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// An OAuth client registered with the provider.
public final class OAuthClient: Extendable {

    /// Identifier of the client.
    public let clientID: String
    /// Redirect URIs registered for the client; `nil` means none are registered.
    public let redirectURIs: [String]?
    /// Secret of the client, if it has one.
    public let clientSecret: String?
    /// Scopes the client may request, if restricted.
    public let validScopes: [String]?
    /// Value of the `confidential` initialiser argument.
    public let confidentialClient: Bool?
    /// Whether the client is a first-party client.
    public let firstParty: Bool
    /// Grant type this client is allowed to use.
    public let allowedGrantType: OAuthFlowType

    public var extend: Vapor.Extend = .init()

    public init(
        clientID: String,
        redirectURIs: [String]?,
        clientSecret: String? = nil,
        validScopes: [String]? = nil,
        confidential: Bool? = nil,
        firstParty: Bool = false,
        allowedGrantType: OAuthFlowType
    ) {
        self.clientID = clientID
        self.redirectURIs = redirectURIs
        self.clientSecret = clientSecret
        self.validScopes = validScopes
        self.confidentialClient = confidential
        self.firstParty = firstParty
        self.allowedGrantType = allowedGrantType
    }

    /// Returns `true` only when `redirectURI` exactly matches one of the registered redirect URIs.
    /// A client without registered redirect URIs never validates.
    func validateRedirectURI(_ redirectURI: String) -> Bool {
        redirectURIs?.contains(redirectURI) ?? false
    }

}
41 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/OAuthCode.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
/// An authorization code together with the request details it was issued for.
public final class OAuthCode {
    /// The code string itself.
    public let codeID: String
    /// Client the code was issued to.
    public let clientID: String
    /// Redirect URI the code was issued for.
    public let redirectURI: String
    /// User who authorized the code.
    public let userID: String
    /// Moment the code expires.
    public let expiryDate: Date
    /// Scopes attached to the code, if any.
    public let scopes: [String]?

    /// Free-form storage for additional values; starts empty.
    public var extend: [String: Any] = [:]

    public init(codeID: String, clientID: String, redirectURI: String, userID: String, expiryDate: Date, scopes: [String]?) {
        self.codeID = codeID
        self.clientID = clientID
        self.redirectURI = redirectURI
        self.userID = userID
        self.expiryDate = expiryDate
        self.scopes = scopes
    }
}
29 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/OAuthResourceServer.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Credentials of a resource server.
public final class OAuthResourceServer: Extendable {
    /// Username the resource server is looked up by.
    public let username: String
    /// Password stored for the resource server.
    public let password: String

    public var extend: Vapor.Extend = .init()

    public init(username: String, password: String) {
        self.password = password
        self.username = username
    }
}
13 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/OAuthUser.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// A user known to the OAuth provider.
///
/// NOTE(review): no custom `encode(to:)` or `CodingKeys` is visible in this class, so the
/// `Encodable` conformance presumably includes `password` — confirm before encoding users
/// into responses.
public final class OAuthUser: Authenticatable, Extendable, Encodable {
    public let username: String
    public let emailAddress: String?
    public var password: String
    /// Identifier of the user; set from the initialiser's `userID` argument.
    // swiftlint:disable:next identifier_name
    public var id: String?

    public var extend: Extend = .init()

    /// - Parameters:
    ///   - userID: Stored as `id`; defaults to `nil`.
    ///   - username: The user's username.
    ///   - emailAddress: The user's email address, if any.
    ///   - password: The user's password value, stored as given.
    public init(userID: String? = nil, username: String, emailAddress: String?, password: String) {
        self.id = userID
        self.username = username
        self.emailAddress = emailAddress
        self.password = password
    }
}
19 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/Tokens/AccessToken.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Read-only description of an issued access token.
public protocol AccessToken {
    /// The token value presented by clients.
    var tokenString: String { get }
    /// Client the token was issued to.
    var clientID: String { get }
    /// User the token acts for; `nil` when the token is not tied to a user.
    var userID: String? { get }
    /// Scopes granted to the token, if any.
    var scopes: [String]? { get }
    /// Moment after which the token is treated as expired.
    var expiryTime: Date { get }
}
10 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Models/Tokens/RefreshToken.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Mutable description of an issued refresh token.
public protocol RefreshToken {
    /// The token value presented by clients.
    var tokenString: String { get set }
    /// Client the token was issued to.
    var clientID: String { get set }
    /// User the token acts for, if any.
    var userID: String? { get set }
    /// Scopes attached to the token, if any.
    var scopes: [String]? { get set }
}
9 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/OAuth2.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Lifecycle handler that registers the OAuth routes and installs the `OAuthHelper`
/// once the application has booted.
public struct OAuth2: LifecycleHandler {
    let codeManager: CodeManager
    let tokenManager: TokenManager
    let clientRetriever: ClientRetriever
    let authorizeHandler: AuthorizeHandler
    let userManager: UserManager
    let validScopes: [String]?
    let resourceServerRetriever: ResourceServerRetriever
    let oAuthHelper: OAuthHelper

    /// Collects the collaborators used by the OAuth routes. Arguments with defaults fall back
    /// to the `Empty…` implementations.
    public init(
        codeManager: CodeManager = EmptyCodeManager(),
        tokenManager: TokenManager,
        clientRetriever: ClientRetriever,
        authorizeHandler: AuthorizeHandler = EmptyAuthorizationHandler(),
        userManager: UserManager = EmptyUserManager(),
        validScopes: [String]? = nil,
        resourceServerRetriever: ResourceServerRetriever = EmptyResourceServerRetriever(),
        oAuthHelper: OAuthHelper
    ) {
        self.codeManager = codeManager
        self.tokenManager = tokenManager
        self.clientRetriever = clientRetriever
        self.authorizeHandler = authorizeHandler
        self.userManager = userManager
        self.validScopes = validScopes
        self.resourceServerRetriever = resourceServerRetriever
        self.oAuthHelper = oAuthHelper
    }

    /// Registers the routes, then stores the helper on the application.
    public func didBoot(_ application: Application) throws {
        addRoutes(to: application)
        application.oAuthHelper = oAuthHelper
    }

    /// Wires up the validators and route handlers and attaches them to `app`.
    private func addRoutes(to app: Application) {
        // Shared validators.
        let scopes = ScopeValidator(validScopes: validScopes, clientRetriever: clientRetriever)
        let clients = ClientValidator(
            clientRetriever: clientRetriever,
            scopeValidator: scopes,
            environment: app.environment
        )

        // Route handlers.
        let tokens = TokenHandler(
            clientValidator: clients,
            tokenManager: tokenManager,
            scopeValidator: scopes,
            codeManager: codeManager,
            userManager: userManager,
            logger: app.logger
        )
        let introspection = TokenIntrospectionHandler(
            clientValidator: clients,
            tokenManager: tokenManager,
            userManager: userManager
        )
        let authorizeGet = AuthorizeGetHandler(
            authorizeHandler: authorizeHandler,
            clientValidator: clients
        )
        let authorizePost = AuthorizePostHandler(
            tokenManager: tokenManager,
            codeManager: codeManager,
            clientValidator: clients
        )

        let serverAuthenticator = ResourceServerAuthenticator(resourceServerRetriever: resourceServerRetriever)

        // GET /oauth/authorize: returns something like an "Authenticate with GitHub" page.
        app.get("oauth", "authorize", use: authorizeGet.handleRequest)
        // POST /oauth/authorize: the "Allow/Deny Access" submission from that page; returns a code.
        // Requires an authenticated `OAuthUser`.
        app.grouped(OAuthUser.guardMiddleware()).post("oauth", "authorize", use: authorizePost.handleRequest)
        // POST /oauth/token: client exchanges the code from POST /authorize for access/refresh tokens.
        app.post("oauth", "token", use: tokens.handleRequest)

        // POST /oauth/token_info: token introspection, guarded by resource-server Basic auth.
        let introspectionAuth = TokenIntrospectionAuthMiddleware(resourceServerAuthenticator: serverAuthenticator)
        let protectedRoutes = app.routes.grouped(introspectionAuth)
        protectedRoutes.post("oauth", "token_info", use: introspection.handleRequest)
    }
}
85 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/AuthorizeHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Produces the responses for the authorization (GET) endpoint.
public protocol AuthorizeHandler {
    /// Builds the response for a validated authorization request.
    /// - Parameters:
    ///   - request: The incoming request.
    ///   - authorizationRequestObject: The validated request parameters plus the CSRF token.
    func handleAuthorizationRequest(_ request: Request, authorizationRequestObject: AuthorizationRequestObject) async throws -> Response

    /// Builds the response for an authorization error that is not reported via redirect.
    func handleAuthorizationError(_ errorType: AuthorizationError) async throws -> Response
}
10 |
/// Errors raised while validating an authorization request.
public enum AuthorizationError: Error {
    case invalidClientID, confidentialClientTokenGrant, invalidRedirectURI, httpRedirectURI
}
17 |
/// Validated parameters of an authorization request, as handed to `AuthorizeHandler`.
public struct AuthorizationRequestObject {
    /// Requested response type.
    public let responseType: String
    /// ID of the requesting client.
    public let clientID: String
    /// Redirect URI supplied with the request.
    public let redirectURI: URI
    /// Requested scopes; empty when none were requested.
    public let scope: [String]
    /// Opaque `state` value supplied by the client, if any.
    public let state: String?
    /// CSRF token generated for this request.
    public let csrfToken: String
}
26 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/ClientRetriever.swift:
--------------------------------------------------------------------------------
/// Looks up registered OAuth clients.
public protocol ClientRetriever {
    /// Returns the client with the given ID, or `nil` when there is none.
    func getClient(clientID: String) async throws -> OAuthClient?
}
4 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/CodeManager.swift:
--------------------------------------------------------------------------------
/// Responsible for generating and managing OAuth Codes
public protocol CodeManager {
    /// Generates a code for the given user, client, redirect URI and scopes, and returns its string value.
    func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) async throws -> String

    /// Returns the stored code matching `code`, or `nil` when there is none.
    func getCode(_ code: String) async throws -> OAuthCode?

    /// Marks the code as used or deletes it.
    ///
    /// This is explicit (rather than implied by `getCode`) to remind developers to ensure
    /// that codes can't be reused.
    func codeUsed(_ code: OAuthCode) async throws
}
10 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/ResourceServerRetriever.swift:
--------------------------------------------------------------------------------
/// Looks up resource servers by username.
public protocol ResourceServerRetriever {
    /// Returns the resource server with the given username, or `nil` when there is none.
    func getServer(_ username: String) async throws -> OAuthResourceServer?
}
4 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/TokenManager.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Issues, stores and looks up access and refresh tokens.
public protocol TokenManager {
    /// Issues an access token together with a refresh token.
    /// - Parameter accessTokenExpiryTime: Lifetime of the access token
    ///   (presumably seconds — TODO confirm against implementations).
    func generateAccessRefreshTokens(
        clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int
    ) async throws -> (AccessToken, RefreshToken)

    /// Issues an access token only.
    /// - Parameter expiryTime: Lifetime of the token (presumably seconds — the authorize POST
    ///   handler passes 3600 and reports `expires_in=3600`).
    func generateAccessToken(
        clientID: String, userID: String?, scopes: [String]?, expiryTime: Int
    ) async throws -> AccessToken

    /// Returns the refresh token with the given string value, or `nil` when unknown.
    func getRefreshToken(_ refreshToken: String) async throws -> RefreshToken?

    /// Returns the access token with the given string value, or `nil` when unknown.
    func getAccessToken(_ accessToken: String) async throws -> AccessToken?

    /// Updates the scopes stored for `refreshToken`.
    func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) async throws
}
22 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Protocols/UserManager.swift:
--------------------------------------------------------------------------------
/// Authenticates users and looks them up by ID.
public protocol UserManager {
    /// Checks the credentials and returns a string on success, `nil` otherwise
    /// (presumably the user's ID — TODO confirm against callers).
    func authenticateUser(username: String, password: String) async throws -> String?

    /// Returns the user with the given ID, or `nil` when there is none.
    func getUser(userID: String) async throws -> OAuthUser?
}
5 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/AuthorizeGetHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Handles `GET /oauth/authorize`: validates the query parameters and the client, then
/// hands over to the `AuthorizeHandler` to render the authorization page.
struct AuthorizeGetHandler {
    /// Characters left unescaped when a value is placed into a redirect query string
    /// (the RFC 3986 "unreserved" set). Everything else, including `&`, `=`, `#` and
    /// whitespace, is percent-encoded.
    private static let queryValueAllowed = CharacterSet(
        charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
    )

    let authorizeHandler: AuthorizeHandler
    let clientValidator: ClientValidator

    /// Validates the request and returns either an error response/redirect or the
    /// response produced by `authorizeHandler.handleAuthorizationRequest`.
    ///
    /// On success a fresh CSRF token is stored in the session and passed to the handler.
    /// - Throws: `Abort(.internalServerError)` if validation yields neither an error response
    ///   nor a request object; any error from the client validator that is not mapped below.
    func handleRequest(request: Request) async throws -> Response {
        let (errorResponse, createdAuthRequestObject) = try await validateRequest(request)

        if let errorResponseReturned = errorResponse {
            return errorResponseReturned
        }

        guard let authRequestObject = createdAuthRequestObject else {
            throw Abort(.internalServerError)
        }

        do {
            try await clientValidator.validateClient(
                clientID: authRequestObject.clientID,
                responseType: authRequestObject.responseType,
                redirectURI: authRequestObject.redirectURIString,
                scopes: authRequestObject.scopes
            )
        } catch AuthorizationError.invalidClientID {
            return try await authorizeHandler.handleAuthorizationError(.invalidClientID)
        } catch AuthorizationError.invalidRedirectURI {
            return try await authorizeHandler.handleAuthorizationError(.invalidRedirectURI)
        } catch ScopeError.unknown {
            return createErrorResponse(
                request: request,
                redirectURI: authRequestObject.redirectURIString,
                errorType: OAuthResponseParameters.ErrorType.invalidScope,
                errorDescription: "scope+is+unknown",
                state: authRequestObject.state)
        } catch ScopeError.invalid {
            return createErrorResponse(
                request: request,
                redirectURI: authRequestObject.redirectURIString,
                errorType: OAuthResponseParameters.ErrorType.invalidScope,
                errorDescription: "scope+is+invalid",
                state: authRequestObject.state)
        } catch AuthorizationError.confidentialClientTokenGrant {
            return createErrorResponse(
                request: request,
                redirectURI: authRequestObject.redirectURIString,
                errorType: OAuthResponseParameters.ErrorType.unauthorizedClient,
                errorDescription: "token+grant+disabled+for+confidential+clients",
                state: authRequestObject.state)
        } catch AuthorizationError.httpRedirectURI {
            return try await authorizeHandler.handleAuthorizationError(.httpRedirectURI)
        }

        let redirectURI = URI(stringLiteral: authRequestObject.redirectURIString)

        // 32 random bytes, hex-encoded; the POST handler compares this against the submitted token.
        let csrfToken = [UInt8].random(count: 32).hex

        request.session.data[SessionData.csrfToken] = csrfToken
        let authorizationRequestObject = AuthorizationRequestObject(
            responseType: authRequestObject.responseType,
            clientID: authRequestObject.clientID, redirectURI: redirectURI,
            scope: authRequestObject.scopes, state: authRequestObject.state,
            csrfToken: csrfToken)

        return try await authorizeHandler.handleAuthorizationRequest(request, authorizationRequestObject: authorizationRequestObject)
    }

    /// Extracts and checks the query parameters.
    /// - Returns: `(errorResponse, nil)` when a parameter is missing or invalid, otherwise
    ///   `(nil, requestObject)`.
    private func validateRequest(_ request: Request) async throws -> (Response?, AuthorizationGetRequestObject?) {
        guard let clientID: String = request.query[OAuthRequestParameters.clientID] else {
            return (try await authorizeHandler.handleAuthorizationError(.invalidClientID), nil)
        }

        guard let redirectURIString: String = request.query[OAuthRequestParameters.redirectURI] else {
            return (try await authorizeHandler.handleAuthorizationError(.invalidRedirectURI), nil)
        }

        // Scopes arrive as one space-separated string; a missing parameter means no scopes.
        let scopes: [String]

        if let scopeQuery: String = request.query[OAuthRequestParameters.scope] {
            scopes = scopeQuery.components(separatedBy: " ")
        } else {
            scopes = []
        }

        let state: String? = request.query[OAuthRequestParameters.state]

        guard let responseType: String = request.query[OAuthRequestParameters.responseType] else {
            let errorResponse = createErrorResponse(
                request: request,
                redirectURI: redirectURIString,
                errorType: OAuthResponseParameters.ErrorType.invalidRequest,
                errorDescription: "Request+was+missing+the+response_type+parameter",
                state: state)
            return (errorResponse, nil)
        }

        guard responseType == ResponseType.code || responseType == ResponseType.token else {
            let errorResponse = createErrorResponse(
                request: request,
                redirectURI: redirectURIString,
                errorType: OAuthResponseParameters.ErrorType.invalidRequest,
                errorDescription: "invalid+response+type", state: state)
            return (errorResponse, nil)
        }

        let authRequestObject = AuthorizationGetRequestObject(
            clientID: clientID, redirectURIString: redirectURIString,
            scopes: scopes, state: state,
            responseType: responseType)

        return (nil, authRequestObject)
    }

    /// Builds a redirect to `redirectURI` carrying `error`, `error_description` and, when
    /// present, `state` as query parameters.
    ///
    /// `state` comes straight from the client's query string, so it is percent-encoded before
    /// being appended; otherwise a value containing `&`, `#` or whitespace could inject extra
    /// parameters or produce a malformed `Location` header. `errorType` and `errorDescription`
    /// are fixed strings chosen by this handler and are inserted as given.
    private func createErrorResponse(
        request: Request,
        redirectURI: String,
        errorType: String,
        errorDescription: String,
        state: String?
    ) -> Vapor.Response {
        var redirectString = "\(redirectURI)?error=\(errorType)&error_description=\(errorDescription)"

        // If encoding fails the state is omitted rather than echoed raw.
        if let state = state,
            let encodedState = state.addingPercentEncoding(withAllowedCharacters: Self.queryValueAllowed)
        {
            redirectString += "&state=\(encodedState)"
        }

        return request.redirect(to: redirectString)
    }
}
130 |
/// Raw parameters extracted from a `GET /oauth/authorize` query before client validation.
struct AuthorizationGetRequestObject {
    /// ID of the requesting client.
    let clientID: String
    /// Redirect URI exactly as supplied in the query.
    let redirectURIString: String
    /// Requested scopes; empty when the query had no scope parameter.
    let scopes: [String]
    /// Opaque `state` value supplied by the client, if any.
    let state: String?
    /// Requested response type.
    let responseType: String
}
138 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/AuthorizePostHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Validated parameters of a `POST /oauth/authorize` request.
struct AuthorizePostRequest {
    /// The authenticated user submitting the decision.
    let user: OAuthUser
    /// The user's ID (`user.id`, unwrapped).
    let userID: String
    /// Redirect URI from the query, before any result parameters are appended.
    let redirectURIBaseString: String
    /// Whether the user approved the application.
    let approveApplication: Bool
    /// ID of the requesting client.
    let clientID: String
    /// Requested response type.
    let responseType: String
    /// CSRF token submitted in the request body.
    let csrfToken: String
    /// Requested scopes; `nil` when the query had no scope parameter.
    let scopes: [String]?
}
13 |
14 | struct AuthorizePostHandler {
15 |
16 | let tokenManager: TokenManager
17 | let codeManager: CodeManager
18 | let clientValidator: ClientValidator
19 |
/// Processes the submitted authorization form and redirects the user agent back to the client.
///
/// The redirect target is the request's `redirect_uri` with the outcome appended:
/// - `token` response type: the access token details in the URL fragment;
/// - `code` response type: a `code` query item;
/// - any other response type, or a declined request: an `error` / `error_description` pair.
///
/// Requested scopes and the `state` query value are appended afterwards. Both originate from
/// the incoming request, so they are percent-encoded before being written into the URL.
///
/// - Parameter request: The authorization POST request.
/// - Returns: A redirect response pointing at the assembled URI.
/// - Throws: `Abort(.forbidden)` when client validation fails with an `AbortError`,
///   `Abort(.badRequest)` when it fails with any other error or when the CSRF token stored in
///   the session differs from the submitted one. Errors from request parsing and from the
///   token/code managers are rethrown.
func handleRequest(request: Request) async throws -> Response {
    // Lifetime handed to the token manager and reported to the client as `expires_in`.
    let accessTokenLifetime = 3600

    // RFC 3986 "unreserved" characters; everything else is escaped so that a request-supplied
    // value cannot introduce additional `&`, `=` or `#` separators into the redirect URL.
    let unreservedCharacters = CharacterSet(
        charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
    let percentEncoded: (String) -> String = { value in
        value.addingPercentEncoding(withAllowedCharacters: unreservedCharacters) ?? value
    }

    let requestObject = try validateAuthPostRequest(request)
    var redirectURI = requestObject.redirectURIBaseString

    do {
        try await clientValidator.validateClient(
            clientID: requestObject.clientID, responseType: requestObject.responseType,
            redirectURI: requestObject.redirectURIBaseString, scopes: requestObject.scopes)
    } catch is AbortError {
        throw Abort(.forbidden)
    } catch {
        throw Abort(.badRequest)
    }

    guard request.session.data[SessionData.csrfToken] == requestObject.csrfToken else {
        throw Abort(.badRequest)
    }

    if requestObject.approveApplication {
        if requestObject.responseType == ResponseType.token {
            let accessToken = try await tokenManager.generateAccessToken(
                clientID: requestObject.clientID,
                userID: requestObject.userID,
                scopes: requestObject.scopes,
                expiryTime: accessTokenLifetime
            )
            redirectURI +=
                "#token_type=bearer&access_token=\(accessToken.tokenString)&expires_in=\(accessTokenLifetime)"
        } else if requestObject.responseType == ResponseType.code {
            let generatedCode = try await codeManager.generateCode(
                userID: requestObject.userID,
                clientID: requestObject.clientID,
                redirectURI: requestObject.redirectURIBaseString,
                scopes: requestObject.scopes
            )
            redirectURI += "?code=\(generatedCode)"
        } else {
            redirectURI += "?error=invalid_request&error_description=unknown+response+type"
        }
    } else {
        redirectURI += "?error=access_denied&error_description=user+denied+the+request"
    }

    if let requestedScopes = requestObject.scopes, !requestedScopes.isEmpty {
        // `+` is the separator between scopes; each scope name itself is escaped.
        redirectURI += "&scope=\(requestedScopes.map(percentEncoded).joined(separator: "+"))"
    }

    if let state = try? request.query.get(String.self, at: OAuthRequestParameters.state) {
        redirectURI += "&state=\(percentEncoded(state))"
    }

    return request.redirect(to: redirectURI)
}
74 |
/// Collects the values the authorization POST needs from the authenticated user, the URL query
/// and the request body.
///
/// - Parameter request: The authorization POST request.
/// - Returns: An `AuthorizePostRequest` holding the extracted values. `scopes` is `nil` when the
///   query has no `scope` item, otherwise the value split on single spaces.
/// - Throws: Whatever `request.auth.require` throws when no `OAuthUser` is authenticated,
///   `Abort(.unauthorized)` when the user has no `id`, and `Abort(.badRequest)` when any of the
///   required query or body values cannot be read.
private func validateAuthPostRequest(_ request: Request) throws -> AuthorizePostRequest {
    let authenticatedUser = try request.auth.require(OAuthUser.self)

    guard let userID = authenticatedUser.id else {
        throw Abort(.unauthorized)
    }

    // Every missing or undecodable value produces the same 400, so they share one guard.
    guard
        let redirectURIBaseString: String = request.query[OAuthRequestParameters.redirectURI],
        let approvalField: Bool? = request.content[OAuthRequestParameters.applicationAuthorized],
        let approveApplication = approvalField,
        let clientID: String = request.query[OAuthRequestParameters.clientID],
        let responseType: String = request.query[OAuthRequestParameters.responseType],
        let csrfToken: String = request.content[OAuthRequestParameters.csrfToken]
    else {
        throw Abort(.badRequest)
    }

    let scopeQuery: String? = request.query[OAuthRequestParameters.scope]
    let scopes = scopeQuery.map { $0.components(separatedBy: " ") }

    return AuthorizePostRequest(
        user: authenticatedUser,
        userID: userID,
        redirectURIBaseString: redirectURIBaseString,
        approveApplication: approveApplication,
        clientID: clientID,
        responseType: responseType,
        csrfToken: csrfToken,
        scopes: scopes)
}
117 |
118 | }
119 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Entry point for token requests: reads `grant_type` from the request body and forwards the
/// request to the handler for that grant.
struct TokenHandler {

    let tokenAuthenticator = TokenAuthenticator()
    let refreshTokenHandler: RefreshTokenHandler
    let clientCredentialsTokenHandler: ClientCredentialsTokenHandler
    let tokenResponseGenerator: TokenResponseGenerator
    let authCodeTokenHandler: AuthCodeTokenHandler
    let passwordTokenHandler: PasswordTokenHandler

    /// Builds the per-grant handlers; all of them share one `TokenResponseGenerator`.
    init(
        clientValidator: ClientValidator, tokenManager: TokenManager, scopeValidator: ScopeValidator,
        codeManager: CodeManager, userManager: UserManager, logger: Logger
    ) {
        let responseGenerator = TokenResponseGenerator()
        tokenResponseGenerator = responseGenerator

        refreshTokenHandler = RefreshTokenHandler(
            scopeValidator: scopeValidator,
            tokenManager: tokenManager,
            clientValidator: clientValidator,
            tokenAuthenticator: tokenAuthenticator,
            tokenResponseGenerator: responseGenerator)

        clientCredentialsTokenHandler = ClientCredentialsTokenHandler(
            clientValidator: clientValidator,
            scopeValidator: scopeValidator,
            tokenManager: tokenManager,
            tokenResponseGenerator: responseGenerator)

        authCodeTokenHandler = AuthCodeTokenHandler(
            clientValidator: clientValidator,
            tokenManager: tokenManager,
            codeManager: codeManager,
            tokenResponseGenerator: responseGenerator)

        passwordTokenHandler = PasswordTokenHandler(
            clientValidator: clientValidator,
            scopeValidator: scopeValidator,
            userManager: userManager,
            logger: logger,
            tokenManager: tokenManager,
            tokenResponseGenerator: responseGenerator)
    }

    /// Routes a token request according to its `grant_type`.
    ///
    /// - Returns: The selected handler's response, an `invalid_request` error response when
    ///   `grant_type` is absent, or an `unsupported_grant_type` error response for any value
    ///   this handler does not route.
    func handleRequest(request: Request) async throws -> Response {
        let submittedGrantType: String? = request.content[OAuthRequestParameters.grantType]

        guard let grantType = submittedGrantType else {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidRequest,
                description: "Request was missing the 'grant_type' parameter")
        }

        switch OAuthFlowType(rawValue: grantType) {
        case .authorization?:
            return try await authCodeTokenHandler.handleAuthCodeTokenRequest(request)
        case .password?:
            return try await passwordTokenHandler.handlePasswordTokenRequest(request)
        case .clientCredentials?:
            return try await clientCredentialsTokenHandler.handleClientCredentialsTokenRequest(request)
        case .refresh?:
            return try await refreshTokenHandler.handleRefreshTokenRequest(request)
        case .implicit?, .tokenIntrospection?, nil:
            // Known flow names that are not token grants are rejected like unknown strings.
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.unsupportedGrant,
                description: "This server does not support the '\(grantType)' grant type")
        }
    }

}
61 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandlers/AuthCodeTokenHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Handles the `authorization_code` grant: exchanges a previously issued code for an access
/// token and a refresh token.
struct AuthCodeTokenHandler {

    let clientValidator: ClientValidator
    let tokenManager: TokenManager
    let codeManager: CodeManager
    let codeValidator = CodeValidator()
    let tokenResponseGenerator: TokenResponseGenerator

    /// Validates the request, marks the code as used and issues tokens.
    ///
    /// - Returns: A token response on success; otherwise an error response: `invalid_request`
    ///   for a missing `code`, `redirect_uri` or `client_id`, `invalid_client` (401) when client
    ///   authentication throws, `invalid_grant` when the code is not found or fails validation.
    /// - Throws: Errors from the code manager, token manager or response generator.
    func handleAuthCodeTokenRequest(_ request: Request) async throws -> Response {
        /// Builds the `invalid_request` response used for every missing parameter.
        func invalidRequest(_ description: String) throws -> Response {
            try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidRequest,
                description: description)
        }

        let submittedCode: String? = request.content[OAuthRequestParameters.code]
        guard let codeString = submittedCode else {
            return try invalidRequest("Request was missing the 'code' parameter")
        }

        let submittedRedirectURI: String? = request.content[OAuthRequestParameters.redirectURI]
        guard let redirectURI = submittedRedirectURI else {
            return try invalidRequest("Request was missing the 'redirect_uri' parameter")
        }

        let submittedClientID: String? = request.content[OAuthRequestParameters.clientID]
        guard let clientID = submittedClientID else {
            return try invalidRequest("Request was missing the 'client_id' parameter")
        }

        let clientSecret = request.content[String.self, at: OAuthRequestParameters.clientSecret]
        do {
            try await clientValidator.authenticateClient(
                clientID: clientID, clientSecret: clientSecret, grantType: .authorization)
        } catch {
            // Any failure while authenticating the client is reported as invalid credentials.
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidClient,
                description: "Request had invalid client credentials", status: .unauthorized)
        }

        let storedCode = try await codeManager.getCode(codeString)
        guard
            let code = storedCode,
            codeValidator.validateCode(code, clientID: clientID, redirectURI: redirectURI)
        else {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidGrant,
                description: "The code provided was invalid or expired, or the redirect URI did not match")
        }

        try await codeManager.codeUsed(code)

        let accessTokenLifetime = 3600
        let grantedScopes = code.scopes

        let (access, refresh) = try await tokenManager.generateAccessRefreshTokens(
            clientID: clientID,
            userID: code.userID,
            scopes: grantedScopes,
            accessTokenExpiryTime: accessTokenLifetime
        )

        return try tokenResponseGenerator.createResponse(
            accessToken: access,
            refreshToken: refresh,
            expires: accessTokenLifetime,
            scope: grantedScopes?.joined(separator: " "))
    }
}
66 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandlers/ClientCredentialsTokenHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Handles the `client_credentials` grant: issues tokens to an authenticated client with no
/// user attached.
struct ClientCredentialsTokenHandler {

    let clientValidator: ClientValidator
    let scopeValidator: ScopeValidator
    let tokenManager: TokenManager
    let tokenResponseGenerator: TokenResponseGenerator

    /// Authenticates the client, validates any requested scopes and issues tokens.
    ///
    /// - Returns: A token response on success; otherwise an error response: `invalid_request`
    ///   for a missing `client_id` or `client_secret`, `invalid_client` (401) or
    ///   `unauthorized_client` for the matching `ClientError`, `invalid_scope` for a `ScopeError`.
    /// - Throws: Any other error raised by the validators, the token manager or the response
    ///   generator.
    func handleClientCredentialsTokenRequest(_ request: Request) async throws -> Response {
        /// Builds the `invalid_request` response used for every missing parameter.
        func invalidRequest(_ description: String) throws -> Response {
            try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidRequest,
                description: description)
        }

        let submittedClientID: String? = request.content[OAuthRequestParameters.clientID]
        guard let clientID = submittedClientID else {
            return try invalidRequest("Request was missing the 'client_id' parameter")
        }

        let submittedSecret: String? = request.content[OAuthRequestParameters.clientSecret]
        guard let clientSecret = submittedSecret else {
            return try invalidRequest("Request was missing the 'client_secret' parameter")
        }

        do {
            try await clientValidator.authenticateClient(
                clientID: clientID,
                clientSecret: clientSecret,
                grantType: .clientCredentials,
                checkConfidentialClient: true)
        } catch ClientError.unauthorized {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidClient,
                description: "Request had invalid client credentials", status: .unauthorized)
        } catch ClientError.notConfidential {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.unauthorizedClient,
                description: "You are not authorized to use the Client Credentials grant type")
        }

        // The raw string is echoed back in the response; the split form goes to the validators.
        let scopeString = request.content[String.self, at: OAuthRequestParameters.scope]
        let requestedScopes = scopeString?.components(separatedBy: " ")

        if let scopesToValidate = requestedScopes {
            do {
                try await scopeValidator.validateScope(clientID: clientID, scopes: scopesToValidate)
            } catch ScopeError.invalid {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an invalid scope")
            } catch ScopeError.unknown {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an unknown scope")
            }
        }

        let accessTokenLifetime = 3600
        let (access, refresh) = try await tokenManager.generateAccessRefreshTokens(
            clientID: clientID,
            userID: nil,
            scopes: requestedScopes,
            accessTokenExpiryTime: accessTokenLifetime)

        return try tokenResponseGenerator.createResponse(
            accessToken: access,
            refreshToken: refresh,
            expires: accessTokenLifetime,
            scope: scopeString)
    }
}
64 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandlers/PasswordTokenHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Handles the `password` grant: authenticates a user by username and password on behalf of a
/// client and issues tokens for that user.
struct PasswordTokenHandler {

    let clientValidator: ClientValidator
    let scopeValidator: ScopeValidator
    let userManager: UserManager
    let logger: Logger
    let tokenManager: TokenManager
    let tokenResponseGenerator: TokenResponseGenerator

    /// Authenticates client and user, validates any requested scopes and issues tokens.
    ///
    /// - Returns: A token response on success; otherwise an error response: `invalid_request`
    ///   for a missing `username`, `password` or `client_id`, `invalid_client` (401) or
    ///   `unauthorized_client` for the matching `ClientError`, `invalid_scope` for a
    ///   `ScopeError`, `invalid_grant` when the user manager returns no user ID.
    /// - Throws: Any other error raised by the validators, managers or response generator.
    func handlePasswordTokenRequest(_ request: Request) async throws -> Response {
        /// Builds the `invalid_request` response used for every missing parameter.
        func invalidRequest(_ description: String) throws -> Response {
            try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidRequest,
                description: description)
        }

        let submittedUsername: String? = request.content[OAuthRequestParameters.usernname]
        guard let username = submittedUsername else {
            return try invalidRequest("Request was missing the 'username' parameter")
        }

        let submittedPassword: String? = request.content[OAuthRequestParameters.password]
        guard let password = submittedPassword else {
            return try invalidRequest("Request was missing the 'password' parameter")
        }

        let submittedClientID: String? = request.content[OAuthRequestParameters.clientID]
        guard let clientID = submittedClientID else {
            return try invalidRequest("Request was missing the 'client_id' parameter")
        }

        let clientSecret = request.content[String.self, at: OAuthRequestParameters.clientSecret]
        do {
            try await clientValidator.authenticateClient(
                clientID: clientID, clientSecret: clientSecret, grantType: .password)
        } catch ClientError.unauthorized {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidClient,
                description: "Request had invalid client credentials", status: .unauthorized)
        } catch ClientError.notFirstParty {
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.unauthorizedClient,
                description: "Password Credentials grant is not allowed")
        }

        // The raw string is echoed back in the response; the split form goes to the validators.
        let scopeString = request.content[String.self, at: OAuthRequestParameters.scope]
        let requestedScopes = scopeString?.components(separatedBy: " ")

        if let scopesToValidate = requestedScopes {
            do {
                try await scopeValidator.validateScope(clientID: clientID, scopes: scopesToValidate)
            } catch ScopeError.invalid {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an invalid scope")
            } catch ScopeError.unknown {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an unknown scope")
            }
        }

        let authenticatedUserID = try await userManager.authenticateUser(username: username, password: password)
        guard let userID = authenticatedUserID else {
            logger.warning("LOGIN WARNING: Invalid login attempt for user \(username)")
            return try tokenResponseGenerator.createResponse(
                error: OAuthResponseParameters.ErrorType.invalidGrant,
                description: "Request had invalid credentials")
        }

        let accessTokenLifetime = 3600
        let (access, refresh) = try await tokenManager.generateAccessRefreshTokens(
            clientID: clientID,
            userID: userID,
            scopes: requestedScopes,
            accessTokenExpiryTime: accessTokenLifetime)

        return try tokenResponseGenerator.createResponse(
            accessToken: access,
            refreshToken: refresh,
            expires: accessTokenLifetime,
            scope: scopeString)
    }
}
82 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandlers/RefreshTokenHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Handles the `refresh_token` grant: issues a new access token for a valid refresh token,
/// optionally narrowing the scopes.
struct RefreshTokenHandler {

    let scopeValidator: ScopeValidator
    let tokenManager: TokenManager
    let clientValidator: ClientValidator
    let tokenAuthenticator: TokenAuthenticator
    let tokenResponseGenerator: TokenResponseGenerator

    /// Validates the request and issues a new access token.
    ///
    /// When `scope` is supplied, every requested scope must pass the scope validator and already
    /// be present on the refresh token; the refresh token is then updated to the requested
    /// scopes. Without `scope`, the refresh token's own scopes are used.
    ///
    /// - Returns: A token response without a refresh token on success, or the error response
    ///   describing why the request was rejected.
    /// - Throws: `Abort(.internalServerError)` if validation yields neither an error response
    ///   nor a request, plus any error raised by the validators, token manager or generator.
    func handleRefreshTokenRequest(_ request: Request) async throws -> Response {
        let (rejection, validatedRequest) = try await validateRefreshTokenRequest(request)

        if let rejection = rejection {
            return rejection
        }

        guard let refreshTokenRequest = validatedRequest else {
            throw Abort(.internalServerError)
        }

        let scopesString: String? = request.content[OAuthRequestParameters.scope]
        let grantedScopes: [String]?

        if let requestedScopes = scopesString?.components(separatedBy: " ") {
            do {
                try await scopeValidator.validateScope(
                    clientID: refreshTokenRequest.clientID, scopes: requestedScopes)
            } catch ScopeError.invalid {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an invalid scope")
            } catch ScopeError.unknown {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained an unknown scope")
            }

            // A refresh token without scopes cannot grant any; otherwise every requested scope
            // must already be on the token.
            guard
                let tokenScopes = refreshTokenRequest.refreshToken.scopes,
                requestedScopes.allSatisfy({ tokenScopes.contains($0) })
            else {
                return try tokenResponseGenerator.createResponse(
                    error: OAuthResponseParameters.ErrorType.invalidScope,
                    description: "Request contained elevated scopes")
            }

            try await tokenManager.updateRefreshToken(refreshTokenRequest.refreshToken, scopes: requestedScopes)
            grantedScopes = requestedScopes
        } else {
            grantedScopes = refreshTokenRequest.refreshToken.scopes
        }

        let accessTokenLifetime = 3600
        let accessToken = try await tokenManager.generateAccessToken(
            clientID: refreshTokenRequest.clientID,
            userID: refreshTokenRequest.refreshToken.userID,
            scopes: grantedScopes,
            expiryTime: accessTokenLifetime
        )

        return try tokenResponseGenerator.createResponse(
            accessToken: accessToken,
            refreshToken: nil,
            expires: accessTokenLifetime,
            scope: scopesString)
    }

    /// Checks the client credentials and the refresh token of a refresh request.
    ///
    /// - Returns: `(errorResponse, nil)` when the request is rejected, `(nil, request)` when it
    ///   is accepted.
    private func validateRefreshTokenRequest(_ request: Request) async throws -> (Response?, RefreshTokenRequest?) {
        /// Wraps an error response in the "rejected" tuple shape.
        func rejected(
            _ error: String, _ description: String, status: HTTPStatus = .badRequest
        ) throws -> (Response?, RefreshTokenRequest?) {
            let response = try tokenResponseGenerator.createResponse(
                error: error, description: description, status: status)
            return (response, nil)
        }

        let submittedClientID: String? = request.content[OAuthRequestParameters.clientID]
        guard let clientID = submittedClientID else {
            return try rejected(
                OAuthResponseParameters.ErrorType.invalidRequest,
                "Request was missing the 'client_id' parameter")
        }

        let submittedSecret: String? = request.content[OAuthRequestParameters.clientSecret]
        guard let clientSecret = submittedSecret else {
            return try rejected(
                OAuthResponseParameters.ErrorType.invalidRequest,
                "Request was missing the 'client_secret' parameter")
        }

        do {
            try await clientValidator.authenticateClient(
                clientID: clientID, clientSecret: clientSecret,
                grantType: nil, checkConfidentialClient: true)
        } catch ClientError.unauthorized {
            return try rejected(
                OAuthResponseParameters.ErrorType.invalidClient,
                "Request had invalid client credentials",
                status: .unauthorized)
        } catch ClientError.notConfidential {
            // NOTE(review): the text mentions the Client Credentials grant although this is the
            // refresh flow; kept verbatim because callers may match on it — confirm intent.
            return try rejected(
                OAuthResponseParameters.ErrorType.unauthorizedClient,
                "You are not authorized to use the Client Credentials grant type")
        }

        let submittedToken: String? = request.content[OAuthRequestParameters.refreshToken]
        guard let refreshTokenString = submittedToken else {
            return try rejected(
                OAuthResponseParameters.ErrorType.invalidRequest,
                "Request was missing the 'refresh_token' parameter")
        }

        let storedToken = try await tokenManager.getRefreshToken(refreshTokenString)
        guard
            let refreshToken = storedToken,
            tokenAuthenticator.validateRefreshToken(refreshToken, clientID: clientID)
        else {
            return try rejected(
                OAuthResponseParameters.ErrorType.invalidGrant,
                "The refresh token is invalid")
        }

        return (
            nil,
            RefreshTokenRequest(clientID: clientID, clientSecret: clientSecret, refreshToken: refreshToken)
        )
    }
}
125 |
/// The validated inputs of a refresh-token request.
struct RefreshTokenRequest {
    /// The `client_id` submitted with the request.
    let clientID: String

    /// The `client_secret` submitted with the request.
    let clientSecret: String

    /// The stored refresh token that the submitted token string resolved to.
    let refreshToken: RefreshToken
}
131 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenHandlers/TokenResponseGenerator.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Builds the JSON responses returned by the token handlers.
struct TokenResponseGenerator {

    /// Creates an error response carrying `error` and `error_description`.
    ///
    /// - Parameters:
    ///   - error: The error code.
    ///   - description: Human-readable explanation of the error.
    ///   - status: HTTP status of the response; `.badRequest` unless specified.
    /// - Throws: Errors from JSON serialization.
    func createResponse(error: String, description: String, status: HTTPStatus = .badRequest) throws -> Response {
        let payload = [
            OAuthResponseParameters.error: error,
            OAuthResponseParameters.errorDescription: description,
        ]
        let body = try JSONSerialization.data(withJSONObject: payload)
        return try createResponseForToken(status: status, jsonData: body)
    }

    /// Creates a 200 response describing an issued bearer token.
    ///
    /// - Parameters:
    ///   - accessToken: The access token whose `tokenString` is returned.
    ///   - refreshToken: Optional refresh token; its key is omitted when `nil`.
    ///   - expires: Value reported as `expires_in`.
    ///   - scope: Optional scope string; its key is omitted when `nil`.
    /// - Throws: Errors from JSON serialization.
    func createResponse(
        accessToken: AccessToken, refreshToken: RefreshToken?,
        expires: Int, scope: String?
    ) throws -> Response {
        var payload: [String: Any] = [
            OAuthResponseParameters.tokenType: "bearer",
            OAuthResponseParameters.expires: expires,
            OAuthResponseParameters.accessToken: accessToken.tokenString,
        ]

        if let refreshTokenString = refreshToken?.tokenString {
            payload[OAuthResponseParameters.refreshToken] = refreshTokenString
        }

        if let grantedScope = scope {
            payload[OAuthResponseParameters.scope] = grantedScope
        }

        let body = try JSONSerialization.data(withJSONObject: payload)
        return try createResponseForToken(status: .ok, jsonData: body)
    }

    /// Wraps serialized JSON in a response marked as JSON and as not cacheable.
    private func createResponseForToken(status: HTTPStatus, jsonData: Data) throws -> Response {
        let response = Response(status: status)
        response.headers.contentType = .json
        response.headers.replaceOrAdd(name: "pragma", value: "no-cache")
        response.headers.cacheControl = HTTPHeaders.CacheControl(noStore: true)
        response.body = .init(data: jsonData)
        return response
    }

}
49 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/RouteHandlers/TokenIntrospectionHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Answers token introspection requests: reports whether an access token is active and, if so,
/// its expiry, client, scopes and user name.
struct TokenIntrospectionHandler {
    let clientValidator: ClientValidator
    let tokenManager: TokenManager
    let userManager: UserManager

    /// Looks up the submitted token and describes it.
    ///
    /// - Returns: A 400 `missing_token` error response when the body has no decodable `token`,
    ///   an `active: false` response when the token is unknown or its expiry is in the past,
    ///   otherwise an `active: true` response with the token's details.
    /// - Throws: Errors from the token manager, user manager or response encoding.
    func handleRequest(_ req: Request) async throws -> Response {

        /// Shape of the request body.
        struct TokenData: Content {
            let token: String
        }

        guard let tokenString = try? req.content.decode(TokenData.self).token else {
            return try createErrorResponse(
                status: .badRequest,
                errorMessage: OAuthResponseParameters.ErrorType.missingToken,
                errorDescription: "The token parameter is required")
        }

        // Unknown and expired tokens are reported identically.
        guard
            let token = try await tokenManager.getAccessToken(tokenString),
            token.expiryTime >= Date()
        else {
            return try createTokenResponse(active: false, expiryDate: nil, clientID: nil)
        }

        var tokenOwner: OAuthUser?
        if let userID = token.userID {
            tokenOwner = try await userManager.getUser(userID: userID)
        }

        return try createTokenResponse(
            active: true,
            expiryDate: token.expiryTime,
            clientID: token.clientID,
            scopes: token.scopes?.joined(separator: " "),
            user: tokenOwner)
    }

    /// Encodes a `TokenResponse` into a 200 response.
    ///
    /// - Parameters:
    ///   - active: Whether the token is active.
    ///   - expiryDate: Expiry reported as whole seconds since 1970; omitted when `nil`.
    ///   - clientID: Client the token belongs to, if reported.
    ///   - scopes: Scope string, if reported.
    ///   - user: User whose `username` is reported, if any.
    func createTokenResponse(
        active: Bool, expiryDate: Date?, clientID: String?, scopes: String? = nil,
        user: OAuthUser? = nil
    ) throws -> Response {
        let tokenResponse = TokenResponse(
            active: active,
            scope: scopes,
            clientID: clientID,
            username: user?.username,
            exp: expiryDate.map { Int($0.timeIntervalSince1970) }
        )

        let response = Response(status: .ok)
        try response.content.encode(tokenResponse)
        return response
    }

    /// Encodes an `ErrorResponse` into a response with the given status.
    func createErrorResponse(status: HTTPStatus, errorMessage: String, errorDescription: String) throws -> Response {
        let errorBody = ErrorResponse(error: errorMessage, errorDescription: errorDescription)
        let response = Response(status: status)
        try response.content.encode(errorBody)
        return response
    }
}
72 |
extension TokenIntrospectionHandler {

    /// Body of an introspection error response.
    struct ErrorResponse: Content {
        /// Error code, encoded as `error`.
        var error: String
        /// Explanation of the error, encoded as `error_description`.
        var errorDescription: String

        enum CodingKeys: String, CodingKey {
            case error
            case errorDescription = "error_description"
        }
    }

    /// Body of an introspection result.
    struct TokenResponse: Content {
        /// Whether the token is active.
        let active: Bool
        /// Scope string of the token, if reported.
        var scope: String?
        /// Client the token belongs to, encoded as `client_id`.
        var clientID: String?
        /// User name associated with the token, if reported.
        var username: String?
        /// Expiry value, encoded as `exp`.
        var exp: Int?

        enum CodingKeys: String, CodingKey {
            case active, scope
            case clientID = "client_id"
            case username, exp
        }
    }
}
100 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Utilities/OAuthFlowType.swift:
--------------------------------------------------------------------------------
/// The OAuth flows known to the library; each raw value is the string used to name the flow.
public enum OAuthFlowType: String {
    case authorization = "authorization_code"
    // Raw value defaults to the case name: "implicit".
    case implicit
    // Raw value defaults to the case name: "password".
    case password
    case clientCredentials = "client_credentials"
    case refresh = "refresh_token"
    case tokenIntrospection = "token_introspection"
}
9 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Utilities/StringDefines.swift:
--------------------------------------------------------------------------------
/// Names of the parameters read from incoming OAuth requests (query items and body fields).
enum OAuthRequestParameters {
    static let clientID = "client_id"
    static let clientSecret = "client_secret"
    static let redirectURI = "redirect_uri"
    static let responseType = "response_type"
    static let scope = "scope"
    static let state = "state"
    static let applicationAuthorized = "applicationAuthorized"
    static let grantType = "grant_type"
    static let refreshToken = "refresh_token"
    static let code = "code"
    static let password = "password"
    /// Name of the `username` parameter.
    static let username = "username"
    /// Misspelled alias of `username`, kept so existing references continue to compile.
    /// Prefer `username` in new code.
    static let usernname = username
    static let csrfToken = "csrfToken"
    static let token = "token"
}
17 |
/// Names of the fields written into OAuth responses.
enum OAuthResponseParameters {

    // Error fields
    static let error = "error"
    static let errorDescription = "error_description"

    // Token fields
    static let tokenType = "token_type"
    static let expires = "expires_in"
    static let accessToken = "access_token"
    static let refreshToken = "refresh_token"
    static let scope = "scope"

    // Introspection fields
    static let active = "active"
    static let clientID = "client_id"
    static let userID = "user_id"
    static let username = "username"
    static let email = "email_address"
    static let expiry = "exp"

    /// Values used for the `error` field.
    enum ErrorType {
        static let invalidRequest = "invalid_request"
        static let invalidScope = "invalid_scope"
        static let invalidClient = "invalid_client"
        static let unauthorizedClient = "unauthorized_client"
        static let unsupportedGrant = "unsupported_grant_type"
        static let invalidGrant = "invalid_grant"
        static let missingToken = "missing_token"
    }
}
44 |
/// Values of the `response_type` parameter that the authorization handlers compare against.
enum ResponseType {
    /// Response type compared against when issuing an authorization code.
    static let code = "code"

    /// Response type compared against when issuing an access token directly.
    static let token = "token"
}
49 |
/// Keys used in the request session's data.
enum SessionData {
    /// Key under which the CSRF token is stored.
    static let csrfToken = "CSRFToken"
}
53 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Utilities/TokenAuthenticator.swift:
--------------------------------------------------------------------------------
/// Stateless checks on refresh and access tokens.
public struct TokenAuthenticator {

    public init() {}

    /// Returns `true` when the refresh token's client ID equals `clientID`.
    func validateRefreshToken(_ refreshToken: RefreshToken, clientID: String) -> Bool {
        refreshToken.clientID == clientID
    }

    /// Returns whether the access token carries every required scope.
    ///
    /// - Parameters:
    ///   - accessToken: The token to check.
    ///   - requiredScopes: Scopes that must all be present; `nil` means nothing is required.
    /// - Returns: `true` when `requiredScopes` is `nil`; `false` when scopes are required but
    ///   the token has none; otherwise whether the token's scopes contain every required scope.
    func validateAccessToken(_ accessToken: AccessToken, requiredScopes: [String]?) -> Bool {
        guard let required = requiredScopes else {
            return true
        }

        guard let granted = accessToken.scopes else {
            return false
        }

        return required.allSatisfy { granted.contains($0) }
    }
}
31 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Validators/ClientValidator.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Validates clients for authorization requests and authenticates them for token requests.
struct ClientValidator {

    let clientRetriever: ClientRetriever
    let scopeValidator: ScopeValidator
    let environment: Environment

    /// Checks that a client may start the authorization flow it is requesting.
    ///
    /// Checks run in this order: the client exists; a confidential client asks for the `code`
    /// response type; the redirect URI is accepted by the client; the client's allowed grant
    /// matches the response type (`code` → authorization, anything else → implicit); the scopes
    /// pass the scope validator; in production the redirect URI uses `https`.
    ///
    /// - Throws: The matching `AuthorizationError`, `Abort(.forbidden)` for a grant mismatch, or
    ///   whatever the client retriever or scope validator throws.
    func validateClient(clientID: String, responseType: String, redirectURI: String, scopes: [String]?) async throws {
        guard let client = try await clientRetriever.getClient(clientID: clientID) else {
            throw AuthorizationError.invalidClientID
        }

        let isConfidential = client.confidentialClient ?? false
        if isConfidential && responseType != ResponseType.code {
            throw AuthorizationError.confidentialClientTokenGrant
        }

        if !client.validateRedirectURI(redirectURI) {
            throw AuthorizationError.invalidRedirectURI
        }

        let requiredGrant: OAuthFlowType = responseType == ResponseType.code ? .authorization : .implicit
        guard client.allowedGrantType == requiredGrant else {
            throw Abort(.forbidden)
        }

        try await scopeValidator.validateScope(clientID: clientID, scopes: scopes)

        if environment == .production && URI(stringLiteral: redirectURI).scheme != "https" {
            throw AuthorizationError.httpRedirectURI
        }
    }

    /// Authenticates a client for a token request.
    ///
    /// - Parameters:
    ///   - clientID: Identifier of the client to look up.
    ///   - clientSecret: Secret supplied by the caller; compared with the stored secret.
    ///   - grantType: When non-`nil`, the client's allowed grant must equal it, and the
    ///     `password` grant additionally requires a first-party client.
    ///   - checkConfidentialClient: When `true`, the client must be confidential.
    /// - Throws: `ClientError.unauthorized` for an unknown client or a secret mismatch,
    ///   `Abort(.forbidden)` for a grant mismatch, `ClientError.notFirstParty` or
    ///   `ClientError.notConfidential` for the respective checks.
    func authenticateClient(
        clientID: String, clientSecret: String?, grantType: OAuthFlowType?,
        checkConfidentialClient: Bool = false
    ) async throws {
        guard
            let client = try await clientRetriever.getClient(clientID: clientID),
            clientSecret == client.clientSecret
        else {
            throw ClientError.unauthorized
        }

        if let requestedGrant = grantType {
            if client.allowedGrantType != requestedGrant {
                throw Abort(.forbidden)
            }

            if requestedGrant == .password && !client.firstParty {
                throw ClientError.notFirstParty
            }
        }

        if checkConfidentialClient && !(client.confidentialClient ?? false) {
            throw ClientError.notConfidential
        }
    }
}
76 |
/// Reasons client authentication can fail.
public enum ClientError: Error {
    /// The client is unknown or the supplied secret does not match.
    case unauthorized
    /// The password grant was requested by a client that is not first-party.
    case notFirstParty
    /// A confidential client was required but the client is not confidential.
    case notConfidential
}
82 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Validators/CodeValidator.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
/// Checks an authorization code against the request that presents it.
struct CodeValidator {
    /// Returns `true` only when the code belongs to `clientID`, its expiry date is not in the
    /// past, and it was issued for `redirectURI`.
    func validateCode(_ code: OAuthCode, clientID: String, redirectURI: String) -> Bool {
        // Evaluated left to right; the first failing condition decides the result.
        code.clientID == clientID
            && code.expiryDate >= Date()
            && code.redirectURI == redirectURI
    }
}
20 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Validators/ResourceServerAuthenticator.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
/// Authenticates resource servers presenting HTTP Basic credentials.
struct ResourceServerAuthenticator {

    let resourceServerRetriever: ResourceServerRetriever

    /// Verifies that a resource server exists for the username and that its password matches.
    ///
    /// - Throws: `Abort(.unauthorized)` when no server is found or the password differs, plus
    ///   any error thrown by the retriever.
    func authenticate(credentials: BasicAuthorization) async throws {
        guard
            let resourceServer = try await resourceServerRetriever.getServer(credentials.username),
            resourceServer.password == credentials.password
        else {
            throw Abort(.unauthorized)
        }
    }
}
17 |
--------------------------------------------------------------------------------
/Sources/VaporOAuth/Validators/ScopeValidator.swift:
--------------------------------------------------------------------------------
/// Validates requested scopes against the provider-wide list and the client's own list.
struct ScopeValidator {
    let validScopes: [String]?
    let clientRetriever: ClientRetriever

    /// Validates `scopes` for the given client.
    ///
    /// Nothing is checked when `scopes` is `nil`. Otherwise, if the provider list is non-empty,
    /// every requested scope must be in it; then, if the client is found and has its own scope
    /// list, every requested scope must be in that list as well.
    ///
    /// - Throws: `ScopeError.unknown` for a scope missing from the provider list,
    ///   `ScopeError.invalid` for a scope missing from the client's list, plus any error thrown
    ///   by the client retriever.
    func validateScope(clientID: String, scopes: [String]?) async throws {
        guard let requestedScopes = scopes else {
            return
        }

        if let providerScopes = validScopes, !providerScopes.isEmpty {
            guard requestedScopes.allSatisfy({ providerScopes.contains($0) }) else {
                throw ScopeError.unknown
            }
        }

        // An unknown client, or a client without a scope list, imposes no further restriction.
        let client = try await clientRetriever.getClient(clientID: clientID)
        guard let clientScopes = client?.validScopes else {
            return
        }

        guard requestedScopes.allSatisfy({ clientScopes.contains($0) }) else {
            throw ScopeError.invalid
        }
    }
}
28 |
29 | /// Errors thrown by `ScopeValidator` when a requested scope is rejected.
30 | public enum ScopeError: Error {
31 |     case invalid, unknown
32 | }
33 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Application+testable.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | extension Application {
6 |     /// Creates an application in the `.testing` environment together with its tester.
7 |     ///
8 |     /// - Throws: Any error from `testable()`; the application is shut down before rethrowing.
9 |     static func testableWithTester() throws -> (Application, XCTApplicationTester) {
10 |         let application = Application(.testing)
11 |         do {
12 |             return try (application, application.testable())
13 |         } catch {
14 |             application.shutdown()
15 |             throw error
16 |         }
17 |     }
18 |
19 |     /// Creates a `.testing` application and discards the tester.
20 |     static func testable() throws -> Application { try testableWithTester().0 }
21 | }
22 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/AuthorizationTests/AuthorizationRequestTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | class AuthorizationRequestTests: XCTestCase {
6 |
7 | // MARK: - Properties
8 |
9 | var app: Application!
10 | var fakeClientRetriever: FakeClientGetter!
11 | var capturingAuthoriseHandler: CapturingAuthoriseHandler!
12 |
13 | let clientID = "1234567890"
14 | let redirectURI = "https://api.brokenhands.io/callback"
15 |
16 | // MARK: - Overrides
17 |
18 |     /// Registers one authorization-code client and builds the application under test.
19 |     override func setUp() async throws {
20 |         try await super.setUp()  // chain to XCTestCase, as tearDown() already does
21 |         fakeClientRetriever = FakeClientGetter()
22 |         capturingAuthoriseHandler = CapturingAuthoriseHandler()
23 |
24 |         fakeClientRetriever.validClients[clientID] = OAuthClient(
25 |             clientID: clientID,
26 |             redirectURIs: [redirectURI],
27 |             allowedGrantType: .authorization
28 |         )
29 |         app = try TestDataBuilder.getOAuth2Application(
30 |             clientRetriever: fakeClientRetriever,
31 |             authorizeHandler: capturingAuthoriseHandler
32 |         )
33 |     }
34 |
35 |     override func tearDown() async throws {  // runs after each test method
36 |         app.shutdown()  // shut down the application currently held in `app`
37 |         try await super.tearDown()
38 |     }
39 |
40 | // MARK: - Tests
41 |
42 |     /// A `code` request reaches the authorize handler with its query parameters intact.
43 |     func testThatAuthorizationCodeRequestCallsAuthoriseHandlerWithQueryParameters() async throws {
44 |         let expectedType = "code"
45 |         _ = try await respondToOAuthRequest(responseType: expectedType, clientID: clientID, redirectURI: redirectURI)
46 |
47 |         XCTAssertEqual(capturingAuthoriseHandler.responseType, expectedType)
48 |         XCTAssertEqual(capturingAuthoriseHandler.clientID, clientID)
49 |         XCTAssertEqual(capturingAuthoriseHandler.redirectURI, URI(string: redirectURI))
50 |     }
51 |
52 |     /// A `token` request from an implicit-grant client is handed to the authorize handler.
53 |     func testThatAuthorizationTokenRequestRedirectsToAuthoriseApplicationPage() async throws {
54 |         let expectedType = "token"
55 |         let implicitID = "implicit"
56 |         fakeClientRetriever.validClients[implicitID] = OAuthClient(
57 |             clientID: implicitID,
58 |             redirectURIs: [redirectURI],
59 |             allowedGrantType: .implicit
60 |         )
61 |
62 |         _ = try await respondToOAuthRequest(
63 |             responseType: expectedType,
64 |             clientID: implicitID,
65 |             redirectURI: redirectURI
66 |         )
67 |
68 |         XCTAssertEqual(capturingAuthoriseHandler.responseType, expectedType)
69 |         XCTAssertEqual(capturingAuthoriseHandler.clientID, implicitID)
70 |         XCTAssertEqual(capturingAuthoriseHandler.redirectURI, URI(string: redirectURI))
71 |     }
72 |
73 |     /// A request without `response_type` is redirected to the client with `invalid_request`.
74 |     func testThatAuthorizeRequestResponseTypeRedirectsBackToClientWithErrorCode() async throws {
75 |         let expectedLocation =
76 |             "\(redirectURI)?error=invalid_request&error_description=Request+was+missing+the+response_type+parameter"
77 |         let response = try await respondToOAuthRequest(
78 |             responseType: nil,
79 |             clientID: clientID,
80 |             redirectURI: redirectURI
81 |         )
82 |
83 |         XCTAssertEqual(response.status, .seeOther)
84 |         XCTAssertEqual(response.headers.first(name: "location"), expectedLocation)
85 |     }
86 |
87 |     /// The error redirect targets the redirect URI registered for the requesting client.
88 |     func testThatBadRequestRedirectsBackToClientRedirectURI() async throws {
89 |         let otherURI = "https://api.test.com/cb"
90 |         let otherClientID = "123ABC"
91 |         fakeClientRetriever.validClients[otherClientID] = OAuthClient(
92 |             clientID: otherClientID,
93 |             redirectURIs: [otherURI],
94 |             allowedGrantType: .authorization
95 |         )
96 |
97 |         let response = try await respondToOAuthRequest(responseType: nil, clientID: otherClientID, redirectURI: otherURI)
98 |
99 |         XCTAssertEqual(
100 |             response.headers.first(name: "location"),
101 |             "\(otherURI)?error=invalid_request&error_description=Request+was+missing+the+response_type+parameter"
102 |         )
103 |     }
104 |
105 |     /// The `state` parameter is echoed in the redirect issued for a missing `response_type`.
106 |     func testThatStateProvidedWhenRedirectingForMissingReponseType() async throws {
107 |         let expectedState = "xcoivjuywkdkhvusuye3kch"
108 |         let response = try await respondToOAuthRequest(
109 |             responseType: nil,
110 |             clientID: clientID,
111 |             redirectURI: redirectURI,
112 |             state: expectedState
113 |         )
114 |         let location = response.headers.location?.value
115 |         XCTAssertTrue(location?.contains("state=\(expectedState)") ?? false)
116 |     }
117 |
118 |     /// An unsupported `response_type` value is redirected back with `invalid_request`.
119 |     func testThatAuthorizeRequestRedirectsBackToClientWithErrorCodeResponseTypeIsNotCodeOrToken() async throws {
120 |         let expectedLocation = "\(redirectURI)?error=invalid_request&error_description=invalid+response+type"
121 |
122 |         let response = try await respondToOAuthRequest(
123 |             responseType: "invalid",
124 |             clientID: clientID,
125 |             redirectURI: redirectURI
126 |         )
127 |
128 |         XCTAssertEqual(response.status, .seeOther)
129 |         XCTAssertEqual(response.headers.location?.value, expectedLocation)
130 |     }
131 |
132 |     /// The `state` parameter is echoed in the redirect issued for an invalid `response_type`.
133 |     func testThatStateProvidedWhenRedirectingForInvalidReponseType() async throws {
134 |         let expectedState = "xcoivjuywkdkhvusuye3kch"
135 |         let response = try await respondToOAuthRequest(
136 |             responseType: "invalid",
137 |             clientID: clientID,
138 |             redirectURI: redirectURI,
139 |             state: expectedState
140 |         )
141 |         let location = response.headers.location?.value
142 |         XCTAssertTrue(location?.contains("state=\(expectedState)") ?? false)
143 |     }
144 |
145 |     /// Omitting `client_id` is reported to the handler as `.invalidClientID`.
146 |     func testThatAuthorizeRequestFailsWithoutClientIDQuery() async throws {
147 |         _ = try await respondToOAuthRequest(clientID: nil, redirectURI: redirectURI)
148 |         XCTAssertEqual(capturingAuthoriseHandler.authorizationError, .invalidClientID)
149 |     }
150 |
151 |     /// Omitting `redirect_uri` is reported to the handler as `.invalidRedirectURI`.
152 |     func testThatAuthorizeRequestFailsWithoutRedirectURI() async throws {
153 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: nil)
154 |         XCTAssertEqual(capturingAuthoriseHandler.authorizationError, .invalidRedirectURI)
155 |     }
156 |
157 |     /// A single requested scope is forwarded to the authorize handler.
158 |     func testThatSingleScopePassedThroughToAuthorizationHandler() async throws {
159 |         let requestedScope = "profile"
160 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI, scope: requestedScope)
161 |
162 |         XCTAssertEqual(capturingAuthoriseHandler.scope?.count, 1)
163 |         XCTAssertTrue(capturingAuthoriseHandler.scope?.contains(requestedScope) ?? false)
164 |     }
165 |
166 |     /// Scopes joined with `+` in the query are forwarded to the handler as separate entries.
167 |     func testThatMultipleScopesPassedThroughToAuthorizationHandler() async throws {
168 |         let scopes = ["profile", "create"]
169 |         let scopeQuery = scopes.joined(separator: "+")
170 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI, scope: scopeQuery)
171 |
172 |         XCTAssertEqual(capturingAuthoriseHandler.scope?.count, 2)
173 |         for scope in scopes {
174 |             XCTAssertTrue(capturingAuthoriseHandler.scope?.contains(scope) ?? false)
175 |         }
176 |     }
177 |
178 |     /// The `state` query parameter is forwarded to the authorize handler unchanged.
179 |     func testStatePassedThroughToAuthorizationHandler() async throws {
180 |         let expectedState = "xcoivjuywkdkhvusuye3kch"
181 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI, state: expectedState)
182 |
183 |         XCTAssertEqual(capturingAuthoriseHandler.state, expectedState)
184 |     }
185 |
186 |     /// Response type, client, redirect URI, scopes and state all reach the handler together.
187 |     func testAllPropertiesPassedThroughToAuthorizationHandler() async throws {
188 |         let expectedType = "code"
189 |         let scopes = ["profile", "create"]
190 |         let expectedState = "xcoivjuywkdkhvusuye3kch"
191 |
192 |         _ = try await respondToOAuthRequest(
193 |             responseType: expectedType,
194 |             clientID: clientID,
195 |             redirectURI: redirectURI,
196 |             scope: scopes.joined(separator: "+"),
197 |             state: expectedState
198 |         )
199 |
200 |         XCTAssertEqual(capturingAuthoriseHandler.responseType, expectedType)
201 |         XCTAssertEqual(capturingAuthoriseHandler.clientID, clientID)
202 |         XCTAssertEqual(capturingAuthoriseHandler.redirectURI, URI(string: redirectURI))
203 |         XCTAssertEqual(capturingAuthoriseHandler.scope?.count, 2)
204 |         for scope in scopes {
205 |             XCTAssertTrue(capturingAuthoriseHandler.scope?.contains(scope) ?? false)
206 |         }
207 |         XCTAssertEqual(capturingAuthoriseHandler.state, expectedState)
208 |     }
209 |
210 |     /// A client ID that is not registered is reported as `.invalidClientID`.
211 |     func testThatAnInvalidClientIDLoadsErrorPage() async throws {
212 |         let unknownClientID = "invalid"
213 |         _ = try await respondToOAuthRequest(clientID: unknownClientID, redirectURI: redirectURI)
214 |
215 |         XCTAssertEqual(capturingAuthoriseHandler.authorizationError, .invalidClientID)
216 |     }
217 |
218 |     /// A redirect URI that is not registered for the client is reported as `.invalidRedirectURI`.
219 |     func testThatInvalidRedirectURICallsErrorHandlerWithCorrectError() async throws {
220 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: "http://this.does.not/match")
221 |         XCTAssertEqual(capturingAuthoriseHandler.authorizationError, .invalidRedirectURI)
222 |     }
223 |
224 |     /// A scope outside the provider-wide list is rejected with `invalid_scope` / "scope is unknown".
225 |     func testThatUnknownScopeReturnsInvalidScopeError() async throws {
226 |         // Rebuild the application with a provider-wide scope list that excludes "create".
227 |         app.shutdown()
228 |         app = try TestDataBuilder.getOAuth2Application(
229 |             clientRetriever: fakeClientRetriever,
230 |             authorizeHandler: capturingAuthoriseHandler,
231 |             validScopes: ["email", "profile", "admin"]
232 |         )
233 |
234 |         let expectedLocation = "\(redirectURI)?error=invalid_scope&error_description=scope+is+unknown"
235 |
236 |         let response = try await respondToOAuthRequest(
237 |             clientID: clientID,
238 |             redirectURI: redirectURI,
239 |             scope: "create"
240 |         )
241 |
242 |         XCTAssertEqual(response.status, .seeOther)
243 |         XCTAssertEqual(response.headers.location?.value, expectedLocation)
244 |     }
245 |
246 |     /// A scope outside the client's own list is rejected with `invalid_scope` / "scope is invalid".
247 |     func testThatClientAccessingScopeItShouldNotReturnsInvalidScopeError() async throws {
248 |         let scopedClientID = "ABCDEFGH"
249 |         let rejectedScope = "create"
250 |         fakeClientRetriever.validClients[scopedClientID] = OAuthClient(
251 |             clientID: scopedClientID,
252 |             redirectURIs: [redirectURI],
253 |             validScopes: ["email", "profile", "admin"],
254 |             allowedGrantType: .authorization
255 |         )
256 |
257 |         // `rejectedScope` is not among the scopes granted to this client.
258 |         let response = try await respondToOAuthRequest(
259 |             clientID: scopedClientID,
260 |             redirectURI: redirectURI,
261 |             scope: rejectedScope
262 |         )
263 |
264 |         XCTAssertEqual(response.status, .seeOther)
265 |         XCTAssertEqual(
266 |             response.headers.location?.value,
267 |             "\(redirectURI)?error=invalid_scope&error_description=scope+is+invalid"
268 |         )
269 |     }
270 |
271 |     /// A confidential client asking for `token` is redirected with `unauthorized_client`.
272 |     func testConfidentialClientMakingTokenRequestResultsInUnauthorizedClientError() async throws {
273 |         let confidentialID = "ABCDEFGH"
274 |         fakeClientRetriever.validClients[confidentialID] = OAuthClient(
275 |             clientID: confidentialID,
276 |             redirectURIs: [redirectURI],
277 |             confidential: true,
278 |             allowedGrantType: .authorization
279 |         )
280 |
281 |         // Request the `token` response type on behalf of the confidential client.
282 |         let response = try await respondToOAuthRequest(
283 |             responseType: "token",
284 |             clientID: confidentialID,
285 |             redirectURI: redirectURI
286 |         )
287 |
288 |         XCTAssertEqual(response.status, .seeOther)
289 |         XCTAssertEqual(
290 |             response.headers.location?.value,
291 |             "\(redirectURI)?error=unauthorized_client&error_description=token+grant+disabled+for+confidential+clients"
292 |         )
293 |     }
294 |
295 |     /// In the production environment a plain-HTTP redirect URI is reported as `.httpRedirectURI`.
296 |     func testNonHTTPSRedirectURICanNotBeUsedWhenInProduction() async throws {
297 |         // Swap the default application for one running in the production environment.
298 |         app.shutdown()
299 |         app = try TestDataBuilder.getOAuth2Application(
300 |             clientRetriever: fakeClientRetriever,
301 |             authorizeHandler: capturingAuthoriseHandler,
302 |             environment: .production
303 |         )
304 |
305 |         let plainHTTPURI = "http://api.brokenhands.io/callback/"
306 |         fakeClientRetriever.validClients[clientID] = OAuthClient(
307 |             clientID: clientID,
308 |             redirectURIs: [plainHTTPURI],
309 |             allowedGrantType: .authorization
310 |         )
311 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: plainHTTPURI)
312 |
313 |         XCTAssertEqual(capturingAuthoriseHandler.authorizationError, .httpRedirectURI)
314 |     }
315 |
316 |     /// The authorize handler receives a CSRF token with every authorization request.
317 |     func testCSRFTokenProvidedToAuthorizeHandler() async throws {
318 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI)
319 |         XCTAssertNotNil(capturingAuthoriseHandler.csrfToken)
320 |     }
321 |
322 |     /// Two consecutive authorization requests are given different CSRF tokens.
323 |     func testCSRFTokenIsDifferentEachTime() async throws {
324 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI)
325 |         let firstToken = capturingAuthoriseHandler.csrfToken
326 |
327 |         _ = try await respondToOAuthRequest(clientID: clientID, redirectURI: redirectURI)
328 |         let secondToken = capturingAuthoriseHandler.csrfToken
329 |         XCTAssertNotEqual(firstToken, secondToken)
330 |     }
331 |
332 |     /// A client limited to the implicit grant gets `.forbidden` for the default `code` request.
333 |     func testClientNotConfiguredWithAccessToAuthCodeFlowCantAccessItForGet() async throws {
334 |         let implicitOnlyID = "not-allowed"
335 |         fakeClientRetriever.validClients[implicitOnlyID] = OAuthClient(
336 |             clientID: implicitOnlyID,
337 |             redirectURIs: [redirectURI],
338 |             clientSecret: nil,
339 |             validScopes: nil,
340 |             allowedGrantType: .implicit
341 |         )
342 |
343 |         let response = try await respondToOAuthRequest(clientID: implicitOnlyID, redirectURI: redirectURI)
344 |
345 |         XCTAssertEqual(response.status, .forbidden)
346 |     }
347 |
348 |     /// A client allowed the authorization-code grant gets `.ok` for the default `code` request.
349 |     func testClientConfiguredWithAccessToAuthCodeFlowCanAccessItForGet() async throws {
350 |         let codeGrantID = "not-allowed"
351 |         fakeClientRetriever.validClients[codeGrantID] = OAuthClient(
352 |             clientID: codeGrantID,
353 |             redirectURIs: [redirectURI],
354 |             clientSecret: nil,
355 |             validScopes: nil,
356 |             allowedGrantType: .authorization
357 |         )
358 |
359 |         let response = try await respondToOAuthRequest(
360 |             clientID: codeGrantID,
361 |             redirectURI: redirectURI
362 |         )
363 |
364 |         XCTAssertEqual(response.status, .ok)
365 |     }
366 |
367 | // MARK: - Private
368 |
369 |     /// Sends an authorization request to the current `app` through `TestDataBuilder`.
370 |     /// `responseType` defaults to "code"; `scope` and `state` default to `nil`.
371 |     /// All arguments are forwarded unchanged.
372 |     private func respondToOAuthRequest(
373 |         responseType: String? = "code",
374 |         clientID: String?,
375 |         redirectURI: String?,
376 |         scope: String? = nil,
377 |         state: String? = nil
378 |     ) async throws -> XCTHTTPResponse {
379 |         try await TestDataBuilder.getAuthRequestResponse(
380 |             with: app,
381 |             responseType: responseType, clientID: clientID, redirectURI: redirectURI,
382 |             scope: scope, state: state
383 |         )
384 |     }
385 |
386 | }
387 |
388 | /// Conformance that lets the tests compare `URI` values with `XCTAssertEqual`.
389 | extension URI: Equatable {
390 |     /// Two URIs are equal when their `description` strings match.
391 |     public static func == (lhs: URI, rhs: URI) -> Bool { lhs.description == rhs.description }
392 | }
393 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/AuthorizationTests/AuthorizationResponseTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | class AuthorizationResponseTests: XCTestCase {
6 |
7 | // MARK: - Properties
8 |
9 | var app: Application!
10 | var fakeClientRetriever: FakeClientGetter!
11 | var capturingAuthoriseHandler: CapturingAuthoriseHandler!
12 | var fakeCodeManager: FakeCodeManager!
13 |
14 | static let clientID = "1234567890"
15 | static let redirectURI = "https://api.brokenhands.io/callback"
16 |
17 | // let fakeSessions: FakeSessions!
18 | let scope1 = "email"
19 | let scope2 = "address"
20 | let scope3 = "profile"
21 | let sessionID = "the-session-ID"
22 | let csrfToken = "the-csrf-token"
23 |
24 | // MARK: - Overrides
25 |
26 |     /// Registers a scoped client, seeds a session holding the CSRF token and builds the application.
27 |     override func setUp() async throws {
28 |         try await super.setUp()  // chain to XCTestCase, as tearDown() already does
29 |         fakeClientRetriever = FakeClientGetter()
30 |         capturingAuthoriseHandler = CapturingAuthoriseHandler()
31 |         fakeCodeManager = FakeCodeManager()
32 |         fakeClientRetriever.validClients[AuthorizationResponseTests.clientID] = OAuthClient(
33 |             clientID: AuthorizationResponseTests.clientID,
34 |             redirectURIs: [AuthorizationResponseTests.redirectURI],
35 |             validScopes: [scope1, scope2],
36 |             allowedGrantType: .authorization
37 |         )
38 |         // Session whose ID and CSRF token equal the defaults used by `getAuthResponse`.
39 |         let fakeSessions = FakeSessions(
40 |             sessions: [SessionID(string: sessionID): SessionData(initialData: ["CSRFToken": csrfToken])]
41 |         )
42 |
43 |         app = try TestDataBuilder.getOAuth2Application(
44 |             codeManager: fakeCodeManager,
45 |             clientRetriever: fakeClientRetriever,
46 |             authorizeHandler: capturingAuthoriseHandler,
47 |             sessions: fakeSessions,
48 |             registeredUsers: [TestDataBuilder.anyOAuthUser()]
49 |         )
50 |     }
50 |
51 |     override func tearDown() async throws {  // runs after each test method
52 |         app.shutdown()  // shut down the application currently held in `app`
53 |         try await super.tearDown()
54 |     }
55 |
56 | // MARK: - Tests
57 |
58 |     /// Denying the request redirects to the client with `access_denied`.
59 |     func testThatCorrectErrorCodeReturnedIfUserDoesNotAuthorizeApplication() async throws {
60 |         let expectedLocation =
61 |             "\(AuthorizationResponseTests.redirectURI)?error=access_denied&error_description=user+denied+the+request"
62 |         let response = try await getAuthResponse(approve: false)
63 |
64 |         XCTAssertEqual(response.status, .seeOther)
65 |         XCTAssertEqual(response.headers.location?.value, expectedLocation)
66 |     }
67 |
68 |     /// The `state` value is appended to the `access_denied` redirect.
69 |     func testThatTheStateIsReturnedIfUserDoesNotAuthorizeApplication() async throws {
70 |         let state = "xcoivjuywkdkhvusuye3kch"
71 |         let expectedLocation =
72 |             "\(AuthorizationResponseTests.redirectURI)?error=access_denied&error_description=user+denied+the+request&state=\(state)"
73 |         let response = try await getAuthResponse(approve: false, state: state)
74 |
75 |         XCTAssertEqual(response.status, .seeOther)
76 |         XCTAssertEqual(response.headers.location?.value, expectedLocation)
77 |     }
78 |
79 |     /// The `access_denied` redirect targets the redirect URI of the client in the request.
80 |     func testThatRedirectURICanBeConfiguredIfUserDoesNotAuthorizeApplication() async throws {
81 |         let otherClientID = "ABCDEFG"
82 |         let otherRedirectURI = "http://new.brokenhands.io/callback"
83 |         fakeClientRetriever.validClients[otherClientID] = OAuthClient(
84 |             clientID: otherClientID,
85 |             redirectURIs: [otherRedirectURI],
86 |             allowedGrantType: .authorization
87 |         )
88 |         let response = try await getAuthResponse(approve: false, clientID: otherClientID, redirectURI: otherRedirectURI)
89 |
90 |         XCTAssertEqual(response.status, .seeOther)
91 |         XCTAssertEqual(
92 |             response.headers.location?.value,
93 |             "\(otherRedirectURI)?error=access_denied&error_description=user+denied+the+request"
94 |         )
95 |     }
96 |
97 |     /// A request without an approval value is answered with `.badRequest`.
98 |     func testThatAuthorizationApprovalMustBeSentInPostRequest() async throws {
99 |         let response = try await getAuthResponse(approve: nil)
100 |         XCTAssertEqual(response.status, .badRequest)
101 |     }
102 |
103 |     /// A request without a client ID is answered with `.badRequest`.
104 |     func testThatClientIDMustBeSentToAuthorizeApproval() async throws {
105 |         let response = try await getAuthResponse(clientID: nil)
106 |         XCTAssertEqual(response.status, .badRequest)
107 |     }
108 |
109 |     /// A request without a redirect URI is answered with `.badRequest`.
110 |     func testThatRedirectURIMustBeSentToAuthorizeApproval() async throws {
111 |         let response = try await getAuthResponse(redirectURI: nil)
112 |         XCTAssertEqual(response.status, .badRequest)
113 |     }
114 |
115 |     /// A request without a response type is answered with `.badRequest`.
116 |     func testThatResponseTypeMustBeSentToAuthorizeApproval() async throws {
117 |         let response = try await getAuthResponse(responseType: nil)
118 |         XCTAssertEqual(response.status, .badRequest)
119 |     }
120 |
121 |     /// A client ID that is not registered is answered with `.badRequest`.
122 |     func testThatInvalidClientIDReturnsBadRequest() async throws {
123 |         let response = try await getAuthResponse(clientID: "DONOTEXIST")
124 |         XCTAssertEqual(response.status, .badRequest)
125 |     }
126 |
127 |     /// A redirect URI not registered for the client is answered with `.badRequest`.
128 |     func testThatRedirectURIThatDoesNotMatchClientIDReturnsBadRequest() async throws {
129 |         let response = try await getAuthResponse(redirectURI: "https://some.invalid.uri")
130 |         XCTAssertEqual(response.status, .badRequest)
131 |     }
132 |
133 |     func testThatRedirectURIMustBeHTTPSForProduction() async throws {
134 |         app.shutdown()
135 |         // Replace the application from setUp() with one running in the production environment.
136 |         app = try TestDataBuilder.getOAuth2Application(
137 |             clientRetriever: fakeClientRetriever,
138 |             authorizeHandler: capturingAuthoriseHandler,
139 |             environment: .production,
140 |             registeredUsers: [TestDataBuilder.anyOAuthUser()]
141 |         )
142 |         // NOTE(review): unlike setUp(), no `sessions:` is passed here — confirm .badRequest comes from the HTTP URI, not a missing session.
143 |         try await Task.sleep(nanoseconds: 1) // Without this the tests are crashing (segmentation fault) on ubuntu
144 |         // Register a client whose only redirect URI uses plain HTTP.
145 |         let clientID = "ABCDE1234"
146 |         let redirectURI = "http://api.brokenhands.io/callback"
147 |         let newClient = OAuthClient(clientID: clientID, redirectURIs: [redirectURI], allowedGrantType: .authorization)
148 |         fakeClientRetriever.validClients[clientID] = newClient
149 |
150 |         let response = try await getAuthResponse(clientID: clientID, redirectURI: redirectURI)
151 |
152 |         XCTAssertEqual(response.status, .badRequest)
153 |     }
154 |
155 |     /// An approved request redirects to the client with the generated code in the query.
156 |     func testThatExpectedTokenReturnedForSuccessfulRequest() async throws {
157 |         let expectedCode = "ABCDEFGHIJKL"
158 |         fakeCodeManager.generatedCode = expectedCode
159 |         let response = try await getAuthResponse()
160 |
161 |         XCTAssertEqual(response.status, .seeOther)
162 |         XCTAssertEqual(response.headers.location?.value, "\(AuthorizationResponseTests.redirectURI)?code=\(expectedCode)")
163 |     }
164 |
165 |     /// When the request carries a `state`, it follows the code in the redirect query.
166 |     func testThatStateReturnedWithCodeIfProvidedInRequest() async throws {
167 |         let expectedCode = "ABDDJFEIOW432423"
168 |         let expectedState = "grugihreiuhgbf8834dscsc"
169 |         fakeCodeManager.generatedCode = expectedCode
170 |         let expectedLocation = "\(AuthorizationResponseTests.redirectURI)?code=\(expectedCode)&state=\(expectedState)"
171 |         let response = try await getAuthResponse(state: expectedState)
172 |         XCTAssertEqual(response.headers.location?.value, expectedLocation)
173 |     }
174 |
175 |     /// A request made without a user is answered with `.unauthorized`.
176 |     func testUserMustBeLoggedInToGetToken() async throws {
177 |         let response = try await getAuthResponse(user: nil)
178 |         XCTAssertEqual(response.status, .unauthorized)
179 |     }
180 |
181 |     /// The stored code records the ID of the user who approved the request.
182 |     func testThatCodeHasUserIDSetOnIt() async throws {
183 |         let codeString = "ABCDEFGHIJKL"
184 |         fakeCodeManager.generatedCode = codeString
185 |         let user = TestDataBuilder.anyOAuthUser()
186 |
187 |         _ = try await getAuthResponse(user: user)
188 |
189 |         // XCTUnwrap reports why the test failed when no code was stored,
190 |         // where the previous bare `XCTFail()` gave no message.
191 |         let code = try XCTUnwrap(fakeCodeManager.getCode(codeString), "no code was stored for \(codeString)")
192 |
193 |         XCTAssertEqual(code.userID, user.id)
194 |     }
195 |
196 |     /// The stored code records the client ID of the request.
197 |     func testThatClientIDSetOnCode() async throws {
198 |         _ = try await getAuthResponse()
199 |
200 |         // XCTUnwrap fails with a diagnostic instead of a bare `XCTFail()`.
201 |         let storedCode = fakeCodeManager.getCode(fakeCodeManager.generatedCode)
202 |         let code = try XCTUnwrap(storedCode, "no authorization code was stored")
203 |
204 |         XCTAssertEqual(code.clientID, AuthorizationResponseTests.clientID)
205 |     }
206 |
207 |     /// When no scope is requested, the stored code has `nil` scopes.
208 |     func testThatScopeOnCodeIsNilIfNotSupplied() async throws {
209 |         _ = try await getAuthResponse(scope: nil)
210 |
211 |         // XCTUnwrap fails with a diagnostic instead of a bare `XCTFail()`.
212 |         let storedCode = fakeCodeManager.getCode(fakeCodeManager.generatedCode)
213 |         let code = try XCTUnwrap(storedCode, "no authorization code was stored")
214 |
215 |         XCTAssertNil(code.scopes)
216 |     }
217 |
218 |     /// Scopes requested as `a+b` are stored on the code as separate entries, in order.
219 |     func testThatCorrectScopesSetOnCodeIfSupplied() async throws {
220 |         // `scope1` ("email") and `scope2` ("address") are the fixture's own properties,
221 |         // so the identical local copies the test used to declare are not needed.
222 |         _ = try await getAuthResponse(scope: "\(scope1)+\(scope2)")
223 |
224 |         // XCTUnwrap fails with a diagnostic when no code was stored,
225 |         // where the previous bare `XCTFail()` gave no message.
226 |         let storedCode = fakeCodeManager.getCode(fakeCodeManager.generatedCode)
227 |         let code = try XCTUnwrap(storedCode, "no authorization code was stored")
228 |
229 |         XCTAssertEqual(code.scopes ?? [], [scope1, scope2])
230 |     }
231 |
232 |     /// The stored code records the redirect URI of the request.
233 |     func testThatRedirectURISetOnCodeCorrectly() async throws {
234 |         _ = try await getAuthResponse()
235 |
236 |         // XCTUnwrap fails with a diagnostic instead of a bare `XCTFail()`.
237 |         let storedCode = fakeCodeManager.getCode(fakeCodeManager.generatedCode)
238 |         let code = try XCTUnwrap(storedCode, "no authorization code was stored")
239 |
240 |         XCTAssertEqual(code.redirectURI, AuthorizationResponseTests.redirectURI)
241 |     }
242 |
243 |     /// `scope3` is not in the client's `validScopes`, so the request gets `.badRequest`.
244 |     func testThatBadRequestReturnedForClientRequestingScopesItDoesNotHaveAccessTo() async throws {
245 |         let response = try await getAuthResponse(scope: scope3)
246 |         XCTAssertEqual(response.status, .badRequest)
247 |     }
248 |
249 |     /// A scope the client was never granted is answered with `.badRequest`.
250 |     func testThatBadRequestReturnedForClientRequestingUnknownScopes() async throws {
251 |         let response = try await getAuthResponse(scope: "some_unkown_scope")
252 |         XCTAssertEqual(response.status, .badRequest)
253 |     }
254 |
255 |     /// A request without a CSRF token is answered with `.badRequest`.
256 |     func testThatCSRFTokenMustBeSubmitted() async throws {
257 |         let response = try await getAuthResponse(csrfToken: nil)
258 |         XCTAssertEqual(response.status, .badRequest)
259 |     }
260 |
261 |     /// A CSRF token that differs from the one in the session is answered with `.badRequest`.
262 |     func testThatRequestWithInvalidCSRFTokenFails() async throws {
263 |         let response = try await getAuthResponse(csrfToken: "someRandomToken")
264 |         XCTAssertEqual(response.status, .badRequest)
265 |     }
266 |
267 |     /// A request without a session ID is answered with `.badRequest`.
268 |     func testThatSessionCookieMustBeSentInRequest() async throws {
269 |         let response = try await getAuthResponse(sessionID: nil)
270 |         XCTAssertEqual(response.status, .badRequest)
271 |     }
272 |
273 |     /// A session ID that was not seeded in setUp() is answered with `.badRequest`.
274 |     func testThatValidSessionCookieMustBeSentInRequest() async throws {
275 |         let response = try await getAuthResponse(sessionID: "someRandomSession")
276 |         XCTAssertEqual(response.status, .badRequest)
277 |     }
278 |
279 |     /// A client limited to the implicit grant gets `.forbidden` for the default `code` response type.
280 |     func testClientNotConfiguredWithAccessToAuthCodeFlowCantAccessItForGet() async throws {
281 |         let implicitOnlyID = "not-allowed"
282 |         fakeClientRetriever.validClients[implicitOnlyID] = OAuthClient(
283 |             clientID: implicitOnlyID, redirectURIs: [AuthorizationResponseTests.redirectURI],
284 |             clientSecret: nil, validScopes: nil, allowedGrantType: .implicit
285 |         )
286 |         let response = try await getAuthResponse(clientID: implicitOnlyID)
287 |
288 |         XCTAssertEqual(response.status, .forbidden)
289 |     }
290 |
291 |     /// A client allowed the authorization-code grant is redirected (`.seeOther`) on approval.
292 |     func testClientConfiguredWithAccessToAuthCodeFlowCanAccessItForGet() async throws {
293 |         let codeGrantID = "not-allowed"
294 |         fakeClientRetriever.validClients[codeGrantID] = OAuthClient(
295 |             clientID: codeGrantID,
296 |             redirectURIs: [AuthorizationResponseTests.redirectURI],
297 |             clientSecret: nil,
298 |             validScopes: nil,
299 |             allowedGrantType: .authorization
300 |         )
301 |
302 |         let response = try await getAuthResponse(clientID: codeGrantID)
303 |
304 |         XCTAssertEqual(response.status, .seeOther)
305 |     }
306 |
307 | // MARK: - Private
308 |
309 |     /// Submits an authorization decision to the current `app` through `TestDataBuilder`.
310 |     ///
311 |     /// The defaults describe an approved `code` request for the client, redirect URI,
312 |     /// CSRF token and session registered in setUp(), so a test only overrides the
313 |     /// argument it is probing. All arguments are forwarded unchanged.
314 |     private func getAuthResponse(
315 |         approve: Bool? = true,
316 |         clientID: String? = clientID,
317 |         redirectURI: String? = redirectURI,
318 |         responseType: String? = "code",
319 |         scope: String? = nil,
320 |         state: String? = nil,
321 |         user: OAuthUser? = TestDataBuilder.anyOAuthUser(),
322 |         csrfToken: String? = "the-csrf-token",
323 |         sessionID: String? = "the-session-ID"
324 |     ) async throws -> XCTHTTPResponse {
325 |         try await TestDataBuilder.getAuthResponseResponse(
326 |             with: app,
327 |             approve: approve, clientID: clientID, redirectURI: redirectURI,
328 |             responseType: responseType, scope: scope, state: state,
329 |             csrfToken: csrfToken, user: user, sessionID: sessionID
330 |         )
331 |     }
332 | }
333 | }
334 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/DefaultImplementationTests/DefaultImplementationTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | /// Exercises the `Empty…` default implementations provided by the library.
6 | class DefaultImplementationTests: XCTestCase {
7 |     // MARK: - Tests
8 |
9 |     /// `EmptyResourceServerRetriever` returns no server for the given username.
10 |     func testThatEmptyResourceServerRetrieverReturnsNilWhenGettingResourceServer() async throws {
11 |         let retriever = EmptyResourceServerRetriever()
12 |         let server = try await retriever.getServer("some username")
13 |         XCTAssertNil(server)
14 |     }
15 |
16 |     /// `EmptyUserManager` returns `nil` when asked to authenticate.
17 |     func testThatEmptyUserManagerReturnsNilWhenAttemptingToAuthenticate() async throws {
18 |         let userManager = EmptyUserManager()
19 |         let result = try await userManager.authenticateUser(username: "username", password: "password")
20 |         XCTAssertNil(result)
21 |     }
22 |
23 |     /// `EmptyUserManager` returns `nil` when asked for a user.
24 |     func testThatEmptyUserManagerReturnsNilWhenTryingToGetUser() async throws {
25 |         let userManager = EmptyUserManager()
26 |         let user = try await userManager.getUser(userID: "some-id")
27 |         XCTAssertNil(user)
28 |     }
29 |
30 |     /// The error response of `EmptyAuthorizationHandler` has an empty body.
31 |     func testThatEmptyAuthHandlerReturnsEmptyStringWhenHandlingAuthError() async throws {
32 |         let handler = EmptyAuthorizationHandler()
33 |         let response = try await handler.handleAuthorizationError(.invalidClientID)
34 |         XCTAssertEqual(response.body.string, "")
35 |     }
36 |
37 |     /// The authorization response of `EmptyAuthorizationHandler` has an empty body.
38 |     func testThatEmptyAuthHandlerReturnsEmptyStringWhenHandlingAuthRequest() async throws {
39 |         let handler = EmptyAuthorizationHandler()
40 |         let app = try Application.testable()
41 |         defer { app.shutdown() }
42 |
43 |         let request = Request(application: app, method: .POST, url: "/oauth/auth/", on: app.eventLoopGroup.next())
44 |         let uri: URI = "https://api.brokenhands.io/callback"
45 |         let requestObject = AuthorizationRequestObject(
46 |             responseType: "token",
47 |             clientID: "client-ID",
48 |             redirectURI: uri,
49 |             scope: ["email"],
50 |             state: "abcdef",
51 |             csrfToken: "01234"
52 |         )
53 |
54 |         let response = try await handler.handleAuthorizationRequest(request, authorizationRequestObject: requestObject)
55 |
56 |         XCTAssertEqual(response.body.string, "")
57 |     }
58 |
59 |     /// `EmptyCodeManager` has no code to return.
60 |     func testThatEmptyCodeManagerReturnsNilWhenGettingCode() {
61 |         XCTAssertNil(EmptyCodeManager().getCode("code"))
62 |     }
63 |
64 |     /// `EmptyCodeManager` generates the empty string as its code.
65 |     func testThatEmptyCodeManagerGeneratesEmptyStringAsCode() throws {
66 |         let codeManager = EmptyCodeManager()
67 |         let userID: String = "identifier"
68 |         let generated = try codeManager.generateCode(
69 |             userID: userID,
70 |             clientID: "client-id",
71 |             redirectURI: "https://api.brokenhands.io/callback",
72 |             scopes: nil
73 |         )
74 |         XCTAssertEqual(generated, "")
75 |     }
76 |
77 |     /// Calls `codeUsed` on `EmptyCodeManager`; the test makes no assertion about the result.
78 |     func testThatCodeUsedDoesNothingInEmptyCodeManager() {
79 |         let usedCode = OAuthCode(
80 |             codeID: "id",
81 |             clientID: "client-id",
82 |             redirectURI: "https://api.brokenhands.io/callback",
83 |             userID: "identifier",
84 |             expiryDate: Date(),
85 |             scopes: nil
86 |         )
87 |         EmptyCodeManager().codeUsed(usedCode)
88 |     }
89 |
90 | }
91 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/AccessToken.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
3 | @testable import VaporOAuth
4 |
5 | struct FakeAccessToken: AccessToken {  // value type used as the `AccessToken` in tests
6 |     let tokenString: String
7 |     let clientID: String
8 |     let userID: String?
9 |     let scopes: [String]?
10 |     let expiryTime: Date
11 |     /// Stores every argument unchanged; `userID` and `scopes` default to `nil`.
12 |     init(tokenString: String, clientID: String, userID: String? = nil, scopes: [String]? = nil, expiryTime: Date) {
13 |         self.tokenString = tokenString
14 |         self.clientID = clientID
15 |         self.userID = userID
16 |         self.scopes = scopes
17 |         self.expiryTime = expiryTime
18 |     }
19 | }
20 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/CapturingAuthorizeHandler.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 | import VaporOAuth
3 |
4 | /// `AuthorizeHandler` test double that records the values it is called with.
5 | class CapturingAuthoriseHandler: AuthorizeHandler {
6 |     private(set) var request: Request?
7 |     private(set) var responseType: String?
8 |     private(set) var clientID: String?
9 |     private(set) var redirectURI: URI?
10 |     private(set) var scope: [String]?
11 |     private(set) var state: String?
12 |     private(set) var csrfToken: String?
13 |     private(set) var authorizationError: AuthorizationError?
14 |
15 |     /// Records the request and each field of `authorizationRequestObject`, then returns an "Allow/Deny" body.
16 |     func handleAuthorizationRequest(
17 |         _ request: Request,
18 |         authorizationRequestObject: AuthorizationRequestObject
19 |     ) async throws -> Response {
20 |         self.request = request
21 |         responseType = authorizationRequestObject.responseType
22 |         clientID = authorizationRequestObject.clientID
23 |         redirectURI = authorizationRequestObject.redirectURI
24 |         scope = authorizationRequestObject.scope
25 |         state = authorizationRequestObject.state
26 |         csrfToken = authorizationRequestObject.csrfToken
27 |         return Response(body: .init(string: "Allow/Deny"))
28 |     }
29 |
30 |     /// Records `errorType` and returns an "Error" body.
31 |     func handleAuthorizationError(_ errorType: AuthorizationError) async throws -> Response {
32 |         authorizationError = errorType
33 |         return Response(body: .init(string: "Error"))
34 |     }
35 | }
36 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/CapturingLogger.swift:
--------------------------------------------------------------------------------
1 | import Logging
2 |
3 | /// `LogHandler` test double that remembers the most recently logged message.
4 | class CapturingLogger: LogHandler {
5 |     /// Single shared instance, so the handler installed in the logging system and the one
6 |     /// inspected by a test can be the same object.
7 |     static var shared: CapturingLogger = CapturingLogger()
8 |
9 |     var metadata: Logging.Logger.Metadata = [:]
10 |     /// Starts at `.trace`. `log(...)` overwrites this with the level of every message it receives.
11 |     var logLevel: Logging.Logger.Level = .trace
12 |     /// Text of the last message passed to `log(...)`; `nil` until something has been logged.
13 |     private(set) var logMessage: String?
14 |
15 |     /// Reads and writes individual entries of `metadata`.
16 |     subscript(metadataKey key: String) -> Logging.Logger.Metadata.Value? {
17 |         get {
18 |             return metadata[key]
19 |         }
20 |         set {
21 |             metadata[key] = newValue
22 |         }
23 |     }
24 |
25 |     /// Captures the message text and its level; metadata, source and location are ignored.
26 |     func log(
27 |         level: Logger.Level,
28 |         message: Logger.Message,
29 |         metadata: Logger.Metadata?,
30 |         source: String,
31 |         file: String,
32 |         function: String,
33 |         line: UInt
34 |     ) {
35 |         logMessage = String(describing: message)
36 |         // NOTE(review): this replaces the handler's configured level with the message's level;
37 |         // presumably tests read `logLevel` to check severity — confirm against callers.
38 |         logLevel = level
39 |     }
40 | }
41 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeAuthenticationMiddleware.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
3 | @testable import VaporOAuth
4 |
5 | /// Middleware for tests that logs a user in when the request's HTTP Basic credentials
6 | /// match one of a fixed list of users. Requests are always passed on to the next responder.
7 | struct FakeAuthenticationMiddleware: AsyncMiddleware {
8 |     typealias User = OAuthUser
9 |
10 |     /// Users whose username/password pair is accepted.
11 |     private let allowedUsers: [OAuthUser]
12 |
13 |     init(allowedUsers: [OAuthUser]) {
14 |         self.allowedUsers = allowedUsers
15 |     }
16 |
17 |     /// Logs in the first allowed user matching the Basic credentials (if any), then forwards
18 |     /// the request. Missing or non-matching credentials leave the request unauthenticated.
19 |     func respond(to request: Vapor.Request, chainingTo next: Vapor.AsyncResponder) async throws -> Vapor.Response {
20 |         if let credentials = request.headers.basicAuthorization {
21 |             let matchingUser = allowedUsers.first { candidate in
22 |                 candidate.username == credentials.username && candidate.password == credentials.password
23 |             }
24 |             if let matchingUser = matchingUser {
25 |                 request.auth.login(matchingUser)
26 |             }
27 |         }
28 |         return try await next.respond(to: request)
29 |     }
30 | }
31 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeClientGetter.swift:
--------------------------------------------------------------------------------
1 | import VaporOAuth
2 |
3 | /// `ClientRetriever` backed by a mutable dictionary that tests fill in directly.
4 | class FakeClientGetter: ClientRetriever {
5 |
6 |     /// Registered clients keyed by client ID.
7 |     var validClients: [String: OAuthClient] = [:]
8 |
9 |     /// Looks `clientID` up in `validClients`; an unknown ID yields `nil` rather than an error.
10 |     func getClient(clientID: String) async throws -> OAuthClient? {
11 |         validClients[clientID]
12 |     }
13 | }
14 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeCodeManager.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import VaporOAuth
3 |
4 | /// In-memory `CodeManager` that issues a predictable code and tracks which codes were used.
5 | class FakeCodeManager: CodeManager {
6 |
7 |     /// IDs of codes passed to `codeUsed(_:)`, in call order.
8 |     private(set) var usedCodes: [String] = []
9 |     /// Currently valid codes keyed by code ID.
10 |     var codes: [String: OAuthCode] = [:]
11 |     /// The ID handed out by `generateCode`; a random UUID string unless a test replaces it.
12 |     var generatedCode = UUID().uuidString
13 |
14 |     /// Returns the stored code for `code`, or `nil` if it is unknown or already used.
15 |     func getCode(_ code: String) -> OAuthCode? {
16 |         codes[code]
17 |     }
18 |
19 |     /// Stores a code under `generatedCode` that expires 60 seconds from now and returns its ID.
20 |     /// Calling this again without changing `generatedCode` replaces the previous entry.
21 |     func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) throws -> String {
22 |         let issuedID = generatedCode
23 |         codes[issuedID] = OAuthCode(
24 |             codeID: issuedID,
25 |             clientID: clientID,
26 |             redirectURI: redirectURI,
27 |             userID: userID,
28 |             expiryDate: Date().addingTimeInterval(60),
29 |             scopes: scopes
30 |         )
31 |         return issuedID
32 |     }
33 |
34 |     /// Records the code's ID in `usedCodes` and removes it from `codes`.
35 |     func codeUsed(_ code: OAuthCode) {
36 |         let usedID = code.codeID
37 |         usedCodes.append(usedID)
38 |         codes[usedID] = nil
39 |     }
40 | }
41 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeResourceServerRetriever.swift:
--------------------------------------------------------------------------------
1 | import VaporOAuth
2 |
3 | /// `ResourceServerRetriever` backed by a mutable dictionary that tests fill in directly.
4 | class FakeResourceServerRetriever: ResourceServerRetriever {
5 |
6 |     /// Registered resource servers keyed by the username used to look them up.
7 |     var resourceServers: [String: OAuthResourceServer] = [:]
8 |
9 |     /// Looks `username` up in `resourceServers`; an unknown name yields `nil` rather than an error.
10 |     func getServer(_ username: String) async throws -> OAuthResourceServer? {
11 |         resourceServers[username]
12 |     }
13 | }
14 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeSessions.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
3 | /// Stand-in `SessionDriver` for tests: nothing is ever written, and reads come from `sessions`.
4 | ///
5 | /// The return types spell out the generic argument of `EventLoopFuture`; a bare
6 | /// `EventLoopFuture` is not a valid return type and does not satisfy `SessionDriver`.
7 | struct FakeSessions: SessionDriver {
8 |     /// Data returned by `readSession`. None of the driver methods mutate it.
9 |     var sessions: [SessionID: SessionData] = [:]
10 |
11 |     /// Ignores `data` and succeeds immediately with an empty-string session ID.
12 |     func createSession(_ data: SessionData, for request: Request) -> EventLoopFuture<SessionID> {
13 |         return request.eventLoop.makeSucceededFuture(.init(string: ""))
14 |     }
15 |
16 |     /// Succeeds with the entry stored under `sessionID`, or `nil` when there is none.
17 |     func readSession(_ sessionID: SessionID, for request: Request) -> EventLoopFuture<SessionData?> {
18 |         return request.eventLoop.makeSucceededFuture(sessions[sessionID])
19 |     }
20 |
21 |     /// Ignores both arguments and succeeds immediately with an empty-string session ID.
22 |     func updateSession(_ sessionID: SessionID, to data: SessionData, for request: Request) -> EventLoopFuture<SessionID> {
23 |         return request.eventLoop.makeSucceededFuture(.init(string: ""))
24 |     }
25 |
26 |     /// Does nothing and succeeds immediately.
27 |     func deleteSession(_ sessionID: SessionID, for request: Request) -> EventLoopFuture<Void> {
28 |         return request.eventLoop.makeSucceededFuture(())
29 |     }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeTokenManager.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import VaporOAuth
3 |
4 | /// In-memory `TokenManager` that issues tokens with test-controlled strings and a
5 | /// test-controlled clock, and keeps every issued token for later inspection.
6 | class FakeTokenManager: TokenManager {
7 |
8 |     /// String used for the next access token that is generated.
9 |     var accessTokenToReturn = "ACCESS-TOKEN-STRING"
10 |     /// String used for the next refresh token that is generated.
11 |     var refreshTokenToReturn = "REFRESH-TOKEN-STRING"
12 |     /// Issued refresh tokens keyed by token string.
13 |     var refreshTokens: [String: RefreshToken] = [:]
14 |     /// Issued access tokens keyed by token string.
15 |     var accessTokens: [String: AccessToken] = [:]
16 |     /// Reference instant that expiry times are computed from.
17 |     var currentTime = Date()
18 |
19 |     func getRefreshToken(_ refreshToken: String) -> RefreshToken? {
20 |         refreshTokens[refreshToken]
21 |     }
22 |
23 |     func getAccessToken(_ accessToken: String) -> AccessToken? {
24 |         accessTokens[accessToken]
25 |     }
26 |
27 |     /// Issues and stores an access token expiring `accessTokenExpiryTime` seconds after
28 |     /// `currentTime`, together with a refresh token carrying the same client, user and scopes.
29 |     func generateAccessRefreshTokens(clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int) throws -> (
30 |         AccessToken, RefreshToken
31 |     ) {
32 |         let issuedAccess = storeNewAccessToken(
33 |             clientID: clientID, userID: userID, scopes: scopes, lifetime: accessTokenExpiryTime)
34 |         let issuedRefresh = FakeRefreshToken(
35 |             tokenString: refreshTokenToReturn,
36 |             clientID: clientID,
37 |             userID: userID,
38 |             scopes: scopes
39 |         )
40 |         refreshTokens[refreshTokenToReturn] = issuedRefresh
41 |         return (issuedAccess, issuedRefresh)
42 |     }
43 |
44 |     /// Issues and stores an access token expiring `expiryTime` seconds after `currentTime`.
45 |     func generateAccessToken(clientID: String, userID: String?, scopes: [String]?, expiryTime: Int) throws -> AccessToken {
46 |         return storeNewAccessToken(clientID: clientID, userID: userID, scopes: scopes, lifetime: expiryTime)
47 |     }
48 |
49 |     /// Stores a copy of `refreshToken` with its scopes replaced, under the token's own string.
50 |     func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) {
51 |         var rescoped = refreshToken
52 |         rescoped.scopes = scopes
53 |         refreshTokens[refreshToken.tokenString] = rescoped
54 |     }
55 |
56 |     /// Builds a `FakeAccessToken` named `accessTokenToReturn` and records it in `accessTokens`.
57 |     private func storeNewAccessToken(clientID: String, userID: String?, scopes: [String]?, lifetime: Int) -> FakeAccessToken {
58 |         let token = FakeAccessToken(
59 |             tokenString: accessTokenToReturn,
60 |             clientID: clientID,
61 |             userID: userID,
62 |             scopes: scopes,
63 |             expiryTime: currentTime.addingTimeInterval(TimeInterval(lifetime))
64 |         )
65 |         accessTokens[accessTokenToReturn] = token
66 |         return token
67 |     }
68 | }
69 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/FakeUserManager.swift:
--------------------------------------------------------------------------------
1 | import VaporOAuth
2 |
3 | /// `UserManager` backed by a plain array of users that tests append to directly.
4 | class FakeUserManager: UserManager {
5 |     /// Known users, searched front to back.
6 |     var users: [OAuthUser] = []
7 |
8 |     /// Returns the `id` of the first user whose username and password both equal the
9 |     /// arguments, or `nil` when no user matches.
10 |     func authenticateUser(username: String, password: String) -> String? {
11 |         let match = users.first { candidate in
12 |             candidate.username == username && candidate.password == password
13 |         }
14 |         return match?.id
15 |     }
16 |
17 |     /// Returns the first user whose `id` equals `userID`, or `nil` when there is none.
18 |     func getUser(userID: String) -> OAuthUser? {
19 |         return users.first(where: { $0.id == userID })
20 |     }
21 | }
22 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/RefreshToken.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
3 | @testable import VaporOAuth
4 |
5 | /// Plain value implementation of `RefreshToken` for use in tests.
6 | /// All fields are mutable so a test (or a fake token manager) can adjust a copy in place.
7 | public struct FakeRefreshToken: RefreshToken {
8 |     /// The raw token string.
9 |     public var tokenString: String
10 |     /// ID of the client this token belongs to.
11 |     public var clientID: String
12 |
13 |     /// ID of the associated user, if any.
14 |     public var userID: String?
15 |     /// Scopes carried by the token, if any.
16 |     public var scopes: [String]?
17 | }
18 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/StubCodeManager.swift:
--------------------------------------------------------------------------------
1 | import VaporOAuth
2 |
3 | /// `CodeManager` stub: hands out one fixed code, knows no codes, and ignores usage reports.
4 | class StubCodeManager: CodeManager {
5 |
6 |     /// The value `generateCode` returns, whatever its arguments.
7 |     var codeToReturn = "ABCDEFHIJKLMNO"
8 |
9 |     /// Always reports that the code is unknown.
10 |     func getCode(_ code: String) -> OAuthCode? {
11 |         nil
12 |     }
13 |
14 |     /// Ignores every argument and returns `codeToReturn`.
15 |     func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?) throws -> String {
16 |         codeToReturn
17 |     }
18 |
19 |     /// Intentionally a no-op.
20 |     func codeUsed(_ code: OAuthCode) {}
21 | }
22 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/StubTokenManager.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import VaporOAuth
3 |
4 | /// `TokenManager` stub that hands out fixed token strings and never stores anything.
5 | class StubTokenManager: TokenManager {
6 |
7 |     /// String used for every generated access token.
8 |     var accessToken = "ABCDEF"
9 |     /// String used for every generated refresh token.
10 |     var refreshToken = "GHIJKL"
11 |
12 |     /// Always reports that the refresh token is unknown.
13 |     func getRefreshToken(_ refreshToken: String) -> RefreshToken? {
14 |         nil
15 |     }
16 |
17 |     /// Always reports that the access token is unknown.
18 |     func getAccessToken(_ accessToken: String) -> AccessToken? {
19 |         nil
20 |     }
21 |
22 |     /// Returns a new access token whose expiry is the moment of creation; `expiryTime` is ignored.
23 |     func generateAccessToken(clientID: String, userID: String?, scopes: [String]?, expiryTime: Int) throws -> AccessToken {
24 |         FakeAccessToken(
25 |             tokenString: accessToken,
26 |             clientID: clientID,
27 |             userID: userID,
28 |             scopes: scopes,
29 |             expiryTime: Date()
30 |         )
31 |     }
32 |
33 |     /// Returns a new access/refresh pair. The access token expires at the moment of creation
34 |     /// (`accessTokenExpiryTime` is ignored) and the refresh token is created with `userID: nil`
35 |     /// even when a user ID is supplied.
36 |     func generateAccessRefreshTokens(clientID: String, userID: String?, scopes: [String]?, accessTokenExpiryTime: Int) throws -> (
37 |         AccessToken, RefreshToken
38 |     ) {
39 |         return (
40 |             FakeAccessToken(tokenString: accessToken, clientID: clientID, userID: userID, scopes: scopes, expiryTime: Date()),
41 |             FakeRefreshToken(tokenString: refreshToken, clientID: clientID, userID: nil, scopes: scopes)
42 |         )
43 |     }
44 |
45 |     /// Intentionally a no-op.
46 |     func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) {}
47 | }
48 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Fakes/StubUserManager.swift:
--------------------------------------------------------------------------------
1 | import VaporOAuth
2 |
3 | /// `UserManager` stub that knows no users: authentication and lookup both yield `nil`.
4 | struct StubUserManager: UserManager {
5 |     /// Always `nil`, whatever user ID is asked for.
6 |     func getUser(userID: String) -> OAuthUser? {
7 |         nil
8 |     }
9 |
10 |     /// Always `nil`, whatever credentials are supplied.
11 |     func authenticateUser(username: String, password: String) -> String? {
12 |         nil
13 |     }
14 | }
15 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/GrantTests/ClientCredentialsTokenTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | /// Exercises the token endpoint with the `client_credentials` grant type against fake collaborators.
6 | class ClientCredentialsTokenTests: XCTestCase {
7 |     // MARK: - Properties
8 |     var app: Application!
9 |     var fakeClientGetter: FakeClientGetter!
10 |     var fakeTokenManager: FakeTokenManager!
11 |
12 |     let testClientID = "ABCDEF"
13 |     let testClientSecret = "01234567890"
14 |     let accessToken = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
15 |     let refreshToken = "ABCDEFGHIJLMNOP1234567890"
16 |     let scope1 = "email"
17 |     let scope2 = "create"
18 |     let scope3 = "edit"
19 |
20 |     // MARK: - Overrides
21 |
22 |     /// Registers a confidential client limited to `scope1`/`scope2` and builds an application that
23 |     /// accepts `scope1`, `scope2` and `scope3`, so `scope3` is known to the server but not granted to the client.
24 |     override func setUp() async throws {
25 |         try await super.setUp()
26 |
27 |         fakeClientGetter = FakeClientGetter()
28 |         fakeTokenManager = FakeTokenManager()
29 |
30 |         let oauthClient = OAuthClient(
31 |             clientID: testClientID,
32 |             redirectURIs: nil,
33 |             clientSecret: testClientSecret,
34 |             validScopes: [scope1, scope2],
35 |             confidential: true,
36 |             allowedGrantType: .clientCredentials
37 |         )
38 |
39 |         fakeClientGetter.validClients[testClientID] = oauthClient
40 |         fakeTokenManager.accessTokenToReturn = accessToken
41 |         fakeTokenManager.refreshTokenToReturn = refreshToken
42 |
43 |         app = try TestDataBuilder.getOAuth2Application(
44 |             tokenManager: fakeTokenManager,
45 |             clientRetriever: fakeClientGetter,
46 |             validScopes: [scope1, scope2, scope3]
47 |         )
48 |     }
49 |
50 |     override func tearDown() async throws {
51 |         app.shutdown()
52 |         try await super.tearDown()
53 |     }
54 |
55 |     // MARK: - Tests
56 |
57 |     func testCorrectErrorWhenGrantTypeNotSupplied() async throws {
58 |         let response = try await getClientCredentialsResponse(grantType: nil)
59 |
60 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
61 |
62 |         XCTAssertEqual(response.status, .badRequest)
63 |         XCTAssertEqual(responseJSON.error, "invalid_request")
64 |         XCTAssertEqual(responseJSON.errorDescription, "Request was missing the 'grant_type' parameter")
65 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
66 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
67 |     }
68 |
69 |     func testCorrectErrorAndHeadersReceivedWhenIncorrectGrantTypeSet() async throws {
70 |         let grantType = "some_unknown_type"
71 |         let response = try await getClientCredentialsResponse(grantType: grantType)
72 |
73 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
74 |
75 |         XCTAssertEqual(response.status, .badRequest)
76 |         XCTAssertEqual(responseJSON.error, "unsupported_grant_type")
77 |         XCTAssertEqual(responseJSON.errorDescription, "This server does not support the '\(grantType)' grant type")
78 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
79 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
80 |     }
81 |
82 |     func testCorrectErrorWhenClientIDNotSupplied() async throws {
83 |         let response = try await getClientCredentialsResponse(clientID: nil)
84 |
85 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
86 |
87 |         XCTAssertEqual(response.status, .badRequest)
88 |         XCTAssertEqual(responseJSON.error, "invalid_request")
89 |         XCTAssertEqual(responseJSON.errorDescription, "Request was missing the 'client_id' parameter")
90 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
91 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
92 |     }
93 |
94 |     func testCorrectErrorWhenClientIDNotValid() async throws {
95 |         let response = try await getClientCredentialsResponse(clientID: "UNKNOWN_CLIENT")
96 |
97 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
98 |
99 |         XCTAssertEqual(response.status, .unauthorized)
100 |         XCTAssertEqual(responseJSON.error, "invalid_client")
101 |         XCTAssertEqual(responseJSON.errorDescription, "Request had invalid client credentials")
102 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
103 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
104 |     }
105 |
106 |     func testCorrectErrorWhenClientDoesNotAuthenticate() async throws {
107 |         let response = try await getClientCredentialsResponse(clientSecret: "incorrectPassword")
108 |
109 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
110 |
111 |         XCTAssertEqual(response.status, .unauthorized)
112 |         XCTAssertEqual(responseJSON.error, "invalid_client")
113 |         XCTAssertEqual(responseJSON.errorDescription, "Request had invalid client credentials")
114 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
115 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
116 |     }
117 |
118 |     func testCorrectErrorIfClientSecretNotSent() async throws {
119 |         let response = try await getClientCredentialsResponse(clientSecret: nil)
120 |
121 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
122 |
123 |         XCTAssertEqual(response.status, .badRequest)
124 |         XCTAssertEqual(responseJSON.error, "invalid_request")
125 |         XCTAssertEqual(responseJSON.errorDescription, "Request was missing the 'client_secret' parameter")
126 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
127 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
128 |     }
129 |
130 |     func testThatTokenReceivedIfClientAuthenticated() async throws {
131 |         let response = try await getClientCredentialsResponse()
132 |
133 |         let responseJSON = try JSONDecoder().decode(SuccessResponse.self, from: response.body)
134 |
135 |         XCTAssertEqual(response.status, .ok)
136 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
137 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
138 |         XCTAssertEqual(responseJSON.tokenType, "bearer")
139 |         XCTAssertEqual(responseJSON.expiresIn, 3600)
140 |         XCTAssertEqual(responseJSON.accessToken, accessToken)
141 |         XCTAssertEqual(responseJSON.refreshToken, refreshToken)
142 |     }
143 |
144 |     func testScopeSetOnTokenIfRequested() async throws {
145 |         let scope = "email create"
146 |
147 |         let response = try await getClientCredentialsResponse(scope: scope)
148 |
149 |         let responseJSON = try JSONDecoder().decode(SuccessResponse.self, from: response.body)
150 |
151 |         XCTAssertEqual(response.status, .ok)
152 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
153 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
154 |         XCTAssertEqual(responseJSON.tokenType, "bearer")
155 |         XCTAssertEqual(responseJSON.expiresIn, 3600)
156 |         XCTAssertEqual(responseJSON.accessToken, accessToken)
157 |         XCTAssertEqual(responseJSON.refreshToken, refreshToken)
158 |         XCTAssertEqual(responseJSON.scope, scope)
159 |
160 |         // Distinct local names keep the stored tokens apart from the `accessToken`/`refreshToken` string fixtures.
161 |         let storedAccessToken = try XCTUnwrap(
162 |             fakeTokenManager.getAccessToken(accessToken), "Access token was not stored by the token manager")
163 |         let storedRefreshToken = try XCTUnwrap(
164 |             fakeTokenManager.getRefreshToken(refreshToken), "Refresh token was not stored by the token manager")
165 |
166 |         XCTAssertEqual(storedAccessToken.scopes ?? [], ["email", "create"])
167 |         XCTAssertEqual(storedRefreshToken.scopes ?? [], ["email", "create"])
168 |     }
169 |
170 |     func testCorrectErrorWhenReqeustingScopeApplicationDoesNotHaveAccessTo() async throws {
171 |         let scope = "email edit"
172 |
173 |         let response = try await getClientCredentialsResponse(scope: scope)
174 |
175 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
176 |
177 |         XCTAssertEqual(response.status, .badRequest)
178 |         XCTAssertEqual(responseJSON.error, "invalid_scope")
179 |         XCTAssertEqual(responseJSON.errorDescription, "Request contained an invalid scope")
180 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
181 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
182 |     }
183 |
184 |     func testCorrectErrorWhenRequestingUnknownScope() async throws {
185 |         let scope = "email unknown"
186 |
187 |         let response = try await getClientCredentialsResponse(scope: scope)
188 |
189 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
190 |
191 |         XCTAssertEqual(response.status, .badRequest)
192 |         XCTAssertEqual(responseJSON.error, "invalid_scope")
193 |         XCTAssertEqual(responseJSON.errorDescription, "Request contained an unknown scope")
194 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
195 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
196 |     }
197 |
198 |     func testCorrectErrorWhenNonConfidentialClientTriesToUseCredentialsGrantType() async throws {
199 |         let newClientID = "1234"
200 |         let newClientSecret = "1234567899"
201 |         let newClient = OAuthClient(
202 |             clientID: newClientID, redirectURIs: nil, clientSecret: newClientSecret, confidential: false,
203 |             allowedGrantType: .clientCredentials)
204 |         fakeClientGetter.validClients[newClientID] = newClient
205 |
206 |         let response = try await getClientCredentialsResponse(clientID: newClientID, clientSecret: newClientSecret)
207 |
208 |         let responseJSON = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
209 |
210 |         XCTAssertEqual(response.status, .badRequest)
211 |         XCTAssertEqual(responseJSON.error, "unauthorized_client")
212 |         XCTAssertEqual(responseJSON.errorDescription, "You are not authorized to use the Client Credentials grant type")
213 |         XCTAssertTrue(response.headers.cacheControl?.noStore ?? false)
214 |         XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
215 |     }
216 |
217 |     func testAccessTokenHasCorrectExpiryTime() async throws {
218 |         let currentTime = Date()
219 |         fakeTokenManager.currentTime = currentTime
220 |
221 |         let response = try await getClientCredentialsResponse()
222 |
223 |         let responseJSON = try JSONDecoder().decode(SuccessResponse.self, from: response.body)
224 |
225 |         let accessTokenString = try XCTUnwrap(responseJSON.accessToken, "Response did not contain an access token")
226 |         let storedAccessToken = try XCTUnwrap(
227 |             fakeTokenManager.getAccessToken(accessTokenString), "Access token was not stored by the token manager")
228 |
229 |         XCTAssertEqual(storedAccessToken.expiryTime, currentTime.addingTimeInterval(3600))
230 |     }
231 |
232 |     func testClientIDSetOnAccessTokenCorrectly() async throws {
233 |         let newClientString = "a-new-client"
234 |         let newClient = OAuthClient(
235 |             clientID: newClientString, redirectURIs: nil, clientSecret: testClientSecret, validScopes: [scope1, scope2], confidential: true,
236 |             allowedGrantType: .clientCredentials)
237 |         fakeClientGetter.validClients[newClientString] = newClient
238 |
239 |         let response = try await getClientCredentialsResponse(clientID: newClientString)
240 |
241 |         let responseJSON = try JSONDecoder().decode(SuccessResponse.self, from: response.body)
242 |
243 |         let accessTokenString = try XCTUnwrap(responseJSON.accessToken, "Response did not contain an access token")
244 |         let storedAccessToken = try XCTUnwrap(
245 |             fakeTokenManager.getAccessToken(accessTokenString), "Access token was not stored by the token manager")
246 |
247 |         XCTAssertEqual(storedAccessToken.clientID, newClientString)
248 |     }
249 |
250 |     func testThatRefreshTokenHasCorrectClientIDSet() async throws {
251 |         let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
252 |         fakeTokenManager.refreshTokenToReturn = refreshTokenString
253 |
254 |         _ = try await getClientCredentialsResponse()
255 |
256 |         let storedRefreshToken = try XCTUnwrap(
257 |             fakeTokenManager.getRefreshToken(refreshTokenString), "Refresh token was not stored by the token manager")
258 |
259 |         XCTAssertEqual(storedRefreshToken.clientID, testClientID)
260 |     }
261 |
262 |     func testThatRefreshTokenHasNoScopesIfNoneRequested() async throws {
263 |         let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
264 |         fakeTokenManager.refreshTokenToReturn = refreshTokenString
265 |
266 |         _ = try await getClientCredentialsResponse(scope: nil)
267 |
268 |         let storedRefreshToken = try XCTUnwrap(
269 |             fakeTokenManager.getRefreshToken(refreshTokenString), "Refresh token was not stored by the token manager")
270 |
271 |         XCTAssertNil(storedRefreshToken.scopes)
272 |     }
273 |
274 |     func testThatRefreshTokenHasCorrectScopesIfSet() async throws {
275 |         let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
276 |         fakeTokenManager.refreshTokenToReturn = refreshTokenString
277 |
278 |         _ = try await getClientCredentialsResponse(scope: "email create")
279 |
280 |         let storedRefreshToken = try XCTUnwrap(
281 |             fakeTokenManager.getRefreshToken(refreshTokenString), "Refresh token was not stored by the token manager")
282 |
283 |         XCTAssertEqual(storedRefreshToken.scopes ?? [], ["email", "create"])
284 |     }
285 |
286 |     func testNoUserIDSetOnRefreshToken() async throws {
287 |         let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
288 |         fakeTokenManager.refreshTokenToReturn = refreshTokenString
289 |
290 |         _ = try await getClientCredentialsResponse()
291 |
292 |         let storedRefreshToken = try XCTUnwrap(
293 |             fakeTokenManager.getRefreshToken(refreshTokenString), "Refresh token was not stored by the token manager")
294 |
295 |         XCTAssertNil(storedRefreshToken.userID)
296 |     }
297 |
298 |     func testClientNotConfiguredWithAccessToClientCredentialsFlowCantAccessIt() async throws {
299 |         let unauthorizedID = "not-allowed"
300 |         let unauthorizedSecret = "client-secret"
301 |         let unauthorizedClient = OAuthClient(
302 |             clientID: unauthorizedID, redirectURIs: nil, clientSecret: unauthorizedSecret, validScopes: nil, confidential: true,
303 |             firstParty: true, allowedGrantType: .refresh)
304 |         fakeClientGetter.validClients[unauthorizedID] = unauthorizedClient
305 |
306 |         let response = try await getClientCredentialsResponse(clientID: unauthorizedID, clientSecret: unauthorizedSecret)
307 |
308 |         XCTAssertEqual(response.status, .forbidden)
309 |     }
310 |
311 |     func testClientConfiguredWithAccessToClientCredentialsFlowCanAccessIt() async throws {
312 |         // The ID is only a lookup key; it was "not-allowed" (copied from the test above), which misdescribed this client.
313 |         let authorizedID = "allowed"
314 |         let authorizedSecret = "client-secret"
315 |         let authorizedClient = OAuthClient(
316 |             clientID: authorizedID, redirectURIs: nil, clientSecret: authorizedSecret, validScopes: nil, confidential: true,
317 |             firstParty: true, allowedGrantType: .clientCredentials)
318 |         fakeClientGetter.validClients[authorizedID] = authorizedClient
319 |
320 |         let response = try await getClientCredentialsResponse(clientID: authorizedID, clientSecret: authorizedSecret)
321 |
322 |         XCTAssertEqual(response.status, .ok)
323 |     }
324 |
325 |     // MARK: - Private
326 |
327 |     /// Sends a token request to `app` through `TestDataBuilder`. The defaults describe a valid request
328 |     /// for the client registered in `setUp`; pass `nil` for a parameter to leave it out of the request.
329 |     func getClientCredentialsResponse(
330 |         grantType: String? = "client_credentials",
331 |         clientID: String? = "ABCDEF",
332 |         clientSecret: String? = "01234567890",
333 |         scope: String? = nil
334 |     ) async throws -> XCTHTTPResponse {
335 |         return try await TestDataBuilder.getTokenRequestResponse(
336 |             with: app,
337 |             grantType: grantType,
338 |             clientID: clientID,
339 |             clientSecret: clientSecret,
340 |             scope: scope
341 |         )
342 |     }
343 |
344 | }
345 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/GrantTests/PasswordGrantTokenTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
5 | class PasswordGrantTokenTests: XCTestCase {
6 | // MARK: - Properties
7 | var app: Application!
8 | var fakeClientGetter: FakeClientGetter!
9 | var fakeUserManager: FakeUserManager!
10 | var fakeTokenManager: FakeTokenManager!
11 | var capturingLogger: CapturingLogger!
12 | let testClientID = "ABCDEF"
13 | let testClientSecret = "01234567890"
14 | let testUsername = "testUser"
15 | let testPassword = "testPassword"
16 | let testUserID = "ABCD-FJUH-31232"
17 | let accessToken = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
18 | let refreshToken = "ABCDEFGHIJLMNOP1234567890"
19 | let scope1 = "email"
20 | let scope2 = "create"
21 | let scope3 = "edit"
22 |
23 | // MARK: - Overrides
24 |     /// Class-level setup: installs a logging factory that always yields `CapturingLogger.shared`.
25 |     override class func setUp() {
26 |         super.setUp()
27 |         // The factory's label argument is ignored, so every label maps to the same handler instance.
28 |         LoggingSystem.bootstrap { _ in CapturingLogger.shared }
29 |     }
30 |
31 |     /// Builds fresh fakes and an application per test, then registers one first-party password-grant
32 |     /// client and one user. A construction error now fails the test (`try`) instead of crashing (`try!`).
33 |     override func setUp() async throws {
34 |         try await super.setUp()
35 |         fakeClientGetter = FakeClientGetter()
36 |         fakeUserManager = FakeUserManager()
37 |         fakeTokenManager = FakeTokenManager()
38 |         capturingLogger = .shared
39 |
40 |         app = try TestDataBuilder.getOAuth2Application(
41 |             tokenManager: fakeTokenManager,
42 |             clientRetriever: fakeClientGetter,
43 |             userManager: fakeUserManager,
44 |             validScopes: [scope1, scope2, scope3],
45 |             logger: capturingLogger
46 |         )
47 |
48 |         fakeClientGetter.validClients[testClientID] = OAuthClient(
49 |             clientID: testClientID,
50 |             redirectURIs: nil,
51 |             clientSecret: testClientSecret,
52 |             validScopes: [scope1, scope2],
53 |             firstParty: true,
54 |             allowedGrantType: .password
55 |         )
56 |         fakeUserManager.users.append(OAuthUser(userID: testUserID, username: testUsername, emailAddress: nil, password: testPassword))
57 |         fakeTokenManager.accessTokenToReturn = accessToken
58 |         fakeTokenManager.refreshTokenToReturn = refreshToken
59 |     }
60 |
61 |     override func tearDown() async throws {
62 |         app.shutdown()  // shut this test's Application down before handing over to XCTest's own teardown
63 |         try await super.tearDown()
64 |     }
65 |
66 | // MARK: - Tests
67 |
68 |     /// A request without `grant_type` is answered with 400 `invalid_request` and no-store/no-cache headers.
69 |     func testCorrectErrorWhenGrantTypeNotSupplied() async throws {
70 |         let tokenResponse = try await getPasswordResponse(grantType: nil)
71 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
72 |
73 |         XCTAssertEqual(tokenResponse.status, .badRequest)
74 |         XCTAssertEqual(errorBody.error, "invalid_request")
75 |         XCTAssertEqual(errorBody.errorDescription, "Request was missing the 'grant_type' parameter")
76 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
77 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
78 |     }
79 |
80 |     /// An unrecognised `grant_type` is answered with 400 `unsupported_grant_type`, echoing the value sent.
81 |     func testCorrectErrorAndHeadersReceivedWhenIncorrectGrantTypeSet() async throws {
82 |         let unsupportedGrantType = "some_unknown_type"
83 |         let tokenResponse = try await getPasswordResponse(grantType: unsupportedGrantType)
84 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
85 |
86 |         XCTAssertEqual(tokenResponse.status, .badRequest)
87 |         XCTAssertEqual(errorBody.error, "unsupported_grant_type")
88 |         XCTAssertEqual(errorBody.errorDescription, "This server does not support the '\(unsupportedGrantType)' grant type")
89 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
90 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
91 |     }
92 |
93 |     /// A request without `username` is answered with 400 `invalid_request` and no-store/no-cache headers.
94 |     func testCorrectErrorWhenUsernameNotSupplied() async throws {
95 |         let tokenResponse = try await getPasswordResponse(username: nil)
96 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
97 |
98 |         XCTAssertEqual(tokenResponse.status, .badRequest)
99 |         XCTAssertEqual(errorBody.error, "invalid_request")
100 |         XCTAssertEqual(errorBody.errorDescription, "Request was missing the 'username' parameter")
101 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
102 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
103 |     }
104 |
105 |     /// A request without `password` is answered with 400 `invalid_request` and no-store/no-cache headers.
106 |     func testCorrectErrorWhenPasswordNotSupplied() async throws {
107 |         let tokenResponse = try await getPasswordResponse(password: nil)
108 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
109 |
110 |         XCTAssertEqual(tokenResponse.status, .badRequest)
111 |         XCTAssertEqual(errorBody.error, "invalid_request")
112 |         XCTAssertEqual(errorBody.errorDescription, "Request was missing the 'password' parameter")
113 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
114 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
115 |     }
116 |
117 |     /// A request without `client_id` is answered with 400 `invalid_request` and no-store/no-cache headers.
118 |     func testCorrectErrorWhenClientIDNotSupplied() async throws {
119 |         let tokenResponse = try await getPasswordResponse(clientID: nil)
120 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
121 |
122 |         XCTAssertEqual(tokenResponse.status, .badRequest)
123 |         XCTAssertEqual(errorBody.error, "invalid_request")
124 |         XCTAssertEqual(errorBody.errorDescription, "Request was missing the 'client_id' parameter")
125 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
126 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
127 |     }
128 |
129 |     /// A `client_id` that is not registered is answered with 401 `invalid_client`.
130 |     func testCorrectErrorWhenClientIDNotValid() async throws {
131 |         let tokenResponse = try await getPasswordResponse(clientID: "UNKNOWN_CLIENT")
132 |         let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: tokenResponse.body)
133 |
134 |         XCTAssertEqual(tokenResponse.status, .unauthorized)
135 |         XCTAssertEqual(errorBody.error, "invalid_client")
136 |         XCTAssertEqual(errorBody.errorDescription, "Request had invalid client credentials")
137 |         XCTAssertEqual(tokenResponse.headers.cacheControl?.noStore, true)
138 |         XCTAssertEqual(tokenResponse.headers[HTTPHeaders.Name.pragma], ["no-cache"])
139 |     }
140 |
/// Registers a client that has a secret, then sends a different secret and
/// expects a 401 `invalid_client` error with the no-cache headers.
func testCorrectErrorWhenClientDoesNotAuthenticate() async throws {
    let clientID = "ABCDEF"
    fakeClientGetter.validClients[clientID] = OAuthClient(
        clientID: clientID,
        redirectURIs: ["https://api.brokenhands.io/callback"],
        clientSecret: "1234567890ABCD",
        allowedGrantType: .password
    )

    let response = try await getPasswordResponse(clientID: clientID, clientSecret: "incorrectPassword")
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .unauthorized)
    XCTAssertEqual(errorBody.error, "invalid_client")
    XCTAssertEqual(errorBody.errorDescription, "Request had invalid client credentials")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
158 |
/// Registers a client that has a secret, then omits the secret from the request
/// and expects a 401 `invalid_client` error with the no-cache headers.
func testCorrectErrorIfClientSecretNotSentAndIsExpected() async throws {
    let clientID = "ABCDEF"
    fakeClientGetter.validClients[clientID] = OAuthClient(
        clientID: clientID,
        redirectURIs: ["https://api.brokenhands.io/callback"],
        clientSecret: "1234567890ABCD",
        allowedGrantType: .password
    )

    let response = try await getPasswordResponse(clientID: clientID, clientSecret: nil)
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .unauthorized)
    XCTAssertEqual(errorBody.error, "invalid_client")
    XCTAssertEqual(errorBody.errorDescription, "Request had invalid client credentials")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
176 |
/// Expects a request for the username "UNKNOWN_USER" to produce a 400
/// `invalid_grant` error whose response carries the no-cache headers.
func testCorrectErrorWhenUserDoesNotExist() async throws {
    let response = try await getPasswordResponse(username: "UNKNOWN_USER")
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .badRequest)
    XCTAssertEqual(errorBody.error, "invalid_grant")
    XCTAssertEqual(errorBody.errorDescription, "Request had invalid credentials")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
188 |
/// Expects a request carrying the password "INCORRECT_PASSWORD" to produce a 400
/// `invalid_grant` error whose response carries the no-cache headers.
func testCorrectErrorWhenPasswordIsIncorrect() async throws {
    let response = try await getPasswordResponse(password: "INCORRECT_PASSWORD")
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .badRequest)
    XCTAssertEqual(errorBody.error, "invalid_grant")
    XCTAssertEqual(errorBody.errorDescription, "Request had invalid credentials")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
200 |
/// Expects a request built entirely from the helper's defaults to succeed with a
/// bearer token response that echoes the fixture's access and refresh tokens.
func testThatTokenReceivedIfUserAuthenticated() async throws {
    let response = try await getPasswordResponse()
    let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .ok)
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
    XCTAssertEqual(tokenBody.tokenType, "bearer")
    XCTAssertEqual(tokenBody.expiresIn, 3600)
    XCTAssertEqual(tokenBody.accessToken, accessToken)
    XCTAssertEqual(tokenBody.refreshToken, refreshToken)
}
214 |
/// Requests the scope "email create" and expects it to be echoed in the response
/// and recorded on both the access and refresh tokens held by the fake token manager.
func testScopeSetOnTokenIfRequested() async throws {
    let scope = "email create"

    let response = try await getPasswordResponse(scope: scope)
    let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

    XCTAssertEqual(response.status, .ok)
    XCTAssertEqual(response.headers.cacheControl?.noStore, true)
    XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"])
    XCTAssertEqual(tokenBody.tokenType, "bearer")
    XCTAssertEqual(tokenBody.expiresIn, 3600)
    XCTAssertEqual(tokenBody.accessToken, accessToken)
    XCTAssertEqual(tokenBody.refreshToken, refreshToken)
    XCTAssertEqual(tokenBody.scope, scope)

    // XCTUnwrap fails the test with a message instead of a bare XCTFail(), and the
    // distinct local names avoid shadowing the fixture's token-string properties.
    let storedAccessToken = try XCTUnwrap(
        fakeTokenManager.getAccessToken(accessToken),
        "Token manager has no access token for the expected token string"
    )
    let storedRefreshToken = try XCTUnwrap(
        fakeTokenManager.getRefreshToken(refreshToken),
        "Token manager has no refresh token for the expected token string"
    )

    XCTAssertEqual(storedAccessToken.scopes ?? [], ["email", "create"])
    XCTAssertEqual(storedRefreshToken.scopes ?? [], ["email", "create"])
}
241 |
/// Requests the scope "email edit" and expects a 400 `invalid_scope` error
/// ("invalid scope") whose response carries the no-cache headers.
func testCorrectErrorWhenReqeustingScopeApplicationDoesNotHaveAccessTo() async throws {
    let requestedScope = "email edit"

    let response = try await getPasswordResponse(scope: requestedScope)
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .badRequest)
    XCTAssertEqual(errorBody.error, "invalid_scope")
    XCTAssertEqual(errorBody.errorDescription, "Request contained an invalid scope")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
255 |
/// Requests the scope "email unknown" and expects a 400 `invalid_scope` error
/// ("unknown scope") whose response carries the no-cache headers.
func testCorrectErrorWhenRequestingUnknownScope() async throws {
    let requestedScope = "email unknown"

    let response = try await getPasswordResponse(scope: requestedScope)
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .badRequest)
    XCTAssertEqual(errorBody.error, "invalid_scope")
    XCTAssertEqual(errorBody.errorDescription, "Request contained an unknown scope")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
269 |
/// Registers a client with `firstParty: false` and expects its password-grant
/// request to be refused with a 400 `unauthorized_client` error.
func testCorrectErrorWhen3rdParyClientTriesToUsePassword() async throws {
    let thirdPartyClientID = "AB1234"
    fakeClientGetter.validClients[thirdPartyClientID] = OAuthClient(
        clientID: thirdPartyClientID,
        redirectURIs: nil,
        firstParty: false,
        allowedGrantType: .password
    )

    let response = try await getPasswordResponse(clientID: thirdPartyClientID, clientSecret: nil)
    let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)
    let headers = response.headers

    XCTAssertEqual(response.status, .badRequest)
    XCTAssertEqual(errorBody.error, "unauthorized_client")
    XCTAssertEqual(errorBody.errorDescription, "Password Credentials grant is not allowed")
    XCTAssertEqual(headers.cacheControl?.noStore, true)
    XCTAssertEqual(headers[HTTPHeaders.Name.pragma], ["no-cache"])
}
285 |
/// Sends a wrong password and expects the capturing logger to have recorded a
/// warning-level message naming the test user.
func testMessageLoggedForIncorrectLogin() async throws {
    // Only the logging side effect matters here; the HTTP response is discarded.
    _ = try await getPasswordResponse(password: "INCORRECT_PASSWORD")

    let expectedMessage = "LOGIN WARNING: Invalid login attempt for user \(testUsername)"
    XCTAssertEqual(capturingLogger.logLevel, .warning)
    XCTAssertEqual(capturingLogger.logMessage, expectedMessage)
}
292 |
/// Expects the access token stored for a successful request to carry the test user's ID.
func testUserIsAssociatedWithTokenID() async throws {
    let response = try await getPasswordResponse()
    let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

    // Unwrap explicitly rather than looking up "" when the response has no token,
    // so a missing token is reported as such instead of as a failed lookup.
    let accessTokenString = try XCTUnwrap(
        tokenBody.accessToken,
        "Token response did not contain an access_token"
    )
    let storedToken = try XCTUnwrap(
        fakeTokenManager.getAccessToken(accessTokenString),
        "Token manager has no access token for the returned token string"
    )

    XCTAssertEqual(storedToken.userID, testUserID)
}
305 |
/// Pins the fake token manager's clock and expects the issued access token to
/// expire exactly 3600 seconds after that instant.
func testExpiryTimeIsSetOnAccessToken() async throws {
    let issuedAt = Date()
    fakeTokenManager.currentTime = issuedAt

    let response = try await getPasswordResponse()
    let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

    // XCTUnwrap reports which lookup failed; the previous bare XCTFail() did not.
    let accessTokenString = try XCTUnwrap(
        tokenBody.accessToken,
        "Token response did not contain an access_token"
    )
    let storedToken = try XCTUnwrap(
        fakeTokenManager.getAccessToken(accessTokenString),
        "Token manager has no access token for the returned token string"
    )

    XCTAssertEqual(storedToken.expiryTime, issuedAt.addingTimeInterval(3600))
}
325 |
/// Expects the refresh token stored for a successful request to carry the test client's ID.
func testThatRefreshTokenHasCorrectClientIDSet() async throws {
    let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    fakeTokenManager.refreshTokenToReturn = refreshTokenString

    _ = try await getPasswordResponse()

    // XCTUnwrap fails with a descriptive message instead of a bare XCTFail().
    let storedRefreshToken = try XCTUnwrap(
        fakeTokenManager.getRefreshToken(refreshTokenString),
        "Token manager has no refresh token for '\(refreshTokenString)'"
    )

    XCTAssertEqual(storedRefreshToken.clientID, testClientID)
}
339 |
/// Expects the stored refresh token to have `nil` scopes when the request carried no scope.
func testThatRefreshTokenHasNoScopesIfNoneRequested() async throws {
    let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    fakeTokenManager.refreshTokenToReturn = refreshTokenString

    _ = try await getPasswordResponse(scope: nil)

    // XCTUnwrap fails with a descriptive message instead of a bare XCTFail().
    let storedRefreshToken = try XCTUnwrap(
        fakeTokenManager.getRefreshToken(refreshTokenString),
        "Token manager has no refresh token for '\(refreshTokenString)'"
    )

    XCTAssertNil(storedRefreshToken.scopes)
}
353 |
/// Expects the stored refresh token to record the scopes "email" and "create"
/// when the request asked for "email create".
func testThatRefreshTokenHasCorrectScopesIfSet() async throws {
    let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    fakeTokenManager.refreshTokenToReturn = refreshTokenString

    _ = try await getPasswordResponse(scope: "email create")

    // XCTUnwrap fails with a descriptive message instead of a bare XCTFail().
    let storedRefreshToken = try XCTUnwrap(
        fakeTokenManager.getRefreshToken(refreshTokenString),
        "Token manager has no refresh token for '\(refreshTokenString)'"
    )

    XCTAssertEqual(storedRefreshToken.scopes ?? [], ["email", "create"])
}
367 |
/// Expects the refresh token stored for a successful request to carry the test user's ID.
func testUserIDSetOnRefreshToken() async throws {
    let refreshTokenString = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    fakeTokenManager.refreshTokenToReturn = refreshTokenString

    _ = try await getPasswordResponse()

    // XCTUnwrap fails with a descriptive message instead of a bare XCTFail().
    let storedRefreshToken = try XCTUnwrap(
        fakeTokenManager.getRefreshToken(refreshTokenString),
        "Token manager has no refresh token for '\(refreshTokenString)'"
    )

    XCTAssertEqual(storedRefreshToken.userID, testUserID)
}
381 |
/// Registers a client whose allowed grant type is `.clientCredentials` and expects
/// its password-grant request to be answered with 403 Forbidden.
func testClientNotConfiguredWithAccessToPasswordFlowCantAccessIt() async throws {
    let clientID = "not-allowed"
    let clientSecret = "client-secret"
    fakeClientGetter.validClients[clientID] = OAuthClient(
        clientID: clientID,
        redirectURIs: nil,
        clientSecret: clientSecret,
        validScopes: nil,
        confidential: true,
        firstParty: true,
        allowedGrantType: .clientCredentials
    )

    let response = try await getPasswordResponse(clientID: clientID, clientSecret: clientSecret)

    XCTAssertEqual(response.status, .forbidden)
}
394 |
/// Registers a client whose allowed grant type is `.password` and expects its
/// password-grant request to be answered with 200 OK.
func testClientConfiguredWithAccessToPasswordFlowCanAccessIt() async throws {
    // The ID literal reads "not-allowed" even though this client is the permitted one.
    let clientID = "not-allowed"
    let clientSecret = "client-secret"
    fakeClientGetter.validClients[clientID] = OAuthClient(
        clientID: clientID,
        redirectURIs: nil,
        clientSecret: clientSecret,
        validScopes: nil,
        confidential: true,
        firstParty: true,
        allowedGrantType: .password
    )

    let response = try await getPasswordResponse(clientID: clientID, clientSecret: clientSecret)

    XCTAssertEqual(response.status, .ok)
}
413 |
414 | // MARK: - Private
415 |
/// Sends a token request through `TestDataBuilder.getTokenRequestResponse` using
/// password-grant defaults; pass `nil` for a parameter to send the request without it.
///
/// - Parameters:
///   - grantType: Value for the grant type; defaults to "password".
///   - username: Resource owner's username; defaults to "testUser".
///   - password: Resource owner's password; defaults to "testPassword".
///   - clientID: Client identifier; defaults to "ABCDEF".
///   - clientSecret: Client secret; defaults to "01234567890".
///   - scope: Scope string to request; defaults to `nil`.
/// - Returns: The response produced for the request.
func getPasswordResponse(
    grantType: String? = "password",
    username: String? = "testUser",
    password: String? = "testPassword",
    clientID: String? = "ABCDEF",
    clientSecret: String? = "01234567890",
    scope: String? = nil
) async throws -> XCTHTTPResponse {
    try await TestDataBuilder.getTokenRequestResponse(
        with: app,
        grantType: grantType,
        clientID: clientID,
        clientSecret: clientSecret,
        scope: scope,
        username: username,
        password: password
    )
}
434 |
435 | }
436 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/GrantTests/TokenRefreshTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
/// Tests for the `refresh_token` grant of the token endpoint, run against a fake
/// client registry and a fake token manager.
class TokenRefreshTests: XCTestCase {

    // MARK: - Properties

    var app: Application!
    var fakeClientGetter: FakeClientGetter!
    var fakeTokenManager: FakeTokenManager!
    let testClientID = "ABCDEF"
    let testClientSecret = "01234567890"
    let refreshTokenString = "ABCDEFGJ-REFRESH-TOKEN"
    let scope1 = "email"
    let scope2 = "create"
    let scope3 = "edit"
    let scope4 = "profile"
    var validRefreshToken: RefreshToken!

    // MARK: - Overrides

    /// Builds the application, registers the confidential test client and seeds
    /// one refresh token holding `scope1` and `scope2`.
    ///
    /// Uses the throwing setup hook so a failure to build the application fails
    /// the test instead of crashing the test process through `try!`.
    override func setUpWithError() throws {
        try super.setUpWithError()

        fakeClientGetter = FakeClientGetter()
        fakeTokenManager = FakeTokenManager()

        app = try TestDataBuilder.getOAuth2Application(
            tokenManager: fakeTokenManager,
            clientRetriever: fakeClientGetter,
            validScopes: [scope1, scope2, scope3, scope4]
        )

        // The client may use scope1, scope2 and scope4 - but not scope3.
        fakeClientGetter.validClients[testClientID] = OAuthClient(
            clientID: testClientID,
            redirectURIs: nil,
            clientSecret: testClientSecret,
            validScopes: [scope1, scope2, scope4],
            confidential: true,
            allowedGrantType: .authorization
        )

        validRefreshToken = FakeRefreshToken(
            tokenString: refreshTokenString,
            clientID: testClientID,
            userID: nil,
            scopes: [scope1, scope2]
        )
        fakeTokenManager.refreshTokens[refreshTokenString] = validRefreshToken
    }

    override func tearDown() async throws {
        // `app` is still nil when setup threw before assigning it.
        app?.shutdown()
        try await super.tearDown()
    }

    // MARK: - Tests

    func testCorrectErrorWhenGrantTypeNotSupplied() async throws {
        let response = try await getTokenResponse(grantType: nil)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_request",
            description: "Request was missing the 'grant_type' parameter"
        )
    }

    func testCorrectErrorAndHeadersReceivedWhenIncorrectGrantTypeSet() async throws {
        let grantType = "some_unknown_type"
        let response = try await getTokenResponse(grantType: grantType)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "unsupported_grant_type",
            description: "This server does not support the '\(grantType)' grant type"
        )
    }

    func testCorrectErrorWhenClientIDNotSupplied() async throws {
        let response = try await getTokenResponse(clientID: nil)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_request",
            description: "Request was missing the 'client_id' parameter"
        )
    }

    func testCorrectErrorWhenClientIDNotValid() async throws {
        let response = try await getTokenResponse(clientID: "UNKNOWN_CLIENT")

        try assertErrorResponse(
            response,
            status: .unauthorized,
            error: "invalid_client",
            description: "Request had invalid client credentials"
        )
    }

    func testCorrectErrorWhenClientDoesNotAuthenticate() async throws {
        let response = try await getTokenResponse(clientSecret: "incorrectPassword")

        try assertErrorResponse(
            response,
            status: .unauthorized,
            error: "invalid_client",
            description: "Request had invalid client credentials"
        )
    }

    func testCorrectErrorIfClientSecretNotSent() async throws {
        let response = try await getTokenResponse(clientSecret: nil)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_request",
            description: "Request was missing the 'client_secret' parameter"
        )
    }

    func testCorrectErrrIfRefreshTokenNotSent() async throws {
        let response = try await getTokenResponse(refreshToken: nil)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_request",
            description: "Request was missing the 'refresh_token' parameter"
        )
    }

    func testThatNonConfidentialClientsGetErrorWhenRequestingToken() async throws {
        let nonConfidentialClientID = "NONCONF"
        let nonConfidentialClientSecret = "SECRET"
        fakeClientGetter.validClients[nonConfidentialClientID] = OAuthClient(
            clientID: nonConfidentialClientID,
            redirectURIs: nil,
            clientSecret: nonConfidentialClientSecret,
            confidential: false,
            allowedGrantType: .authorization
        )

        let response = try await getTokenResponse(
            clientID: nonConfidentialClientID,
            clientSecret: nonConfidentialClientSecret
        )

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "unauthorized_client",
            description: "You are not authorized to use the Client Credentials grant type"
        )
    }

    func testThatAttemptingRefreshWithNonExistentTokenReturnsError() async throws {
        let nonExistentRefreshToken = "NONEXISTENTTOKEN"

        let response = try await getTokenResponse(refreshToken: nonExistentRefreshToken)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_grant",
            description: "The refresh token is invalid"
        )
    }

    func testThatAttemptingRefreshWithRefreshTokenFromDifferentClientReturnsError() async throws {
        let otherClientID = "ABCDEFGHIJKLMON"
        let otherClientSecret = "1234"
        fakeClientGetter.validClients[otherClientID] = OAuthClient(
            clientID: otherClientID,
            redirectURIs: nil,
            clientSecret: otherClientSecret,
            confidential: true,
            allowedGrantType: .authorization
        )

        // The default refresh token was seeded for `testClientID`, not this client.
        let response = try await getTokenResponse(clientID: otherClientID, clientSecret: otherClientSecret)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_grant",
            description: "The refresh token is invalid"
        )
    }

    func testThatProvidingValidRefreshTokenProvidesAccessTokenInResponse() async throws {
        let accessToken = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        fakeTokenManager.accessTokenToReturn = accessToken
        let response = try await getTokenResponse()

        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        XCTAssertEqual(response.status, .ok)
        assertNoCacheHeaders(response)
        XCTAssertEqual(tokenBody.tokenType, "bearer")
        XCTAssertEqual(tokenBody.expiresIn, 3600)
        XCTAssertEqual(tokenBody.accessToken, accessToken)
        XCTAssertNil(tokenBody.refreshToken)
    }

    func testCorrectErrorWhenReqeustingScopeApplicationDoesNotHaveAccessTo() async throws {
        let scope = "email edit"

        let response = try await getTokenResponse(scope: scope)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_scope",
            description: "Request contained an invalid scope"
        )
    }

    func testCorrectErrorWhenRequestingUnknownScope() async throws {
        let scope = "email unknown"

        let response = try await getTokenResponse(scope: scope)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_scope",
            description: "Request contained an unknown scope"
        )
    }

    func testErrorIfRequestingScopeGreaterThanOriginallyRequestedEvenIfApplicatioHasAccess() async throws {
        // scope4 is valid for the client but is not on the seeded refresh token.
        let response = try await getTokenResponse(scope: "\(scope1) \(scope4)")

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_scope",
            description: "Request contained elevated scopes"
        )
    }

    func testLoweringScopeOnRefreshSetsScopeCorrectlyOnAccessAndRefreshTokens() async throws {
        let response = try await getTokenResponse(scope: scope1)

        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        let accessTokenString = try XCTUnwrap(
            tokenBody.accessToken,
            "Token response did not contain an access_token"
        )
        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessTokenString),
            "Token manager has no access token for the returned token string"
        )

        XCTAssertEqual(storedAccessToken.scopes ?? [], [scope1])

        XCTAssertEqual(response.status, .ok)
        XCTAssertEqual(tokenBody.scope, scope1)
        assertNoCacheHeaders(response)

        let storedRefreshToken = try XCTUnwrap(
            fakeTokenManager.getRefreshToken(refreshTokenString),
            "Token manager has no refresh token for '\(refreshTokenString)'"
        )

        XCTAssertEqual(storedRefreshToken.scopes ?? [], [scope1])
    }

    func testNotRequestingScopeOnRefreshDoesNotAlterOriginalScope() async throws {
        // Unwrapping up front turns a nil scope list into a test failure rather
        // than the process crash a force unwrap would cause.
        let originalScopes = try XCTUnwrap(
            validRefreshToken.scopes,
            "Seeded refresh token has no scopes"
        )

        let response = try await getTokenResponse()

        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        let accessTokenString = try XCTUnwrap(
            tokenBody.accessToken,
            "Token response did not contain an access_token"
        )
        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessTokenString),
            "Token manager has no access token for the returned token string"
        )
        let storedRefreshToken = try XCTUnwrap(
            fakeTokenManager.getRefreshToken(refreshTokenString),
            "Token manager has no refresh token for '\(refreshTokenString)'"
        )

        let accessScopes = try XCTUnwrap(storedAccessToken.scopes, "Access token has no scopes")
        let refreshScopes = try XCTUnwrap(storedRefreshToken.scopes, "Refresh token has no scopes")

        XCTAssertEqual(accessScopes, originalScopes)
        XCTAssertEqual(refreshScopes, originalScopes)
    }

    func testRequestingTheSameScopeWhenRefreshingWorksCorrectlyAndReturnsResult() async throws {
        let scopesToRequest = try XCTUnwrap(
            validRefreshToken.scopes,
            "Seeded refresh token has no scopes"
        )
        let response = try await getTokenResponse(scope: scopesToRequest.joined(separator: " "))

        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        let accessTokenString = try XCTUnwrap(
            tokenBody.accessToken,
            "Token response did not contain an access_token"
        )
        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessTokenString),
            "Token manager has no access token for the returned token string"
        )
        let storedRefreshToken = try XCTUnwrap(
            fakeTokenManager.getRefreshToken(refreshTokenString),
            "Token manager has no refresh token for '\(refreshTokenString)'"
        )

        let accessScopes = try XCTUnwrap(storedAccessToken.scopes, "Access token has no scopes")
        let refreshScopes = try XCTUnwrap(storedRefreshToken.scopes, "Refresh token has no scopes")

        XCTAssertEqual(accessScopes, scopesToRequest)
        XCTAssertEqual(refreshScopes, scopesToRequest)
    }

    func testErrorWhenRequestingScopeWithNoScopesOriginallyRequestedOnRefreshToken() async throws {
        let newRefreshToken = "NEW_REFRESH_TOKEN"
        fakeTokenManager.refreshTokens[newRefreshToken] = FakeRefreshToken(
            tokenString: newRefreshToken,
            clientID: testClientID,
            userID: nil,
            scopes: nil
        )

        let response = try await getTokenResponse(refreshToken: newRefreshToken, scope: scope1)

        try assertErrorResponse(
            response,
            status: .badRequest,
            error: "invalid_scope",
            description: "Request contained elevated scopes"
        )
    }

    func testUserIDIsSetOnAccessTokenIfRefreshTokenHasOne() async throws {
        let userID = "abcdefg-123456"
        let accessToken = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        let userIDRefreshTokenString = "ASHFUIEWHFIHEWIUF"
        fakeTokenManager.refreshTokens[userIDRefreshTokenString] = FakeRefreshToken(
            tokenString: userIDRefreshTokenString,
            clientID: testClientID,
            userID: userID,
            scopes: [scope1, scope2]
        )
        fakeTokenManager.accessTokenToReturn = accessToken
        _ = try await getTokenResponse(refreshToken: userIDRefreshTokenString)

        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessToken),
            "Token manager has no access token for '\(accessToken)'"
        )

        XCTAssertEqual(storedAccessToken.userID, userID)
    }

    func testClientIDSetOnAccessTokenFromRefreshToken() async throws {
        let refreshTokenString = "some-new-refreshToken"
        let clientID = "the-client-id-to-set"
        fakeTokenManager.refreshTokens[refreshTokenString] = FakeRefreshToken(
            tokenString: refreshTokenString,
            clientID: clientID,
            userID: "some-user"
        )
        fakeClientGetter.validClients[clientID] = OAuthClient(
            clientID: clientID,
            redirectURIs: nil,
            clientSecret: testClientSecret,
            confidential: true,
            allowedGrantType: .authorization
        )

        let response = try await getTokenResponse(clientID: clientID, refreshToken: refreshTokenString)
        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        let accessTokenString = try XCTUnwrap(
            tokenBody.accessToken,
            "Token response did not contain an access_token"
        )
        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessTokenString),
            "Token manager has no access token for the returned token string"
        )

        XCTAssertEqual(storedAccessToken.clientID, clientID)
    }

    func testExpiryTimeSetOnNewAccessToken() async throws {
        let issuedAt = Date()
        fakeTokenManager.currentTime = issuedAt

        let response = try await getTokenResponse()
        let tokenBody = try JSONDecoder().decode(SuccessResponse.self, from: response.body)

        let accessTokenString = try XCTUnwrap(
            tokenBody.accessToken,
            "Token response did not contain an access_token"
        )
        let storedAccessToken = try XCTUnwrap(
            fakeTokenManager.getAccessToken(accessTokenString),
            "Token manager has no access token for the returned token string"
        )

        XCTAssertEqual(storedAccessToken.expiryTime, issuedAt.addingTimeInterval(3600))
    }

    // MARK: - Private

    /// Sends a token request through `TestDataBuilder.getTokenRequestResponse` using
    /// refresh-grant defaults; pass `nil` for a parameter to send the request without it.
    ///
    /// - Parameters:
    ///   - grantType: Value for the grant type; defaults to "refresh_token".
    ///   - clientID: Client identifier; defaults to "ABCDEF".
    ///   - clientSecret: Client secret; defaults to "01234567890".
    ///   - refreshToken: Refresh token string; defaults to "ABCDEFGJ-REFRESH-TOKEN".
    ///   - scope: Scope string to request; defaults to `nil`.
    /// - Returns: The response produced for the request.
    func getTokenResponse(
        grantType: String? = "refresh_token",
        clientID: String? = "ABCDEF",
        clientSecret: String? = "01234567890",
        refreshToken: String? = "ABCDEFGJ-REFRESH-TOKEN",
        scope: String? = nil
    ) async throws -> XCTHTTPResponse {
        return try await TestDataBuilder.getTokenRequestResponse(
            with: app,
            grantType: grantType,
            clientID: clientID,
            clientSecret: clientSecret,
            scope: scope,
            refreshToken: refreshToken
        )
    }

    /// Asserts that `response` carries `Cache-Control: no-store` and `Pragma: no-cache`.
    /// `file` and `line` are forwarded so failures point at the calling test.
    private func assertNoCacheHeaders(
        _ response: XCTHTTPResponse,
        file: StaticString = #filePath,
        line: UInt = #line
    ) {
        XCTAssertEqual(response.headers.cacheControl?.noStore, true, file: file, line: line)
        XCTAssertEqual(response.headers[HTTPHeaders.Name.pragma], ["no-cache"], file: file, line: line)
    }

    /// Decodes the body as an `ErrorResponse` and asserts the status, error code,
    /// error description and no-cache headers.
    ///
    /// - Throws: The decoding error when the body is not a valid `ErrorResponse`.
    private func assertErrorResponse(
        _ response: XCTHTTPResponse,
        status: HTTPStatus,
        error: String,
        description: String,
        file: StaticString = #filePath,
        line: UInt = #line
    ) throws {
        let errorBody = try JSONDecoder().decode(ErrorResponse.self, from: response.body)

        XCTAssertEqual(response.status, status, file: file, line: line)
        XCTAssertEqual(errorBody.error, error, file: file, line: line)
        XCTAssertEqual(errorBody.errorDescription, description, file: file, line: line)
        assertNoCacheHeaders(response, file: file, line: line)
    }

}
425 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Helpers/HTTPHeaders+location.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 |
extension HTTPHeaders {
    /// String-backed value of the `Location` header; can be written as a string literal.
    public struct Location: ExpressibleByStringLiteral, Equatable {
        /// The raw header value.
        public let value: String

        public init(value: String) {
            self.value = value
        }

        public init(stringLiteral value: String) {
            self.init(value: value)
        }
    }

    /// The first `Location` header, if any. Assigning a value replaces or adds the
    /// header; assigning `nil` removes it.
    public var location: Location? {
        get {
            guard let rawValue = first(name: .location) else {
                return nil
            }
            return Location(value: rawValue)
        }
        set {
            guard let newLocation = newValue else {
                remove(name: .location)
                return
            }
            replaceOrAdd(name: .location, value: newLocation.value)
        }
    }
}
29 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Helpers/Responses.swift:
--------------------------------------------------------------------------------
/// Decodable mirror of a JSON error body with `error` and `error_description` keys.
struct ErrorResponse: Decodable {
    /// Value of the `error` key.
    var error: String
    /// Value of the `error_description` key.
    var errorDescription: String

    /// Maps the snake_case JSON key onto the camelCase property.
    enum CodingKeys: String, CodingKey {
        case error
        case errorDescription = "error_description"
    }
}
10 |
/// Decodable mirror of a JSON token body. Every field is optional, so a key that
/// is absent from the JSON decodes to `nil`.
struct SuccessResponse: Decodable {
    /// Value of the `token_type` key.
    var tokenType: String?
    /// Value of the `expires_in` key.
    var expiresIn: Int?
    /// Value of the `access_token` key.
    var accessToken: String?
    /// Value of the `refresh_token` key.
    var refreshToken: String?
    /// Value of the `scope` key.
    var scope: String?

    /// Maps the snake_case JSON keys onto the camelCase properties.
    enum CodingKeys: String, CodingKey {
        case tokenType = "token_type"
        case expiresIn = "expires_in"
        case accessToken = "access_token"
        case refreshToken = "refresh_token"
        case scope
    }
}
26 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/Helpers/TestDataBuilder.swift:
--------------------------------------------------------------------------------
1 | import Vapor
2 | import XCTVapor
3 |
4 | @testable import VaporOAuth
5 |
/// Shared helpers for the OAuth test suites: builds a configured test `Application`,
/// sends requests to the token and authorize routes, and provides a sample user.
class TestDataBuilder {
    // MARK: - Fixtures

    static let anyUserID: String = "12345-asbdsadi"

    /// Returns a sample user whose ID is `anyUserID`.
    static func anyOAuthUser() -> OAuthUser {
        OAuthUser(
            userID: anyUserID,
            username: "hansolo",
            emailAddress: "han.solo@therebelalliance.com",
            password: "leia"
        )
    }

    // MARK: - Application

    /// Creates an `Application` with the fake authentication middleware, the sessions
    /// middleware and an `OAuth2` lifecycle handler built from the supplied collaborators.
    ///
    /// - Parameters:
    ///   - sessions: When non-nil, installed as the application's session driver.
    ///   - registeredUsers: Passed to `FakeAuthenticationMiddleware` as its allowed users.
    ///   - logger: Accepted for call-site compatibility; this method does not reference it.
    /// - Returns: The configured application. The caller is responsible for shutting it down.
    /// - Throws: Whatever `app.testable()` throws; the application is shut down first.
    static func getOAuth2Application(
        codeManager: CodeManager = EmptyCodeManager(),
        tokenManager: TokenManager = StubTokenManager(),
        clientRetriever: ClientRetriever = FakeClientGetter(),
        userManager: UserManager = EmptyUserManager(),
        authorizeHandler: AuthorizeHandler = EmptyAuthorizationHandler(),
        validScopes: [String]? = nil,
        resourceServerRetriever: ResourceServerRetriever = EmptyResourceServerRetriever(),
        environment: Environment = .testing,
        logger: CapturingLogger? = nil,
        sessions: FakeSessions? = nil,
        registeredUsers: [OAuthUser] = []
    ) throws -> Application {
        let application = Application(environment)

        if let fakeSessions = sessions {
            application.sessions.use { _ in fakeSessions }
        }

        // Registration order is significant: authentication first, then sessions.
        application.middleware.use(FakeAuthenticationMiddleware(allowedUsers: registeredUsers))
        application.middleware.use(application.sessions.middleware)

        let oauthProvider = OAuth2(
            codeManager: codeManager,
            tokenManager: tokenManager,
            clientRetriever: clientRetriever,
            authorizeHandler: authorizeHandler,
            userManager: userManager,
            validScopes: validScopes,
            resourceServerRetriever: resourceServerRetriever,
            oAuthHelper: .local(
                tokenAuthenticator: nil,
                userManager: nil,
                tokenManager: nil
            )
        )
        application.lifecycle.use(oauthProvider)

        do {
            _ = try application.testable()
            return application
        } catch {
            // Do not hand back (or leak) an application whose tester could not be created.
            application.shutdown()
            throw error
        }
    }

    // MARK: - Requests

    /// Sends a form-encoded `POST /oauth/token/` request and returns the response.
    ///
    /// Arguments that are `nil` are left out of the encoded form.
    static func getTokenRequestResponse(
        with app: Application,
        grantType: String?,
        clientID: String?,
        clientSecret: String?,
        redirectURI: String? = nil,
        code: String? = nil,
        scope: String? = nil,
        username: String? = nil,
        password: String? = nil,
        refreshToken: String? = nil
    ) async throws -> XCTHTTPResponse {
        struct TokenRequestForm: Content {
            enum CodingKeys: String, CodingKey {
                case grantType = "grant_type"
                case clientID = "client_id"
                case clientSecret = "client_secret"
                case redirectURI = "redirect_uri"
                case code, scope, username, password
                case refreshToken = "refresh_token"
            }

            var grantType: String?
            var clientID: String?
            var clientSecret: String?
            var redirectURI: String?
            var code: String?
            var scope: String?
            var username: String?
            var password: String?
            var refreshToken: String?
        }

        let form = TokenRequestForm(
            grantType: grantType,
            clientID: clientID,
            clientSecret: clientSecret,
            redirectURI: redirectURI,
            code: code,
            scope: scope,
            username: username,
            password: password,
            refreshToken: refreshToken
        )

        // `app.test` reports its result through callbacks; the continuation bridges that into async.
        return try await withCheckedThrowingContinuation { continuation in
            do {
                try app.test(
                    .POST,
                    "/oauth/token/",
                    beforeRequest: { request in
                        try request.content.encode(form, as: .urlEncodedForm)
                    },
                    afterResponse: { continuation.resume(returning: $0) }
                )
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

    /// Sends `GET /oauth/authorize` with the non-nil arguments as query parameters
    /// (in the order `response_type`, `client_id`, `redirect_uri`, `scope`, `state`).
    static func getAuthRequestResponse(
        with app: Application,
        responseType: String?,
        clientID: String?,
        redirectURI: String?,
        scope: String?,
        state: String?
    ) async throws -> XCTHTTPResponse {
        let query = makeQuery([
            ("response_type", responseType),
            ("client_id", clientID),
            ("redirect_uri", redirectURI),
            ("scope", scope),
            ("state", state),
        ])

        return try await withCheckedThrowingContinuation { continuation in
            do {
                try app.test(
                    .GET,
                    "/oauth/authorize?\(query)",
                    afterResponse: { continuation.resume(returning: $0) }
                )
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

    /// Sends `POST /oauth/authorize` with a form-encoded approval body.
    ///
    /// - Parameters:
    ///   - approve: Encoded as `applicationAuthorized` when non-nil.
    ///   - csrfToken: Encoded as `csrfToken` when non-nil.
    ///   - user: When non-nil, its username and password are sent as basic authorization.
    ///   - sessionCookie: Cookie header to send; takes precedence over `sessionID`.
    ///   - sessionID: When non-nil (and `sessionCookie` is nil), sent as the `vapor-session` cookie.
    static func getAuthResponseResponse(
        with app: Application,
        approve: Bool?,
        clientID: String?,
        redirectURI: String?,
        responseType: String?,
        scope: String?,
        state: String?,
        csrfToken: String?,
        user: OAuthUser?,
        sessionCookie: HTTPCookies? = nil,
        sessionID: String? = nil
    ) async throws -> XCTHTTPResponse {
        struct AuthorizeForm: Encodable {
            var applicationAuthorized: Bool?
            var csrfToken: String?
        }

        let query = makeQuery([
            ("client_id", clientID),
            ("redirect_uri", redirectURI),
            ("state", state),
            ("scope", scope),
            ("response_type", responseType),
        ])
        let form = AuthorizeForm(applicationAuthorized: approve, csrfToken: csrfToken)

        // An explicit cookie wins; otherwise fall back to one derived from the session ID.
        let cookies = sessionCookie ?? sessionID.map { id -> HTTPCookies in
            ["vapor-session": .init(string: id)]
        }

        return try await withCheckedThrowingContinuation { continuation in
            do {
                try app.test(
                    .POST,
                    "/oauth/authorize?\(query)",
                    beforeRequest: { request in
                        if let cookies = cookies {
                            request.headers.cookie = cookies
                        }
                        try request.content.encode(form, as: .urlEncodedForm)

                        if let credentials = user {
                            request.headers.basicAuthorization = .init(
                                username: credentials.username,
                                password: credentials.password
                            )
                        }
                    },
                    afterResponse: { continuation.resume(returning: $0) }
                )
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

    // MARK: - Private helpers

    /// Joins the pairs whose value is non-nil as `name=value`, separated by `&`, keeping
    /// the given order. Values are inserted verbatim (no percent-encoding is applied).
    private static func makeQuery(_ parameters: [(name: String, value: String?)]) -> String {
        parameters
            .compactMap { pair in pair.value.map { "\(pair.name)=\($0)" } }
            .joined(separator: "&")
    }
}
252 |
--------------------------------------------------------------------------------
/Tests/VaporOAuthTests/TokenIntrospectionTests/TokenIntrospectionTests.swift:
--------------------------------------------------------------------------------
1 | import XCTVapor
2 |
3 | @testable import VaporOAuth
4 |
/// Exercises `POST /oauth/token_info`: request validation, basic-auth checks for the
/// calling resource server, and the fields reported for known and unknown tokens.
class TokenIntrospectionTests: XCTestCase {
    // MARK: - Properties
    var app: Application!
    var fakeTokenManager: FakeTokenManager!
    var fakeUserManager: FakeUserManager!
    var fakeResourceServerRetriever: FakeResourceServerRetriever!
    let testClientID = "ABCDEF"
    let testClientSecret = "01234567890"
    let accessToken = "ABDEFGHIJKLMNO01234567890"
    let scope1 = "email"
    let scope2 = "create"
    let resourceServerName = "brokenhands-users"
    let resourceServerPassword = "users"
    let clientID = "some-client"

    // MARK: - Overrides

    /// Builds a fresh application backed by fakes, registers one resource server and
    /// stores one unexpired access token under `accessToken`.
    ///
    /// Uses the throwing setup hook so that a failure to build the application is reported
    /// as a test failure instead of crashing the test process through `try!`.
    override func setUpWithError() throws {
        try super.setUpWithError()

        fakeTokenManager = FakeTokenManager()
        fakeUserManager = FakeUserManager()
        fakeResourceServerRetriever = FakeResourceServerRetriever()

        app = try TestDataBuilder.getOAuth2Application(
            tokenManager: fakeTokenManager,
            userManager: fakeUserManager,
            validScopes: [scope1, scope2],
            resourceServerRetriever: fakeResourceServerRetriever
        )

        let resourceServer = OAuthResourceServer(username: resourceServerName, password: resourceServerPassword)
        fakeResourceServerRetriever.resourceServers[resourceServerName] = resourceServer

        let validToken = FakeAccessToken(
            tokenString: accessToken,
            clientID: clientID,
            userID: nil,
            expiryTime: Date().addingTimeInterval(60)
        )
        fakeTokenManager.accessTokens[accessToken] = validToken
    }

    override func tearDown() async throws {
        // `app` is still nil if set-up threw before assigning it, so unwrap defensively.
        app?.shutdown()
        try await super.tearDown()
    }

    // MARK: - Tests
    func testCorrectErrorWhenTokenParameterNotSuppliedInRequest() async throws {
        let response = try await getInfoResponse(token: nil)
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.ErrorResponse.self)

        XCTAssertEqual(response.status, .badRequest)
        XCTAssertEqual(responseJSON.error, "missing_token")
        XCTAssertEqual(responseJSON.errorDescription, "The token parameter is required")
    }

    func testCorrectErrorWhenNoAuthorisationSuppliied() async throws {
        let response = try await getInfoResponse(authHeader: nil)

        XCTAssertEqual(response.status, .unauthorized)
    }

    func testCorrectErrorWhenInvalidAuthorisationSupplied() async throws {
        let response = try await getInfoResponse(authHeader: "INVALID")

        XCTAssertEqual(response.status, .unauthorized)
    }

    func testCorrectErrorWhenInvalidUsernnameSuppliedForAuthorisation() async throws {
        let header = "UNKOWNUSER:\(resourceServerPassword)".base64String()
        let response = try await getInfoResponse(authHeader: header)

        XCTAssertEqual(response.status, .unauthorized)
    }

    func testCorrectErrorWhenInvalidPasswordSuppliedForAuthorisation() async throws {
        let header = "\(resourceServerName):SOMEPASSWORD".base64String()
        let response = try await getInfoResponse(authHeader: header)

        XCTAssertEqual(response.status, .unauthorized)
    }

    func testThatInvalidTokenReturnsInactive() async throws {
        let response = try await getInfoResponse(token: "UNKNOWN_TOKEN")
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertFalse(responseJSON.active)
    }

    func testThatExpiredTokenReturnsInactive() async throws {
        let tokenString = "EXPIRED_TOKEN"
        let expiredToken = FakeAccessToken(
            tokenString: tokenString,
            clientID: testClientID,
            userID: nil,
            expiryTime: Date().addingTimeInterval(-60)
        )
        fakeTokenManager.accessTokens[tokenString] = expiredToken

        let response = try await getInfoResponse(token: tokenString)
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertFalse(responseJSON.active)
    }

    func testThatValidTokenReturnsActive() async throws {
        let response = try await getInfoResponse()
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertTrue(responseJSON.active)
    }

    func testThatScopeReturnedInReponseIfTokenHasScope() async throws {
        let tokenString = "VALID_TOKEN"
        let validToken = FakeAccessToken(
            tokenString: tokenString,
            clientID: clientID,
            userID: nil,
            scopes: ["email", "profile"],
            expiryTime: Date().addingTimeInterval(60)
        )
        fakeTokenManager.accessTokens[tokenString] = validToken

        let response = try await getInfoResponse(token: tokenString)
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertTrue(responseJSON.active)
        XCTAssertEqual(responseJSON.scope, "email profile")
    }

    func testCliendIDReturnedInTokenResponse() async throws {
        let response = try await getInfoResponse()
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertTrue(responseJSON.active)
        XCTAssertEqual(responseJSON.clientID, clientID)
    }

    func testUsernameReturnedInTokenResponseIfTokenHasAUser() async throws {
        let userID = "123"
        let username = "hansolo"
        let tokenString = "VALID_TOKEN"
        let validToken = FakeAccessToken(
            tokenString: tokenString,
            clientID: clientID,
            userID: userID,
            expiryTime: Date().addingTimeInterval(60)
        )
        fakeTokenManager.accessTokens[tokenString] = validToken
        let newUser = OAuthUser(
            userID: userID,
            username: username,
            emailAddress: "han@therebelalliance.com",
            password: "leia"
        )
        fakeUserManager.users.append(newUser)

        let response = try await getInfoResponse(token: tokenString)
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertTrue(responseJSON.active)
        XCTAssertEqual(responseJSON.username, username)
    }

    func testTokenExpiryReturnedInResponse() async throws {
        let tokenString = "VALID_TOKEN"
        let expiryDate = Date().addingTimeInterval(60)
        let validToken = FakeAccessToken(
            tokenString: tokenString,
            clientID: clientID,
            userID: nil,
            expiryTime: expiryDate
        )
        fakeTokenManager.accessTokens[tokenString] = validToken

        let response = try await getInfoResponse(token: tokenString)
        let responseJSON = try response.content.decode(TokenIntrospectionHandler.TokenResponse.self)

        XCTAssertEqual(response.status, .ok)
        XCTAssertTrue(responseJSON.active)
        XCTAssertEqual(responseJSON.exp, Int(expiryDate.timeIntervalSince1970))
    }

    // MARK: - Helper method

    /// Sends `POST /oauth/token_info` and returns the response.
    ///
    /// - Parameters:
    ///   - token: Encoded as the `token` field of the request body; when `nil` no body is encoded.
    ///     The default equals the `accessToken` stored during set-up.
    ///   - authHeader: Sent as `authorization: Basic <authHeader>`; when `nil` the header is omitted.
    ///     The default is `brokenhands-users:users` Base64 encoded.
    func getInfoResponse(
        token: String? = "ABDEFGHIJKLMNO01234567890",
        authHeader: String? = "YnJva2VuaGFuZHMtdXNlcnM6dXNlcnM="
    ) async throws -> XCTHTTPResponse {
        // TODO - try Form URL encoded
        struct TokenData: Content {
            var token: String?
        }

        // `app.test` reports its result through callbacks; the continuation bridges that into async.
        return try await withCheckedThrowingContinuation { continuation in
            do {
                try app.test(
                    .POST,
                    "/oauth/token_info",
                    beforeRequest: { request in
                        if let authHeader = authHeader {
                            request.headers.add(name: "authorization", value: "Basic \(authHeader)")
                        }
                        if let token = token {
                            let tokenData = TokenData(token: token)
                            try request.content.encode(tokenData)
                        }
                    },
                    afterResponse: { response in
                        continuation.resume(returning: response)
                    }
                )
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

}
214 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | range: "0...100"
3 | ignore:
4 | - "Tests/"
5 |
--------------------------------------------------------------------------------
/docker-test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Builds the project's Docker image and runs it in a container that is removed on exit.
# Abort on the first failing command so that a failed build never falls through to
# running a previously built (stale) image.
set -euo pipefail

docker build --tag vapor-oauth .
docker run --rm vapor-oauth
4 |
--------------------------------------------------------------------------------