├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .swift-version ├── .swiftformat ├── LICENSE.md ├── Package.resolved ├── Package.swift ├── README.md ├── Sources └── WebSocket │ ├── OSLog+WebSocket.swift │ ├── SystemURLSession.swift │ ├── SystemWebSocket.swift │ ├── WebSocket.swift │ ├── WebSocketClose.swift │ ├── WebSocketCloseCode.swift │ ├── WebSocketError.swift │ ├── WebSocketMessage.swift │ └── WebSocketOptions.swift ├── Tests ├── LinuxMain.swift └── WebSocketTests │ ├── Server │ ├── ServerBootstrap+WebSocketTests.swift │ └── WebSocketServer.swift │ ├── SystemWebSocketTests.swift │ ├── URLSessionWebSocketTaskCloseCodeTests.swift │ └── XCTestManifests.swift └── bin └── format.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: push 4 | 5 | jobs: 6 | test: 7 | runs-on: macos-14 8 | 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Select Xcode 15 12 | run: sudo xcode-select -s /Applications/Xcode_15.4.app 13 | - name: Test 14 | run: swift test 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # from https://github.com/github/gitignore/blob/master/Swift.gitignore 2 | 3 | .env 4 | 5 | .DS_Store 6 | 7 | .swiftpm 8 | 9 | /.vscode 10 | 11 | # Xcode 12 | # 13 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 14 | 15 | ## Build generated 16 | build/ 17 | DerivedData 18 | 19 | ## Various settings 20 | *.pbxuser 21 | !default.pbxuser 22 | *.mode1v3 23 | !default.mode1v3 24 | *.mode2v3 25 | !default.mode2v3 26 | *.perspectivev3 27 | !default.perspectivev3 28 | xcuserdata 29 | 30 | ## Other 31 | *.xccheckout 32 | *.moved-aside 33 | *.xcuserstate 34 | *.xcscmblueprint 35 | 36 | ## Obj-C/Swift specific 37 | *.hmap 38 | *.ipa 39 | 40 | # Swift Package Manager 41 | # 42 | # Add this line 
if you want to avoid checking in source code from Swift Package Manager dependencies. 43 | # Packages/ 44 | .build/ 45 | 46 | # CocoaPods 47 | # 48 | # We recommend against adding the Pods directory to your .gitignore. However 49 | # you should judge for yourself, the pros and cons are mentioned at: 50 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 51 | # 52 | # Pods/ 53 | 54 | # Carthage 55 | Carthage/Checkouts 56 | Carthage/Build 57 | 58 | # fastlane 59 | # 60 | # It is recommended to not store the screenshots in the git repo. Instead, use fastlane to re-generate the 61 | # screenshots whenever they are needed. 62 | # For more information about the recommended setup visit: 63 | # https://github.com/fastlane/fastlane/blob/master/docs/Gitignore.md 64 | 65 | fastlane/report.xml 66 | fastlane/screenshots 67 | -------------------------------------------------------------------------------- /.swift-version: -------------------------------------------------------------------------------- 1 | 5.9.0 2 | 3 | -------------------------------------------------------------------------------- /.swiftformat: -------------------------------------------------------------------------------- 1 | --disable \ 2 | hoistAwait, \ 3 | hoistTry 4 | 5 | --decimalgrouping 3,5 6 | --funcattributes prev-line 7 | --minversion 0.47.2 8 | --maxwidth 96 9 | --typeattributes prev-line 10 | --wraparguments before-first 11 | --wrapparameters before-first 12 | --wrapcollections before-first 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 Shareup Software Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software 
without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "async-extensions", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/shareup/async-extensions.git", 7 | "state" : { 8 | "revision" : "7e727e3b9009a5de429393691f9f499aedb7a109", 9 | "version" : "4.3.0" 10 | } 11 | }, 12 | { 13 | "identity" : "dispatch-timer", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/shareup/dispatch-timer.git", 16 | "state" : { 17 | "revision" : "2d8c304aa6f382a7a362cd5a814884f3930c5662", 18 | "version" : "3.0.1" 19 | } 20 | }, 21 | { 22 | "identity" : "swift-atomics", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/apple/swift-atomics.git", 25 | "state" : { 26 | "revision" : "cd142fd2f64be2100422d658e7411e39489da985", 27 | "version" : "1.2.0" 28 | } 29 | }, 30 | { 31 | "identity" : "swift-collections", 32 | "kind" : "remoteSourceControl", 
33 | "location" : "https://github.com/apple/swift-collections.git", 34 | "state" : { 35 | "revision" : "3d2dc41a01f9e49d84f0a3925fb858bed64f702d", 36 | "version" : "1.1.2" 37 | } 38 | }, 39 | { 40 | "identity" : "swift-http-types", 41 | "kind" : "remoteSourceControl", 42 | "location" : "https://github.com/apple/swift-http-types", 43 | "state" : { 44 | "revision" : "ae67c8178eb46944fd85e4dc6dd970e1f3ed6ccd", 45 | "version" : "1.3.0" 46 | } 47 | }, 48 | { 49 | "identity" : "swift-nio", 50 | "kind" : "remoteSourceControl", 51 | "location" : "https://github.com/apple/swift-nio.git", 52 | "state" : { 53 | "revision" : "4c4453b489cf76e6b3b0f300aba663eb78182fad", 54 | "version" : "2.70.0" 55 | } 56 | }, 57 | { 58 | "identity" : "swift-nio-extras", 59 | "kind" : "remoteSourceControl", 60 | "location" : "https://github.com/apple/swift-nio-extras.git", 61 | "state" : { 62 | "revision" : "d1ead62745cc3269e482f1c51f27608057174379", 63 | "version" : "1.24.0" 64 | } 65 | }, 66 | { 67 | "identity" : "swift-nio-http2", 68 | "kind" : "remoteSourceControl", 69 | "location" : "https://github.com/apple/swift-nio-http2.git", 70 | "state" : { 71 | "revision" : "b5f7062b60e4add1e8c343ba4eb8da2e324b3a94", 72 | "version" : "1.34.0" 73 | } 74 | }, 75 | { 76 | "identity" : "swift-nio-ssl", 77 | "kind" : "remoteSourceControl", 78 | "location" : "https://github.com/apple/swift-nio-ssl.git", 79 | "state" : { 80 | "revision" : "a9fa5efd86e7ce2e5c1b6de113262e58035ca251", 81 | "version" : "2.27.1" 82 | } 83 | }, 84 | { 85 | "identity" : "swift-nio-transport-services", 86 | "kind" : "remoteSourceControl", 87 | "location" : "https://github.com/apple/swift-nio-transport-services.git", 88 | "state" : { 89 | "revision" : "38ac8221dd20674682148d6451367f89c2652980", 90 | "version" : "1.21.0" 91 | } 92 | }, 93 | { 94 | "identity" : "swift-system", 95 | "kind" : "remoteSourceControl", 96 | "location" : "https://github.com/apple/swift-system.git", 97 | "state" : { 98 | "revision" : 
"d2ba781702a1d8285419c15ee62fd734a9437ff5", 99 | "version" : "1.3.2" 100 | } 101 | }, 102 | { 103 | "identity" : "synchronized", 104 | "kind" : "remoteSourceControl", 105 | "location" : "https://github.com/shareup/synchronized.git", 106 | "state" : { 107 | "revision" : "85653e23270ec88ae19f8d494157769487e34aed", 108 | "version" : "4.0.1" 109 | } 110 | }, 111 | { 112 | "identity" : "websocket-kit", 113 | "kind" : "remoteSourceControl", 114 | "location" : "https://github.com/vapor/websocket-kit.git", 115 | "state" : { 116 | "revision" : "4232d34efa49f633ba61afde365d3896fc7f8740", 117 | "version" : "2.15.0" 118 | } 119 | } 120 | ], 121 | "version" : 2 122 | } 123 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version:5.9 2 | 3 | import PackageDescription 4 | 5 | let package = Package( 6 | name: "WebSocket", 7 | platforms: [ 8 | .macOS(.v12), .iOS(.v15), .tvOS(.v15), .watchOS(.v8), 9 | ], 10 | products: [ 11 | .library( 12 | name: "WebSocket", 13 | targets: ["WebSocket"] 14 | ), 15 | ], 16 | dependencies: [ 17 | .package( 18 | url: "https://github.com/shareup/async-extensions.git", 19 | from: "4.3.0" 20 | ), 21 | .package( 22 | url: "https://github.com/shareup/dispatch-timer.git", 23 | from: "3.0.0" 24 | ), 25 | .package( 26 | url: "https://github.com/shareup/synchronized.git", 27 | from: "4.0.0" 28 | ), 29 | .package( 30 | url: "https://github.com/vapor/websocket-kit.git", 31 | from: "2.15.0" 32 | ), 33 | .package( 34 | url: "https://github.com/apple/swift-nio.git", 35 | from: "2.62.0" 36 | ), 37 | ], 38 | targets: [ 39 | .target( 40 | name: "WebSocket", 41 | dependencies: [ 42 | .product(name: "AsyncExtensions", package: "async-extensions"), 43 | .product(name: "DispatchTimer", package: "dispatch-timer"), 44 | // The sources `import Synchronized` directly, so declare it explicitly 45 | // rather than relying on it being pulled in transitively. 46 | .product(name: "Synchronized", package: "synchronized"), 47 | ] 48 | ), 49 | .testTarget( 50 | name: "WebSocketTests", 51 | dependencies: [ 52 | .product(name: "NIO", package: "swift-nio"), 53 | .product(name: "NIOHTTP1", package: "swift-nio"), 54 | .product(name:
"NIOWebSocket", package: "swift-nio"), 55 | "WebSocket", 56 | .product(name: "WebSocketKit", package: "websocket-kit"), 57 | ] 58 | ), 59 | ] 60 | ) 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WebSocket wrapper around `URLSessionWebSocketTask` 2 | 3 | ## _(macOS, iOS, iPadOS, tvOS, and watchOS)_ 4 | 5 | A concrete WebSocket client implementation that wraps Apple's [`URLSessionWebSocketTask`](https://developer.apple.com/documentation/foundation/urlsessionwebsockettask). 6 | 7 | The public interface of `WebSocket` is a simple struct whose public methods are exposed as closures. The reason for this design is to make it easy to inject fake WebSockets into your code for testing purposes. 8 | 9 | The actual implementation is `SystemWebSocket`, but this type is not publicly accessible. Instead, you can access it via `WebSocket.system(url:)` or `WebSocket.system(request:)`. `SystemWebSocket` tries its best to mirror the documented behavior of web browsers' [`WebSocket`](http://developer.mozilla.org/en-US/docs/Web/API/WebSocket). Please report any deviations as bugs. 10 | 11 | `WebSocket` exposes a simple API and makes heavy use of [Swift Concurrency](https://developer.apple.com/documentation/swift/swift_standard_library/concurrency). 12 | 13 | ## Installation 14 | 15 | To use WebSocket, add a dependency to your Package.swift file: 16 | 17 | ```swift 18 | let package = Package( 19 | dependencies: [ 20 | .package( 21 | url: "https://github.com/shareup/websocket-apple.git", 22 | from: "4.0.0" 23 | ) 24 | ] 25 | ) 26 | ``` 27 | 28 | ## Usage 29 | 30 | ```swift 31 | // `WebSocket` starts connecting to the specified `URL` immediately. 32 | let socket = try await WebSocket.system(url: url(49999)) 33 | 34 | // Wait for `WebSocket` to be ready to send and receive messages.
35 | try await socket.open() 36 | 37 | // Send a message to the server 38 | try await socket.send(.text("hello")) 39 | 40 | // Receive messages from the server 41 | for await message in socket.messages { 42 | print(message) 43 | } 44 | 45 | try await socket.close() 46 | ``` 47 | 48 | ## Tests 49 | 50 | 1. In your Terminal, navigate to the `websocket-apple` directory 51 | 2. Run the tests using `swift test` 52 | 53 | ## Notices 54 | 55 | This library includes code from [WebSocketKit](https://github.com/vapor/websocket-kit) and [SwiftNIO](https://github.com/apple/swift-nio), the use of which depends on their licenses. 56 | -------------------------------------------------------------------------------- /Sources/WebSocket/OSLog+WebSocket.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import os.log 3 | 4 | extension OSLog { 5 | static let webSocket = OSLog(subsystem: subsystem, category: "websocket") 6 | } 7 | 8 | private let subsystem = 9 | Bundle.main.bundleIdentifier ?? "app.shareup.websocket-apple" 10 | -------------------------------------------------------------------------------- /Sources/WebSocket/SystemURLSession.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Synchronized 3 | 4 | func webSocketTask( 5 | for request: URLRequest, 6 | options: WebSocketOptions, 7 | onOpen: @escaping @Sendable () async -> Void, 8 | onClose: @escaping @Sendable (WebSocketCloseCode, Data?) async -> Void 9 | ) -> URLSessionWebSocketTask { 10 | let session = session(for: options) 11 | 12 | let task = session.webSocketTask(with: request) 13 | task.maximumMessageSize = options.maximumMessageSize 14 | 15 | let delegate = session.delegate as!
Delegate 16 | delegate.set(onOpen: onOpen, onClose: onClose, for: ObjectIdentifier(task)) 17 | 18 | return task 19 | } 20 | 21 | func cancelAndInvalidateAllTasks() { 22 | sessions.access { sessions in 23 | sessions.forEach { $0.value.invalidateAndCancel() } 24 | sessions.removeAll() 25 | } 26 | } 27 | 28 | private let sessions = Locked<[WebSocketOptions: URLSession]>([:]) 29 | 30 | private func session(for options: WebSocketOptions) -> URLSession { 31 | sessions.access { sessions in 32 | if let session = sessions[options] { 33 | return session 34 | } else { 35 | let session = URLSession( 36 | configuration: configuration(with: options), 37 | delegate: Delegate(), 38 | delegateQueue: nil 39 | ) 40 | 41 | sessions[options] = session 42 | 43 | return session 44 | } 45 | } 46 | } 47 | 48 | private func configuration(with options: WebSocketOptions) -> URLSessionConfiguration { 49 | let config = URLSessionConfiguration.default 50 | config.waitsForConnectivity = false 51 | config.timeoutIntervalForRequest = options.timeoutIntervalForRequest 52 | config.timeoutIntervalForResource = options.timeoutIntervalForResource 53 | return config 54 | } 55 | 56 | private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { 57 | private struct Callbacks: Sendable { 58 | let onOpen: @Sendable () async -> Void 59 | let onClose: @Sendable (WebSocketCloseCode, Data?) async -> Void 60 | } 61 | 62 | // Registered callbacks, keyed by each task's `ObjectIdentifier`. 63 | private let state: Locked<[ObjectIdentifier: Callbacks]> = .init([:]) 64 | 65 | func set( 66 | onOpen: @escaping @Sendable () async -> Void, 67 | onClose: @escaping @Sendable (WebSocketCloseCode, Data?) async -> Void, 68 | for taskID: ObjectIdentifier 69 | ) { 70 | state.access { $0[taskID] = .init(onOpen: onOpen, onClose: onClose) } 71 | } 72 | 73 | func urlSession( 74 | _: URLSession, 75 | webSocketTask: URLSessionWebSocketTask, 76 | didOpenWithProtocol _: String?
77 | ) { 78 | let taskID = ObjectIdentifier(webSocketTask) 79 | 80 | if let onOpen = state.access({ $0[taskID]?.onOpen }) { 81 | Task { await onOpen() } 82 | } 83 | } 84 | 85 | func urlSession( 86 | _: URLSession, 87 | webSocketTask: URLSessionWebSocketTask, 88 | didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, 89 | reason: Data? 90 | ) { 91 | let taskID = ObjectIdentifier(webSocketTask) 92 | 93 | if let onClose = state.access({ $0[taskID]?.onClose }) { 94 | Task { await onClose(WebSocketCloseCode(closeCode), reason) } 95 | } 96 | } 97 | 98 | func urlSession( 99 | _: URLSession, 100 | task: URLSessionTask, 101 | didCompleteWithError error: Error? 102 | ) { 103 | let taskID = ObjectIdentifier(task) 104 | 105 | // This is the last delegate message for `task`, so remove its callbacks 106 | // now. Deferring removal until after `onClose` could delete callbacks that 107 | // a newer task registered under a reused `ObjectIdentifier`. 108 | if let onClose = state.access({ $0.removeValue(forKey: taskID)?.onClose }) { 109 | Task { 110 | if let error { 111 | await onClose( 112 | .abnormalClosure, 113 | Data(error.localizedDescription.utf8) 114 | ) 115 | } else { 116 | await onClose(.normalClosure, nil) 117 | } 118 | } 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /Sources/WebSocket/SystemWebSocket.swift: -------------------------------------------------------------------------------- 1 | import AsyncExtensions 2 | @preconcurrency import Combine 3 | import Foundation 4 | import os.log 5 | import Synchronized 6 | 7 | private typealias OpenFuture = AsyncThrowingFuture<Void> 8 | 9 | private typealias CloseFuture = AsyncThrowingFuture< 10 | (code: WebSocketCloseCode, reason: Data?)
11 | > 12 | 13 | final actor SystemWebSocket: Publisher { 14 | typealias Output = WebSocketMessage 15 | typealias Failure = Never 16 | 17 | var isOpen: Bool { get async { 18 | guard case .open = state else { return false } 19 | return true 20 | } } 21 | 22 | var isClosed: Bool { get async { 23 | guard case .closed = state else { return false } 24 | return true 25 | } } 26 | 27 | nonisolated var url: URL { request.url! } 28 | nonisolated let request: URLRequest 29 | nonisolated let options: WebSocketOptions 30 | nonisolated let onOpen: WebSocketOnOpen 31 | nonisolated let onClose: WebSocketOnClose 32 | 33 | private var state: State = .unopened 34 | 35 | private var didOpen: OpenFuture 36 | private var didClose: CloseFuture? 37 | 38 | private var messageIndex = 0 // Used to identify sent messages 39 | 40 | private nonisolated let subject = PassthroughSubject<WebSocketMessage, Never>() 41 | 42 | // Deliver messages to the subscribers on a separate queue because it's a bad idea 43 | // to let the subscribers, who could potentially be doing long-running tasks with the 44 | // data we send them, block our network queue.
45 | private let subscriberQueue = DispatchQueue( 46 | label: "app.shareup.websocket.subjectqueue", 47 | attributes: [], 48 | autoreleaseFrequency: .workItem, 49 | target: DispatchQueue.global(qos: .default) 50 | ) 51 | 52 | init( 53 | request: URLRequest, 54 | options: WebSocketOptions = .init(), 55 | onOpen: @escaping WebSocketOnOpen = {}, 56 | onClose: @escaping WebSocketOnClose = { _ in } 57 | ) async throws { 58 | self.request = request 59 | self.options = options 60 | self.onOpen = onOpen 61 | self.onClose = onClose 62 | 63 | didOpen = .init(timeout: options.timeoutIntervalForRequest) 64 | 65 | try connect() 66 | } 67 | 68 | deinit { 69 | didOpen.fail(CancellationError()) 70 | didClose?.fail(CancellationError()) 71 | state.ws?.cancel() 72 | subject.send(completion: .finished) 73 | } 74 | 75 | nonisolated func receive<S: Subscriber>( 76 | subscriber: S 77 | ) where S.Input == WebSocketMessage, S.Failure == Never { 78 | subject 79 | .receive(on: subscriberQueue) 80 | .receive(subscriber: subscriber) 81 | } 82 | 83 | func open() async throws { 84 | switch state { 85 | case .unopened, .connecting: 86 | do { 87 | try await didOpen.value 88 | } catch is CancellationError { 89 | doClose(closeCode: .cancelled, reason: Data("cancelled".utf8)) 90 | throw CancellationError() 91 | } catch is TimeoutError { 92 | doClose(closeCode: .timeout, reason: Data("timeout".utf8)) 93 | throw TimeoutError() 94 | } catch let error as WebSocketError { 95 | doClose( 96 | closeCode: error.closeCode ??
.unknown, 97 | reason: error.reason 98 | ) 99 | throw error 100 | } catch { 101 | preconditionFailure("Invalid error: \(String(reflecting: error))") 102 | } 103 | 104 | case .open: 105 | return 106 | 107 | case .closed: 108 | throw WebSocketError(.alreadyClosed, nil) 109 | } 110 | } 111 | 112 | func send(_ message: WebSocketMessage) async throws { 113 | // Mirrors the documented behavior of JavaScript's `WebSocket` 114 | // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send 115 | switch state { 116 | case let .open(ws): 117 | messageIndex += 1 118 | 119 | os_log( 120 | "send: index=%ld message=%@", 121 | log: .webSocket, 122 | type: .debug, 123 | messageIndex, 124 | message.description 125 | ) 126 | 127 | try await ws.send(message.wsMessage) 128 | 129 | case .unopened, .connecting: 130 | os_log( 131 | "send message while connecting: %@", 132 | log: .webSocket, 133 | type: .error, 134 | message.description 135 | ) 136 | throw WebSocketError.sendMessageWhileConnecting 137 | 138 | case .closed: 139 | os_log( 140 | "send message while closed: %@", 141 | log: .webSocket, 142 | type: .debug, 143 | message.description 144 | ) 145 | } 146 | } 147 | 148 | func close( 149 | code: WebSocketCloseCode = .normalClosure, 150 | reason: Data? = nil, 151 | timeout: TimeInterval? = nil 152 | ) async throws { 153 | switch state { 154 | case .unopened: 155 | doClose(closeCode: code, reason: reason) 156 | 157 | case .connecting, .open: 158 | if let didClose { 159 | _ = try await didClose.value 160 | } else { 161 | let didClose = CloseFuture( 162 | timeout: timeout ??
options.timeoutIntervalForRequest 163 | ) 164 | self.didClose = didClose 165 | doClose(closeCode: code, reason: reason) 166 | _ = try await didClose.value 167 | } 168 | 169 | case .closed: 170 | doClose(closeCode: code, reason: reason) 171 | } 172 | } 173 | } 174 | 175 | private extension SystemWebSocket { 176 | var isUnopened: Bool { 177 | guard case .unopened = state else { return false } 178 | return true 179 | } 180 | 181 | func connect() throws { 182 | precondition(isUnopened) 183 | let task = webSocketTask( 184 | for: request, 185 | options: options, 186 | onOpen: { [weak self] in await self?.doOpen() }, 187 | onClose: { [weak self] closeCode, reason async in 188 | await self?.doClose(closeCode: closeCode, reason: reason) 189 | } 190 | ) 191 | state = .connecting(task) 192 | task.resume() 193 | } 194 | 195 | func doOpen() { 196 | switch state { 197 | case let .connecting(ws): 198 | os_log("open", log: .webSocket, type: .debug) 199 | state = .open(ws) 200 | onOpen() 201 | didOpen.resolve() 202 | doReceiveMessage(ws) 203 | 204 | case .unopened: 205 | os_log("received open before connecting", log: .webSocket, type: .error) 206 | preconditionFailure("Cannot receive open before trying to connect") 207 | 208 | case .open: 209 | // Ignore this because there might be multiple consumers 210 | // waiting on `open()` to return.
211 | break 212 | 213 | case .closed: 214 | os_log( 215 | "trying to open already-closed connection", 216 | log: .webSocket, 217 | type: .error 218 | ) 219 | doClose(closeCode: .alreadyClosed, reason: nil) 220 | } 221 | } 222 | 223 | func doReceiveMessage(_ ws: URLSessionWebSocketTask) { 224 | guard ws.closeCode == .invalid, !Task.isCancelled else { return } 225 | 226 | ws.receive { [weak self] (result: Result<URLSessionWebSocketTask.Message, Error>) in 227 | guard let self, ws.closeCode == .invalid, !Task.isCancelled else { return } 228 | 229 | switch result { 230 | case let .success(msg): 231 | let message = WebSocketMessage(msg) 232 | os_log( 233 | "receive: message=%@", 234 | log: .webSocket, 235 | type: .debug, 236 | message.description 237 | ) 238 | subject.send(message) 239 | Task { [weak self] in await self?.doReceiveMessage(ws) } 240 | 241 | case let .failure(error): 242 | Task { [weak self] in 243 | await self?.doClose( 244 | closeCode: .abnormalClosure, 245 | reason: Data(error.localizedDescription.utf8) 246 | ) 247 | } 248 | } 249 | } 250 | } 251 | 252 | func doClose(closeCode: WebSocketCloseCode, reason: Data?) { 253 | switch state { 254 | case .unopened: 255 | state = .closed(.init(closeCode, reason)) 256 | 257 | case let .connecting(ws), let .open(ws): 258 | os_log( 259 | "close: code=%{public}@", 260 | log: .webSocket, 261 | type: .debug, 262 | closeCode.description 263 | ) 264 | 265 | // When the task is not yet closed, this value is `.invalid`.
266 | if ws.closeCode == .invalid { 267 | if let code = closeCode.wsCloseCode { 268 | ws.cancel(with: code, reason: reason) 269 | } else { 270 | ws.cancel() 271 | } 272 | } 273 | 274 | let close = WebSocketClose(closeCode, reason) 275 | state = .closed(close) 276 | onClose(close) 277 | didClose?.resolve((code: closeCode, reason: reason)) 278 | subject.send(completion: .finished) 279 | 280 | case .closed: 281 | break 282 | } 283 | } 284 | } 285 | 286 | private extension SystemWebSocket { 287 | enum State: CustomStringConvertible, CustomDebugStringConvertible { 288 | case unopened 289 | case connecting(URLSessionWebSocketTask) 290 | case open(URLSessionWebSocketTask) 291 | case closed(WebSocketClose) 292 | 293 | var ws: URLSessionWebSocketTask? { 294 | switch self { 295 | case let .connecting(ws), let .open(ws): 296 | ws 297 | 298 | case .unopened, .closed: 299 | nil 300 | } 301 | } 302 | 303 | var description: String { 304 | switch self { 305 | case .unopened: "unopened" 306 | case .connecting: "connecting" 307 | case .open: "open" 308 | case .closed: "closed" 309 | } 310 | } 311 | 312 | var debugDescription: String { 313 | switch self { 314 | case .unopened: "unopened" 315 | case let .connecting(ws): "connecting(\(String(reflecting: ws)))" 316 | case let .open(ws): "open(\(String(reflecting: ws)))" 317 | case let .closed(close): "closed(\(close.description))" 318 | } 319 | } 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocket.swift: -------------------------------------------------------------------------------- 1 | import Combine 2 | import Foundation 3 | import Synchronized 4 | 5 | public typealias WebSocketOnOpen = @Sendable () -> Void 6 | public typealias WebSocketOnClose = @Sendable (WebSocketClose) 7 | -> Void 8 | 9 | public struct WebSocket: Identifiable, Sendable { 10 | public var id: Int 11 | 12 | /// The closure called when the WebSocket connects successfully.
13 | public var onOpen: WebSocketOnOpen 14 | 15 | /// The closure called when the WebSocket closes. 16 | public var onClose: WebSocketOnClose 17 | 18 | /// Opens the WebSocket connection. After this function returns, 19 | /// the WebSocket connection is open and ready to be used. If the 20 | /// connection fails or times out, an error is thrown. 21 | public var open: @Sendable () async throws -> Void 22 | 23 | /// Sends a close frame to the server with the given close code. 24 | public var close: @Sendable (WebSocketCloseCode, TimeInterval?) async throws -> Void 25 | 26 | /// Invalidates **all** WebSocket connections. It should only be used 27 | /// when all WebSocket connections in the current process need to be 28 | /// cancelled. 29 | public var invalidateAll: @Sendable () -> Void 30 | 31 | /// Sends a text or binary message. 32 | public var send: @Sendable (WebSocketMessage) async throws -> Void 33 | 34 | /// Publishes messages received from the WebSocket. Finishes when the 35 | /// WebSocket connection closes. 36 | public var messagesPublisher: @Sendable () 37 | -> AnyPublisher<WebSocketMessage, Never> 38 | 39 | public init( 40 | id: Int, 41 | onOpen: @escaping WebSocketOnOpen = {}, 42 | onClose: @escaping WebSocketOnClose = { _ in }, 43 | open: @escaping @Sendable () async throws -> Void = {}, 44 | close: @escaping @Sendable (WebSocketCloseCode, TimeInterval?)
async throws 45 | -> Void = { _, _ in }, 46 | invalidateAll: @escaping @Sendable () -> Void = {}, 47 | send: @escaping @Sendable (WebSocketMessage) async throws -> Void = { _ in }, 48 | messagesPublisher: @escaping @Sendable () -> AnyPublisher<WebSocketMessage, Never> = { 49 | Empty<WebSocketMessage, Never>(completeImmediately: false).eraseToAnyPublisher() 50 | } 51 | ) { 52 | self.id = id 53 | self.onOpen = onOpen 54 | self.onClose = onClose 55 | self.open = open 56 | self.close = close 57 | self.invalidateAll = invalidateAll 58 | self.send = send 59 | self.messagesPublisher = messagesPublisher 60 | } 61 | } 62 | 63 | public extension WebSocket { 64 | /// Calls `WebSocket.close(.normalClosure, nil)`. 65 | func close() async throws { 66 | try await close(.normalClosure, nil) 67 | } 68 | 69 | /// Calls `WebSocket.close(.normalClosure, timeout)`. 70 | func close(timeout: TimeInterval) async throws { 71 | try await close(.normalClosure, timeout) 72 | } 73 | 74 | /// The WebSocket's received messages as an asynchronous stream. 75 | /// 76 | /// Finishes when `messagesPublisher` completes. If iteration ends first 77 | /// (task cancellation or the stream being released), the subscription is cancelled. 78 | var messages: AsyncStream<WebSocketMessage> { 79 | let subscription = Locked<AnyCancellable?>(nil) 80 | 81 | return AsyncStream { cont in 82 | // Cancel the Combine subscription however the stream terminates. 83 | cont.onTermination = { _ in 84 | subscription.access { stored in 85 | stored?.cancel() 86 | stored = nil 87 | } 88 | } 89 | 90 | // `finish()` is idempotent, so a synchronous completion is safe. 91 | let cancellable = self.messagesPublisher().sink( 92 | receiveCompletion: { _ in cont.finish() }, 93 | receiveValue: { cont.yield($0) } 94 | ) 95 | subscription.access { $0 = cancellable } 96 | } 97 | } 98 | } 99 | 100 | public extension WebSocket { 101 | /// System WebSocket implementation powered by `URLSessionWebSocketTask`.
102 | static func system( 103 | url: URL, 104 | options: WebSocketOptions = .init(), 105 | onOpen: @escaping WebSocketOnOpen = {}, 106 | onClose: @escaping WebSocketOnClose = { _ in } 107 | ) async throws -> Self { 108 | try await system( 109 | request: URLRequest(url: url), 110 | options: options, 111 | onOpen: onOpen, 112 | onClose: onClose 113 | ) 114 | } 115 | 116 | /// System WebSocket implementation powered by `URLSessionWebSocketTask`. 117 | static func system( 118 | request: URLRequest, 119 | options: WebSocketOptions = .init(), 120 | onOpen: @escaping WebSocketOnOpen = {}, 121 | onClose: @escaping WebSocketOnClose = { _ in } 122 | ) async throws -> Self { 123 | let ws = try await SystemWebSocket( 124 | request: request, 125 | options: options, 126 | onOpen: onOpen, 127 | onClose: onClose 128 | ) 129 | return try await .system(ws) 130 | } 131 | 132 | // This is only intended for use in tests. 133 | internal static func system(_ ws: SystemWebSocket) async throws -> Self { 134 | Self( 135 | id: Int(bitPattern: ObjectIdentifier(ws)), 136 | onOpen: ws.onOpen, 137 | onClose: ws.onClose, 138 | open: { try await ws.open() }, 139 | close: { code, timeout in try await ws.close(code: code, timeout: timeout) }, 140 | invalidateAll: { cancelAndInvalidateAllTasks() }, 141 | send: { message in try await ws.send(message) }, 142 | messagesPublisher: { ws.eraseToAnyPublisher() } 143 | ) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocketClose.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | public struct WebSocketClose: Hashable, CustomStringConvertible, Sendable { 4 | public let code: WebSocketCloseCode 5 | public let reason: Data? 6 | 7 | public init(_ code: WebSocketCloseCode, _ reason: Data?)
{ 8 | self.code = code 9 | self.reason = reason 10 | } 11 | 12 | public var description: String { "\(code.description)" } 13 | } 14 | 15 | public extension WebSocketClose { 16 | var isNormal: Bool { 17 | switch code { 18 | case .normalClosure: true 19 | default: false 20 | } 21 | } 22 | 23 | var isCancelled: Bool { 24 | switch code { 25 | case .cancelled: true 26 | default: false 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocketCloseCode.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Network 3 | 4 | /// A code indicating why a WebSocket connection closed. 5 | /// 6 | /// Mirrors [URLSessionWebSocketTask](https://developer.apple.com/documentation/foundation/urlsessionwebsockettask/closecode). 7 | public enum WebSocketCloseCode: Int, CaseIterable, Sendable { 8 | /// A code that indicates the connection is still open. 9 | case invalid = 0 10 | 11 | /// A code that indicates normal connection closure. 12 | case normalClosure = 1000 13 | 14 | /// A code that indicates an endpoint is going away. 15 | case goingAway = 1001 16 | 17 | /// A code that indicates an endpoint terminated the connection due to a 18 | /// protocol error. 19 | case protocolError = 1002 20 | 21 | /// A code that indicates an endpoint terminated the connection after 22 | /// receiving a type of data it can’t accept. 23 | case unsupportedData = 1003 24 | 25 | /// A reserved code that indicates an endpoint expected a status code and 26 | /// didn’t receive one. 27 | case noStatusReceived = 1005 28 | 29 | /// A reserved code that indicates the connection closed without a close 30 | /// control frame. 31 | case abnormalClosure = 1006 32 | 33 | /// A code that indicates the server terminated the connection because it 34 | /// received data inconsistent with the message’s type. 
35 | case invalidFramePayloadData = 1007 36 | 37 | /// A code that indicates an endpoint terminated the connection because it 38 | /// received a message that violates its policy. 39 | case policyViolation = 1008 40 | 41 | /// A code that indicates an endpoint is terminating the connection because 42 | /// it received a message too big for it to process. 43 | case messageTooBig = 1009 44 | 45 | /// A code that indicates the client terminated the connection because the 46 | /// server didn’t negotiate a required extension. 47 | case mandatoryExtensionMissing = 1010 48 | 49 | /// A code that indicates the server terminated the connection because it 50 | /// encountered an unexpected condition. 51 | case internalServerError = 1011 52 | 53 | /// A reserved code that indicates the connection closed due to the failure 54 | /// to perform a TLS handshake. 55 | case tlsHandshakeFailure = 1015 56 | 57 | // NOTE: Status codes in the range 4000-4999 are reserved for private use 58 | // and thus can't be registered. Such codes can be used by prior 59 | // agreements between WebSocket applications. The interpretation of 60 | // these codes is undefined by this protocol. 61 | // 62 | // https://www.rfc-editor.org/rfc/rfc6455#section-7.4.1 63 | 64 | /// A code that indicates the connection closed because it was cancelled by 65 | /// the client. 66 | case cancelled = 4000 67 | 68 | /// A code that indicates the connection failed to open because it had 69 | /// already been closed. 70 | case alreadyClosed = 4001 71 | 72 | /// A code that indicates the connection timed out while opening. 73 | case timeout = 4002 74 | 75 | /// A code that indicates the connection closed because of an unknown reason. 
76 | case unknown = 4999 77 | } 78 | 79 | extension WebSocketCloseCode: CustomStringConvertible { 80 | public var description: String { 81 | switch self { 82 | case .invalid: "invalid" 83 | case .normalClosure: "normalClosure" 84 | case .goingAway: "goingAway" 85 | case .protocolError: "protocolError" 86 | case .unsupportedData: "unsupportedData" 87 | case .noStatusReceived: "noStatusReceived" 88 | case .abnormalClosure: "abnormalClosure" 89 | case .invalidFramePayloadData: "invalidFramePayloadData" 90 | case .policyViolation: "policyViolation" 91 | case .messageTooBig: "messageTooBig" 92 | case .mandatoryExtensionMissing: "mandatoryExtensionMissing" 93 | case .internalServerError: "internalServerError" 94 | case .tlsHandshakeFailure: "tlsHandshakeFailure" 95 | case .cancelled: "cancelled" 96 | case .alreadyClosed: "alreadyClosed" 97 | case .timeout: "timeout" 98 | case .unknown: "unknown" 99 | } 100 | } 101 | } 102 | 103 | extension WebSocketCloseCode { 104 | init(_ code: URLSessionWebSocketTask.CloseCode) { 105 | switch code { 106 | case .invalid: 107 | self = .invalid 108 | case .normalClosure: 109 | self = .normalClosure 110 | case .goingAway: 111 | self = .goingAway 112 | case .protocolError: 113 | self = .protocolError 114 | case .unsupportedData: 115 | self = .unsupportedData 116 | case .noStatusReceived: 117 | self = .noStatusReceived 118 | case .abnormalClosure: 119 | self = .abnormalClosure 120 | case .invalidFramePayloadData: 121 | self = .invalidFramePayloadData 122 | case .policyViolation: 123 | self = .policyViolation 124 | case .messageTooBig: 125 | self = .messageTooBig 126 | case .mandatoryExtensionMissing: 127 | self = .mandatoryExtensionMissing 128 | case .internalServerError: 129 | self = .internalServerError 130 | case .tlsHandshakeFailure: 131 | self = .tlsHandshakeFailure 132 | @unknown default: 133 | self = .unknown 134 | } 135 | } 136 | 137 | var wsCloseCode: URLSessionWebSocketTask.CloseCode? 
{ 138 | switch self { 139 | case .invalid: .invalid 140 | case .normalClosure: .normalClosure 141 | case .goingAway: .goingAway 142 | case .protocolError: .protocolError 143 | case .unsupportedData: .unsupportedData 144 | case .noStatusReceived: .noStatusReceived 145 | case .abnormalClosure: .abnormalClosure 146 | case .invalidFramePayloadData: .invalidFramePayloadData 147 | case .policyViolation: .policyViolation 148 | case .messageTooBig: .messageTooBig 149 | case .mandatoryExtensionMissing: .mandatoryExtensionMissing 150 | case .internalServerError: .internalServerError 151 | case .tlsHandshakeFailure: .tlsHandshakeFailure 152 | case .cancelled: nil 153 | case .alreadyClosed: nil 154 | case .timeout: nil 155 | case .unknown: nil 156 | } 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocketError.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Network 3 | 4 | public enum WebSocketError: Error, Equatable { 5 | case closeCodeAndReason(WebSocketCloseCode, Data?) 6 | case invalidURL(URL) 7 | case sendMessageWhileConnecting 8 | 9 | init(_ closeCode: WebSocketCloseCode?, _ reason: Data?) { 10 | self = .closeCodeAndReason( 11 | closeCode ?? .unknown, 12 | reason 13 | ) 14 | } 15 | 16 | var closeCode: WebSocketCloseCode? { 17 | guard case let .closeCodeAndReason(code, _) = self 18 | else { return nil } 19 | return code 20 | } 21 | 22 | var reason: Data? { 23 | guard case let .closeCodeAndReason(_, reason) = self 24 | else { return nil } 25 | return reason 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocketMessage.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Network 3 | 4 | /// An enumeration of the types of messages that can be sent or received. 
5 | public enum WebSocketMessage: CustomStringConvertible, Hashable, Sendable { 6 | /// A WebSocket message that contains a block of data. 7 | case data(Data) 8 | 9 | /// A WebSocket message that contains a UTF-8 formatted string. 10 | case text(String) 11 | 12 | public var description: String { 13 | switch self { 14 | case let .data(data): String(decoding: data.prefix(100), as: UTF8.self) 15 | case let .text(text): text 16 | } 17 | } 18 | } 19 | 20 | public extension WebSocketMessage { 21 | var stringValue: String? { 22 | switch self { 23 | case let .data(data): 24 | String(data: data, encoding: .utf8) 25 | 26 | case let .text(text): 27 | text 28 | } 29 | } 30 | } 31 | 32 | extension WebSocketMessage { 33 | init(_ message: URLSessionWebSocketTask.Message) { 34 | switch message { 35 | case let .data(data): 36 | self = .data(data) 37 | 38 | case let .string(text): 39 | self = .text(text) 40 | 41 | @unknown default: 42 | fatalError("Unhandled message: \(message)") 43 | } 44 | } 45 | 46 | var wsMessage: URLSessionWebSocketTask.Message { 47 | switch self { 48 | case let .data(data): .data(data) 49 | case let .text(text): .string(text) 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /Sources/WebSocket/WebSocketOptions.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | public struct WebSocketOptions: Hashable, Sendable { 4 | public var maximumMessageSize: Int 5 | public var timeoutIntervalForRequest: TimeInterval 6 | public var timeoutIntervalForResource: TimeInterval 7 | 8 | public init( 9 | maximumMessageSize: Int = 1024 * 1024, // 1 MiB 10 | timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds 11 | timeoutIntervalForResource: TimeInterval = 604_800 // 7 days 12 | ) { 13 | self.maximumMessageSize = maximumMessageSize 14 | self.timeoutIntervalForRequest = timeoutIntervalForRequest 15 | self.timeoutIntervalForResource = 
timeoutIntervalForResource 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /Tests/LinuxMain.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | import WebSocketTests 4 | 5 | var tests = [XCTestCaseEntry]() 6 | tests += WebSocketTests.allTests() 7 | XCTMain(tests) 8 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/Server/ServerBootstrap+WebSocketTests.swift: -------------------------------------------------------------------------------- 1 | // MARK: - WebSocketKit 2 | 3 | // 4 | // The following is borrowed from WebSocketKit. 5 | // https://github.com/vapor/websocket-kit/blob/main/Tests/WebSocketKitTests/Utilities.swift 6 | // 7 | 8 | import Atomics 9 | import NIO 10 | import NIOExtras 11 | import NIOHTTP1 12 | import NIOSSL 13 | import NIOWebSocket 14 | import WebSocketKit 15 | import XCTest 16 | 17 | extension ServerBootstrap { 18 | static func webSocket( 19 | on eventLoopGroup: EventLoopGroup, 20 | onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> Void 21 | ) -> ServerBootstrap { 22 | ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in 23 | let webSocket = NIOWebSocketServerUpgrader( 24 | shouldUpgrade: { channel, _ in 25 | channel.eventLoop.makeSucceededFuture([:]) 26 | }, 27 | upgradePipelineHandler: { channel, req in 28 | WebSocket.server(on: channel) { ws in 29 | onUpgrade(req, ws) 30 | } 31 | } 32 | ) 33 | return channel.pipeline.configureHTTPServerPipeline( 34 | withServerUpgrade: ( 35 | upgraders: [webSocket], 36 | completionHandler: { _ in 37 | // complete 38 | } 39 | ) 40 | ) 41 | } 42 | } 43 | } 44 | 45 | final class WebsocketBin { 46 | enum BindTarget { 47 | case unixDomainSocket(String) 48 | case localhostIPv4RandomPort 49 | case localhostIPv6RandomPort 50 | } 51 | 52 | enum Mode { 53 | // refuses all connections 54 | case refuse 55 | // supports http1.1 
connections only, which can be either plain text or encrypted 56 | case http1_1(ssl: Bool = false) 57 | } 58 | 59 | enum Proxy { 60 | case none 61 | case simulate(config: ProxyConfig, authorization: String?) 62 | } 63 | 64 | struct ProxyConfig { 65 | var tls: Bool 66 | let headVerification: (ChannelHandlerContext, HTTPRequestHead) -> Void 67 | } 68 | 69 | let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) 70 | 71 | var port: Int { 72 | Int(serverChannel.localAddress!.port!) 73 | } 74 | 75 | private let mode: Mode 76 | private let sslContext: NIOSSLContext? 77 | private var serverChannel: Channel! 78 | private let isShutdown = ManagedAtomic(false) 79 | 80 | init( 81 | _ mode: Mode = .http1_1(ssl: false), 82 | proxy: Proxy = .none, 83 | bindTarget: BindTarget = .localhostIPv4RandomPort, 84 | sslContext: NIOSSLContext?, 85 | onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> Void 86 | ) { 87 | self.mode = mode 88 | self.sslContext = sslContext 89 | 90 | let socketAddress: SocketAddress = switch bindTarget { 91 | case .localhostIPv4RandomPort: 92 | try! SocketAddress(ipAddress: "127.0.0.1", port: 0) 93 | case .localhostIPv6RandomPort: 94 | try! SocketAddress(ipAddress: "::1", port: 0) 95 | case let .unixDomainSocket(path): 96 | try! SocketAddress(unixDomainSocketPath: path) 97 | } 98 | 99 | serverChannel = try! 
ServerBootstrap(group: group) 100 | .serverChannelOption( 101 | ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), 102 | value: 1 103 | ) 104 | .childChannelInitializer { channel in 105 | do { 106 | if case .refuse = mode { 107 | throw HTTPBinError.refusedConnection 108 | } 109 | 110 | let webSocket = NIOWebSocketServerUpgrader( 111 | shouldUpgrade: { channel, _ in 112 | channel.eventLoop.makeSucceededFuture([:]) 113 | }, 114 | upgradePipelineHandler: { channel, req in 115 | WebSocket.server(on: channel) { ws in 116 | onUpgrade(req, ws) 117 | } 118 | } 119 | ) 120 | 121 | // if we need to simulate a proxy, we need to add those handlers first 122 | if case let .simulate( 123 | config: config, 124 | authorization: expectedAuthorization 125 | ) = proxy { 126 | if config.tls { 127 | try self.syncAddTLSHTTPProxyHandlers( 128 | to: channel, 129 | proxyConfig: config, 130 | expectedAuthorization: expectedAuthorization, 131 | upgraders: [webSocket] 132 | ) 133 | } else { 134 | try self.syncAddHTTPProxyHandlers( 135 | to: channel, 136 | proxyConfig: config, 137 | expectedAuthorization: expectedAuthorization, 138 | upgraders: [webSocket] 139 | ) 140 | } 141 | return channel.eventLoop.makeSucceededVoidFuture() 142 | } 143 | 144 | // if a connection has been established, we need to negotiate TLS before 145 | // anything else. Depending on the negotiation, the HTTPHandlers will be 146 | // added. 
147 | if let sslContext = self.sslContext { 148 | try channel.pipeline.syncOperations 149 | .addHandler(NIOSSLServerHandler(context: sslContext)) 150 | } 151 | 152 | // if neither HTTP Proxy nor TLS are wanted, we can add HTTP1 handlers 153 | // directly 154 | try channel.pipeline.syncOperations.configureHTTPServerPipeline( 155 | withPipeliningAssistance: true, 156 | withServerUpgrade: ( 157 | upgraders: [webSocket], 158 | completionHandler: { _ in 159 | // complete 160 | } 161 | ), 162 | withErrorHandling: true 163 | ) 164 | return channel.eventLoop.makeSucceededVoidFuture() 165 | } catch { 166 | return channel.eventLoop.makeFailedFuture(error) 167 | } 168 | }.bind(to: socketAddress).wait() 169 | } 170 | 171 | // In the TLS case we must set up the 'proxy' and the 'server' handlers sequentially 172 | // rather than re-using parts because the requestDecoder stops parsing after a CONNECT 173 | // request 174 | private func syncAddTLSHTTPProxyHandlers( 175 | to channel: Channel, 176 | proxyConfig: ProxyConfig, 177 | expectedAuthorization: String?, 178 | upgraders: [HTTPServerProtocolUpgrader] 179 | ) throws { 180 | let sync = channel.pipeline.syncOperations 181 | let promise = channel.eventLoop.makePromise(of: Void.self) 182 | 183 | let responseEncoder = HTTPResponseEncoder() 184 | let requestDecoder = 185 | ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) 186 | let proxySimulator = HTTPProxySimulator( 187 | promise: promise, 188 | config: proxyConfig, 189 | expectedAuthorization: expectedAuthorization 190 | ) 191 | 192 | try sync.addHandler(responseEncoder) 193 | try sync.addHandler(requestDecoder) 194 | 195 | try sync.addHandler(proxySimulator) 196 | 197 | promise.futureResult.flatMap { _ in 198 | channel.pipeline.removeHandler(proxySimulator) 199 | }.flatMap { _ in 200 | channel.pipeline.removeHandler(responseEncoder) 201 | }.flatMap { _ in 202 | channel.pipeline.removeHandler(requestDecoder) 203 | }.whenComplete { result in 204 | 
switch result { 205 | case .failure: 206 | channel.close(mode: .all, promise: nil) 207 | case .success: 208 | self.httpProxyEstablished(channel, upgraders: upgraders) 209 | } 210 | } 211 | } 212 | 213 | // In the plain-text case we must set up the 'proxy' and the 'server' handlers 214 | // simultaneously 215 | // so that the combined proxy/upgrade request can be processed by the separate proxy and 216 | // upgrade handlers 217 | private func syncAddHTTPProxyHandlers( 218 | to channel: Channel, 219 | proxyConfig: ProxyConfig, 220 | expectedAuthorization: String?, 221 | upgraders: [HTTPServerProtocolUpgrader] 222 | ) throws { 223 | let sync = channel.pipeline.syncOperations 224 | let promise = channel.eventLoop.makePromise(of: Void.self) 225 | 226 | let responseEncoder = HTTPResponseEncoder() 227 | let requestDecoder = 228 | ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) 229 | let proxySimulator = HTTPProxySimulator( 230 | promise: promise, 231 | config: proxyConfig, 232 | expectedAuthorization: expectedAuthorization 233 | ) 234 | 235 | let serverPipelineHandler = HTTPServerPipelineHandler() 236 | let serverProtocolErrorHandler = HTTPServerProtocolErrorHandler() 237 | 238 | let extraHTTPHandlers: [RemovableChannelHandler] = [ 239 | requestDecoder, 240 | serverPipelineHandler, 241 | serverProtocolErrorHandler, 242 | ] 243 | 244 | try sync.addHandler(responseEncoder) 245 | try sync.addHandler(requestDecoder) 246 | 247 | try sync.addHandler(proxySimulator) 248 | 249 | try sync.addHandler(serverPipelineHandler) 250 | try sync.addHandler(serverProtocolErrorHandler) 251 | 252 | let upgrader = HTTPServerUpgradeHandler( 253 | upgraders: upgraders, 254 | httpEncoder: responseEncoder, 255 | extraHTTPHandlers: extraHTTPHandlers, 256 | upgradeCompletionHandler: { _ in 257 | // complete 258 | } 259 | ) 260 | 261 | try sync.addHandler(upgrader) 262 | 263 | promise.futureResult.flatMap { () -> EventLoopFuture in 264 | 
channel.pipeline.removeHandler(proxySimulator) 265 | }.whenComplete { result in 266 | switch result { 267 | case .failure: 268 | channel.close(mode: .all, promise: nil) 269 | case .success: 270 | break 271 | } 272 | } 273 | } 274 | 275 | private func httpProxyEstablished( 276 | _ channel: Channel, 277 | upgraders: [HTTPServerProtocolUpgrader] 278 | ) { 279 | do { 280 | // if a connection has been established, we need to negotiate TLS before 281 | // anything else. Depending on the negotiation, the HTTPHandlers will be added. 282 | if let sslContext { 283 | try channel.pipeline.syncOperations 284 | .addHandler(NIOSSLServerHandler(context: sslContext)) 285 | } 286 | 287 | try channel.pipeline.syncOperations.configureHTTPServerPipeline( 288 | withPipeliningAssistance: true, 289 | withServerUpgrade: ( 290 | upgraders: upgraders, 291 | completionHandler: { _ in 292 | // complete 293 | } 294 | ), 295 | withErrorHandling: true 296 | ) 297 | } catch { 298 | // in case of an while modifying the pipeline we should close the connection 299 | channel.close(mode: .all, promise: nil) 300 | } 301 | } 302 | 303 | func shutdown() throws { 304 | isShutdown.store(true, ordering: .relaxed) 305 | try group.syncShutdownGracefully() 306 | } 307 | } 308 | 309 | enum HTTPBinError: Error { 310 | case refusedConnection 311 | case invalidProxyRequest 312 | } 313 | 314 | final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { 315 | typealias InboundIn = HTTPServerRequestPart 316 | typealias InboundOut = HTTPServerResponsePart 317 | typealias OutboundOut = HTTPServerResponsePart 318 | 319 | // the promise to succeed, once the proxy connection is setup 320 | let promise: EventLoopPromise 321 | let config: WebsocketBin.ProxyConfig 322 | let expectedAuthorization: String? 323 | 324 | var head: HTTPResponseHead 325 | 326 | init( 327 | promise: EventLoopPromise, 328 | config: WebsocketBin.ProxyConfig, 329 | expectedAuthorization: String? 
330 | ) { 331 | self.promise = promise 332 | self.config = config 333 | self.expectedAuthorization = expectedAuthorization 334 | head = HTTPResponseHead( 335 | version: .init(major: 1, minor: 1), 336 | status: .ok, 337 | headers: .init([("Content-Length", "0")]) 338 | ) 339 | } 340 | 341 | func channelRead(context: ChannelHandlerContext, data: NIOAny) { 342 | let request = unwrapInboundIn(data) 343 | switch request { 344 | case let .head(head): 345 | if config.tls { 346 | guard head.method == .CONNECT else { 347 | self.head.status = .badRequest 348 | return 349 | } 350 | } else { 351 | guard head.method == .GET else { 352 | self.head.status = .badRequest 353 | return 354 | } 355 | } 356 | 357 | config.headVerification(context, head) 358 | 359 | if let expectedAuthorization { 360 | guard let authorization = head.headers["proxy-authorization"].first, 361 | expectedAuthorization == authorization 362 | else { 363 | self.head.status = .proxyAuthenticationRequired 364 | return 365 | } 366 | } 367 | if !config.tls { 368 | context.fireChannelRead(data) 369 | } 370 | 371 | case .body: 372 | () 373 | 374 | case .end: 375 | if config.tls { 376 | context.write(wrapOutboundOut(.head(head)), promise: nil) 377 | context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil) 378 | } 379 | if head.status == .ok { 380 | if !config.tls { 381 | context.fireChannelRead(data) 382 | } 383 | promise.succeed(()) 384 | } else { 385 | promise.fail(HTTPBinError.invalidProxyRequest) 386 | } 387 | } 388 | } 389 | } 390 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/Server/WebSocketServer.swift: -------------------------------------------------------------------------------- 1 | import Combine 2 | import Foundation 3 | import NIO 4 | import WebSocket 5 | import WebSocketKit 6 | 7 | enum WebSocketServerOutput: Hashable { 8 | case message(WebSocketMessage) 9 | case remoteClose 10 | } 11 | 12 | final class WebSocketServer { 13 | var port: 
Int { channel!.localAddress!.port! } 14 | 15 | let maximumMessageSize: Int 16 | 17 | // Publisher provided by consumers of `WebSocketServer` to provide the output 18 | // `WebSocketServer` should send to its clients. 19 | private let outputPublisher: AnyPublisher 20 | private var outputPublisherSubscription: AnyCancellable? 21 | 22 | // Publisher that repeats everything sent to it by clients. 23 | private let inputSubject = PassthroughSubject() 24 | 25 | private let eventLoopGroup: EventLoopGroup 26 | private var channel: Channel? 27 | 28 | init( 29 | outputPublisher: P, 30 | maximumMessageSize: Int = 1024 * 1024 31 | ) throws where P.Output == WebSocketServerOutput, P.Failure == Error { 32 | self.outputPublisher = outputPublisher.eraseToAnyPublisher() 33 | self.maximumMessageSize = maximumMessageSize 34 | 35 | eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) 36 | 37 | channel = try ServerBootstrap 38 | .webSocket(on: eventLoopGroup) { [weak self] _, ws in 39 | guard let self else { return } 40 | subscribeToOutputPublisher(ws) 41 | 42 | ws.onText { [weak self] _, text in 43 | self?.inputSubject.send(.text(text)) 44 | } 45 | ws.onBinary { [weak self] _, binary in 46 | var binary = binary 47 | guard let data = binary.readData( 48 | length: binary.readableBytes, 49 | byteTransferStrategy: .copy 50 | ) else { return } 51 | self?.inputSubject.send(.data(data)) 52 | } 53 | }.bind(host: "127.0.0.1", port: 0).wait() 54 | } 55 | 56 | private func subscribeToOutputPublisher(_ ws: WebSocketKit.WebSocket) { 57 | outputPublisherSubscription = outputPublisher 58 | .sink( 59 | receiveCompletion: { completion in 60 | switch completion { 61 | case .finished: 62 | _ = ws.close(code:) 63 | 64 | case .failure: 65 | _ = ws.close(code: .unexpectedServerError) 66 | } 67 | }, 68 | receiveValue: { output in 69 | switch output { 70 | case .remoteClose: 71 | do { try ws.close(code: .goingAway).wait() } 72 | catch {} 73 | 74 | case let .message(message): 75 | switch message { 
76 | case let .data(data): 77 | ws.send(raw: data, opcode: .binary) 78 | 79 | case let .text(text): 80 | ws.send(text) 81 | } 82 | } 83 | } 84 | ) 85 | } 86 | 87 | func shutDown() { 88 | try? channel?.close(mode: .all).wait() 89 | try? eventLoopGroup.syncShutdownGracefully() 90 | } 91 | 92 | var inputPublisher: AnyPublisher { 93 | inputSubject.eraseToAnyPublisher() 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/SystemWebSocketTests.swift: -------------------------------------------------------------------------------- 1 | import Combine 2 | import Synchronized 3 | @testable import WebSocket 4 | import XCTest 5 | 6 | class SystemWebSocketTests: XCTestCase { 7 | var subject: PassthroughSubject! 8 | 9 | override func setUp() async throws { 10 | try await super.setUp() 11 | subject = .init() 12 | } 13 | 14 | func testCanConnectToAndDisconnectFromServer() async throws { 15 | let openEx = expectation(description: "Should have opened") 16 | let closeEx = expectation(description: "Should have closed") 17 | let (server, client) = try await makeServerAndClient( 18 | onOpen: { openEx.fulfill() }, 19 | onClose: { close in 20 | XCTAssertEqual(.normalClosure, close.code) 21 | XCTAssertNil(close.reason) 22 | closeEx.fulfill() 23 | } 24 | ) 25 | defer { server.shutDown() } 26 | 27 | try await client.open() 28 | await fulfillment(of: [openEx], timeout: 2) 29 | 30 | let isOpen = await client.isOpen 31 | XCTAssertTrue(isOpen) 32 | 33 | try await client.close() 34 | await fulfillment(of: [closeEx], timeout: 2) 35 | } 36 | 37 | func testErrorWhenServerIsUnreachable() async throws { 38 | let ex = expectation(description: "Should have errored") 39 | let (server, client) = try await makeOfflineServerAndClient( 40 | onOpen: { XCTFail("Should not have opened") }, 41 | onClose: { close in 42 | XCTAssertEqual(.abnormalClosure, close.code) 43 | XCTAssertNil(close.reason) 44 | ex.fulfill() 45 | } 46 | ) 47 | defer { 
server.shutDown() } 48 | 49 | await fulfillment(of: [ex], timeout: 2) 50 | 51 | let isClosed = await client.isClosed 52 | XCTAssertTrue(isClosed) 53 | } 54 | 55 | func _testErrorWhenRemoteCloses() async throws { 56 | let errorEx = expectation(description: "Should have closed") 57 | let (server, client) = try await makeServerAndClient( 58 | onClose: { close in 59 | DispatchQueue.main.async { 60 | XCTAssertTrue( 61 | close.code == .goingAway || close.code == .cancelled 62 | ) 63 | errorEx.fulfill() 64 | } 65 | } 66 | ) 67 | defer { server.shutDown() } 68 | 69 | // When running tests repeatedly (i.e., on the order of 1000s of times), 70 | // sometimes the server fails and causes `.open()` to throw. 71 | do { try await client.open() } 72 | catch {} 73 | 74 | subject.send(.remoteClose) 75 | await fulfillment(of: [errorEx], timeout: 2) 76 | } 77 | 78 | func testWebSocketCannotBeOpenedTwice() async throws { 79 | let closeCount = Locked(0) 80 | 81 | let firstCloseEx = expectation(description: "Should have closed once") 82 | let secondCloseEx = expectation(description: "Should not have closed more than once") 83 | secondCloseEx.isInverted = true 84 | 85 | let (server, client) = try await makeServerAndClient( 86 | onClose: { _ in 87 | let c = closeCount.access { count -> Int in 88 | count += 1 89 | return count 90 | } 91 | if c == 1 { 92 | firstCloseEx.fulfill() 93 | } else { 94 | secondCloseEx.fulfill() 95 | } 96 | } 97 | ) 98 | defer { server.shutDown() } 99 | 100 | try await client.open() 101 | 102 | try await client.close() 103 | await fulfillment(of: [firstCloseEx], timeout: 2) 104 | 105 | do { 106 | try await client.open() 107 | XCTFail("Should not have successfully reopened") 108 | } catch { 109 | guard let wserror = error as? 
WebSocketError, 110 | case .alreadyClosed = wserror.closeCode 111 | else { return XCTFail("Received wrong error: \(error)") } 112 | } 113 | 114 | await fulfillment(of: [secondCloseEx], timeout: 0.1) 115 | } 116 | 117 | func testPushAndReceiveText() async throws { 118 | let (server, client) = try await makeServerAndClient() 119 | defer { server.shutDown() } 120 | 121 | let sentEx = expectation(description: "Server should have received message") 122 | let sentSub = server.inputPublisher 123 | .sink(receiveValue: { message in 124 | guard case let .text(text) = message 125 | else { return XCTFail("Should have received text") } 126 | XCTAssertEqual("hello", text) 127 | sentEx.fulfill() 128 | }) 129 | defer { sentSub.cancel() } 130 | 131 | try await client.open() 132 | 133 | let receivedEx = expectation(description: "Should have received message") 134 | let receivedSub = client.sink { message in 135 | defer { receivedEx.fulfill() } 136 | guard case let .text(text) = message 137 | else { return XCTFail("Should have received text") } 138 | XCTAssertEqual("hi, to you too!", text) 139 | } 140 | defer { receivedSub.cancel() } 141 | 142 | try await client.send(.text("hello")) 143 | await fulfillment(of: [sentEx], timeout: 2) 144 | subject.send(.message(.text("hi, to you too!"))) 145 | await fulfillment(of: [receivedEx], timeout: 2) 146 | } 147 | 148 | @available(iOS 15.0, macOS 12.0, *) 149 | func testPushAndReceiveTextWithAsyncPublisher() async throws { 150 | let (server, client) = try await makeServerAndClient() 151 | defer { server.shutDown() } 152 | 153 | try await client.open() 154 | 155 | try await client.send(.text("hello")) 156 | subject.send(.message(.text("hi, to you too!"))) 157 | 158 | for await message in client.values { 159 | guard case let .text(text) = message else { 160 | XCTFail("Should have received text") 161 | break 162 | } 163 | XCTAssertEqual("hi, to you too!", text) 164 | break 165 | } 166 | } 167 | 168 | func testPushAndReceiveData() async throws { 169 
| let (server, client) = try await makeServerAndClient() 170 | defer { server.shutDown() } 171 | 172 | let sentEx = expectation(description: "Server should have received message") 173 | let sentSub = server.inputPublisher 174 | .sink(receiveValue: { message in 175 | guard case let .data(data) = message 176 | else { return XCTFail("Should have received data") } 177 | XCTAssertEqual(Data("hello".utf8), data) 178 | sentEx.fulfill() 179 | }) 180 | defer { sentSub.cancel() } 181 | 182 | try await client.open() 183 | 184 | let receivedEx = expectation(description: "Should have received message") 185 | let receivedSub = client.sink { message in 186 | defer { receivedEx.fulfill() } 187 | guard case let .data(data) = message 188 | else { return XCTFail("Should have received data") } 189 | XCTAssertEqual(Data("hi, to you too!".utf8), data) 190 | } 191 | defer { receivedSub.cancel() } 192 | 193 | try await client.send(.data(Data("hello".utf8))) 194 | await fulfillment(of: [sentEx], timeout: 2) 195 | subject.send(.message(.data(Data("hi, to you too!".utf8)))) 196 | await fulfillment(of: [receivedEx], timeout: 2) 197 | } 198 | 199 | @available(iOS 15.0, macOS 12.0, *) 200 | func testPushAndReceiveDataWithAsyncPublisher() async throws { 201 | let (server, client) = try await makeServerAndClient() 202 | defer { server.shutDown() } 203 | 204 | try await client.open() 205 | 206 | try await client.send(.data(Data("hello bytes".utf8))) 207 | subject.send(.message(.data(Data("howdy".utf8)))) 208 | 209 | for await message in client.values { 210 | guard case let .data(data) = message else { 211 | XCTFail("Should have received data") 212 | break 213 | } 214 | XCTAssertEqual("howdy", String(data: data, encoding: .utf8)) 215 | break 216 | } 217 | } 218 | 219 | @available(iOS 15.0, macOS 12.0, *) 220 | func testPublisherFinishesOnClose() async throws { 221 | let (server, client) = try await makeServerAndClient() 222 | defer { server.shutDown() } 223 | 224 | try await client.open() 225 | 226 
| let task = Task.detached { 227 | var count = 1 228 | repeat { 229 | self.subject.send(.message(.text(String(count)))) 230 | count += 1 231 | try await Task.sleep(nanoseconds: 20 * NSEC_PER_MSEC) 232 | } while !Task.isCancelled 233 | } 234 | 235 | var receivedMessages = 0 236 | for await message in client.values { 237 | guard let _ = message.stringValue else { return XCTFail() } 238 | receivedMessages += 1 239 | if receivedMessages == 3 { 240 | try await client.close() 241 | } 242 | } 243 | 244 | XCTAssertEqual(3, receivedMessages) 245 | 246 | task.cancel() 247 | } 248 | 249 | @available(iOS 15.0, macOS 12.0, *) 250 | func testPublisherFinishesOnCloseFromServer() async throws { 251 | let (server, client) = try await makeServerAndClient() 252 | defer { server.shutDown() } 253 | 254 | try await client.open() 255 | 256 | let task = Task.detached { 257 | var count = 1 258 | repeat { 259 | self.subject.send(.message(.text(String(count)))) 260 | count += 1 261 | try await Task.sleep(nanoseconds: 20 * NSEC_PER_MSEC) 262 | } while !Task.isCancelled 263 | } 264 | 265 | var receivedMessages = 0 266 | for await message in client.values { 267 | guard let _ = message.stringValue else { return XCTFail() } 268 | receivedMessages += 1 269 | if receivedMessages == 3 { 270 | subject.send(.remoteClose) 271 | } 272 | } 273 | 274 | XCTAssertEqual(3, receivedMessages) 275 | 276 | task.cancel() 277 | } 278 | 279 | func testWrappedSystemWebSocket() async throws { 280 | let openEx = expectation(description: "Should have opened") 281 | let closeEx = expectation(description: "Should have closed") 282 | let (server, client) = try await makeServerAndWrappedClient( 283 | onOpen: { openEx.fulfill() }, 284 | onClose: { close in 285 | XCTAssertEqual(.normalClosure, close.code) 286 | XCTAssertNil(close.reason) 287 | closeEx.fulfill() 288 | } 289 | ) 290 | defer { server.shutDown() } 291 | 292 | let messagesToSendToServer: [WebSocketMessage] = [ 293 | .text("client: one"), 294 | .data(Data("client: 
two".utf8)), 295 | .text("client: three"), 296 | ] 297 | 298 | let messagesToReceiveFromServer: [WebSocketMessage] = [ 299 | .text("server: one"), 300 | .data(Data("server: two".utf8)), 301 | .text("server: three"), 302 | ] 303 | 304 | var messagesReceivedByServer = 0 305 | let sentSub = server.inputPublisher 306 | .sink(receiveValue: { message in 307 | let i = messagesReceivedByServer 308 | defer { messagesReceivedByServer += 1 } 309 | XCTAssertEqual(messagesToSendToServer[i], message) 310 | }) 311 | defer { sentSub.cancel() } 312 | 313 | // These two lines are redundant, but the goal 314 | // is to test everything in `WebSocket`. 315 | try await client.open() 316 | await fulfillment(of: [openEx], timeout: 2) 317 | 318 | // This message has to be sent after the `AsyncStream` is 319 | // subscribed to below. 320 | let messageToReceiveFromServer = messagesToReceiveFromServer[0] 321 | Task.detached { 322 | try await Task.sleep(nanoseconds: 10_000_000) // 10ms 323 | self.subject.send(.message(messageToReceiveFromServer)) 324 | } 325 | 326 | var messagesReceivedByClient = 0 327 | for await message in client.messages { 328 | let i = messagesReceivedByClient 329 | defer { messagesReceivedByClient += 1 } 330 | 331 | XCTAssertEqual(messagesToReceiveFromServer[i], message) 332 | try await client.send(messagesToSendToServer[i]) 333 | 334 | if i < 2 { 335 | subject.send(.message(messagesToReceiveFromServer[i + 1])) 336 | } else { 337 | try await client.close() 338 | } 339 | } 340 | 341 | XCTAssertEqual(3, messagesReceivedByClient) 342 | XCTAssertEqual(3, messagesReceivedByServer) 343 | 344 | await fulfillment(of: [closeEx], timeout: 2) 345 | } 346 | } 347 | 348 | private let empty: Empty = Empty( 349 | completeImmediately: false, 350 | outputType: WebSocketServerOutput.self, 351 | failureType: Error.self 352 | ) 353 | 354 | private extension SystemWebSocketTests { 355 | func request(_ port: Int) -> URLRequest { 356 | URLRequest( 357 | url: URL(string: 
"ws://127.0.0.1:\(port)/socket")! 358 | ) 359 | } 360 | 361 | func makeServerAndClient( 362 | onOpen: @escaping @Sendable () -> Void = {}, 363 | onClose: @escaping @Sendable (WebSocketClose) -> Void = { _ in } 364 | ) async throws -> (WebSocketServer, SystemWebSocket) { 365 | let server = try WebSocketServer(outputPublisher: subject) 366 | let client = try! await SystemWebSocket( 367 | request: request(server.port), 368 | options: .init(timeoutIntervalForRequest: 2), 369 | onOpen: onOpen, 370 | onClose: onClose 371 | ) 372 | return (server, client) 373 | } 374 | 375 | func makeOfflineServerAndClient( 376 | onOpen: @escaping @Sendable () -> Void = {}, 377 | onClose: @escaping @Sendable (WebSocketClose) -> Void = { _ in } 378 | ) async throws -> (WebSocketServer, SystemWebSocket) { 379 | let server = try WebSocketServer(outputPublisher: empty) 380 | let client = try! await SystemWebSocket( 381 | request: request(19), 382 | options: .init(timeoutIntervalForRequest: 2), 383 | onOpen: onOpen, 384 | onClose: onClose 385 | ) 386 | return (server, client) 387 | } 388 | 389 | func makeServerAndWrappedClient( 390 | onOpen: @escaping @Sendable () -> Void = {}, 391 | onClose: @escaping @Sendable (WebSocketClose) -> Void = { _ in } 392 | ) async throws -> (WebSocketServer, WebSocket) { 393 | let server = try WebSocketServer(outputPublisher: subject) 394 | let client = try! await SystemWebSocket( 395 | request: request(server.port), 396 | options: .init(timeoutIntervalForRequest: 2), 397 | onOpen: onOpen, 398 | onClose: onClose 399 | ) 400 | return (server, try! 
await .system(client)) 401 | } 402 | } 403 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift: -------------------------------------------------------------------------------- 1 | @testable import WebSocket 2 | import XCTest 3 | 4 | class URLSessionWebSocketTaskCloseCodeTests: XCTestCase { 5 | func testCanInitializeURLSessionWebSocketTaskCloseCode() throws { 6 | let urlSessionCloseCodes: [URLSessionWebSocketTask.CloseCode] = [ 7 | .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, 8 | .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, 9 | .messageTooBig, .mandatoryExtensionMissing, .internalServerError, 10 | .tlsHandshakeFailure, 11 | ] 12 | 13 | let closeCodes: [WebSocketCloseCode] = [ 14 | .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, 15 | .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, 16 | .messageTooBig, .mandatoryExtensionMissing, .internalServerError, 17 | .tlsHandshakeFailure, 18 | ] 19 | 20 | for (urlSessionCloseCode, closeCode) in zip(urlSessionCloseCodes, closeCodes) { 21 | XCTAssertEqual(urlSessionCloseCode, closeCode.wsCloseCode) 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /Tests/WebSocketTests/XCTestManifests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | #if !canImport(ObjectiveC) 4 | public func allTests() -> [XCTestCaseEntry] { 5 | [ 6 | testCase(WebSocketTests.allTests), 7 | ] 8 | } 9 | #endif 10 | -------------------------------------------------------------------------------- /bin/format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | SELF=`realpath $0` 6 | DIR=`dirname $SELF` 7 | DEV_DIR=`echo ${DIR%/*}` 8 | 9 | pushd "$DEV_DIR" 
&>/dev/null 10 | 11 | if command -v swiftformat >/dev/null 2>&1; then 12 | swiftformat --quiet --config .swiftformat . 13 | else 14 | echo "warning: Install swiftformat by running 'brew install swiftformat'" 15 | fi 16 | 17 | popd &>/dev/null 18 | --------------------------------------------------------------------------------