├── .dockerignore ├── .editorconfig ├── .github ├── CODEOWNERS ├── FUNDING.yml ├── dependabot.yml └── workflows │ ├── api-breakage.yml │ ├── ci.yml │ ├── nightly.yml │ ├── validate.yml │ └── verify-documentation.yml ├── .gitignore ├── .swift-format ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Package.swift ├── README.md ├── Sources ├── WSClient │ ├── Client │ │ ├── ClientChannel.swift │ │ ├── ClientConnection.swift │ │ ├── Parser.swift │ │ ├── TLSClientChannel.swift │ │ ├── TSTLSOptions.swift │ │ └── URI.swift │ ├── Exports.swift │ ├── WebSocketClient.swift │ ├── WebSocketClientChannel.swift │ ├── WebSocketClientConfiguration.swift │ └── WebSocketClientError.swift ├── WSCompression │ └── PerMessageDeflateExtension.swift └── WSCore │ ├── Extensions │ ├── WebSocketExtension.swift │ ├── WebSocketExtensionBuilder.swift │ └── WebSocketExtensionHTTPParameters.swift │ ├── String+validatingString.swift │ ├── UnsafeTransfer.swift │ ├── WebSocketContext.swift │ ├── WebSocketDataFrame.swift │ ├── WebSocketDataHandler.swift │ ├── WebSocketFrameSequence.swift │ ├── WebSocketHandler.swift │ ├── WebSocketInboundMessageStream.swift │ ├── WebSocketInboundStream.swift │ ├── WebSocketMessage.swift │ ├── WebSocketOutboundWriter.swift │ └── WebSocketStateMachine.swift ├── Tests └── WebSocketTests │ ├── AutobahnTests.swift │ ├── ClientTests.swift │ ├── WebSocketExtensionNegotiationTests.swift │ └── WebSocketStateMachineTests.swift └── scripts ├── autobahn-config └── fuzzingserver.json ├── autobahn-server.sh └── validate.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | .build 2 | .git -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_style = space 5 | indent_size = 4 6 | end_of_line = lf 7 | insert_final_newline = true 8 | 
trim_trailing_whitespace = true -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @adam-fowler @Joannis 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: adam-fowler 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | ignore: 8 | - dependency-name: "codecov/codecov-action" 9 | update-types: ["version-update:semver-major"] 10 | groups: 11 | dependencies: 12 | patterns: 13 | - "*" 14 | - package-ecosystem: "swift" 15 | directory: "/" 16 | schedule: 17 | interval: "daily" 18 | open-pull-requests-limit: 5 19 | allow: 20 | - dependency-type: all 21 | groups: 22 | all-dependencies: 23 | patterns: 24 | - "*" 25 | -------------------------------------------------------------------------------- /.github/workflows/api-breakage.yml: -------------------------------------------------------------------------------- 1 | name: API breaking changes 2 | 3 | on: 4 | pull_request: 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }}-apibreakage 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | linux: 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 15 13 | container: 14 | image: swift:latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | # https://github.com/actions/checkout/issues/766 21 | - name: Mark the workspace as safe 22 | run: git config --global --add safe.directory ${GITHUB_WORKSPACE} 23 | - name: API breaking changes 24 | run: | 25 | swift package 
diagnose-api-breaking-changes origin/${GITHUB_BASE_REF} 26 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - '**.swift' 9 | - '**.yml' 10 | pull_request: 11 | workflow_dispatch: 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }}-ci 14 | cancel-in-progress: true 15 | 16 | env: 17 | FUZZING_SERVER: autobahn 18 | jobs: 19 | linux: 20 | runs-on: ubuntu-latest 21 | timeout-minutes: 15 22 | strategy: 23 | matrix: 24 | image: ["swift:5.10", "swift:6.0", "swift:6.1"] 25 | container: 26 | image: ${{ matrix.image }} 27 | services: 28 | autobahn: 29 | image: crossbario/autobahn-testsuite 30 | options: --name fuzzingserver 31 | ports: 32 | - 9001:9001 33 | volumes: 34 | - ${{ github.workspace }}/scripts/autobahn-config:/config 35 | 36 | steps: 37 | - name: Checkout 38 | uses: actions/checkout@v4 39 | - name: Restart Autobahn 40 | # The autobahn service container is started *before* swift-websocket is checked 41 | # out. Restarting the container after the checkout step is needed for the 42 | # container to see volumes populated from the checked out workspace. 
43 | uses: docker://docker 44 | with: 45 | args: docker restart fuzzingserver 46 | 47 | - name: Test 48 | run: | 49 | swift test --enable-code-coverage 50 | - name: Convert coverage files 51 | run: | 52 | llvm-cov export -format="lcov" \ 53 | .build/debug/swift-websocketPackageTests.xctest \ 54 | -ignore-filename-regex="\/Tests\/" \ 55 | -ignore-filename-regex="\/Benchmarks\/" \ 56 | -instr-profile .build/debug/codecov/default.profdata > info.lcov 57 | - name: Upload to codecov.io 58 | uses: codecov/codecov-action@v4 59 | with: 60 | files: info.lcov 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | -------------------------------------------------------------------------------- /.github/workflows/nightly.yml: -------------------------------------------------------------------------------- 1 | name: Swift nightly build 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | env: 7 | FUZZING_SERVER: autobahn 8 | jobs: 9 | linux: 10 | runs-on: ubuntu-latest 11 | timeout-minutes: 15 12 | strategy: 13 | matrix: 14 | image: ['nightly-focal', 'nightly-jammy', 'nightly-amazonlinux2'] 15 | container: 16 | image: swiftlang/swift:${{ matrix.image }} 17 | services: 18 | autobahn: 19 | image: crossbario/autobahn-testsuite 20 | options: --name fuzzingserver 21 | ports: 22 | - 9001:9001 23 | volumes: 24 | - ${{ github.workspace }}/scripts/autobahn-config:/config 25 | 26 | steps: 27 | - name: Checkout 28 | uses: actions/checkout@v4 29 | - name: Restart Autobahn 30 | # The autobahn service container is started *before* swift-websocket is checked 31 | # out. Restarting the container after the checkout step is needed for the 32 | # container to see volumes populated from the checked out workspace. 
33 | uses: docker://docker 34 | with: 35 | args: docker restart fuzzingserver 36 | 37 | - name: Test 38 | run: | 39 | swift test 40 | -------------------------------------------------------------------------------- /.github/workflows/validate.yml: -------------------------------------------------------------------------------- 1 | name: Validity Check 2 | 3 | on: 4 | pull_request: 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }}-validate 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | validate: 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 15 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 1 18 | - name: run script 19 | run: ./scripts/validate.sh 20 | -------------------------------------------------------------------------------- /.github/workflows/verify-documentation.yml: -------------------------------------------------------------------------------- 1 | name: Verify Documentation 2 | 3 | on: 4 | pull_request: 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }}-verifydocs 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | linux: 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 15 13 | container: 14 | image: swift:latest 15 | steps: 16 | - name: Install rsync 📚 17 | run: | 18 | apt-get update && apt-get install -y rsync bc 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | path: "package" 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | with: 27 | repository: "hummingbird-project/hummingbird-docs" 28 | fetch-depth: 0 29 | path: "documentation" 30 | - name: Verify 31 | run: | 32 | cd documentation 33 | swift package edit ${GITHUB_REPOSITORY#*/} --path ../package 34 | ./scripts/build-docc.sh -e 35 | 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .build/ 3 | .swiftpm/ 4 | .vscode/ 5 | 
.index-build/ 6 | .devcontainer/ 7 | /Packages 8 | /*.xcodeproj 9 | xcuserdata/ 10 | Package.resolved 11 | /public 12 | /docs 13 | .benchmarkBaselines -------------------------------------------------------------------------------- /.swift-format: -------------------------------------------------------------------------------- 1 | { 2 | "version" : 1, 3 | "indentation" : { 4 | "spaces" : 4 5 | }, 6 | "tabWidth" : 4, 7 | "fileScopedDeclarationPrivacy" : { 8 | "accessLevel" : "private" 9 | }, 10 | "spacesAroundRangeFormationOperators" : false, 11 | "indentConditionalCompilationBlocks" : false, 12 | "indentSwitchCaseLabels" : false, 13 | "lineBreakAroundMultilineExpressionChainComponents" : false, 14 | "lineBreakBeforeControlFlowKeywords" : false, 15 | "lineBreakBeforeEachArgument" : true, 16 | "lineBreakBeforeEachGenericRequirement" : true, 17 | "lineLength" : 150, 18 | "maximumBlankLines" : 1, 19 | "respectsExistingLineBreaks" : true, 20 | "prioritizeKeepingFunctionOutputTogether" : true, 21 | "multiElementCollectionTrailingCommas" : true, 22 | "rules" : { 23 | "AllPublicDeclarationsHaveDocumentation" : false, 24 | "AlwaysUseLiteralForEmptyCollectionInit" : false, 25 | "AlwaysUseLowerCamelCase" : false, 26 | "AmbiguousTrailingClosureOverload" : true, 27 | "BeginDocumentationCommentWithOneLineSummary" : false, 28 | "DoNotUseSemicolons" : true, 29 | "DontRepeatTypeInStaticProperties" : true, 30 | "FileScopedDeclarationPrivacy" : true, 31 | "FullyIndirectEnum" : true, 32 | "GroupNumericLiterals" : true, 33 | "IdentifiersMustBeASCII" : true, 34 | "NeverForceUnwrap" : false, 35 | "NeverUseForceTry" : false, 36 | "NeverUseImplicitlyUnwrappedOptionals" : false, 37 | "NoAccessLevelOnExtensionDeclaration" : true, 38 | "NoAssignmentInExpressions" : true, 39 | "NoBlockComments" : true, 40 | "NoCasesWithOnlyFallthrough" : true, 41 | "NoEmptyTrailingClosureParentheses" : true, 42 | "NoLabelsInCasePatterns" : true, 43 | "NoLeadingUnderscores" : false, 44 | 
"NoParensAroundConditions" : true, 45 | "NoVoidReturnOnFunctionSignature" : true, 46 | "OmitExplicitReturns" : true, 47 | "OneCasePerLine" : true, 48 | "OneVariableDeclarationPerLine" : true, 49 | "OnlyOneTrailingClosureArgument" : true, 50 | "OrderedImports" : true, 51 | "ReplaceForEachWithForLoop" : true, 52 | "ReturnVoidInsteadOfEmptyTuple" : true, 53 | "UseEarlyExits" : false, 54 | "UseExplicitNilCheckInConditions" : false, 55 | "UseLetInEveryBoundCaseVariable" : false, 56 | "UseShorthandTypeNames" : true, 57 | "UseSingleLinePropertyGetter" : false, 58 | "UseSynthesizedInitializer" : false, 59 | "UseTripleSlashForDocumentationComments" : true, 60 | "UseWhereClausesInForLoops" : false, 61 | "ValidateDocumentationComments" : false 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | All developers should feel welcome and encouraged to contribute to Hummingbird. Because of this we have adopted the code of conduct defined by [contributor-covenant.org](https://www.contributor-covenant.org). This document is used across many open source 4 | communities, and we think it articulates our values well. The full text is copied below: 5 | 6 | ## Contributor Code of Conduct v1.3 7 | 8 | As contributors and maintainers of this project, and in the interest of 9 | fostering an open and welcoming community, we pledge to respect all people who 10 | contribute through reporting issues, posting feature requests, updating 11 | documentation, submitting pull requests or patches, and other activities. 12 | 13 | We are committed to making participation in this project a harassment-free 14 | experience for everyone, regardless of level of experience, gender, gender 15 | identity and expression, sexual orientation, disability, personal appearance, 16 | body size, race, ethnicity, age, religion, or nationality. 
17 | 18 | Examples of unacceptable behavior by participants include: 19 | 20 | * The use of sexualized language or imagery 21 | * Personal attacks 22 | * Trolling or insulting/derogatory comments 23 | * Public or private harassment 24 | * Publishing other's private information, such as physical or electronic 25 | addresses, without explicit permission 26 | * Other unethical or unprofessional conduct 27 | 28 | Project maintainers have the right and responsibility to remove, edit, or 29 | reject comments, commits, code, wiki edits, issues, and other contributions 30 | that are not aligned to this Code of Conduct, or to ban temporarily or 31 | permanently any contributor for other behaviors that they deem inappropriate, 32 | threatening, offensive, or harmful. 33 | 34 | By adopting this Code of Conduct, project maintainers commit themselves to 35 | fairly and consistently applying these principles to every aspect of managing 36 | this project. Project maintainers who do not follow or enforce the Code of 37 | Conduct may be permanently removed from the project team. 38 | 39 | This Code of Conduct applies both within project spaces and in public spaces 40 | when an individual is representing the project or its community. 41 | 42 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 43 | reported by contacting a project maintainer at [INSERT EMAIL ADDRESS]. All 44 | complaints will be reviewed and investigated and will result in a response that 45 | is deemed necessary and appropriate to the circumstances. Maintainers are 46 | obligated to maintain confidentiality with regard to the reporter of an 47 | incident. 
48 | 49 | 50 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 51 | version 1.3.0, available at https://www.contributor-covenant.org/version/1/3/0/code-of-conduct.html 52 | 53 | [homepage]: https://www.contributor-covenant.org 54 | 55 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Legal 4 | By submitting a pull request, you represent that you have the right to license your contribution to the community, and agree by submitting the patch 5 | that your contributions are licensed under the Apache 2.0 license (see [LICENSE](LICENSE)). 6 | 7 | ## Contributor Conduct 8 | All contributors are expected to adhere to the project's [Code of Conduct](CODE_OF_CONDUCT.md). 9 | 10 | ## Submitting a bug or issue 11 | Please include the following in your bug report: 12 | - A concise description of the issue: what happened and what you expected. 13 | - Simple reproduction steps 14 | - Version of the library you are using 15 | - Contextual information (Swift version, OS, etc.) 16 | 17 | ## Submitting a Pull Request 18 | 19 | Please include the following in your pull request: 20 | - A description of what you are trying to do: what the PR provides to the library (additional functionality, a bug fix, etc.) 21 | - A description of the code changes 22 | - Documentation on how these changes are being tested 23 | - Additional tests to show your code working and to ensure future changes don't break your code. 24 | 25 | Please keep your PRs to a minimal number of changes. If a PR is large, try to split it up into smaller PRs. Don't move code around unnecessarily; it makes comparing old with new very hard. 26 | 27 | The main development branch of the repository is `main`. 28 | 29 | ### Formatting 30 | 31 | We use Apple's swift-format for formatting code.
PRs will not be accepted if they haven't been formatted. -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # ================================ 2 | # Build image 3 | # ================================ 4 | FROM swift:6.0 as build 5 | 6 | WORKDIR /build 7 | 8 | # First just resolve dependencies. 9 | # This creates a cached layer that can be reused 10 | # as long as your Package.swift/Package.resolved 11 | # files do not change. 12 | COPY ./Package.* ./ 13 | RUN swift package resolve 14 | 15 | # Copy entire repo into container 16 | COPY . . 17 | 18 | RUN swift test --sanitize=thread 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Adam Fowler 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.10 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 
3 | 4 | import PackageDescription 5 | 6 | let swiftSettings: [SwiftSetting] = [.enableExperimentalFeature("StrictConcurrency=complete")] 7 | 8 | let package = Package( 9 | name: "swift-websocket", 10 | platforms: [.macOS(.v13), .iOS(.v16), .tvOS(.v16)], 11 | products: [ 12 | .library(name: "WSClient", targets: ["WSClient"]), 13 | .library(name: "WSCompression", targets: ["WSCompression"]), 14 | .library(name: "WSCore", targets: ["WSCore"]), 15 | ], 16 | dependencies: [ 17 | .package(url: "https://github.com/apple/swift-http-types.git", from: "1.0.0"), 18 | .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), 19 | .package(url: "https://github.com/apple/swift-nio.git", from: "2.77.0"), 20 | .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"), 21 | .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"), 22 | .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.20.0"), 23 | .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.3.0"), 24 | .package(url: "https://github.com/swift-server/swift-service-lifecycle", from: "2.0.0"), 25 | ], 26 | targets: [ 27 | .target( 28 | name: "WSClient", 29 | dependencies: [ 30 | .byName(name: "WSCore"), 31 | .product(name: "HTTPTypes", package: "swift-http-types"), 32 | .product(name: "Logging", package: "swift-log"), 33 | .product(name: "NIOCore", package: "swift-nio"), 34 | .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), 35 | .product(name: "NIOPosix", package: "swift-nio"), 36 | .product(name: "NIOSSL", package: "swift-nio-ssl"), 37 | .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), 38 | .product(name: "NIOWebSocket", package: "swift-nio"), 39 | ], 40 | swiftSettings: swiftSettings 41 | ), 42 | .target( 43 | name: "WSCore", 44 | dependencies: [ 45 | .product(name: "HTTPTypes", package: "swift-http-types"), 46 | .product(name: "NIOCore", package: "swift-nio"), 47 | 
.product(name: "NIOWebSocket", package: "swift-nio"), 48 | .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), 49 | ], 50 | swiftSettings: swiftSettings 51 | ), 52 | .target( 53 | name: "WSCompression", 54 | dependencies: [ 55 | .byName(name: "WSCore"), 56 | .product(name: "CompressNIO", package: "compress-nio"), 57 | ], 58 | swiftSettings: swiftSettings 59 | ), 60 | 61 | .testTarget( 62 | name: "WebSocketTests", 63 | dependencies: [ 64 | .byName(name: "WSClient"), 65 | .byName(name: "WSCompression"), 66 | ] 67 | ), 68 | ], 69 | swiftLanguageVersions: [.v5, .version("6")] 70 | ) 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## swift-websocket 2 | 3 | Support for WebSockets 4 | 5 | ### Overview 6 | 7 | Package containing support for WebSockets. It contains three libraries 8 | - WSCore: Core WebSocket handler (can be used by both server and client) 9 | - WSClient: WebSocket client 10 | - WSCompression: WebSocket compression support 11 | 12 | ### Client 13 | 14 | The WebSocketClient is built on top of structured concurrency. When you connect it calls the closure you provide with an inbound stream of frames, a writer to write outbound frames and a context structure. When you exit the closure the client will automatically perform the close handshake for you. 
15 | 16 | ```swift 17 | import WSClient 18 | 19 | let ws = try await WebSocketClient.connect(url: "ws://mywebsocket.com/ws") { inbound, outbound, context in 20 | try await outbound.write(.text("Hello")) 21 | // you can convert the inbound stream of frames into a stream of full messages using `messages(maxSize:)` 22 | for try await message in inbound.messages(maxSize: 1 << 14) { 23 | context.logger.info("\(message)") 24 | } 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/ClientChannel.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Logging 16 | import NIOCore 17 | 18 | /// ClientConnection child channel setup protocol 19 | @_documentation(visibility: internal) 20 | public protocol ClientConnectionChannel: Sendable { 21 | associatedtype Value: Sendable 22 | associatedtype Result 23 | 24 | /// Set up child channel 25 | /// - Parameters: 26 | /// - channel: Child channel 27 | /// - logger: Logger used during setup 28 | /// - Returns: Object to process input/output on child channel 29 | func setup(channel: Channel, logger: Logger) -> EventLoopFuture 30 | 31 | /// Handle messages being passed down the channel pipeline 32 | /// - Parameters: 33 | /// - value: Object to process input/output on child channel 34 | /// - logger: Logger to use while processing messages 35 | func handle(value: Value, logger: Logger) async throws ->
Result 36 | } 37 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/ClientConnection.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Logging 16 | import NIOCore 17 | import NIOPosix 18 | import NIOWebSocket 19 | 20 | #if canImport(Network) 21 | import Network 22 | import NIOTransportServices 23 | #endif 24 | 25 | /// A generic client connection to a server.
26 | /// 27 | /// Actual client protocol is implemented in `ClientChannel` generic parameter 28 | @_documentation(visibility: internal) 29 | public struct ClientConnection: Sendable { 30 | /// Address to connect to 31 | public struct Address: Sendable, Equatable { 32 | enum _Internal: Equatable { 33 | case hostname(_ host: String, port: Int) 34 | case unixDomainSocket(path: String) 35 | } 36 | 37 | let value: _Internal 38 | init(_ value: _Internal) { 39 | self.value = value 40 | } 41 | 42 | /// Address defined by host and port 43 | public static func hostname(_ host: String, port: Int) -> Self { .init(.hostname(host, port: port)) } 44 | /// Address defined by unix domain socket path 45 | public static func unixDomainSocket(path: String) -> Self { .init(.unixDomainSocket(path: path)) } 46 | } 47 | 48 | typealias ChannelResult = ClientChannel.Value 49 | /// Logger used by the client 50 | let logger: Logger 51 | let eventLoopGroup: EventLoopGroup 52 | let clientChannel: ClientChannel 53 | let address: Address 54 | #if canImport(Network) 55 | let tlsOptions: NWProtocolTLS.Options?
56 | #endif 57 | 58 | /// Initialize Client 59 | public init( 60 | _ clientChannel: ClientChannel, 61 | address: Address, 62 | eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, 63 | logger: Logger 64 | ) { 65 | self.clientChannel = clientChannel 66 | self.address = address 67 | self.eventLoopGroup = eventLoopGroup 68 | self.logger = logger 69 | #if canImport(Network) 70 | self.tlsOptions = nil 71 | #endif 72 | } 73 | 74 | #if canImport(Network) 75 | /// Initialize Client with TLS options 76 | public init( 77 | _ clientChannel: ClientChannel, 78 | address: Address, 79 | transportServicesTLSOptions: TSTLSOptions, 80 | eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, 81 | logger: Logger 82 | ) throws { 83 | self.clientChannel = clientChannel 84 | self.address = address 85 | self.eventLoopGroup = eventLoopGroup 86 | self.logger = logger 87 | self.tlsOptions = transportServicesTLSOptions.options 88 | } 89 | #endif 90 | /// Connect to the server, then pass the value produced by the channel setup to `clientChannel.handle(value:logger:)` 91 | public func run() async throws -> ClientChannel.Result { 92 | let channelResult = try await self.makeClient( 93 | clientChannel: self.clientChannel, 94 | address: self.address 95 | ) 96 | return try await self.clientChannel.handle(value: channelResult, logger: self.logger) 97 | } 98 | 99 | /// Connect to server and return the value produced by `clientChannel.setup(channel:logger:)` 100 | /// - Parameters: 101 | /// - clientChannel: Channel setup applied to the connected channel 102 | /// - address: Address to connect to 103 | /// - Returns: Value returned by the channel setup 104 | func makeClient(clientChannel: ClientChannel, address: Address) async throws -> ChannelResult { 105 | // get bootstrap 106 | let bootstrap: ClientBootstrapProtocol 107 | #if canImport(Network) 108 | if let tsBootstrap = self.createTSBootstrap() { 109 | bootstrap = tsBootstrap 110 | } else { 111 | #if os(iOS) || os(tvOS) 112 | self.logger.warning( 113 | "Running BSD sockets on iOS or tvOS is not recommended.
Please use NIOTSEventLoopGroup, to run with the Network framework" 114 | ) 115 | #endif 116 | bootstrap = self.createSocketsBootstrap() 117 | } 118 | #else 119 | bootstrap = self.createSocketsBootstrap() 120 | #endif 121 | 122 | // connect (errors thrown by the bootstrap propagate to the caller unchanged) 123 | let result: ChannelResult 124 | switch address.value { 125 | case .hostname(let host, let port): 126 | result = 127 | try await bootstrap 128 | .connect(host: host, port: port) { channel in 129 | clientChannel.setup(channel: channel, logger: self.logger) 130 | } 131 | self.logger.debug("Client connected to \(host):\(port)") 132 | case .unixDomainSocket(let path): 133 | result = 134 | try await bootstrap 135 | .connect(unixDomainSocketPath: path) { channel in 136 | clientChannel.setup(channel: channel, logger: self.logger) 137 | } 138 | self.logger.debug("Client connected to socket path \(path)") 139 | } 140 | return result 141 | } 142 | 143 | /// create a BSD sockets based bootstrap 144 | private func createSocketsBootstrap() -> ClientBootstrap { 145 | ClientBootstrap(group: self.eventLoopGroup) 146 | .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) 147 | } 148 | 149 | #if canImport(Network) 150 | /// create a NIOTransportServices bootstrap using Network.framework 151 | private func createTSBootstrap() -> NIOTSConnectionBootstrap? { 152 | guard 153 | let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)?
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) 155 | else { 156 | return nil 157 | } 158 | if let tlsOptions { 159 | return bootstrap.tlsOptions(tlsOptions) 160 | } 161 | return bootstrap 162 | } 163 | #endif 164 | } 165 | 166 | protocol ClientBootstrapProtocol { 167 | func connect( 168 | host: String, 169 | port: Int, 170 | channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture 171 | ) async throws -> Output 172 | 173 | func connect( 174 | unixDomainSocketPath: String, 175 | channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture 176 | ) async throws -> Output 177 | } 178 | 179 | extension ClientBootstrap: ClientBootstrapProtocol {} 180 | #if canImport(Network) 181 | extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {} 182 | #endif 183 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/Parser.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2021-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | // Half inspired by Reader class from John Sundell's Ink project 16 | // https://github.com/JohnSundell/Ink/blob/master/Sources/Ink/Internal/Reader.swift 17 | // with optimisation working ie removing String and doing my own UTF8 processing inspired by Fabian Fett's work in 18 | // https://github.com/fabianfett/pure-swift-json/blob/master/Sources/PureSwiftJSONParsing/DocumentReader.swift 19 | 20 | /// Reader object for
parsing String buffers 21 | struct Parser: Sendable { 22 | enum Error: Swift.Error { 23 | case overflow 24 | case unexpected 25 | case emptyString 26 | case invalidUTF8 27 | } 28 | 29 | /// Create a Parser object 30 | /// - Parameter string: UTF8 data to parse 31 | init?(_ utf8Data: some Collection, validateUTF8: Bool = true) { 32 | if let buffer = utf8Data as? [UInt8] { 33 | self.buffer = buffer 34 | } else { 35 | self.buffer = Array(utf8Data) 36 | } 37 | self.index = 0 38 | self.range = 0.. 65 | } 66 | 67 | // MARK: sub-parsers 68 | 69 | extension Parser { 70 | /// initialise a parser that parses a section of the buffer attached to another parser 71 | private init(_ parser: Parser, range: Range) { 72 | self.buffer = parser.buffer 73 | self.index = range.startIndex 74 | self.range = range 75 | 76 | precondition(range.startIndex >= 0 && range.endIndex <= self.buffer.endIndex) 77 | precondition(range.startIndex == self.buffer.endIndex || self.buffer[range.startIndex] & 0xC0 != 0x80) // check we arent in the middle of a UTF8 character 78 | } 79 | 80 | /// initialise a parser that parses a section of the buffer attached to this parser 81 | func subParser(_ range: Range) -> Parser { 82 | Parser(self, range: range) 83 | } 84 | } 85 | 86 | extension Parser { 87 | /// Return current character 88 | /// - Throws: .overflow 89 | /// - Returns: Current character 90 | mutating func character() throws -> Unicode.Scalar { 91 | guard !self.reachedEnd() else { throw Error.overflow } 92 | return unsafeCurrentAndAdvance() 93 | } 94 | 95 | /// Read the current character and return if it is as intended. 
If character test returns true then move forward 1 96 | /// - Parameter char: character to compare against 97 | /// - Throws: .overflow 98 | /// - Returns: If current character was the one we expected 99 | mutating func read(_ char: Unicode.Scalar) throws -> Bool { 100 | let initialIndex = self.index 101 | let c = try character() 102 | guard c == char else { 103 | self.index = initialIndex 104 | return false 105 | } 106 | return true 107 | } 108 | 109 | /// Read the current character and check if it is in a set of characters If character test returns true then move forward 1 110 | /// - Parameter characterSet: Set of characters to compare against 111 | /// - Throws: .overflow 112 | /// - Returns: If current character is in character set 113 | mutating func read(_ characterSet: Set) throws -> Bool { 114 | let initialIndex = self.index 115 | let c = try character() 116 | guard characterSet.contains(c) else { 117 | self.index = initialIndex 118 | return false 119 | } 120 | return true 121 | } 122 | 123 | /// Compare characters at current position against provided string. 
If the characters are the same as string provided advance past string 124 | /// - Parameter string: String to compare against 125 | /// - Throws: .overflow, .emptyString 126 | /// - Returns: If characters at current position equal string 127 | mutating func read(_ string: String) throws -> Bool { 128 | let initialIndex = self.index 129 | guard string.count > 0 else { throw Error.emptyString } 130 | let subString = try read(count: string.count) 131 | guard subString.string == string else { 132 | self.index = initialIndex 133 | return false 134 | } 135 | return true 136 | } 137 | 138 | /// Read next so many characters from buffer 139 | /// - Parameter count: Number of characters to read 140 | /// - Throws: .overflow 141 | /// - Returns: The string read from the buffer 142 | mutating func read(count: Int) throws -> Parser { 143 | var count = count 144 | var readEndIndex = self.index 145 | while count > 0 { 146 | guard readEndIndex != self.range.endIndex else { throw Error.overflow } 147 | readEndIndex = skipUTF8Character(at: readEndIndex) 148 | count -= 1 149 | } 150 | let result = self.subParser(self.index.. Parser { 160 | let startIndex = self.index 161 | while !self.reachedEnd() { 162 | if unsafeCurrent() == until { 163 | return self.subParser(startIndex.., throwOnOverflow: Bool = true) throws -> Parser { 179 | let startIndex = self.index 180 | while !self.reachedEnd() { 181 | if characterSet.contains(unsafeCurrent()) { 182 | return self.subParser(startIndex.. Bool, throwOnOverflow: Bool = true) throws -> Parser { 198 | let startIndex = self.index 199 | while !self.reachedEnd() { 200 | if until(unsafeCurrent()) { 201 | return self.subParser(startIndex.., throwOnOverflow: Bool = true) throws -> Parser { 217 | let startIndex = self.index 218 | while !self.reachedEnd() { 219 | if unsafeCurrent()[keyPath: keyPath] { 220 | return self.subParser(startIndex.. 
Parser { 238 | var untilString = untilString 239 | return try untilString.withUTF8 { utf8 in 240 | guard utf8.count > 0 else { throw Error.emptyString } 241 | let startIndex = self.index 242 | var foundIndex = self.index 243 | var untilIndex = 0 244 | while !self.reachedEnd() { 245 | if self.buffer[self.index] == utf8[untilIndex] { 246 | if untilIndex == 0 { 247 | foundIndex = self.index 248 | } 249 | untilIndex += 1 250 | if untilIndex == utf8.endIndex { 251 | unsafeAdvance() 252 | if skipToEnd == false { 253 | self.index = foundIndex 254 | } 255 | let result = self.subParser(startIndex.. Parser { 274 | let startIndex = self.index 275 | self.index = self.range.endIndex 276 | return self.subParser(startIndex.. Int { 283 | var count = 0 284 | while !self.reachedEnd(), 285 | unsafeCurrent() == `while` 286 | { 287 | unsafeAdvance() 288 | count += 1 289 | } 290 | return count 291 | } 292 | 293 | /// Read while character at current position is in supplied set 294 | /// - Parameter while: character set to check 295 | /// - Returns: String read from buffer 296 | @discardableResult mutating func read(while characterSet: Set) -> Parser { 297 | let startIndex = self.index 298 | while !self.reachedEnd(), 299 | characterSet.contains(unsafeCurrent()) 300 | { 301 | unsafeAdvance() 302 | } 303 | return self.subParser(startIndex.. Bool) -> Parser { 310 | let startIndex = self.index 311 | while !self.reachedEnd(), 312 | `while`(unsafeCurrent()) 313 | { 314 | unsafeAdvance() 315 | } 316 | return self.subParser(startIndex..) -> Parser { 323 | let startIndex = self.index 324 | while !self.reachedEnd(), 325 | unsafeCurrent()[keyPath: keyPath] 326 | { 327 | unsafeAdvance() 328 | } 329 | return self.subParser(startIndex.. 
[Parser] { 336 | var subParsers: [Parser] = [] 337 | while !self.reachedEnd() { 338 | do { 339 | let section = try read(until: separator) 340 | subParsers.append(section) 341 | unsafeAdvance() 342 | } catch { 343 | if !self.reachedEnd() { 344 | subParsers.append(self.readUntilTheEnd()) 345 | } 346 | } 347 | } 348 | return subParsers 349 | } 350 | 351 | /// Return whether we have reached the end of the buffer 352 | /// - Returns: Have we reached the end 353 | func reachedEnd() -> Bool { 354 | self.index == self.range.endIndex 355 | } 356 | } 357 | 358 | /// Public versions of internal functions which include tests for overflow 359 | extension Parser { 360 | /// Return the character at the current position 361 | /// - Throws: .overflow 362 | /// - Returns: Unicode.Scalar 363 | func current() -> Unicode.Scalar { 364 | guard !self.reachedEnd() else { return Unicode.Scalar(0) } 365 | return unsafeCurrent() 366 | } 367 | 368 | /// Move forward one character 369 | /// - Throws: .overflow 370 | mutating func advance() throws { 371 | guard !self.reachedEnd() else { throw Error.overflow } 372 | return self.unsafeAdvance() 373 | } 374 | 375 | /// Move forward so many character 376 | /// - Parameter amount: number of characters to move forward 377 | /// - Throws: .overflow 378 | mutating func advance(by amount: Int) throws { 379 | var amount = amount 380 | while amount > 0 { 381 | guard !self.reachedEnd() else { throw Error.overflow } 382 | self.index = skipUTF8Character(at: self.index) 383 | amount -= 1 384 | } 385 | } 386 | 387 | /// Move backwards one character 388 | /// - Throws: .overflow 389 | mutating func retreat() throws { 390 | guard self.index > self.range.startIndex else { throw Error.overflow } 391 | self.index = backOneUTF8Character(at: self.index) 392 | } 393 | 394 | /// Move back so many characters 395 | /// - Parameter amount: number of characters to move back 396 | /// - Throws: .overflow 397 | mutating func retreat(by amount: Int) throws { 398 | var amount = 
amount 399 | while amount > 0 { 400 | guard self.index > self.range.startIndex else { throw Error.overflow } 401 | self.index = backOneUTF8Character(at: self.index) 402 | amount -= 1 403 | } 404 | } 405 | 406 | /// Move parser to beginning of string 407 | mutating func moveToStart() { 408 | self.index = self.range.startIndex 409 | } 410 | 411 | /// Move parser to end of string 412 | mutating func moveToEnd() { 413 | self.index = self.range.endIndex 414 | } 415 | 416 | mutating func unsafeAdvance() { 417 | self.index = skipUTF8Character(at: self.index) 418 | } 419 | 420 | mutating func unsafeAdvance(by amount: Int) { 421 | var amount = amount 422 | while amount > 0 { 423 | self.index = skipUTF8Character(at: self.index) 424 | amount -= 1 425 | } 426 | } 427 | } 428 | 429 | /// extend Parser to conform to Sequence 430 | extension Parser: Sequence { 431 | public typealias Element = Unicode.Scalar 432 | 433 | public func makeIterator() -> Iterator { 434 | Iterator(self) 435 | } 436 | 437 | public struct Iterator: IteratorProtocol { 438 | public typealias Element = Unicode.Scalar 439 | 440 | var parser: Parser 441 | 442 | init(_ parser: Parser) { 443 | self.parser = parser 444 | } 445 | 446 | public mutating func next() -> Unicode.Scalar? 
{ 447 | guard !self.parser.reachedEnd() else { return nil } 448 | return self.parser.unsafeCurrentAndAdvance() 449 | } 450 | } 451 | } 452 | 453 | // internal versions without checks 454 | extension Parser { 455 | fileprivate func unsafeCurrent() -> Unicode.Scalar { 456 | decodeUTF8Character(at: self.index).0 457 | } 458 | 459 | fileprivate mutating func unsafeCurrentAndAdvance() -> Unicode.Scalar { 460 | let (unicodeScalar, index) = decodeUTF8Character(at: self.index) 461 | self.index = index 462 | return unicodeScalar 463 | } 464 | 465 | fileprivate mutating func _setPosition(_ index: Int) { 466 | self.index = index 467 | } 468 | 469 | fileprivate func makeString(_ bytes: Bytes) -> String where Bytes.Element == UInt8, Bytes.Index == Int { 470 | if let string = bytes.withContiguousStorageIfAvailable({ String(decoding: $0, as: Unicode.UTF8.self) }) { 471 | return string 472 | } else { 473 | return String(decoding: bytes, as: Unicode.UTF8.self) 474 | } 475 | } 476 | } 477 | 478 | // UTF8 parsing 479 | extension Parser { 480 | func decodeUTF8Character(at index: Int) -> (Unicode.Scalar, Int) { 481 | var index = index 482 | let byte1 = UInt32(buffer[index]) 483 | var value: UInt32 484 | if byte1 & 0xC0 == 0xC0 { 485 | index += 1 486 | let byte2 = UInt32(buffer[index] & 0x3F) 487 | if byte1 & 0xE0 == 0xE0 { 488 | index += 1 489 | let byte3 = UInt32(buffer[index] & 0x3F) 490 | if byte1 & 0xF0 == 0xF0 { 491 | index += 1 492 | let byte4 = UInt32(buffer[index] & 0x3F) 493 | value = (byte1 & 0x7) << 18 + byte2 << 12 + byte3 << 6 + byte4 494 | } else { 495 | value = (byte1 & 0xF) << 12 + byte2 << 6 + byte3 496 | } 497 | } else { 498 | value = (byte1 & 0x1F) << 6 + byte2 499 | } 500 | } else { 501 | value = byte1 & 0x7F 502 | } 503 | let unicodeScalar = Unicode.Scalar(value)! 
504 | return (unicodeScalar, index + 1) 505 | } 506 | 507 | func skipUTF8Character(at index: Int) -> Int { 508 | if self.buffer[index] & 0x80 != 0x80 { return index + 1 } 509 | if self.buffer[index + 1] & 0xC0 == 0x80 { return index + 2 } 510 | if self.buffer[index + 2] & 0xC0 == 0x80 { return index + 3 } 511 | return index + 4 512 | } 513 | 514 | func backOneUTF8Character(at index: Int) -> Int { 515 | if self.buffer[index - 1] & 0xC0 != 0x80 { return index - 1 } 516 | if self.buffer[index - 2] & 0xC0 != 0x80 { return index - 2 } 517 | if self.buffer[index - 3] & 0xC0 != 0x80 { return index - 3 } 518 | return index - 4 519 | } 520 | 521 | /// same as `decodeUTF8Character` but adds extra validation, so we can make assumptions later on in decode and skip 522 | func validateUTF8Character(at index: Int) -> (Unicode.Scalar?, Int) { 523 | var index = index 524 | let byte1 = UInt32(buffer[index]) 525 | var value: UInt32 526 | if byte1 & 0xC0 == 0xC0 { 527 | index += 1 528 | let byte = UInt32(buffer[index]) 529 | guard byte & 0xC0 == 0x80 else { return (nil, index) } 530 | let byte2 = UInt32(byte & 0x3F) 531 | if byte1 & 0xE0 == 0xE0 { 532 | index += 1 533 | let byte = UInt32(buffer[index]) 534 | guard byte & 0xC0 == 0x80 else { return (nil, index) } 535 | let byte3 = UInt32(byte & 0x3F) 536 | if byte1 & 0xF0 == 0xF0 { 537 | index += 1 538 | let byte = UInt32(buffer[index]) 539 | guard byte & 0xC0 == 0x80 else { return (nil, index) } 540 | let byte4 = UInt32(byte & 0x3F) 541 | value = (byte1 & 0x7) << 18 + byte2 << 12 + byte3 << 6 + byte4 542 | } else { 543 | value = (byte1 & 0xF) << 12 + byte2 << 6 + byte3 544 | } 545 | } else { 546 | value = (byte1 & 0x1F) << 6 + byte2 547 | } 548 | } else { 549 | value = byte1 & 0x7F 550 | } 551 | let unicodeScalar = Unicode.Scalar(value) 552 | return (unicodeScalar, index + 1) 553 | } 554 | 555 | /// return if the buffer is valid UTF8 556 | func validateUTF8() -> Bool { 557 | var index = self.range.startIndex 558 | while index < 
self.range.endIndex { 559 | let (scalar, newIndex) = self.validateUTF8Character(at: index) 560 | guard scalar != nil else { return false } 561 | index = newIndex 562 | } 563 | return true 564 | } 565 | 566 | private static let asciiHexValues: [UInt8] = [ 567 | /* 00 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 568 | /* 08 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 569 | /* 10 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 570 | /* 18 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 571 | /* 20 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 572 | /* 28 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 573 | /* 30 */ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 574 | /* 38 */ 0x08, 0x09, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 575 | /* 40 */ 0x80, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x80, 576 | /* 48 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 577 | /* 50 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 578 | /* 58 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 579 | /* 60 */ 0x80, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x80, 580 | /* 68 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 581 | /* 70 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 582 | /* 78 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 583 | 584 | /* 80 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 585 | /* 88 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 586 | /* 90 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 587 | /* 98 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 588 | /* A0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 589 | /* A8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 590 | /* B0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 591 | /* B8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 592 | /* C0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 593 | /* C8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 594 | /* D0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 595 | /* D8 */ 0x80, 0x80, 0x80, 
0x80, 0x80, 0x80, 0x80, 0x80, 596 | /* E0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 597 | /* E8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 598 | /* F0 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 599 | /* F8 */ 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 600 | ] 601 | 602 | /// percent decode UTF8 603 | public func percentDecode() -> String? { 604 | struct DecodeError: Swift.Error {} 605 | func _percentDecode(_ original: ArraySlice, _ bytes: UnsafeMutableBufferPointer) throws -> Int { 606 | var newIndex = 0 607 | var index = original.startIndex 608 | while index < (original.endIndex - 2) { 609 | // if we have found a percent sign 610 | if original[index] == 0x25 { 611 | let high = Self.asciiHexValues[Int(original[index + 1])] 612 | let low = Self.asciiHexValues[Int(original[index + 2])] 613 | index += 3 614 | if ((high | low) & 0x80) != 0 { 615 | throw DecodeError() 616 | } 617 | bytes[newIndex] = (high << 4) | low 618 | newIndex += 1 619 | } else { 620 | bytes[newIndex] = original[index] 621 | newIndex += 1 622 | index += 1 623 | } 624 | } 625 | while index < original.endIndex { 626 | bytes[newIndex] = original[index] 627 | newIndex += 1 628 | index += 1 629 | } 630 | return newIndex 631 | } 632 | guard self.index != self.range.endIndex else { return "" } 633 | do { 634 | if #available(macOS 11, macCatalyst 14.0, iOS 14.0, tvOS 14.0, *) { 635 | return try String(unsafeUninitializedCapacity: range.endIndex - index) { bytes -> Int in 636 | try _percentDecode(self.buffer[self.index.. 
{ 667 | init(_ string: String) { 668 | self = Set(string.unicodeScalars) 669 | } 670 | } 671 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/TLSClientChannel.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Logging 16 | import NIOCore 17 | import NIOSSL 18 | import NIOWebSocket 19 | 20 | /// Sets up client channel to use TLS before accessing base channel setup 21 | @_documentation(visibility: internal) 22 | public struct TLSClientChannel: ClientConnectionChannel { 23 | public typealias Value = BaseChannel.Value 24 | public typealias Result = BaseChannel.Result 25 | /// Initialize TLSClientChannel 26 | /// - Parameters: 27 | /// - baseChannel: Base child channel to wrap 28 | /// - tlsConfiguration: TLS configuration used to create the `NIOSSLContext` 29 | /// - serverHostname: Hostname passed to `NIOSSLClientHandler`, if known 30 | public init(_ baseChannel: BaseChannel, tlsConfiguration: TLSConfiguration, serverHostname: String?
= nil) throws { 31 | self.baseChannel = baseChannel 32 | self.sslContext = try NIOSSLContext(configuration: tlsConfiguration) 33 | self.serverHostname = serverHostname 34 | } 35 | 36 | /// Set up child channel with TLS and the base channel setup 37 | /// - Parameters: 38 | /// - channel: Child channel 39 | /// - logger: Logger used during setup 40 | /// - Returns: Object to process input/output on child channel 41 | @inlinable 42 | public func setup(channel: Channel, logger: Logger) -> EventLoopFuture { 43 | channel.eventLoop.makeCompletedFuture { 44 | let sslHandler = try NIOSSLClientHandler(context: self.sslContext, serverHostname: self.serverHostname) 45 | try channel.pipeline.syncOperations.addHandler(sslHandler) 46 | }.flatMap { 47 | self.baseChannel.setup(channel: channel, logger: logger) 48 | } 49 | } 50 | 51 | /// Handle messages being passed down the channel pipeline by forwarding them to the base channel 52 | /// - Parameters: 53 | /// - value: Object to process input/output on child channel 54 | /// - logger: Logger to use while processing messages 55 | @inlinable 56 | public func handle(value: BaseChannel.Value, logger: Logging.Logger) async throws -> Result { 57 | try await self.baseChannel.handle(value: value, logger: logger) 58 | } 59 | 60 | @usableFromInline 61 | let sslContext: NIOSSLContext 62 | @usableFromInline 63 | let serverHostname: String?
64 | @usableFromInline 65 | var baseChannel: BaseChannel 66 | } 67 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/TSTLSOptions.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2021-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | #if canImport(Network) 16 | import Foundation 17 | import Network 18 | import Security 19 | 20 | /// Wrapper for NIO transport services TLS options 21 | public struct TSTLSOptions: Sendable { 22 | public struct Error: Swift.Error, Equatable { 23 | enum _Internal: Equatable { 24 | case invalidFormat 25 | case interactionNotAllowed 26 | case verificationFailed 27 | } 28 | 29 | private let value: _Internal 30 | init(_ value: _Internal) { 31 | self.value = value 32 | } 33 | 34 | /// certificate or identity data could not be imported 35 | public static var invalidFormat: Self { .init(.invalidFormat) } 36 | /// unable to import p12 as no interaction is allowed 37 | public static var interactionNotAllowed: Self { .init(.interactionNotAllowed) } 38 | /// MAC verification failed during PKCS12 import (wrong password?)
39 | public static var verificationFailed: Self { .init(.verificationFailed) } 40 | } 41 | 42 | public struct Identity { 43 | let secIdentity: SecIdentity 44 | /// Identity from an already loaded SecIdentity 45 | public static func secIdentity(_ secIdentity: SecIdentity) -> Self { 46 | .init(secIdentity: secIdentity) 47 | } 48 | /// Identity loaded from a PKCS#12 file; throws `.invalidFormat` if the file contains no identity 49 | public static func p12(filename: String, password: String) throws -> Self { 50 | guard let secIdentity = try Self.loadP12(filename: filename, password: password) else { throw Error.invalidFormat } 51 | return .init(secIdentity: secIdentity) 52 | } 53 | /// Import PKCS#12 data and return the identity of its first item, if any 54 | private static func loadP12(filename: String, password: String) throws -> SecIdentity? { 55 | let data = try Data(contentsOf: URL(fileURLWithPath: filename)) 56 | let options: [String: String] = [kSecImportExportPassphrase as String: password] 57 | var rawItems: CFArray? 58 | let result = SecPKCS12Import(data as CFData, options as CFDictionary, &rawItems) 59 | switch result { 60 | case errSecSuccess: 61 | break 62 | case errSecInteractionNotAllowed: 63 | throw Error.interactionNotAllowed 64 | case errSecPkcs12VerifyFailure: 65 | throw Error.verificationFailed 66 | default: 67 | throw Error.invalidFormat 68 | } 69 | guard let items = rawItems as? [[String: Any]], let firstItem = items.first else { throw Error.invalidFormat } 70 | // a missing identity returns nil, which `p12(filename:password:)` reports as `.invalidFormat` 71 | return firstItem[kSecImportItemIdentity as String] as! SecIdentity?
72 | } 73 | } 74 | 75 | /// Struct defining an array of certificates 76 | public struct Certificates { 77 | let certificates: [SecCertificate] 78 | 79 | /// Empty certificate array 80 | public static var none: Self { .init(certificates: []) } 81 | 82 | /// Create certificate array from already loaded SecCertificate array 83 | public static func certificates(_ secCertificates: [SecCertificate]) -> Self { .init(certificates: secCertificates) } 84 | 85 | /// Create certificate array from DER file 86 | public static func der(filename: String) throws -> Self { 87 | let certificateData = try Data(contentsOf: URL(fileURLWithPath: filename)) 88 | guard let secCertificate = SecCertificateCreateWithData(nil, certificateData as CFData) else { throw Error.invalidFormat } 89 | return .init(certificates: [secCertificate]) 90 | } 91 | } 92 | 93 | /// Initialize TSTLSOptions 94 | public init(_ options: NWProtocolTLS.Options?) { 95 | if let options { 96 | self.value = .some(options) 97 | } else { 98 | self.value = .none 99 | } 100 | } 101 | 102 | /// TSTLSOptions holding options 103 | public static func options(_ options: NWProtocolTLS.Options) -> Self { 104 | .init(value: .some(options)) 105 | } 106 | /// TSTLSOptions with a server identity. Returns `nil` if a `sec_identity_t` cannot be created from it 107 | public static func options( 108 | serverIdentity: Identity 109 | ) -> Self? { 110 | let options = NWProtocolTLS.Options() 111 | 112 | // server identity 113 | guard let secIdentity = sec_identity_create(serverIdentity.secIdentity) else { return nil } 114 | sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) 115 | 116 | return .init(value: .some(options)) 117 | } 118 | /// TSTLSOptions with a client identity, optional trust roots and optional server name. Returns `nil` if a `sec_identity_t` cannot be created from the identity 119 | public static func options( 120 | clientIdentity: Identity, 121 | trustRoots: Certificates = .none, 122 | serverName: String? = nil 123 | ) -> Self?
{ 124 | let options = NWProtocolTLS.Options() 125 | 126 | // client identity 127 | guard let secIdentity = sec_identity_create(clientIdentity.secIdentity) else { return nil } 128 | sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) 129 | 130 | // optional server name used during the TLS handshake 131 | if let serverName { 132 | sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverName) 133 | } 134 | 135 | // when trust roots are provided, evaluate the peer's trust using them as anchor certificates 136 | if trustRoots.certificates.count > 0 { 137 | sec_protocol_options_set_verify_block( 138 | options.securityProtocolOptions, 139 | { _, sec_trust, sec_protocol_verify_complete in 140 | let trust = sec_trust_copy_ref(sec_trust).takeRetainedValue() 141 | SecTrustSetAnchorCertificates(trust, trustRoots.certificates as CFArray) 142 | SecTrustEvaluateAsyncWithError(trust, Self.tlsDispatchQueue) { _, result, error in 143 | if let error { 144 | print("Trust failed: \(error.localizedDescription)") 145 | } 146 | sec_protocol_verify_complete(result) 147 | } 148 | }, 149 | Self.tlsDispatchQueue 150 | ) 151 | } 152 | return .init(value: .some(options)) 153 | } 154 | 155 | /// Empty TSTLSOptions 156 | public static var none: Self { 157 | .init(value: .none) 158 | } 159 | /// Underlying Network framework TLS options, or nil for `.none` 160 | var options: NWProtocolTLS.Options? { 161 | if case .some(let options) = self.value { return options } 162 | return nil 163 | } 164 | 165 | /// Internal storage for TSTLSOptions.
@unchecked Sendable while NWProtocolTLS.Options 166 | /// is not Sendable 167 | private enum Internal: @unchecked Sendable { 168 | case some(NWProtocolTLS.Options) 169 | case none 170 | } 171 | 172 | private let value: Internal 173 | private init(value: Internal) { self.value = value } 174 | 175 | /// Dispatch queue used by Network framework TLS to control certificate verification 176 | static let tlsDispatchQueue = DispatchQueue(label: "WSTSTLSConfiguration") 177 | } 178 | #endif 179 | -------------------------------------------------------------------------------- /Sources/WSClient/Client/URI.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2021-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | /// Simple URL parser 16 | struct URI: Sendable, CustomStringConvertible, ExpressibleByStringLiteral { 17 | struct Scheme: RawRepresentable, Equatable { 18 | let rawValue: String 19 | 20 | init(rawValue: String) { 21 | self.rawValue = rawValue 22 | } 23 | 24 | static var http: Self { .init(rawValue: "http") } 25 | static var https: Self { .init(rawValue: "https") } 26 | static var unix: Self { .init(rawValue: "unix") } 27 | static var http_unix: Self { .init(rawValue: "http_unix") } 28 | static var https_unix: Self { .init(rawValue: "https_unix") } 29 | static var ws: Self { .init(rawValue: "ws") } 30 | static var wss: Self { .init(rawValue: "wss") } 31 | } 32 | 33 | let string: String 34 | 35 | /// URL scheme 36 | var scheme: Scheme?
{ self._scheme.map { .init(rawValue: $0.string) } } 37 | /// URL host 38 | var host: String? { self._host.map(\.string) } 39 | /// URL port 40 | var port: Int? { self._port.map { Int($0.string) } ?? nil } 41 | /// URL path 42 | var path: String { self._path.map(\.string) ?? "/" } 43 | /// URL query 44 | var query: String? { self._query.map { String($0.string) } } 45 | 46 | private let _scheme: Parser? 47 | private let _host: Parser? 48 | private let _port: Parser? 49 | private let _path: Parser? 50 | private let _query: Parser? 51 | 52 | var description: String { self.string } 53 | 54 | /// Initialize `URI` from `String` 55 | /// - Parameter string: input string 56 | init(_ string: String) { 57 | enum ParsingState { 58 | case readingScheme 59 | case readingHost 60 | case readingPort 61 | case readingPath 62 | case readingQuery 63 | case finished 64 | } 65 | var scheme: Parser? 66 | var host: Parser? 67 | var port: Parser? 68 | var path: Parser? 69 | var query: Parser? 70 | var state: ParsingState = .readingScheme 71 | if string.first == "/" { 72 | state = .readingPath 73 | } 74 | 75 | var parser = Parser(string) 76 | while state != .finished { 77 | if parser.reachedEnd() { break } 78 | switch state { 79 | case .readingScheme: 80 | // search for "://" to find scheme and host 81 | scheme = try? parser.read(untilString: "://", skipToEnd: true) 82 | if scheme != nil { 83 | state = .readingHost 84 | } else { 85 | state = .readingPath 86 | } 87 | 88 | case .readingHost: 89 | let h = try! parser.read(until: Self.hostEndSet, throwOnOverflow: false) 90 | if h.count != 0 { 91 | host = h 92 | } 93 | if parser.current() == ":" { 94 | state = .readingPort 95 | } else if parser.current() == "?" { 96 | state = .readingQuery 97 | } else { 98 | state = .readingPath 99 | } 100 | 101 | case .readingPort: 102 | parser.unsafeAdvance() 103 | port = try! parser.read(until: Self.portEndSet, throwOnOverflow: false) 104 | state = .readingPath 105 | 106 | case .readingPath: 107 | path = try! 
parser.read(until: "?", throwOnOverflow: false) 108 | state = .readingQuery 109 | 110 | case .readingQuery: 111 | parser.unsafeAdvance() 112 | query = try! parser.read(until: "#", throwOnOverflow: false) 113 | state = .finished 114 | 115 | case .finished: 116 | break 117 | } 118 | } 119 | 120 | self.string = string 121 | self._scheme = scheme 122 | self._host = host 123 | self._port = port 124 | self._path = path 125 | self._query = query 126 | } 127 | 128 | init(stringLiteral value: String) { 129 | self.init(value) 130 | } 131 | 132 | private static let hostEndSet: Set = Set(":/?") 133 | private static let portEndSet: Set = Set("/?") 134 | } 135 | -------------------------------------------------------------------------------- /Sources/WSClient/Exports.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | @_exported @_documentation(visibility: internal) import WSCore 16 | -------------------------------------------------------------------------------- /Sources/WSClient/WebSocketClient.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024-2025 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license 
information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import HTTPTypes 16 | import Logging 17 | import NIOCore 18 | import NIOPosix 19 | import NIOSSL 20 | import NIOTransportServices 21 | import NIOWebSocket 22 | import WSCore 23 | 24 | /// WebSocket client 25 | /// 26 | /// Connect to HTTP server with WebSocket upgrade available. 27 | /// 28 | /// Supports TLS via both NIOSSL and Network framework. 29 | /// 30 | /// Initialize the WebSocketClient with your handler and then call ``WebSocketClient/run()`` 31 | /// to connect. The handler is provided with an `inbound` stream of WebSocket packets coming 32 | /// from the server and an `outbound` writer that can be used to write packets to the server. 33 | /// ```swift 34 | /// let webSocket = WebSocketClient(url: "ws://test.org/ws", logger: logger) { inbound, outbound, context in 35 | /// for try await packet in inbound { 36 | /// if case .text(let string) = packet { 37 | /// try await outbound.write(.text(string)) 38 | /// } 39 | /// } 40 | /// } 41 | /// ``` 42 | public struct WebSocketClient { 43 | /// Client implementation of ``/WSCore/WebSocketContext``.
44 | public struct Context: WebSocketContext { 45 | public let logger: Logger 46 | 47 | package init(logger: Logger) { 48 | self.logger = logger 49 | } 50 | } 51 | /// TLS configuration backed either by NIOSSL or, where Network is available, by Network framework options 52 | enum MultiPlatformTLSConfiguration: Sendable { 53 | case niossl(TLSConfiguration) 54 | #if canImport(Network) 55 | case ts(TSTLSOptions) 56 | #endif 57 | } 58 | 59 | /// WebSocket URL 60 | let url: URI 61 | /// WebSocket data handler 62 | let handler: WebSocketDataHandler 63 | /// WebSocket client configuration 64 | let configuration: WebSocketClientConfiguration 65 | /// EventLoopGroup to use 66 | let eventLoopGroup: EventLoopGroup 67 | /// Logger 68 | let logger: Logger 69 | /// TLS configuration 70 | let tlsConfiguration: MultiPlatformTLSConfiguration? 71 | 72 | /// Initialize websocket client 73 | /// 74 | /// - Parameters: 75 | /// - url: URL of websocket 76 | /// - tlsConfiguration: TLS configuration 77 | /// - handler: WebSocket data handler 78 | /// - configuration: WebSocket client configuration 79 | /// - eventLoopGroup: EventLoopGroup to run WebSocket client on 80 | /// - logger: Logger 81 | public init( 82 | url: String, 83 | configuration: WebSocketClientConfiguration = .init(), 84 | tlsConfiguration: TLSConfiguration?
= nil, 85 | eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, 86 | logger: Logger, 87 | handler: @escaping WebSocketDataHandler 88 | ) { 89 | self.url = .init(url) 90 | self.handler = handler 91 | self.configuration = configuration 92 | self.eventLoopGroup = eventLoopGroup 93 | self.logger = logger 94 | self.tlsConfiguration = tlsConfiguration.map { .niossl($0) } 95 | } 96 | 97 | #if canImport(Network) 98 | /// Initialize websocket client 99 | /// 100 | /// - Parameters: 101 | /// - url: URL of websocket 102 | /// - transportServicesTLSOptions: TLS options for NIOTransportServices 103 | /// - handler: WebSocket data handler 104 | /// - configuration: WebSocket client configuration 105 | /// - eventLoopGroup: EventLoopGroup to run WebSocket client on 106 | /// - logger: Logger 107 | public init( 108 | url: String, 109 | configuration: WebSocketClientConfiguration = .init(), 110 | transportServicesTLSOptions: TSTLSOptions, 111 | eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, 112 | logger: Logger, 113 | handler: @escaping WebSocketDataHandler 114 | ) { 115 | self.url = .init(url) 116 | self.handler = handler 117 | self.configuration = configuration 118 | self.eventLoopGroup = eventLoopGroup 119 | self.logger = logger 120 | self.tlsConfiguration = .ts(transportServicesTLSOptions) 121 | } 122 | #endif 123 | 124 | /// Connect and run handler 125 | /// - Returns: WebSocket close frame details if server returned any 126 | @discardableResult public func run() async throws -> WebSocketCloseFrame? { 127 | guard let host = url.host else { throw WebSocketClientError.invalidURL } 128 | let requiresTLS = self.url.scheme == .wss || self.url.scheme == .https 129 | let port = self.url.port ?? (requiresTLS ?
443 : 80) 130 | if requiresTLS { 131 | switch self.tlsConfiguration { 132 | case .niossl(let tlsConfiguration): 133 | let client = try ClientConnection( 134 | TLSClientChannel( 135 | WebSocketClientChannel(handler: handler, url: url, configuration: self.configuration), 136 | tlsConfiguration: tlsConfiguration, 137 | serverHostname: self.configuration.sniHostname ?? host 138 | ), 139 | address: .hostname(host, port: port), 140 | eventLoopGroup: self.eventLoopGroup, 141 | logger: self.logger 142 | ) 143 | return try await client.run() 144 | 145 | #if canImport(Network) 146 | case .ts(let tlsOptions): 147 | let client = try ClientConnection( 148 | WebSocketClientChannel(handler: handler, url: url, configuration: self.configuration), 149 | address: .hostname(host, port: port), 150 | transportServicesTLSOptions: tlsOptions, 151 | eventLoopGroup: self.eventLoopGroup, 152 | logger: self.logger 153 | ) 154 | return try await client.run() 155 | 156 | #endif 157 | case .none: 158 | let client = try ClientConnection( 159 | TLSClientChannel( 160 | WebSocketClientChannel( 161 | handler: handler, 162 | url: url, 163 | configuration: self.configuration 164 | ), 165 | tlsConfiguration: TLSConfiguration.makeClientConfiguration(), 166 | serverHostname: self.configuration.sniHostname ?? host // honour `sniHostname` as the `.niossl` branch does 167 | ), 168 | address: .hostname(host, port: port), 169 | eventLoopGroup: self.eventLoopGroup, 170 | logger: self.logger 171 | ) 172 | return try await client.run() 173 | } 174 | } else { 175 | let client = try ClientConnection( 176 | WebSocketClientChannel( 177 | handler: handler, 178 | url: url, 179 | configuration: self.configuration 180 | ), 181 | address: .hostname(host, port: port), 182 | eventLoopGroup: self.eventLoopGroup, 183 | logger: self.logger 184 | ) 185 | return try await client.run() 186 | } 187 | } 188 | } 189 | 190 | extension WebSocketClient { 191 | /// Create websocket client, connect and handle connection 192 | /// 193 | /// - Parameters: 194 | /// - url: URL of websocket 195 | /// - tlsConfiguration: TLS
configuration 196 | /// - configuration: WebSocket client configuration 197 | /// - eventLoopGroup: EventLoopGroup to run WebSocket client on 198 | /// - logger: Logger 199 | /// - handler: Closure handling webSocket 200 | /// - Returns: WebSocket close frame details if server returned any 201 | @discardableResult public static func connect( 202 | url: String, 203 | configuration: WebSocketClientConfiguration = .init(), 204 | tlsConfiguration: TLSConfiguration? = nil, 205 | eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, 206 | logger: Logger, 207 | handler: @escaping WebSocketDataHandler 208 | ) async throws -> WebSocketCloseFrame? { 209 | let ws = self.init( 210 | url: url, 211 | configuration: configuration, 212 | tlsConfiguration: tlsConfiguration, 213 | eventLoopGroup: eventLoopGroup, 214 | logger: logger, 215 | handler: handler 216 | ) 217 | return try await ws.run() 218 | } 219 | 220 | #if canImport(Network) 221 | /// Create websocket client, connect and handle connection 222 | /// 223 | /// - Parameters: 224 | /// - url: URL of websocket 225 | /// - transportServicesTLSOptions: TLS options for NIOTransportServices 226 | /// - configuration: WebSocket client configuration 227 | /// - eventLoopGroup: EventLoopGroup to run WebSocket client on 228 | /// - logger: Logger 229 | /// - handler: WebSocket data handler 230 | /// - Returns: WebSocket close frame details if server returned any 231 | @discardableResult public static func connect( 232 | url: String, 233 | configuration: WebSocketClientConfiguration = .init(), 234 | transportServicesTLSOptions: TSTLSOptions, 235 | eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, 236 | logger: Logger, 237 | handler: @escaping WebSocketDataHandler 238 | ) async throws -> WebSocketCloseFrame?
{ 239 | let ws = self.init( 240 | url: url, 241 | configuration: configuration, 242 | transportServicesTLSOptions: transportServicesTLSOptions, 243 | eventLoopGroup: eventLoopGroup, 244 | logger: logger, 245 | handler: handler 246 | ) 247 | return try await ws.run() 248 | } 249 | #endif 250 | } 251 | -------------------------------------------------------------------------------- /Sources/WSClient/WebSocketClientChannel.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import HTTPTypes 16 | import Logging 17 | import NIOCore 18 | import NIOHTTP1 19 | import NIOHTTPTypesHTTP1 20 | import NIOWebSocket 21 | @_spi(WSInternal) import WSCore 22 | /// Client connection channel that sends the WebSocket upgrade request in `setup` and, once upgraded, runs the WebSocket handler in `handle` 23 | struct WebSocketClientChannel: ClientConnectionChannel { 24 | enum UpgradeResult { 25 | case websocket(NIOAsyncChannel, [any WebSocketExtension]) 26 | case notUpgraded 27 | } 28 | 29 | typealias Value = EventLoopFuture 30 | 31 | let urlPath: String 32 | let hostHeader: String 33 | let handler: WebSocketDataHandler 34 | let configuration: WebSocketClientConfiguration 35 | 36 | init(handler: @escaping WebSocketDataHandler, url: URI, configuration: WebSocketClientConfiguration) throws { 37 | guard let hostHeader = Self.urlHostHeader(for: url) else { throw WebSocketClientError.invalidURL } 38 | self.hostHeader = hostHeader 39 | self.urlPath = Self.urlPath(for: url) 40 | self.handler = handler 41 | self.configuration = configuration 42 | } 43 | 44 |
func setup(channel: any Channel, logger: Logger) -> NIOCore.EventLoopFuture { 45 | channel.eventLoop.makeCompletedFuture { 46 | let upgrader = NIOTypedWebSocketClientUpgrader( 47 | maxFrameSize: self.configuration.maxFrameSize, 48 | upgradePipelineHandler: { channel, head in 49 | channel.eventLoop.makeCompletedFuture { 50 | let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) 51 | // work out what extensions we should add based off the server response 52 | let headerFields = HTTPFields(head.headers, splitCookie: false) 53 | let extensions = try configuration.extensions.buildClientExtensions(from: headerFields) 54 | if !extensions.isEmpty { 55 | logger.debug( 56 | "Enabled extensions", 57 | metadata: ["hb.ws.extensions": .string(extensions.map(\.name).joined(separator: ","))] 58 | ) 59 | } 60 | return UpgradeResult.websocket(asyncChannel, extensions) 61 | } 62 | } 63 | ) 64 | 65 | var headers = HTTPHeaders() 66 | headers.add(name: "Content-Length", value: "0") 67 | headers.add(name: "Host", value: self.hostHeader) 68 | let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) 69 | headers.add(contentsOf: additionalHeaders) 70 | // add websocket extensions to headers 71 | headers.add( 72 | contentsOf: self.configuration.extensions.compactMap { 73 | let requestHeaders = $0.clientRequestHeader() 74 | return requestHeaders != "" ?
("Sec-WebSocket-Extensions", requestHeaders) : nil 75 | } 76 | ) 77 | 78 | let requestHead = HTTPRequestHead( 79 | version: .http1_1, 80 | method: .GET, 81 | uri: self.urlPath, 82 | headers: headers 83 | ) 84 | 85 | let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( 86 | upgradeRequestHead: requestHead, 87 | upgraders: [upgrader], 88 | notUpgradingCompletionHandler: { channel in 89 | channel.eventLoop.makeCompletedFuture { 90 | return UpgradeResult.notUpgraded 91 | } 92 | } 93 | ) 94 | 95 | var pipelineConfiguration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: clientUpgradeConfiguration) 96 | pipelineConfiguration.leftOverBytesStrategy = .forwardBytes 97 | let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( 98 | configuration: pipelineConfiguration 99 | ) 100 | 101 | return negotiationResultFuture 102 | } 103 | } 104 | 105 | func handle(value: Value, logger: Logger) async throws -> WebSocketCloseFrame? { 106 | switch try await value.get() { 107 | case .websocket(let webSocketChannel, let extensions): 108 | return try await WebSocketHandler.handle( 109 | type: .client, 110 | configuration: .init( 111 | extensions: extensions, 112 | autoPing: self.configuration.autoPing, 113 | closeTimeout: self.configuration.closeTimeout, 114 | validateUTF8: self.configuration.validateUTF8 115 | ), 116 | asyncChannel: webSocketChannel, 117 | context: WebSocketClient.Context(logger: logger), 118 | handler: self.handler 119 | ) 120 | case .notUpgraded: 121 | // The upgrade to websocket did not succeed. 122 | logger.debug("Upgrade declined") 123 | throw WebSocketClientError.webSocketUpgradeFailed 124 | } 125 | } 126 | /// Request target for the upgrade request: the URL path followed by `?query` when a query is present 127 | static func urlPath(for url: URI) -> String { 128 | url.path + (url.query.map { "?\($0)" } ?? "") 129 | } 130 | /// `Host` header value: the URL host, with `:port` appended when the URL specifies a port; nil when the URL has no host 131 | static func urlHostHeader(for url: URI) -> String?
{ 132 | guard let host = url.host else { return nil } 133 | if let port = url.port { 134 | return "\(host):\(port)" 135 | } else { 136 | return host 137 | } 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /Sources/WSClient/WebSocketClientConfiguration.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2025 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import HTTPTypes 16 | import WSCore 17 | 18 | /// Configuration for a client connecting to a WebSocket 19 | public struct WebSocketClientConfiguration: Sendable { 20 | /// Max websocket frame size that can be sent/received 21 | public var maxFrameSize: Int 22 | /// Additional headers to be sent with the initial HTTP request 23 | public var additionalHeaders: HTTPFields 24 | /// WebSocket extensions 25 | public var extensions: [any WebSocketExtensionBuilder] 26 | /// Close timeout 27 | public var closeTimeout: Duration 28 | /// Automatic ping setup 29 | public var autoPing: AutoPingSetup 30 | /// Should text be validated to be UTF8 31 | public var validateUTF8: Bool 32 | /// Hostname used during TLS handshake 33 | public var sniHostname: String?
34 | /// Initialize WebSocketClient configuration 35 | /// - Parameters: 36 | /// - maxFrameSize: Max websocket frame size that can be sent/received 37 | /// - additionalHeaders: Additional headers to be sent with the initial HTTP request 38 | /// - extensions: WebSocket extensions 39 | /// - closeTimeout: Close timeout (defaults to 15 seconds) 40 | /// - autoPing: Automatic Ping configuration 41 | /// - validateUTF8: Should text be checked to see if it is valid UTF8 42 | /// - sniHostname: Hostname used during TLS handshake 43 | public init( 44 | maxFrameSize: Int = (1 << 14), 45 | additionalHeaders: HTTPFields = .init(), 46 | extensions: [WebSocketExtensionFactory] = [], 47 | closeTimeout: Duration = .seconds(15), 48 | autoPing: AutoPingSetup = .disabled, 49 | validateUTF8: Bool = false, 50 | sniHostname: String? = nil 51 | ) { 52 | self.maxFrameSize = maxFrameSize 53 | self.additionalHeaders = additionalHeaders 54 | self.extensions = extensions.map { $0.build() } 55 | self.closeTimeout = closeTimeout 56 | self.autoPing = autoPing 57 | self.validateUTF8 = validateUTF8 58 | self.sniHostname = sniHostname 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /Sources/WSClient/WebSocketClientError.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | /// Errors returned by ``WebSocketClient`` 16 | public struct WebSocketClientError: Swift.Error, Equatable { 17 | private enum _Internal: Equatable { 18 |
case invalidURL 19 | case webSocketUpgradeFailed 20 | } 21 | 22 | private let value: _Internal 23 | private init(_ value: _Internal) { 24 | self.value = value 25 | } 26 | 27 | /// Provided URL is invalid 28 | public static var invalidURL: Self { .init(.invalidURL) } 29 | /// WebSocket upgrade failed. 30 | public static var webSocketUpgradeFailed: Self { .init(.webSocketUpgradeFailed) } 31 | } 32 | 33 | extension WebSocketClientError: CustomStringConvertible { 34 | public var description: String { 35 | switch self.value { 36 | case .invalidURL: "Invalid URL" 37 | case .webSocketUpgradeFailed: "WebSocket upgrade failed" 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /Sources/WSCompression/PerMessageDeflateExtension.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import CompressNIO 16 | import NIOCore 17 | import NIOWebSocket 18 | import WSCore 19 | 20 | /// PerMessageDeflate Websocket extension builder 21 | struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { 22 | static let name = "permessage-deflate" 23 | 24 | let clientMaxWindow: Int? 25 | let clientNoContextTakeover: Bool 26 | let serverMaxWindow: Int? 27 | let serverNoContextTakeover: Bool 28 | let compressionLevel: Int? 29 | let memoryLevel: Int?
30 | let maxDecompressedFrameSize: Int 31 | let minFrameSizeToCompress: Int 32 | 33 | init( 34 | clientMaxWindow: Int? = nil, 35 | clientNoContextTakeover: Bool = false, 36 | serverMaxWindow: Int? = nil, 37 | serverNoContextTakeover: Bool = false, 38 | compressionLevel: Int? = nil, 39 | memoryLevel: Int? = nil, 40 | maxDecompressedFrameSize: Int = (1 << 14), 41 | minFrameSizeToCompress: Int = 256 42 | ) { 43 | self.clientMaxWindow = clientMaxWindow 44 | self.clientNoContextTakeover = clientNoContextTakeover 45 | self.serverMaxWindow = serverMaxWindow 46 | self.serverNoContextTakeover = serverNoContextTakeover 47 | self.compressionLevel = compressionLevel 48 | self.memoryLevel = memoryLevel 49 | self.maxDecompressedFrameSize = maxDecompressedFrameSize 50 | self.minFrameSizeToCompress = minFrameSizeToCompress 51 | } 52 | 53 | /// Return client request header 54 | func clientRequestHeader() -> String { 55 | var header = "permessage-deflate" 56 | if let maxWindow = self.clientMaxWindow { 57 | header += ";client_max_window_bits=\(maxWindow)" 58 | } 59 | if self.clientNoContextTakeover { 60 | header += ";client_no_context_takeover" 61 | } 62 | if let maxWindow = self.serverMaxWindow { 63 | header += ";server_max_window_bits=\(maxWindow)" 64 | } 65 | if self.serverNoContextTakeover { 66 | header += ";server_no_context_takeover" 67 | } 68 | return header 69 | } 70 | 71 | /// Return server response header, given a client request 72 | /// - Parameter request: Client request header parameters 73 | /// - Returns: Server response parameters 74 | func serverReponseHeader(to request: WebSocketExtensionHTTPParameters) -> String?
{ 75 | let configuration = self.responseConfiguration(to: request) 76 | var header = "permessage-deflate" 77 | if let maxWindow = configuration.receiveMaxWindow { 78 | header += ";client_max_window_bits=\(maxWindow)" 79 | } 80 | if configuration.receiveNoContextTakeover { 81 | header += ";client_no_context_takeover" 82 | } 83 | if let maxWindow = configuration.sendMaxWindow { 84 | header += ";server_max_window_bits=\(maxWindow)" 85 | } 86 | if configuration.sendNoContextTakeover { 87 | header += ";server_no_context_takeover" 88 | } 89 | return header 90 | } 91 | 92 | /// Create server PerMessageDeflateExtension based off request headers 93 | /// - Parameters: 94 | /// - request: Client request 95 | func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { 96 | let configuration = self.responseConfiguration(to: request) 97 | return try PerMessageDeflateExtension(configuration: configuration) 98 | } 99 | 100 | /// Create client PerMessageDeflateExtension based off response headers 101 | /// - Parameters: 102 | /// - response: Server response 103 | func clientExtension(from response: WebSocketExtensionHTTPParameters) throws -> WebSocketExtension?
{ 104 | let clientMaxWindowParam = response.parameters["client_max_window_bits"]?.integer 105 | let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil 106 | let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer 107 | let serverNoContextTakeoverParam = response.parameters["server_no_context_takeover"] != nil 108 | return try PerMessageDeflateExtension( 109 | configuration: .init( 110 | receiveMaxWindow: serverMaxWindowParam, 111 | receiveNoContextTakeover: serverNoContextTakeoverParam, 112 | sendMaxWindow: clientMaxWindowParam, 113 | sendNoContextTakeover: clientNoContextTakeoverParam, 114 | compressionLevel: self.compressionLevel, 115 | memoryLevel: self.memoryLevel, 116 | maxDecompressedFrameSize: self.maxDecompressedFrameSize, 117 | minFrameSizeToCompress: self.minFrameSizeToCompress 118 | ) 119 | ) 120 | } 121 | 122 | private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration { 123 | let requestServerMaxWindow = request.parameters["server_max_window_bits"] 124 | let requestServerNoContextTakeover = request.parameters["server_no_context_takeover"] != nil 125 | let requestClientMaxWindow = request.parameters["client_max_window_bits"] 126 | let requestClientNoContextTakeover = request.parameters["client_no_context_takeover"] != nil 127 | 128 | // calculate client max window. If parameter doesn't exist then server cannot set it, if it does 129 | // exist then the value should be set to minimum of both values, or the value of the other if 130 | // one is nil 131 | let receiveMaxWindow: Int?
= 132 | if let requestClientMaxWindow { 133 | optionalMin(requestClientMaxWindow.integer, self.clientMaxWindow) 134 | } else { 135 | nil 136 | } 137 | 138 | return PerMessageDeflateExtension.Configuration( 139 | receiveMaxWindow: receiveMaxWindow, 140 | receiveNoContextTakeover: requestClientNoContextTakeover || self.clientNoContextTakeover, 141 | sendMaxWindow: optionalMin(requestServerMaxWindow?.integer, self.serverMaxWindow), 142 | sendNoContextTakeover: requestServerNoContextTakeover || self.serverNoContextTakeover, 143 | compressionLevel: self.compressionLevel, 144 | memoryLevel: self.memoryLevel, 145 | maxDecompressedFrameSize: self.maxDecompressedFrameSize, 146 | minFrameSizeToCompress: self.minFrameSizeToCompress 147 | ) 148 | } 149 | } 150 | 151 | /// PerMessageDeflate websocket extension 152 | /// 153 | /// Uses deflate to compress messages sent across a WebSocket 154 | /// See RFC 7692 for more details https://www.rfc-editor.org/rfc/rfc7692 155 | struct PerMessageDeflateExtension: WebSocketExtension { 156 | struct Configuration: Sendable { 157 | let receiveMaxWindow: Int? 158 | let receiveNoContextTakeover: Bool 159 | let sendMaxWindow: Int? 160 | let sendNoContextTakeover: Bool 161 | let compressionLevel: Int? 162 | let memoryLevel: Int?
163 | let maxDecompressedFrameSize: Int 164 | let minFrameSizeToCompress: Int 165 | } 166 | 167 | actor Decompressor { 168 | enum ReceiveState: Sendable { 169 | case idle 170 | case receivingMessage 171 | case decompressingMessage 172 | } 173 | 174 | fileprivate var decompressor: ZlibDecompressor 175 | var state: ReceiveState 176 | 177 | init(algorithm: ZlibAlgorithm, windowSize: Int32) throws { 178 | self.state = .idle 179 | self.decompressor = try ZlibDecompressor(algorithm: algorithm, windowSize: windowSize) 180 | } 181 | 182 | func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame { 183 | if self.state == .idle { 184 | if frame.rsv1 { 185 | self.state = .decompressingMessage 186 | } else { 187 | self.state = .receivingMessage 188 | } 189 | } 190 | if self.state == .decompressingMessage { 191 | var frame = frame 192 | var unmaskedData = frame.unmaskedData 193 | if frame.fin { 194 | // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame 195 | // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2).
196 | unmaskedData.writeBytes([0, 0, 255, 255]) 197 | self.state = .idle 198 | } 199 | frame.data = try unmaskedData.decompressStream( 200 | with: self.decompressor, 201 | maxSize: maxSize, 202 | allocator: ByteBufferAllocator() 203 | ) 204 | frame.maskKey = nil 205 | if resetStream, frame.fin { 206 | try self.decompressor.reset() 207 | } 208 | return frame 209 | } 210 | if frame.fin { 211 | self.state = .idle 212 | } 213 | return frame 214 | } 215 | } 216 | 217 | actor Compressor { 218 | enum SendState: Sendable { 219 | case idle 220 | case sendingMessage 221 | } 222 | 223 | fileprivate var compressor: ZlibCompressor 224 | var sendState: SendState 225 | let minFrameSizeToCompress: Int 226 | 227 | init(algorithm: ZlibAlgorithm, configuration: ZlibConfiguration, minFrameSizeToCompress: Int) throws { 228 | self.compressor = try ZlibCompressor(algorithm: algorithm, configuration: configuration) 229 | self.minFrameSizeToCompress = minFrameSizeToCompress 230 | self.sendState = .idle 231 | } 232 | 233 | func compress(_ frame: WebSocketFrame, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame { 234 | // if the frame is larger than `minFrameSizeToCompress` bytes, we haven't received a final frame 235 | // or we are in the process of sending a message compress the data 236 | let shouldWeCompress = frame.data.readableBytes >= self.minFrameSizeToCompress || !frame.fin || self.sendState != .idle 237 | if shouldWeCompress { 238 | var newFrame = frame 239 | if self.sendState == .idle { 240 | newFrame.rsv1 = true 241 | self.sendState = .sendingMessage 242 | } 243 | newFrame.data = try newFrame.data.compressStream(with: self.compressor, flush: .sync, allocator: ByteBufferAllocator()) 244 | // if final frame then remove last four bytes 0x00 0x00 0xff 0xff 245 | // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) 246 | if newFrame.fin { 247 | newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length:
newFrame.data.readableBytes - 4) ?? newFrame.data 248 | self.sendState = .idle 249 | if resetStream { 250 | try self.compressor.reset() 251 | } 252 | } 253 | return newFrame 254 | } 255 | return frame 256 | } 257 | } 258 | 259 | let name = "permessage-deflate" 260 | let configuration: Configuration 261 | let decompressor: Decompressor 262 | let compressor: Compressor 263 | 264 | init(configuration: Configuration) throws { 265 | self.configuration = configuration 266 | self.decompressor = try .init( 267 | algorithm: .deflate, 268 | windowSize: numericCast(configuration.receiveMaxWindow ?? 15) 269 | ) 270 | self.compressor = try .init( 271 | algorithm: .deflate, 272 | configuration: .init( 273 | windowSize: numericCast(configuration.sendMaxWindow ?? 15), 274 | compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, 275 | memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 276 | ), 277 | minFrameSizeToCompress: self.configuration.minFrameSizeToCompress 278 | ) 279 | } 280 | 281 | func shutdown() async {} 282 | 283 | func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 284 | try await self.decompressor.decompress( 285 | frame, 286 | maxSize: self.configuration.maxDecompressedFrameSize, 287 | resetStream: self.configuration.receiveNoContextTakeover, 288 | context: context 289 | ) 290 | } 291 | 292 | func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 293 | let isCorrectType = frame.opcode == .text || frame.opcode == .binary || frame.opcode == .continuation 294 | if isCorrectType { 295 | return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context) 296 | } 297 | return frame 298 | } 299 | 300 | /// Reserved bits extension uses 301 | var reservedBits: WebSocketFrame.ReservedBits { .rsv1 } 302 | } 303 | 304 | extension WebSocketExtensionFactory
{ 305 | /// permessage-deflate websocket extension 306 | /// - Parameters: 307 | /// - maxWindow: Max window to be used for decompression and compression 308 | /// - noContextTakeover: Should we reset window on every message 309 | /// - maxDecompressedFrameSize: Maximum size for a decompressed frame 310 | /// - minFrameSizeToCompress: Minimum size of a frame before compression is applied 311 | public static func perMessageDeflate( 312 | maxWindow: Int? = nil, 313 | noContextTakeover: Bool = false, 314 | maxDecompressedFrameSize: Int = 1 << 14, 315 | minFrameSizeToCompress: Int = 256 316 | ) -> WebSocketExtensionFactory { 317 | .init { 318 | PerMessageDeflateExtensionBuilder( 319 | clientMaxWindow: maxWindow, 320 | clientNoContextTakeover: noContextTakeover, 321 | serverMaxWindow: maxWindow, 322 | serverNoContextTakeover: noContextTakeover, 323 | compressionLevel: nil, 324 | memoryLevel: nil, 325 | maxDecompressedFrameSize: maxDecompressedFrameSize, 326 | minFrameSizeToCompress: minFrameSizeToCompress 327 | ) 328 | } 329 | } 330 | 331 | /// permessage-deflate websocket extension 332 | /// - Parameters: 333 | /// - clientMaxWindow: Max window to be used for client compression 334 | /// - clientNoContextTakeover: Should client reset window on every message 335 | /// - serverMaxWindow: Max window to be used for server compression 336 | /// - serverNoContextTakeover: Should server reset window on every message 337 | /// - compressionLevel: Zlib compression level. Value between 0 and 9 where 1 gives best speed, 9 gives 338 | /// give best compression and 0 gives no compression. 339 | /// - memoryLevel: Defines how much memory should be given to compression. Value between 1 and 9 where 1 340 | /// uses least memory and 9 give best compression and optimal speed. 
341 | /// - maxDecompressedFrameSize: Maximum size for a decompressed frame 342 | /// - minFrameSizeToCompress: Minimum size of a frame before compression is applied 343 | public static func perMessageDeflate( 344 | clientMaxWindow: Int? = nil, 345 | clientNoContextTakeover: Bool = false, 346 | serverMaxWindow: Int? = nil, 347 | serverNoContextTakeover: Bool = false, 348 | compressionLevel: Int? = nil, 349 | memoryLevel: Int? = nil, 350 | maxDecompressedFrameSize: Int = 1 << 14, 351 | minFrameSizeToCompress: Int = 256 352 | ) -> WebSocketExtensionFactory { 353 | .init { 354 | PerMessageDeflateExtensionBuilder( 355 | clientMaxWindow: clientMaxWindow, 356 | clientNoContextTakeover: clientNoContextTakeover, 357 | serverMaxWindow: serverMaxWindow, 358 | serverNoContextTakeover: serverNoContextTakeover, 359 | compressionLevel: compressionLevel, 360 | memoryLevel: memoryLevel, 361 | maxDecompressedFrameSize: maxDecompressedFrameSize, 362 | minFrameSizeToCompress: minFrameSizeToCompress 363 | ) 364 | } 365 | } 366 | } 367 | 368 | /// Minimum of two optional integers. 369 | /// 370 | /// Returns the other is one of them is nil 371 | private func optionalMin(_ a: Int?, _ b: Int?) -> Int? 
{ 372 | switch (a, b) { 373 | case (.some(let a), .some(let b)): 374 | return min(a, b) 375 | case (.some(a), .none): 376 | return a 377 | case (.none, .some(b)): 378 | return b 379 | default: 380 | return nil 381 | } 382 | } 383 | -------------------------------------------------------------------------------- /Sources/WSCore/Extensions/WebSocketExtension.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Foundation 16 | import HTTPTypes 17 | import Logging 18 | import NIOCore 19 | import NIOWebSocket 20 | 21 | /// Basic context implementation of ``WebSocketContext``. 
public struct WebSocketExtensionContext: Sendable {
    /// Logger available to extensions while they process frames
    public let logger: Logger

    init(logger: Logger) {
        self.logger = logger
    }
}

/// Protocol for WebSocket extension
public protocol WebSocketExtension: Sendable {
    /// Extension name
    var name: String { get }
    /// Process frame received from websocket
    func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
    /// Process frame about to be sent to websocket
    func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
    /// Reserved bits extension uses
    var reservedBits: WebSocketFrame.ReservedBits { get }
    /// shutdown extension
    func shutdown() async
}

extension WebSocketExtension {
    /// Reserved bits extension uses (default is none)
    public var reservedBits: WebSocketFrame.ReservedBits { .init() }
}

--------------------------------------------------------------------------------
/Sources/WSCore/Extensions/WebSocketExtensionBuilder.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import HTTPTypes

/// Protocol for WebSocket extension builder
public protocol WebSocketExtensionBuilder: Sendable {
    /// name of WebSocket extension name
    static var name: String { get }
    /// construct client request header
    func clientRequestHeader() -> String
    /// construct server response header based on client request
    func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String?
    /// construct server version of extension based on client request
    func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)?
    /// construct client version of extension based on server response
    func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)?
}

extension WebSocketExtensionBuilder {
    /// construct server response header based on all client requests
    ///
    /// Returns the response for the first request with a matching name that this builder accepts.
    public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? {
        for request in requests {
            guard request.name == Self.name else { continue }
            if let response = serverReponseHeader(to: request) {
                return response
            }
        }
        return nil
    }

    /// construct server extension based on all client requests
    ///
    /// Falls back to building a non-negotiable extension when no request matches.
    public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? {
        for request in requests {
            guard request.name == Self.name else { continue }
            if let ext = try serverExtension(from: request) {
                return ext
            }
        }
        if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol {
            return nonNegotiableExtensionBuilder.build()
        }
        return nil
    }

    /// construct client extension based on all server responses
    ///
    /// Falls back to building a non-negotiable extension when no response matches.
    public func clientExtension(from responses: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? {
        for response in responses {
            guard response.name == Self.name else { continue }
            if let ext = try clientExtension(from: response) {
                return ext
            }
        }
        if let nonNegotiableExtensionBuilder = self as? any _WebSocketNonNegotiableExtensionBuilderProtocol {
            return nonNegotiableExtensionBuilder.build()
        }
        return nil
    }
}

/// Protocol for a WebSocket extension that is applied without any negotiation with the other side
protocol _WebSocketNonNegotiableExtensionBuilderProtocol: WebSocketExtensionBuilder {
    associatedtype Extension: WebSocketExtension
    func build() -> Extension
}

/// A WebSocket extension that is applied without any negotiation with the other side
public struct WebSocketNonNegotiableExtensionBuilder: _WebSocketNonNegotiableExtensionBuilderProtocol {
    // Describe the extension type itself; `type(of: Extension.self)` would describe its
    // metatype and produce a name ending in ".Type".
    public static var name: String { String(describing: Extension.self) }

    let _build: @Sendable () -> Extension

    init(_ build: @escaping @Sendable () -> Extension) {
        self._build = build
    }

    public func build() -> Extension {
        self._build()
    }
}

extension WebSocketNonNegotiableExtensionBuilder {
    /// construct client request header
    public func clientRequestHeader() -> String { "" }
    /// construct server response header based on client request
    public func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? { nil }
    /// construct server version of extension based on client request
    public func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() }
    /// construct client version of extension based on server response
    public func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? { self.build() }
}

extension [any WebSocketExtensionBuilder] {
    /// Build client extensions from response from WebSocket server
    /// - Parameter responseHeaders: Server response headers
    /// - Returns: Array of client extensions to enable
    public func buildClientExtensions(from responseHeaders: HTTPFields) throws -> [any WebSocketExtension] {
        let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(responseHeaders)
        return try self.compactMap {
            try $0.clientExtension(from: serverExtensions)
        }
    }

    /// Do the client/server WebSocket negotiation based off headers received from the client.
    /// - Parameter requestHeaders: Client request headers
    /// - Returns: Headers to pass back to client and array of server extensions to enable
    public func serverExtensionNegotiation(requestHeaders: HTTPFields) throws -> (HTTPFields, [any WebSocketExtension]) {
        var responseHeaders: HTTPFields = .init()
        let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders)
        let extensionResponseHeaders = self.compactMap { $0.serverResponseHeader(to: clientHeaders) }
        responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) })
        let extensions = try self.compactMap {
            try $0.serverExtension(from: clientHeaders)
        }
        return (responseHeaders, extensions)
    }
}

/// Build WebSocket extension builder
public struct WebSocketExtensionFactory: Sendable {
    public let build: @Sendable () -> any WebSocketExtensionBuilder

    public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) {
        self.build = build
    }

    /// Extension to be applied without negotiation with the other side.
    ///
    /// Most extensions involve some form of negotiation between the client and the server
    /// to decide on whether they should be applied and with what parameters. This extension
    /// builder is for the situation where no negotiation is needed or that negotiation has
    /// already occurred.
    ///
    /// - Parameter build: closure creating extension
    /// - Returns: WebSocketExtensionFactory
    public static func nonNegotiatedExtension(_ build: @escaping @Sendable () -> some WebSocketExtension) -> Self {
        .init {
            WebSocketNonNegotiableExtensionBuilder(build)
        }
    }
}

--------------------------------------------------------------------------------
/Sources/WSCore/Extensions/WebSocketExtensionHTTPParameters.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes

/// Parsed parameters from `Sec-WebSocket-Extensions` header
public struct WebSocketExtensionHTTPParameters: Sendable, Equatable {
    /// A single parameter
    public enum Parameter: Sendable, Equatable {
        /// Parameter with a value
        case value(String)
        /// Parameter with no value
        case null

        /// Convert to optional
        public var optional: String? {
            switch self {
            case .value(let string):
                return .some(string)
            case .null:
                return .none
            }
        }

        /// Convert to integer
        public var integer: Int? {
            switch self {
            case .value(let string):
                return Int(string)
            case .null:
                return .none
            }
        }
    }

    public let parameters: [String: Parameter]
    public let name: String

    /// initialise WebSocket extension parameters from a single extension description,
    /// e.g. `permessage-deflate; client_max_window_bits=10`
    ///
    /// Returns nil if the description has no (non-empty) extension name.
    init?(from header: some StringProtocol) {
        let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
        guard let name = split.first, !name.isEmpty else {
            return nil
        }
        self.name = name
        var parameters: [String: Parameter] = [:]
        for element in split.dropFirst() {
            let keyValue = element.split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
            guard let key = keyValue.first else { continue }
            if keyValue.count > 1 {
                parameters[key] = .value(Self.unquote(keyValue[1]))
            } else {
                parameters[key] = .null
            }
        }
        self.parameters = parameters
    }

    /// Parse all `Sec-WebSocket-Extensions` header values
    ///
    /// A single header value may list several extensions separated by commas (RFC 6455 section 9.1),
    /// so each value is split into its individual extension descriptions before parsing.
    /// - Parameters:
    ///   - headers: headers coming from other
    /// - Returns: Array of extensions
    public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] {
        let extHeaders = headers[values: .secWebSocketExtensions]
        return extHeaders
            .flatMap { $0.split(separator: ",") }
            .compactMap { .init(from: $0) }
    }

    /// Strip the surrounding double quotes from a quoted-string parameter value
    /// (e.g. `client_max_window_bits="10"`, allowed by RFC 7692 section 7.1.2).
    private static func unquote(_ value: String) -> String {
        guard value.count >= 2, value.first == "\"", value.last == "\"" else { return value }
        return String(value.dropFirst().dropLast())
    }
}

extension WebSocketExtensionHTTPParameters {
    /// Initialiser used by tests
    package init(_ name: String, parameters: [String: Parameter]) {
        self.name = name
        self.parameters = parameters
    }
}

--------------------------------------------------------------------------------
/Sources/WSCore/String+validatingString.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore

extension String {
    /// Create a String from the readable bytes of a ByteBuffer.
    /// - Parameters:
    ///   - buffer: Buffer holding the string's UTF-8 bytes
    ///   - validateUTF8: When true, return nil if the bytes are not valid UTF-8. When false, any
    ///     invalid bytes are converted the same way `String(buffer:)` converts them.
    init?(buffer: ByteBuffer, validateUTF8: Bool) {
        guard validateUTF8 else {
            self = .init(buffer: buffer)
            return
        }
        #if compiler(>=6)
        if #available(macOS 15, iOS 18, tvOS 18, watchOS 11, visionOS 2, *) {
            var copy = buffer
            guard let string = try? copy.readUTF8ValidatedString(length: copy.readableBytes) else {
                return nil
            }
            self = string
            return
        }
        #endif  // compiler(>=6)
        // Validating reader is unavailable here, so validate with the standard library transcoder,
        // which reports whether it met an ill-formed UTF-8 sequence.
        let hasEncodingError = buffer.withUnsafeReadableBytes { bytes in
            transcode(bytes.makeIterator(), from: UTF8.self, to: UTF32.self, stoppingOnError: true) { _ in }
        }
        guard !hasEncodingError else { return nil }
        self = .init(buffer: buffer)
    }
}

--------------------------------------------------------------------------------
/Sources/WSCore/UnsafeTransfer.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`.
/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler.
/// It can be used similar to `@unsafe Sendable` but for values instead of types.
@usableFromInline
struct UnsafeTransfer {
    @usableFromInline
    var wrappedValue: Wrapped

    @inlinable
    init(_ value: Wrapped) {
        self.wrappedValue = value
    }
}

extension UnsafeTransfer: Equatable where Wrapped: Equatable {}
extension UnsafeTransfer: Hashable where Wrapped: Hashable {}

extension UnsafeTransfer: @unchecked Sendable {}

/// ``UnsafeMutableTransferBox`` is a reference box that lets a non-`Sendable` value be captured and
/// mutated from inside a `@Sendable` closure.
/// It performs no synchronisation and bypasses the compiler's sendable checks, so every access must
/// be kept race-free by the caller.
@usableFromInline
final class UnsafeMutableTransferBox {
    @usableFromInline
    var wrappedValue: Wrapped

    @inlinable
    init(_ value: Wrapped) {
        self.wrappedValue = value
    }
}

extension UnsafeMutableTransferBox: @unchecked Sendable {}

--------------------------------------------------------------------------------
/Sources/WSCore/WebSocketContext.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore

/// Context value handed to WebSocket data handling functions; it supplies the logger to use.
public protocol WebSocketContext: Sendable {
    var logger: Logger { get }
}

--------------------------------------------------------------------------------
/Sources/WSCore/WebSocketDataFrame.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import NIOWebSocket

/// A text, binary or continuation WebSocket frame whose payload has been unmasked.
public struct WebSocketDataFrame: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible {
    /// Frame kinds that carry message data
    public enum Opcode: String, Sendable {
        case text
        case binary
        case continuation
    }

    public var opcode: Opcode
    public var data: ByteBuffer
    public var fin: Bool

    /// Convert a raw frame; fails for any opcode that is not text, binary or continuation.
    init?(from frame: WebSocketFrame) {
        let mappedOpcode: Opcode
        switch frame.opcode {
        case .text: mappedOpcode = .text
        case .binary: mappedOpcode = .binary
        case .continuation: mappedOpcode = .continuation
        default: return nil
        }
        self.opcode = mappedOpcode
        self.fin = frame.fin
        self.data = frame.unmaskedData
    }

    public var description: String {
        let payload = self.data.description
        return "\(self.opcode): \(payload), finished: \(self.fin)"
    }

    public var debugDescription: String {
        let payload = self.data.debugDescription
        return "\(self.opcode): \(payload), finished: \(self.fin)"
    }
}

--------------------------------------------------------------------------------
/Sources/WSCore/WebSocketDataHandler.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

/// Closure type for user code driving a WebSocket connection. It is given the inbound
/// stream, the outbound writer and the connection context.
public typealias WebSocketDataHandler = @Sendable (
    WebSocketInboundStream,
    WebSocketOutboundWriter,
    Context
) async throws -> Void

--------------------------------------------------------------------------------
/Sources/WSCore/WebSocketFrameSequence.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore
import NIOWebSocket

/// Sequence of fragmented WebSocket frames.
struct WebSocketFrameSequence {
    /// Frames of the message received so far, starting with its text/binary frame
    var frames: [WebSocketDataFrame]
    /// Total payload size in bytes of all frames in `frames`
    var size: Int
    var first: WebSocketDataFrame { self.frames[0] }

    init(frame: WebSocketDataFrame) {
        assert(frame.opcode != .continuation, "Cannot create a WebSocketFrameSequence starting with a continuation")
        self.frames = [frame]
        // Include the first frame's payload so `size` covers every frame, matching `append`
        self.size = frame.data.readableBytes
    }

    mutating func append(_ frame: WebSocketDataFrame) {
        assert(frame.opcode == .continuation)
        self.frames.append(frame)
        self.size += frame.data.readableBytes
    }

    /// Payloads of all frames concatenated into a single buffer
    var bytes: ByteBuffer {
        if self.frames.count == 1 {
            return self.frames[0].data
        }
        var result = ByteBufferAllocator().buffer(capacity: self.size)
        for frame in self.frames {
            result.writeImmutableBuffer(frame.data)
        }
        return result
    }

    func getMessage(validateUTF8: Bool) -> WebSocketMessage? {
        .init(frame: self.collated, validate: validateUTF8)
    }

    /// Single final frame with the first frame's opcode and all frames' payloads combined
    var collated: WebSocketDataFrame {
        var frame = self.first
        frame.data = self.bytes
        frame.fin = true
        return frame
    }
}

--------------------------------------------------------------------------------
/Sources/WSCore/WebSocketHandler.swift:
--------------------------------------------------------------------------------
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2023-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOWebSocket
import ServiceLifecycle
19 | 20 | /// WebSocket type 21 | public enum WebSocketType: Sendable { 22 | case client 23 | case server 24 | } 25 | 26 | /// Automatic ping setup 27 | public struct AutoPingSetup: Sendable { 28 | enum Internal { 29 | case disabled 30 | case enabled(timePeriod: Duration) 31 | } 32 | 33 | internal let value: Internal 34 | internal init(_ value: Internal) { 35 | self.value = value 36 | } 37 | 38 | /// disable auto ping 39 | public static var disabled: Self { .init(.disabled) } 40 | /// send ping with fixed period 41 | public static func enabled(timePeriod: Duration) -> Self { .init(.enabled(timePeriod: timePeriod)) } 42 | } 43 | 44 | /// Close frame that caused WebSocket close 45 | public struct WebSocketCloseFrame: Sendable { 46 | // status code indicating a reason for closure 47 | public let closeCode: WebSocketErrorCode 48 | // close reason 49 | public let reason: String? 50 | } 51 | 52 | /// Handler processing raw WebSocket packets. Used by WebSocket transports 53 | /// 54 | /// Manages ping, pong and close messages. Collates data and text messages into final frame 55 | /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. 56 | /// 57 | /// SPI WSInternal is used to make the WebSocket Handler available to both client and server 58 | /// implementations 59 | @_spi(WSInternal) public actor WebSocketHandler { 60 | enum InternalError: Error { 61 | case close(WebSocketErrorCode) 62 | } 63 | 64 | enum CloseState { 65 | case open 66 | case closing 67 | case closed(WebSocketCloseFrame?) 
68 | } 69 | 70 | @_spi(WSInternal) public struct Configuration: Sendable { 71 | let extensions: [any WebSocketExtension] 72 | let autoPing: AutoPingSetup 73 | let validateUTF8: Bool 74 | let reservedBits: WebSocketFrame.ReservedBits 75 | let closeTimeout: Duration 76 | 77 | @_spi(WSInternal) public init( 78 | extensions: [any WebSocketExtension], 79 | autoPing: AutoPingSetup, 80 | closeTimeout: Duration = .seconds(15), 81 | validateUTF8: Bool 82 | ) { 83 | self.extensions = extensions 84 | self.autoPing = autoPing 85 | self.closeTimeout = closeTimeout 86 | self.validateUTF8 = validateUTF8 87 | // store reserved bits used by this handler 88 | self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in 89 | partialResult.union(`extension`.reservedBits) 90 | } 91 | } 92 | } 93 | 94 | let channel: Channel 95 | var outbound: NIOAsyncChannelOutboundWriter 96 | let type: WebSocketType 97 | let configuration: Configuration 98 | let logger: Logger 99 | var stateMachine: WebSocketStateMachine 100 | 101 | private init( 102 | channel: Channel, 103 | outbound: NIOAsyncChannelOutboundWriter, 104 | type: WebSocketType, 105 | configuration: Configuration, 106 | context: some WebSocketContext 107 | ) { 108 | self.channel = channel 109 | self.outbound = outbound 110 | self.type = type 111 | self.configuration = configuration 112 | self.logger = context.logger 113 | self.stateMachine = .init(autoPingSetup: configuration.autoPing) 114 | } 115 | 116 | @_spi(WSInternal) public static func handle( 117 | type: WebSocketType, 118 | configuration: Configuration, 119 | asyncChannel: NIOAsyncChannel, 120 | context: Context, 121 | handler: WebSocketDataHandler 122 | ) async throws -> WebSocketCloseFrame? 
{ 123 | defer { 124 | context.logger.debug("Closed WebSocket") 125 | } 126 | do { 127 | let rt = try await asyncChannel.executeThenClose { inbound, outbound in 128 | defer { 129 | context.logger.trace("Closing WebSocket") 130 | } 131 | return try await withTaskCancellationHandler { 132 | let webSocketHandler = Self( 133 | channel: asyncChannel.channel, 134 | outbound: outbound, 135 | type: type, 136 | configuration: configuration, 137 | context: context 138 | ) 139 | return try await webSocketHandler.handle( 140 | type: type, 141 | inbound: inbound, 142 | outbound: outbound, 143 | handler: handler, 144 | context: context 145 | ) 146 | } onCancel: { 147 | Task { 148 | try await asyncChannel.channel.close(mode: .input) 149 | } 150 | } 151 | } 152 | return rt 153 | } catch let error as NIOAsyncWriterError { 154 | // ignore already finished errors 155 | if error == NIOAsyncWriterError.alreadyFinished() { return nil } 156 | throw error 157 | } 158 | } 159 | 160 | func handle( 161 | type: WebSocketType, 162 | inbound: NIOAsyncChannelInboundStream, 163 | outbound: NIOAsyncChannelOutboundWriter, 164 | handler: WebSocketDataHandler, 165 | context: Context 166 | ) async throws -> WebSocketCloseFrame? { 167 | try await withGracefulShutdownHandler { 168 | try await withThrowingTaskGroup(of: Void.self) { group in 169 | if case .enabled = configuration.autoPing.value { 170 | /// Add task sending ping frames every so often and verifying a pong frame was sent back 171 | group.addTask { 172 | try await self.runAutoPingLoop() 173 | } 174 | } 175 | let webSocketOutbound = WebSocketOutboundWriter(handler: self) 176 | var inboundIterator = inbound.makeAsyncIterator() 177 | let webSocketInbound = WebSocketInboundStream( 178 | iterator: inboundIterator, 179 | handler: self 180 | ) 181 | let closeCode: WebSocketErrorCode 182 | var clientError: Error? 
183 | do { 184 | // handle websocket data and text 185 | try await handler(webSocketInbound, webSocketOutbound, context) 186 | closeCode = .normalClosure 187 | } catch InternalError.close(let code) { 188 | closeCode = code 189 | } catch { 190 | clientError = error 191 | closeCode = .unexpectedServerError 192 | } 193 | do { 194 | try await self.close(code: closeCode) 195 | if case .closing = self.stateMachine.state { 196 | group.addTask { 197 | try await Task.sleep(for: self.configuration.closeTimeout) 198 | try await self.channel.close(mode: .input) 199 | } 200 | // Close handshake. Wait for responding close or until inbound ends 201 | while let frame = try await inboundIterator.next() { 202 | if case .connectionClose = frame.opcode { 203 | try await self.receivedClose(frame) 204 | // only the server can close the connection, so clients 205 | // should continue reading from inbound until it is closed 206 | if type == .server { 207 | break 208 | } 209 | } 210 | } 211 | } 212 | // don't propagate error if channel is already closed 213 | } catch ChannelError.ioOnClosedChannel {} 214 | if type == .client, let clientError { 215 | throw clientError 216 | } 217 | 218 | group.cancelAll() 219 | } 220 | } onGracefulShutdown: { 221 | Task { 222 | try? 
await self.close(code: .normalClosure) 223 | } 224 | } 225 | return switch self.stateMachine.state { 226 | case .closed(let code): code 227 | default: nil 228 | } 229 | } 230 | 231 | func runAutoPingLoop() async throws { 232 | let period = self.stateMachine.pingTimePeriod 233 | try await Task.sleep(for: period) 234 | while true { 235 | switch self.stateMachine.sendPing() { 236 | case .sendPing(let buffer): 237 | try await self.write(frame: .init(fin: true, opcode: .ping, data: buffer)) 238 | 239 | case .wait(let time): 240 | try await Task.sleep(for: time) 241 | 242 | case .closeConnection(let errorCode): 243 | try await self.sendClose(code: errorCode, reason: "Ping timeout") 244 | try await self.channel.close(mode: .input) 245 | return 246 | 247 | case .stop: 248 | return 249 | } 250 | } 251 | } 252 | 253 | /// Send WebSocket frame 254 | func write(frame: WebSocketFrame) async throws { 255 | var frame = frame 256 | do { 257 | for ext in self.configuration.extensions { 258 | frame = try await ext.processFrameToSend( 259 | frame, 260 | context: WebSocketExtensionContext(logger: self.logger) 261 | ) 262 | } 263 | } catch { 264 | self.logger.debug("Closing as we failed to generate valid frame data") 265 | throw WebSocketHandler.InternalError.close(.unexpectedServerError) 266 | } 267 | // Set mask key if client 268 | if self.type == .client { 269 | frame.maskKey = self.makeMaskKey() 270 | } 271 | try await self.outbound.write(frame) 272 | 273 | self.logger.trace("Sent \(frame.traceDescription)") 274 | } 275 | 276 | func finish() { 277 | self.outbound.finish() 278 | } 279 | 280 | /// Respond to ping 281 | func onPing(_ frame: WebSocketFrame) async throws { 282 | // a ping frame without the FIN flag is illegal 283 | guard frame.fin else { 284 | self.channel.close(promise: nil) 285 | return 286 | } 287 | switch self.stateMachine.receivedPing(frameData: frame.unmaskedData) { 288 | case .pong(let frameData): 289 | try await self.write(frame: .init(fin: true, opcode: .pong, 
data: frameData)) 290 | 291 | case .protocolError: 292 | try await self.close(code: .protocolError) 293 | 294 | case .doNothing: 295 | break 296 | } 297 | } 298 | 299 | /// Respond to pong 300 | func onPong(_ frame: WebSocketFrame) async throws { 301 | // a pong frame without the FIN flag is illegal 302 | guard frame.fin else { 303 | self.channel.close(promise: nil) 304 | return 305 | } 306 | self.stateMachine.receivedPong(frameData: frame.unmaskedData) 307 | } 308 | 309 | /// Send close 310 | func close(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws { 311 | switch self.stateMachine.close() { 312 | case .sendClose: 313 | try await self.sendClose(code: code, reason: reason) 314 | // Only server should initiate a connection close. Clients should wait for the 315 | // server to close the connection when it receives the WebSocket close packet 316 | // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 317 | if self.type == .server { 318 | self.outbound.finish() 319 | } 320 | case .doNothing: 321 | break 322 | } 323 | } 324 | 325 | func receivedClose(_ frame: WebSocketFrame) async throws { 326 | guard frame.reservedBits.isEmpty else { 327 | try await self.sendClose(code: .protocolError, reason: nil) 328 | // Only server should initiate a connection close. Clients should wait for the 329 | // server to close the connection when it receives the WebSocket close packet 330 | // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 331 | if self.type == .server { 332 | self.outbound.finish() 333 | } 334 | return 335 | } 336 | switch self.stateMachine.receivedClose(frameData: frame.unmaskedData, validateUTF8: self.configuration.validateUTF8) { 337 | case .sendClose(let errorCode): 338 | try await self.sendClose(code: errorCode, reason: nil) 339 | // Only server should initiate a connection close. 
Clients should wait for the 340 | // server to close the connection when it receives the WebSocket close packet 341 | // See https://www.rfc-editor.org/rfc/rfc6455#section-7.1.1 342 | if self.type == .server { 343 | self.outbound.finish() 344 | } 345 | case .doNothing: 346 | break 347 | } 348 | } 349 | 350 | private func sendClose(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) async throws { 351 | var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 0)) 352 | buffer.write(webSocketErrorCode: code) 353 | if let reason { 354 | buffer.writeString(reason) 355 | } 356 | 357 | try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer)) 358 | } 359 | 360 | /// Make mask key to be used in WebSocket frame 361 | private func makeMaskKey() -> WebSocketMaskingKey? { 362 | guard self.type == .client else { return nil } 363 | let bytes: [UInt8] = (0...3).map { _ in UInt8.random(in: .min ... .max) } 364 | return WebSocketMaskingKey(bytes) 365 | } 366 | } 367 | 368 | extension WebSocketErrorCode { 369 | init(_ error: any Error) { 370 | switch error { 371 | case NIOWebSocketError.invalidFrameLength: 372 | self = .messageTooLarge 373 | case NIOWebSocketError.fragmentedControlFrame, 374 | NIOWebSocketError.multiByteControlFrameLength: 375 | self = .protocolError 376 | case WebSocketHandler.InternalError.close(let error): 377 | self = error 378 | default: 379 | self = .unexpectedServerError 380 | } 381 | } 382 | } 383 | -------------------------------------------------------------------------------- /Sources/WSCore/WebSocketInboundMessageStream.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See 
LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOCore 16 | import NIOWebSocket 17 | 18 | /// AsyncSequence of complete WebSocket messages, assembled from an inbound stream of data frames. 19 | public struct WebSocketInboundMessageStream: AsyncSequence, Sendable { 20 | public typealias Element = WebSocketMessage 21 | 22 | let inboundStream: WebSocketInboundStream 23 | let maxSize: Int 24 | 25 | public struct AsyncIterator: AsyncIteratorProtocol { 26 | var frameIterator: WebSocketInboundStream.AsyncIterator 27 | let maxSize: Int 28 | /// Returns the next complete message, or `nil` when no further message is available. 29 | public mutating func next() async throws -> Element? { 30 | return try await self.frameIterator.nextMessage(maxSize: self.maxSize) 31 | } 32 | } 33 | /// Creates a message iterator; the underlying WebSocketInboundStream only supports being iterated once. 34 | public func makeAsyncIterator() -> AsyncIterator { 35 | AsyncIterator(frameIterator: self.inboundStream.makeAsyncIterator(), maxSize: self.maxSize) 36 | } 37 | } 38 | 39 | extension WebSocketInboundStream { 40 | /// Wrap this frame stream in an AsyncSequence of whole WebSocket messages 41 | /// 42 | /// A single message may be split over several frames; the returned sequence reassembles 43 | /// continuation frames so that each element is one complete text or binary 44 | /// message.
45 | /// 46 | /// - Parameter maxSize: The maximum size, in bytes, of a message that may be read 47 | public func messages(maxSize: Int) -> WebSocketInboundMessageStream { 48 | WebSocketInboundMessageStream(inboundStream: self, maxSize: maxSize) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /Sources/WSCore/WebSocketInboundStream.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOConcurrencyHelpers 16 | import NIOCore 17 | import NIOWebSocket 18 | 19 | /// Inbound WebSocket data frame AsyncSequence 20 | /// 21 | /// This AsyncSequence only returns binary, text and continuation frames.
All other frames 22 | /// are dealt with internally 23 | public final class WebSocketInboundStream: AsyncSequence, Sendable { 24 | public typealias Element = WebSocketDataFrame 25 | 26 | typealias UnderlyingIterator = NIOAsyncChannelInboundStream.AsyncIterator 27 | /// Underlying NIOAsyncChannelInboundStream 28 | let underlyingIterator: UnsafeTransfer 29 | /// Handler for websockets 30 | let handler: WebSocketHandler 31 | internal let alreadyIterated: NIOLockedValueBox 32 | 33 | init( 34 | iterator: UnderlyingIterator, 35 | handler: WebSocketHandler 36 | ) { 37 | self.underlyingIterator = .init(iterator) 38 | self.handler = handler 39 | self.alreadyIterated = .init(false) 40 | } 41 | 42 | /// Inbound websocket data AsyncSequence iterator 43 | public struct AsyncIterator: AsyncIteratorProtocol { 44 | let handler: WebSocketHandler 45 | var iterator: UnderlyingIterator 46 | var closed: Bool 47 | 48 | init(sequence: WebSocketInboundStream, closed: Bool) { 49 | self.handler = sequence.handler 50 | self.iterator = sequence.underlyingIterator.wrappedValue 51 | self.closed = closed 52 | } 53 | 54 | /// Return next WebSocket frame, while dealing with any other frames 55 | public mutating func next() async throws -> WebSocketDataFrame?
{ 56 | guard !self.closed else { return nil } 57 | // parse messages coming from inbound 58 | while let frame = try await self.iterator.next() { 59 | do { 60 | self.handler.logger.trace("Received \(frame.traceDescription)") 61 | switch frame.opcode { 62 | case .connectionClose: 63 | try await self.handler.receivedClose(frame) 64 | return nil 65 | case .ping: 66 | try await self.handler.onPing(frame) 67 | case .pong: 68 | try await self.handler.onPong(frame) 69 | case .text, .binary, .continuation: 70 | guard self.handler.configuration.reservedBits.contains(frame.reservedBits) else { 71 | throw WebSocketHandler.InternalError.close(.protocolError) 72 | } 73 | // apply extensions 74 | var frame = frame 75 | for ext in self.handler.configuration.extensions.reversed() { 76 | frame = try await ext.processReceivedFrame( 77 | frame, 78 | context: WebSocketExtensionContext(logger: self.handler.logger) 79 | ) 80 | } 81 | return .init(from: frame) 82 | default: 83 | // if we receive a reserved opcode we should fail the connection 84 | self.handler.logger.trace("Received reserved opcode", metadata: ["hb.ws.opcode": .stringConvertible(frame.opcode)]) 85 | throw WebSocketHandler.InternalError.close(.protocolError) 86 | } 87 | } catch { 88 | self.handler.logger.trace("Error: \(error)") 89 | // catch errors while processing websocket frames so responding close message 90 | // can be dealt with 91 | let errorCode = WebSocketErrorCode(error) 92 | do { 93 | try await self.handler.close(code: errorCode) 94 | // don't propagate error if channel is already closed 95 | } catch ChannelError.ioOnClosedChannel {} 96 | } 97 | } 98 | 99 | return nil 100 | } 101 | 102 | /// Return next WebSocket message, while dealing with any other frames. Returns `nil` if no complete message arrives before the inbound stream ends 103 | /// 104 | /// A WebSocket message can be fragmented across multiple WebSocket frames. This 105 | /// function collates fragmented frames until it has a full message. A message larger than `maxSize` bytes, whether it arrives in one frame or several, throws `InternalError.close(.messageTooLarge)` 106 | public mutating func nextMessage(maxSize: Int) async throws -> WebSocketMessage?
{ 107 | // parse first frame: it must start a text or binary message and, like continuation frames, must respect maxSize 108 | guard let frame = try await self.next() else { return nil } 109 | guard frame.opcode == .text || frame.opcode == .binary else { 110 | throw WebSocketHandler.InternalError.close(.protocolError) 111 | } 112 | guard frame.data.readableBytes <= maxSize else { 113 | throw WebSocketHandler.InternalError.close(.messageTooLarge) 114 | } 115 | var frameSequence = WebSocketFrameSequence(frame: frame) 116 | if frame.fin { 117 | guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else { 118 | throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage) 119 | } 120 | return message 121 | } 122 | // parse continuation frames until we get a frame with a FIN flag 123 | while let frame = try await self.next() { 124 | guard frame.opcode == .continuation else { 125 | throw WebSocketHandler.InternalError.close(.protocolError) 126 | } 127 | guard frameSequence.size + frame.data.readableBytes <= maxSize else { 128 | throw WebSocketHandler.InternalError.close(.messageTooLarge) 129 | } 130 | frameSequence.append(frame) 131 | if frame.fin { 132 | guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else { 133 | throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage) 134 | } 135 | return message 136 | } 137 | } 138 | return nil 139 | } 140 | } 141 | 142 | /// Creates the Asynchronous Iterator 143 | public func makeAsyncIterator() -> AsyncIterator { 144 | // verify if an iterator has already been created. If it has then create an 145 | // iterator that returns nothing. This could be a precondition failure (currently 146 | // an assert) as you should not be allowed to do this.
147 | let done = self.alreadyIterated.withLockedValue { 148 | assert($0 == false, "Can only create iterator from WebSocketInboundStream once") 149 | let done = $0 150 | $0 = true 151 | return done 152 | } 153 | return .init(sequence: self, closed: done) 154 | } 155 | } 156 | 157 | /// Extend WebSocketFrame to provide debug description for trace logs 158 | extension WebSocketFrame { 159 | var traceDescription: String { 160 | var flags: [String] = [] 161 | if self.fin { 162 | flags.append("FIN") 163 | } 164 | if self.rsv1 { 165 | flags.append("RSV1") 166 | } 167 | if self.rsv2 { 168 | flags.append("RSV2") 169 | } 170 | if self.rsv3 { 171 | flags.append("RSV3") 172 | } 173 | let unmaskedData = self.unmaskedData 174 | var desc = "[" 175 | let slice = unmaskedData.getSlice(at: unmaskedData.readerIndex, length: min(24, unmaskedData.readableBytes)) 176 | for byte in slice!.readableBytesView { 177 | let hexByte = String(byte, radix: 16) 178 | desc += " \(hexByte.count == 1 ? "0" : "")\(hexByte)" 179 | } 180 | if unmaskedData.readableBytes > 24 { 181 | desc += " ..."
182 | } 183 | desc += " ]" 184 | 185 | return "WebSocketFrame(\(self.opcode), flags: \(flags.joined(separator: ",")), data: {length: \(unmaskedData.readableBytes), bytes: \(desc)})" 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /Sources/WSCore/WebSocketMessage.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOCore 16 | import NIOWebSocket 17 | 18 | /// Enumeration holding WebSocket message 19 | public enum WebSocketMessage: Equatable, Sendable, CustomStringConvertible, CustomDebugStringConvertible { 20 | case text(String) 21 | case binary(ByteBuffer) 22 | 23 | init?(frame: WebSocketDataFrame, validate: Bool) { 24 | switch frame.opcode { 25 | case .text: 26 | guard let string = String(buffer: frame.data, validateUTF8: validate) else { 27 | return nil 28 | } 29 | self = .text(string) 30 | case .binary: 31 | self = .binary(frame.data) 32 | default: 33 | return nil 34 | } 35 | } 36 | 37 | public var description: String { 38 | switch self { 39 | case .text(let string): 40 | return "string(\"\(string)\")" 41 | case .binary(let buffer): 42 | return "binary(\(buffer.description))" 43 | } 44 | } 45 | 46 | public var debugDescription: String { 47 | switch self { 48 | case .text(let string): 49 | return "string(\"\(string)\")" 50 | case .binary(let buffer): 51 | return "binary(\(buffer.debugDescription))" 52 | } 53 | } 54 | }
55 | -------------------------------------------------------------------------------- /Sources/WSCore/WebSocketOutboundWriter.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2023-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOCore 16 | import NIOWebSocket 17 | 18 | /// Outbound websocket writer 19 | public struct WebSocketOutboundWriter: Sendable { 20 | /// WebSocket frame that can be written 21 | public enum OutboundFrame: Sendable { 22 | /// Text frame 23 | case text(String) 24 | /// Binary data frame 25 | case binary(ByteBuffer) 26 | /// Unsolicited pong frame 27 | case pong 28 | /// A custom frame not supported by the above 29 | case custom(WebSocketFrame) 30 | } 31 | 32 | let handler: WebSocketHandler 33 | 34 | /// Write WebSocket frame 35 | public func write(_ frame: OutboundFrame) async throws { 36 | try Task.checkCancellation() 37 | switch frame { 38 | case .binary(let buffer): 39 | // send binary data 40 | try await self.handler.write(frame: .init(fin: true, opcode: .binary, data: buffer)) 41 | case .text(let string): 42 | // send text based data 43 | let buffer = ByteBuffer(string: string) 44 | try await self.handler.write(frame: .init(fin: true, opcode: .text, data: buffer)) 45 | case .pong: 46 | // send unexplained pong as a heartbeat 47 | try await self.handler.write(frame: .init(fin: true, opcode: .pong, data: .init())) 48 | case .custom(let frame): 49 | // send custom WebSocketFrame 50 | try await 
self.handler.write(frame: frame) 51 | } 52 | } 53 | 54 | /// Send close control frame. 55 | /// 56 | /// In most cases calling this is unnecessary as the WebSocket handling code will do 57 | /// this for you automatically, but if you want to send a custom close code or reason 58 | /// use this function. 59 | /// 60 | /// After calling this function you should not send anymore data 61 | /// - Parameters: 62 | /// - closeCode: Close code 63 | /// - reason: Close reason string 64 | public func close(_ closeCode: WebSocketErrorCode, reason: String?) async throws { 65 | try await self.handler.close(code: closeCode, reason: reason) 66 | } 67 | 68 | /// Write WebSocket message as a series as frames 69 | public struct MessageWriter { 70 | let opcode: WebSocketOpcode 71 | let handler: WebSocketHandler 72 | var prevFrame: WebSocketFrame? 73 | 74 | /// Write string to WebSocket frame 75 | public mutating func callAsFunction(_ text: String) async throws { 76 | let buffer = ByteBuffer(string: text) 77 | try await self.write(buffer, opcode: self.opcode) 78 | } 79 | 80 | /// Write buffer to WebSocket frame 81 | public mutating func callAsFunction(_ buffer: ByteBuffer) async throws { 82 | try await self.write(buffer, opcode: self.opcode) 83 | } 84 | 85 | mutating func write(_ data: ByteBuffer, opcode: WebSocketOpcode) async throws { 86 | if let prevFrame { 87 | try await self.handler.write(frame: prevFrame) 88 | self.prevFrame = .init(fin: false, opcode: .continuation, data: data) 89 | } else { 90 | self.prevFrame = .init(fin: false, opcode: opcode, data: data) 91 | } 92 | } 93 | 94 | func finish() async throws { 95 | if var prevFrame { 96 | prevFrame.fin = true 97 | try await self.handler.write(frame: prevFrame) 98 | } 99 | } 100 | } 101 | 102 | /// Write a single WebSocket text message as a series of fragmented frames 103 | /// - Parameter write: Function writing frames 104 | public func withTextMessageWriter(_ write: (inout MessageWriter) async throws -> Value) async throws -> 
Value { 105 | var writer = MessageWriter(opcode: .text, handler: self.handler) 106 | let value: Value 107 | do { 108 | value = try await write(&writer) 109 | } catch { 110 | try await writer.finish() 111 | throw error 112 | } 113 | try await writer.finish() 114 | return value 115 | } 116 | 117 | /// Write a single WebSocket binary message as a series of fragmented frames 118 | /// - Parameter write: Function writing frames 119 | public func withBinaryMessageWriter(_ write: (inout MessageWriter) async throws -> Value) async throws -> Value { 120 | var writer = MessageWriter(opcode: .binary, handler: self.handler) 121 | let value: Value 122 | do { 123 | value = try await write(&writer) 124 | } catch { 125 | try await writer.finish() 126 | throw error 127 | } 128 | try await writer.finish() 129 | return value 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /Sources/WSCore/WebSocketStateMachine.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOCore 16 | import NIOWebSocket 17 | 18 | struct WebSocketStateMachine { 19 | static let pingDataSize = 16 20 | let pingTimePeriod: Duration 21 | var state: State 22 | 23 | init(autoPingSetup: AutoPingSetup) { 24 | switch autoPingSetup.value { 25 | case .enabled(let timePeriod): 26 | self.pingTimePeriod = timePeriod 27 | case .disabled: 28 | self.pingTimePeriod = .nanoseconds(0) 29 | } 30 
| self.state = .open(.init()) 31 | } 32 | 33 | enum CloseResult { 34 | case sendClose 35 | case doNothing 36 | } 37 | 38 | mutating func close() -> CloseResult { 39 | switch self.state { 40 | case .open: 41 | self.state = .closing 42 | return .sendClose 43 | case .closing: 44 | return .doNothing 45 | case .closed: 46 | return .doNothing 47 | } 48 | } 49 | 50 | enum ReceivedCloseResult { 51 | case sendClose(WebSocketErrorCode) 52 | case doNothing 53 | } 54 | 55 | // we received a connection close. 56 | // send a close back if it hasn't already been send and exit 57 | mutating func receivedClose(frameData: ByteBuffer, validateUTF8: Bool) -> ReceivedCloseResult { 58 | var frameData = frameData 59 | let dataSize = frameData.readableBytes 60 | // read close code and close reason 61 | let closeCode = frameData.readWebSocketErrorCode() 62 | let hasReason = frameData.readableBytes > 0 63 | let reason: String? = 64 | if hasReason { 65 | String(buffer: frameData, validateUTF8: validateUTF8) 66 | } else { 67 | nil 68 | } 69 | 70 | switch self.state { 71 | case .open: 72 | if hasReason, reason == nil { 73 | self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) 74 | return .sendClose(.protocolError) 75 | } 76 | self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) 77 | let code: WebSocketErrorCode = 78 | if dataSize == 0 || closeCode != nil { 79 | // codes 3000 - 3999 are reserved for use by libraries, frameworks 80 | // codes 4000 - 4999 are reserved for private use 81 | // both of these are considered valid. 
82 | if case .unknown(let code) = closeCode, code < 3000 || code > 4999 { 83 | .protocolError 84 | } else { 85 | .normalClosure 86 | } 87 | } else { 88 | .protocolError 89 | } 90 | return .sendClose(code) 91 | case .closing: 92 | self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) }) 93 | return .doNothing 94 | case .closed: 95 | return .doNothing 96 | } 97 | } 98 | 99 | enum SendPingResult { 100 | case sendPing(ByteBuffer) 101 | case wait(Duration) 102 | case closeConnection(WebSocketErrorCode) 103 | case stop 104 | } 105 | 106 | mutating func sendPing() -> SendPingResult { 107 | switch self.state { 108 | case .open(var state): 109 | if let lastPingTime = state.lastPingTime { 110 | let timeSinceLastPing = .now - lastPingTime 111 | // if time is less than timeout value, set wait time to when it would timeout 112 | // and re-run loop 113 | if timeSinceLastPing < self.pingTimePeriod { 114 | return .wait(self.pingTimePeriod - timeSinceLastPing) 115 | } else { 116 | return .closeConnection(.goingAway) 117 | } 118 | } 119 | // creating random payload 120 | let random = (0.. 
ReceivedPingResult { 142 | switch self.state { 143 | case .open: 144 | guard frameData.readableBytes < 126 else { return .protocolError } 145 | return .pong(frameData) 146 | 147 | case .closing: 148 | return .pong(frameData) 149 | 150 | case .closed: 151 | return .doNothing 152 | } 153 | } 154 | 155 | mutating func receivedPong(frameData: ByteBuffer) { 156 | switch self.state { 157 | case .open(var state): 158 | let frameData = frameData 159 | // ignore pong frames with frame data not the same as the last ping 160 | guard frameData == state.pingData else { return } 161 | // clear ping data 162 | state.lastPingTime = nil 163 | self.state = .open(state) 164 | 165 | case .closing: 166 | break 167 | 168 | case .closed: 169 | break 170 | } 171 | } 172 | } 173 | 174 | extension WebSocketStateMachine { 175 | struct OpenState { 176 | var pingData: ByteBuffer 177 | var lastPingTime: ContinuousClock.Instant? 178 | 179 | init() { 180 | self.pingData = ByteBufferAllocator().buffer(capacity: WebSocketStateMachine.pingDataSize) 181 | self.lastPingTime = nil 182 | } 183 | } 184 | 185 | enum State { 186 | case open(OpenState) 187 | case closing 188 | case closed(WebSocketCloseFrame?) 
189 | } 190 | } 191 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/AutobahnTests.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Foundation 16 | import Logging 17 | import NIOConcurrencyHelpers 18 | import NIOPosix 19 | import WSClient 20 | import WSCompression 21 | import XCTest 22 | 23 | /// The Autobahn|Testsuite provides a fully automated test suite to verify client and server 24 | /// implementations of The WebSocket Protocol for specification conformance and implementation robustness. 25 | /// You can find out more at https://github.com/crossbario/autobahn-testsuite 26 | /// 27 | /// Before running these tests run `./scripts/autobahn-server.sh` to running the test server. 28 | final class AutobahnTests: XCTestCase { 29 | /// To run all the autobahn compression tests takes a long time. By default we only run a selection. 30 | /// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them. 31 | var runAllTests: Bool { ProcessInfo.processInfo.environment["AUTOBAHN_ALL_TESTS"] == "true" } 32 | var autobahnServer: String { ProcessInfo.processInfo.environment["FUZZING_SERVER"] ?? 
"localhost" } 33 | 34 | func getValue(_ path: String, as: T.Type) async throws -> T { 35 | let result: NIOLockedValueBox = .init(nil) 36 | try await WebSocketClient.connect( 37 | url: "ws://\(self.autobahnServer):9001/\(path)", 38 | configuration: .init(validateUTF8: true), 39 | logger: Logger(label: "Autobahn") 40 | ) { inbound, _, _ in 41 | var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator() 42 | switch try await inboundIterator.next() { 43 | case .text(let text): 44 | let data = Data(text.utf8) 45 | let report = try JSONDecoder().decode(T.self, from: data) 46 | result.withLockedValue { $0 = report } 47 | 48 | case .binary: 49 | preconditionFailure("Received unexpected data") 50 | 51 | case .none: 52 | return 53 | } 54 | } 55 | return try result.withLockedValue { try XCTUnwrap($0) } 56 | } 57 | 58 | /// Run a number of autobahn tests 59 | func autobahnTests( 60 | cases: Set, 61 | extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)] 62 | ) async throws { 63 | // These are broken in CI currently 64 | try XCTSkipIf(ProcessInfo.processInfo.environment["CI"] != nil) 65 | 66 | struct CaseInfo: Decodable { 67 | let id: String 68 | let description: String 69 | } 70 | struct CaseStatus: Decodable { 71 | let behavior: String 72 | } 73 | 74 | let logger = Logger(label: "Autobahn") 75 | 76 | // Run tests 77 | do { 78 | for index in cases.sorted() { 79 | // get case info 80 | let info = try await getValue("getCaseInfo?case=\(index)&agent=swift-websocket", as: CaseInfo.self) 81 | logger.info("\(info.id): \(info.description)") 82 | 83 | // run case 84 | try await WebSocketClient.connect( 85 | url: "ws://\(self.autobahnServer):9001/runCase?case=\(index)&agent=swift-websocket", 86 | configuration: .init( 87 | maxFrameSize: 16_777_216, 88 | extensions: extensions, 89 | validateUTF8: true 90 | ), 91 | logger: logger 92 | ) { inbound, outbound, _ in 93 | for try await msg in inbound.messages(maxSize: .max) { 94 | 
switch msg { 95 | case .binary(let buffer): 96 | try await outbound.write(.binary(buffer)) 97 | case .text(let string): 98 | try await outbound.write(.text(string)) 99 | } 100 | } 101 | } 102 | 103 | // get case status 104 | let status = try await getValue("getCaseStatus?case=\(index)&agent=swift-websocket", as: CaseStatus.self) 105 | XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL" || status.behavior == "NON-STRICT") 106 | } 107 | 108 | try await WebSocketClient.connect( 109 | url: .init("ws://\(self.autobahnServer):9001/updateReports?agent=HB"), 110 | logger: logger 111 | ) { inbound, _, _ in 112 | for try await _ in inbound {} 113 | } 114 | } catch let error as NIOConnectionError { 115 | logger.error("Autobahn tests require a running Autobahn fuzzing server. Run ./scripts/autobahn-server.sh") 116 | throw error 117 | } 118 | } 119 | 120 | func test_1_Framing() async throws { 121 | try await self.autobahnTests(cases: .init(1..<17)) 122 | } 123 | 124 | func test_2_PingPongs() async throws { 125 | try await self.autobahnTests(cases: .init(17..<28)) 126 | } 127 | 128 | func test_3_ReservedBits() async throws { 129 | try await self.autobahnTests(cases: .init(28..<35)) 130 | } 131 | 132 | func test_4_Opcodes() async throws { 133 | try await self.autobahnTests(cases: .init(35..<45)) 134 | } 135 | 136 | func test_5_Fragmentation() async throws { 137 | try await self.autobahnTests(cases: .init(45..<65)) 138 | } 139 | 140 | func test_6_UTF8Handling() async throws { 141 | // UTF8 validation is available on swift 5.10 or earlier 142 | #if compiler(<6) 143 | try XCTSkipIf(true) 144 | #endif 145 | try await self.autobahnTests(cases: .init(65..<210)) 146 | } 147 | 148 | func test_7_CloseHandling() async throws { 149 | // UTF8 validation is available on swift 5.10 or earlier 150 | #if compiler(<6) 151 | try await self.autobahnTests(cases: .init(210..<222)) 152 | try await self.autobahnTests(cases: .init(223..<247)) 153 | #else 154 | try await 
self.autobahnTests(cases: .init(210..<247)) 155 | #endif 156 | } 157 | 158 | func test_9_Performance() async throws { 159 | if !self.runAllTests { 160 | try await self.autobahnTests(cases: .init([247, 260, 270, 281, 291, 296])) 161 | } else { 162 | try await self.autobahnTests(cases: .init(247..<301)) 163 | } 164 | } 165 | 166 | func test_10_AutoFragmentation() async throws { 167 | try await self.autobahnTests(cases: .init([301])) 168 | } 169 | 170 | func test_12_CompressionDifferentPayloads() async throws { 171 | if !self.runAllTests { 172 | try await self.autobahnTests(cases: .init([302, 330, 349, 360, 388])) 173 | } else { 174 | try await self.autobahnTests(cases: .init(302..<391)) 175 | } 176 | } 177 | 178 | func test_13_CompressionDifferentParameters() async throws { 179 | if !self.runAllTests { 180 | try await self.autobahnTests( 181 | cases: .init([392]), 182 | extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 183 | ) 184 | try await self.autobahnTests( 185 | cases: .init([427]), 186 | extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 187 | ) 188 | try await self.autobahnTests( 189 | cases: .init([440]), 190 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 191 | ) 192 | try await self.autobahnTests( 193 | cases: .init([451]), 194 | extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 195 | ) 196 | try await self.autobahnTests( 197 | cases: .init([473]), 198 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 199 | ) 200 | try await self.autobahnTests( 201 | cases: .init([498]), 202 | extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 203 | ) 204 | // case 13.7.x are repeated with different setups 205 | try await self.autobahnTests( 206 | cases: 
.init([509]), 207 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 208 | ) 209 | try await self.autobahnTests( 210 | cases: .init([517]), 211 | extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 212 | ) 213 | try await self.autobahnTests( 214 | cases: .init([504]), 215 | extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 216 | ) 217 | } else { 218 | try await self.autobahnTests( 219 | cases: .init(392..<410), 220 | extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 221 | ) 222 | try await self.autobahnTests( 223 | cases: .init(410..<428), 224 | extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 225 | ) 226 | try await self.autobahnTests( 227 | cases: .init(428..<446), 228 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 229 | ) 230 | try await self.autobahnTests( 231 | cases: .init(446..<464), 232 | extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 233 | ) 234 | try await self.autobahnTests( 235 | cases: .init(464..<482), 236 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 237 | ) 238 | try await self.autobahnTests( 239 | cases: .init(482..<500), 240 | extensions: [.perMessageDeflate(maxWindow: 15, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 241 | ) 242 | // case 13.7.x are repeated with different setups 243 | try await self.autobahnTests( 244 | cases: .init(500..<518), 245 | extensions: [.perMessageDeflate(maxWindow: 9, noContextTakeover: true, maxDecompressedFrameSize: 131_072)] 246 | ) 247 | try await self.autobahnTests( 248 | cases: .init(500..<518), 249 | extensions: [.perMessageDeflate(noContextTakeover: true, maxDecompressedFrameSize: 
131_072)] 250 | ) 251 | try await self.autobahnTests( 252 | cases: .init(500..<518), 253 | extensions: [.perMessageDeflate(noContextTakeover: false, maxDecompressedFrameSize: 131_072)] 254 | ) 255 | } 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/ClientTests.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024-2025 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import Logging 16 | import NIOCore 17 | import NIOSSL 18 | import NIOWebSocket 19 | import WSClient 20 | import XCTest 21 | 22 | final class WebSocketClientTests: XCTestCase { 23 | 24 | func testEchoServer() async throws { 25 | let clientLogger = { 26 | var logger = Logger(label: "client") 27 | logger.logLevel = .trace 28 | return logger 29 | }() 30 | try await WebSocketClient.connect( 31 | url: "wss://echo.websocket.org/", 32 | tlsConfiguration: TLSConfiguration.makeClientConfiguration(), 33 | logger: clientLogger 34 | ) { inbound, outbound, _ in 35 | var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator() 36 | try await outbound.write(.text("hello")) 37 | if let msg = try await inboundIterator.next() { 38 | print(msg) 39 | } 40 | } 41 | } 42 | 43 | func testEchoServerWithSNIHostname() async throws { 44 | let clientLogger = { 45 | var logger = Logger(label: "client") 46 | logger.logLevel = .trace 47 | return logger 48 | }() 49 | try await WebSocketClient.connect( 50 | url: 
"wss://echo.websocket.org/", 51 | configuration: .init(sniHostname: "echo.websocket.org"), 52 | tlsConfiguration: TLSConfiguration.makeClientConfiguration(), 53 | logger: clientLogger 54 | ) { inbound, outbound, _ in 55 | var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator() 56 | try await outbound.write(.text("hello")) 57 | if let msg = try await inboundIterator.next() { 58 | print(msg) 59 | } 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/WebSocketExtensionNegotiationTests.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2021-2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import HTTPTypes 16 | import NIOWebSocket 17 | import XCTest 18 | 19 | @testable import WSCompression 20 | @testable import WSCore 21 | 22 | final class WebSocketExtensionNegotiationTests: XCTestCase { 23 | func testExtensionHeaderParsing() { 24 | let headers: HTTPFields = .init([ 25 | .init(name: .secWebSocketExtensions, value: "permessage-deflate; client_max_window_bits; server_max_window_bits=10"), 26 | .init(name: .secWebSocketExtensions, value: "permessage-deflate;client_max_window_bits"), 27 | ]) 28 | let extensions = WebSocketExtensionHTTPParameters.parseHeaders(headers) 29 | XCTAssertEqual( 30 | extensions, 31 | [ 32 | .init("permessage-deflate", parameters: ["client_max_window_bits": .null, "server_max_window_bits": .value("10")]), 33 | 
.init("permessage-deflate", parameters: ["client_max_window_bits": .null]), 34 | ] 35 | ) 36 | } 37 | 38 | func testDeflateServerResponse() { 39 | let requestHeaders: [WebSocketExtensionHTTPParameters] = [ 40 | .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]) 41 | ] 42 | let ext = PerMessageDeflateExtensionBuilder(clientNoContextTakeover: true, serverNoContextTakeover: true) 43 | let serverResponse = ext.serverResponseHeader(to: requestHeaders) 44 | XCTAssertEqual( 45 | serverResponse, 46 | "permessage-deflate;client_max_window_bits=10;client_no_context_takeover;server_no_context_takeover" 47 | ) 48 | } 49 | 50 | func testDeflateServerResponseClientMaxWindowBits() { 51 | let requestHeaders: [WebSocketExtensionHTTPParameters] = [ 52 | .init("permessage-deflate", parameters: ["client_max_window_bits": .null]) 53 | ] 54 | let ext1 = PerMessageDeflateExtensionBuilder(serverNoContextTakeover: true) 55 | let serverResponse1 = ext1.serverResponseHeader(to: requestHeaders) 56 | XCTAssertEqual( 57 | serverResponse1, 58 | "permessage-deflate;server_no_context_takeover" 59 | ) 60 | let ext2 = PerMessageDeflateExtensionBuilder(clientNoContextTakeover: true, serverMaxWindow: 12) 61 | let serverResponse2 = ext2.serverResponseHeader(to: requestHeaders) 62 | XCTAssertEqual( 63 | serverResponse2, 64 | "permessage-deflate;client_no_context_takeover;server_max_window_bits=12" 65 | ) 66 | } 67 | 68 | func testUnregonisedExtensionServerResponse() throws { 69 | let serverExtensions: [WebSocketExtensionBuilder] = [PerMessageDeflateExtensionBuilder()] 70 | let (headers, extensions) = try serverExtensions.serverExtensionNegotiation( 71 | requestHeaders: [ 72 | .secWebSocketExtensions: "permessage-foo;bar=baz", 73 | .secWebSocketExtensions: "permessage-deflate;client_max_window_bits=10", 74 | ] 75 | ) 76 | XCTAssertEqual( 77 | headers[.secWebSocketExtensions], 78 | "permessage-deflate;client_max_window_bits=10" 79 | ) 80 | XCTAssertEqual(extensions.count, 1) 
81 | let firstExtension = try XCTUnwrap(extensions.first) 82 | XCTAssert(firstExtension is PerMessageDeflateExtension) 83 | 84 | let requestExtensions = try serverExtensions.buildClientExtensions(from: headers) 85 | XCTAssertEqual(requestExtensions.count, 1) 86 | XCTAssert(requestExtensions[0] is PerMessageDeflateExtension) 87 | } 88 | 89 | func testNonNegotiableClientExtension() throws { 90 | struct MyExtension: WebSocketExtension { 91 | var name = "my-extension" 92 | 93 | func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 94 | frame 95 | } 96 | 97 | func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 98 | frame 99 | } 100 | 101 | func shutdown() async {} 102 | } 103 | let clientExtensionBuilders: [WebSocketExtensionBuilder] = [ 104 | WebSocketExtensionFactory.nonNegotiatedExtension { 105 | MyExtension() 106 | }.build() 107 | ] 108 | let clientExtensions = try clientExtensionBuilders.buildClientExtensions(from: [:]) 109 | XCTAssertEqual(clientExtensions.count, 1) 110 | let myExtension = try XCTUnwrap(clientExtensions.first) 111 | XCTAssert(myExtension is MyExtension) 112 | } 113 | 114 | func testNonNegotiableServerExtension() throws { 115 | struct MyExtension: WebSocketExtension { 116 | var name = "my-extension" 117 | 118 | func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 119 | frame 120 | } 121 | 122 | func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame { 123 | frame 124 | } 125 | 126 | func shutdown() async {} 127 | } 128 | let serverExtensionBuilders: [WebSocketExtensionBuilder] = [WebSocketNonNegotiableExtensionBuilder { MyExtension() }] 129 | let (headers, serverExtensions) = try serverExtensionBuilders.serverExtensionNegotiation( 130 | requestHeaders: [:] 131 | ) 132 | 
XCTAssertEqual(headers.count, 0) 133 | XCTAssertEqual(serverExtensions.count, 1) 134 | let myExtension = try XCTUnwrap(serverExtensions.first) 135 | XCTAssert(myExtension is MyExtension) 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/WebSocketStateMachineTests.swift: -------------------------------------------------------------------------------- 1 | //===----------------------------------------------------------------------===// 2 | // 3 | // This source file is part of the Hummingbird server framework project 4 | // 5 | // Copyright (c) 2024 the Hummingbird authors 6 | // Licensed under Apache License v2.0 7 | // 8 | // See LICENSE.txt for license information 9 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 10 | // 11 | // SPDX-License-Identifier: Apache-2.0 12 | // 13 | //===----------------------------------------------------------------------===// 14 | 15 | import NIOCore 16 | import NIOWebSocket 17 | import XCTest 18 | 19 | @testable import WSCore 20 | 21 | final class WebSocketStateMachineTests: XCTestCase { 22 | private func closeFrameData(code: WebSocketErrorCode = .normalClosure, reason: String? = nil) -> ByteBuffer { 23 | var buffer = ByteBufferAllocator().buffer(capacity: 2 + (reason?.utf8.count ?? 
0)) 24 | buffer.write(webSocketErrorCode: code) 25 | if let reason { 26 | buffer.writeString(reason) 27 | } 28 | return buffer 29 | } 30 | 31 | func testClose() { 32 | var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) 33 | guard case .sendClose = stateMachine.close() else { 34 | XCTFail() 35 | return 36 | } 37 | guard case .doNothing = stateMachine.close() else { 38 | XCTFail() 39 | return 40 | } 41 | guard case .doNothing = stateMachine.receivedClose(frameData: self.closeFrameData(), validateUTF8: false) else { 42 | XCTFail() 43 | return 44 | } 45 | guard case .closed(let frame) = stateMachine.state else { 46 | XCTFail() 47 | return 48 | } 49 | XCTAssertEqual(frame?.closeCode, .normalClosure) 50 | } 51 | 52 | func testReceivedClose() { 53 | var stateMachine = WebSocketStateMachine(autoPingSetup: .disabled) 54 | guard case .sendClose(let error) = stateMachine.receivedClose(frameData: closeFrameData(code: .goingAway), validateUTF8: false) else { 55 | XCTFail() 56 | return 57 | } 58 | XCTAssertEqual(error, .normalClosure) 59 | guard case .closed(let frame) = stateMachine.state else { 60 | XCTFail() 61 | return 62 | } 63 | XCTAssertEqual(frame?.closeCode, .goingAway) 64 | } 65 | 66 | func testPingLoopNoPong() { 67 | var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) 68 | guard case .sendPing = stateMachine.sendPing() else { 69 | XCTFail() 70 | return 71 | } 72 | guard case .wait = stateMachine.sendPing() else { 73 | XCTFail() 74 | return 75 | } 76 | } 77 | 78 | func testPingLoop() { 79 | var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .seconds(15))) 80 | guard case .sendPing(let buffer) = stateMachine.sendPing() else { 81 | XCTFail() 82 | return 83 | } 84 | guard case .wait = stateMachine.sendPing() else { 85 | XCTFail() 86 | return 87 | } 88 | stateMachine.receivedPong(frameData: buffer) 89 | guard case .open(let openState) = stateMachine.state else { 90 | XCTFail() 91 | return 92 
| } 93 | XCTAssertEqual(openState.lastPingTime, nil) 94 | guard case .sendPing = stateMachine.sendPing() else { 95 | XCTFail() 96 | return 97 | } 98 | } 99 | 100 | // Verify ping buffer size doesnt grow 101 | func testPingBufferSize() async throws { 102 | var stateMachine = WebSocketStateMachine(autoPingSetup: .enabled(timePeriod: .milliseconds(1))) 103 | var currentBuffer = ByteBuffer() 104 | var count = 0 105 | while true { 106 | switch stateMachine.sendPing() { 107 | case .sendPing(let buffer): 108 | XCTAssertEqual(buffer.readableBytes, 16) 109 | currentBuffer = buffer 110 | count += 1 111 | if count > 4 { 112 | return 113 | } 114 | 115 | case .wait(let time): 116 | try await Task.sleep(for: time) 117 | stateMachine.receivedPong(frameData: currentBuffer) 118 | 119 | case .closeConnection: 120 | XCTFail("Should not timeout") 121 | return 122 | 123 | case .stop: 124 | XCTFail("Should not stop") 125 | return 126 | } 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /scripts/autobahn-config/fuzzingserver.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "ws://127.0.0.1:9001", 3 | "outdir": "./reports/", 4 | "cases": ["*"], 5 | "exclude-agent-cases": {} 6 | } 7 | -------------------------------------------------------------------------------- /scripts/autobahn-server.sh: -------------------------------------------------------------------------------- 1 | docker run -it --rm \ 2 | -v "${PWD}/scripts/autobahn-config:/config" \ 3 | -v "${PWD}/.build/reports:/reports" \ 4 | -p 9001:9001 \ 5 | --name fuzzingserver \ 6 | crossbario/autobahn-testsuite 7 | -------------------------------------------------------------------------------- /scripts/validate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ##===----------------------------------------------------------------------===## 3 | ## 4 | ## This source 
file is part of the Hummingbird server framework project 5 | ## 6 | ## Copyright (c) 2021-2024 the Hummingbird authors 7 | ## Licensed under Apache License v2.0 8 | ## 9 | ## See LICENSE.txt for license information 10 | ## See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 11 | ## 12 | ## SPDX-License-Identifier: Apache-2.0 13 | ## 14 | ##===----------------------------------------------------------------------===## 15 | ##===----------------------------------------------------------------------===## 16 | ## 17 | ## This source file is part of the SwiftNIO open source project 18 | ## 19 | ## Copyright (c) 2017-2019 Apple Inc. and the SwiftNIO project authors 20 | ## Licensed under Apache License v2.0 21 | ## 22 | ## See LICENSE.txt for license information 23 | ## See CONTRIBUTORS.txt for the list of SwiftNIO project authors 24 | ## 25 | ## SPDX-License-Identifier: Apache-2.0 26 | ## 27 | ##===----------------------------------------------------------------------===## 28 | 29 | SWIFT_FORMAT_VERSION=0.53.10 30 | 31 | set -eu 32 | here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 33 | 34 | function replace_acceptable_years() { 35 | # this needs to replace all acceptable forms with 'YEARS' 36 | sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' -e '/^#!/ d' 37 | } 38 | 39 | printf "=> Checking format... " 40 | FIRST_OUT="$(git status --porcelain)" 41 | git ls-files -z '*.swift' | xargs -0 swift format format --parallel --in-place 42 | git diff --exit-code '*.swift' 43 | 44 | SECOND_OUT="$(git status --porcelain)" 45 | if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then 46 | printf "\033[0;31mformatting issues!\033[0m\n" 47 | git --no-pager diff 48 | exit 1 49 | else 50 | printf "\033[0;32mokay.\033[0m\n" 51 | fi 52 | printf "=> Checking license headers... 
" 53 | tmp=$(mktemp /tmp/.soto-core-sanity_XXXXXX) 54 | 55 | exit 0 56 | 57 | for language in swift-or-c; do 58 | declare -a matching_files 59 | declare -a exceptions 60 | expections=( ) 61 | matching_files=( -name '*' ) 62 | case "$language" in 63 | swift-or-c) 64 | exceptions=( -path '*/Benchmarks/.build/*' -o -name Package.swift) 65 | matching_files=( -name '*.swift' -o -name '*.c' -o -name '*.h' ) 66 | cat > "$tmp" <<"EOF" 67 | //===----------------------------------------------------------------------===// 68 | // 69 | // This source file is part of the Hummingbird server framework project 70 | // 71 | // Copyright (c) YEARS the Hummingbird authors 72 | // Licensed under Apache License v2.0 73 | // 74 | // See LICENSE.txt for license information 75 | // See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 76 | // 77 | // SPDX-License-Identifier: Apache-2.0 78 | // 79 | //===----------------------------------------------------------------------===// 80 | EOF 81 | ;; 82 | bash) 83 | matching_files=( -name '*.sh' ) 84 | cat > "$tmp" <<"EOF" 85 | ##===----------------------------------------------------------------------===## 86 | ## 87 | ## This source file is part of the Hummingbird server framework project 88 | ## 89 | ## Copyright (c) YEARS the Hummingbird authors 90 | ## Licensed under Apache License v2.0 91 | ## 92 | ## See LICENSE.txt for license information 93 | ## See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors 94 | ## 95 | ## SPDX-License-Identifier: Apache-2.0 96 | ## 97 | ##===----------------------------------------------------------------------===## 98 | EOF 99 | ;; 100 | *) 101 | echo >&2 "ERROR: unknown language '$language'" 102 | ;; 103 | esac 104 | 105 | lines_to_compare=$(cat "$tmp" | wc -l | tr -d " ") 106 | # need to read one more line as we remove the '#!' line 107 | lines_to_read=$(expr "$lines_to_compare" + 1) 108 | expected_sha=$(cat "$tmp" | shasum) 109 | 110 | ( 111 | cd "$here/.." 112 | find . 
\ 113 | \( \! -path './.build/*' -a \ 114 | \( "${matching_files[@]}" \) -a \ 115 | \( \! \( "${exceptions[@]}" \) \) \) | while read line; do 116 | if [[ "$(cat "$line" | head -n $lines_to_read | replace_acceptable_years | head -n $lines_to_compare | shasum)" != "$expected_sha" ]]; then 117 | printf "\033[0;31mmissing headers in file '$line'!\033[0m\n" 118 | diff -u <(cat "$line" | head -n $lines_to_read | replace_acceptable_years | head -n $lines_to_compare) "$tmp" 119 | exit 1 120 | fi 121 | done 122 | printf "\033[0;32mokay.\033[0m\n" 123 | ) 124 | done 125 | 126 | rm "$tmp" 127 | --------------------------------------------------------------------------------