├── .envrc ├── .eslintrc.cjs ├── .gitattributes ├── .github ├── CODEOWNERS ├── pull_request_template.md ├── release-drafter.yml └── workflows │ ├── ci.yaml │ ├── npmjs-publish.yml │ └── release-drafter.yml ├── .gitignore ├── .prettierignore ├── .prettierrc ├── .replit ├── LICENSE ├── PROTOCOL.md ├── README.md ├── __tests__ ├── __snapshots__ │ └── serialize.test.ts.snap ├── allocation.test.ts ├── bandwidth.bench.ts ├── cancellation.test.ts ├── cleanup.test.ts ├── context.test.ts ├── disconnects.test.ts ├── e2e.test.ts ├── globalSetup.ts ├── handler.test.ts ├── invalid-request.test.ts ├── middleware.test.ts ├── negative.test.ts ├── serialize.test.ts ├── streams.test.ts └── typescript-stress.test.ts ├── codec ├── adapter.ts ├── binary.ts ├── codec.test.ts ├── index.ts ├── json.ts └── types.ts ├── flake.lock ├── flake.nix ├── flake.sh ├── logging ├── index.ts └── log.ts ├── package-lock.json ├── package.json ├── replit.nix ├── router ├── client.ts ├── context.ts ├── errors.ts ├── handshake.ts ├── index.ts ├── procedures.ts ├── result.ts ├── server.ts ├── services.ts └── streams.ts ├── testUtil ├── duplex │ ├── duplexPair.test.ts │ └── duplexPair.ts ├── fixtures │ ├── cleanup.ts │ ├── codec.ts │ ├── matrix.ts │ ├── mockTransport.ts │ ├── services.ts │ └── transports.ts ├── index.ts └── observable │ ├── observable.test.ts │ └── observable.ts ├── tracing ├── index.ts └── tracing.test.ts ├── transport ├── client.ts ├── connection.ts ├── events.test.ts ├── events.ts ├── id.ts ├── impls │ └── ws │ │ ├── client.ts │ │ ├── connection.ts │ │ ├── server.ts │ │ ├── ws.test.ts │ │ └── wslike.ts ├── index.ts ├── message.test.ts ├── message.ts ├── options.ts ├── rateLimit.test.ts ├── rateLimit.ts ├── results.ts ├── server.ts ├── sessionStateMachine │ ├── SessionBackingOff.ts │ ├── SessionConnected.ts │ ├── SessionConnecting.ts │ ├── SessionHandshaking.ts │ ├── SessionNoConnection.ts │ ├── SessionWaitingForHandshake.ts │ ├── common.ts │ ├── index.ts │ ├── stateMachine.test.ts │ └── 
transitions.ts ├── stringifyError.ts ├── transport.test.ts └── transport.ts ├── tsconfig.json ├── tsup.config.ts └── vitest.config.ts /.envrc: -------------------------------------------------------------------------------- 1 | use flake 2 | dotenv_if_exists 3 | PATH_add ./node_modules/.bin 4 | -------------------------------------------------------------------------------- /.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | env: { 3 | node: true, 4 | es2021: true, 5 | }, 6 | extends: [ 7 | 'eslint:recommended', 8 | 'plugin:@typescript-eslint/strict-type-checked', 9 | 'plugin:@typescript-eslint/stylistic-type-checked', 10 | 'plugin:prettier/recommended', 11 | ], 12 | parser: '@typescript-eslint/parser', 13 | parserOptions: { 14 | ecmaVersion: 'latest', 15 | sourceType: 'module', 16 | project: ['./tsconfig.json'], 17 | }, 18 | plugins: ['@typescript-eslint', '@stylistic/js', '@stylistic/ts'], 19 | rules: { 20 | 'linebreak-style': ['error', 'unix'], 21 | '@typescript-eslint/no-confusing-void-expression': [ 22 | 'error', 23 | { ignoreArrowShorthand: true }, 24 | ], 25 | '@typescript-eslint/no-unused-vars': [ 26 | 'error', 27 | { 28 | args: 'all', 29 | argsIgnorePattern: '^_', 30 | caughtErrors: 'all', 31 | caughtErrorsIgnorePattern: '^_', 32 | destructuredArrayIgnorePattern: '^_', 33 | varsIgnorePattern: '^_', 34 | ignoreRestSiblings: true, 35 | }, 36 | ], 37 | '@typescript-eslint/require-await': 'off', 38 | '@typescript-eslint/array-type': ['error', { default: 'generic' }], 39 | '@typescript-eslint/no-invalid-void-type': 'off', 40 | '@typescript-eslint/restrict-template-expressions': [ 41 | 'error', 42 | { 43 | allowNullish: true, 44 | allowNumber: true, 45 | }, 46 | ], 47 | '@stylistic/ts/lines-between-class-members': [ 48 | 'error', 49 | 'always', 50 | { exceptAfterSingleLine: true }, 51 | ], 52 | '@stylistic/js/no-multiple-empty-lines': ['error', { max: 1 }], 53 | 
'@stylistic/ts/padding-line-between-statements': [ 54 | 'error', 55 | { blankLine: 'always', prev: '*', next: 'return' }, 56 | { 57 | blankLine: 'always', 58 | prev: '*', 59 | next: ['enum', 'interface', 'type'], 60 | }, 61 | ], 62 | }, 63 | ignorePatterns: ['dist/**/*'], 64 | }; 65 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @replit/workspace-infra 2 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Why 2 | 3 | 4 | 5 | ## What changed 6 | 7 | 8 | 9 | ## Versioning 10 | 11 | - [ ] Breaking protocol change 12 | - [ ] Breaking ts/js API change 13 | 14 | 15 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$RESOLVED_VERSION' 2 | tag-template: 'v$RESOLVED_VERSION' 3 | categories: 4 | - title: '🚀 Features' 5 | labels: 6 | - 'feature' 7 | - 'enhancement' 8 | - title: '🐛 Bug Fixes' 9 | labels: 10 | - 'fix' 11 | - 'bugfix' 12 | - 'bug' 13 | - title: '🧰 Maintenance' 14 | label: 'chore' 15 | - title: '🤖 Dependencies' 16 | label: 'dependencies' 17 | change-template: '- $TITLE @$AUTHOR (#$NUMBER)' 18 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. 
19 | version-resolver: 20 | major: 21 | labels: 22 | - 'major' 23 | minor: 24 | labels: 25 | - 'minor' 26 | patch: 27 | labels: 28 | - 'patch' 29 | default: patch 30 | template: | 31 | ## Changes 32 | 33 | $CHANGES 34 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - '**' 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build-and-test: 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [macos-latest, ubuntu-latest] 17 | runs-on: ${{ matrix.os }} 18 | permissions: 19 | contents: write 20 | actions: read 21 | checks: write 22 | steps: 23 | - uses: actions/checkout@v3 24 | with: 25 | fetch-depth: 0 26 | 27 | - name: Setup Node 28 | uses: actions/setup-node@v3 29 | with: 30 | node-version: 18 31 | 32 | - name: Cache dependencies 33 | uses: actions/cache@v3 34 | with: 35 | path: ~/.npm 36 | key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} 37 | restore-keys: | 38 | ${{ runner.os }}-node- 39 | 40 | - run: npm ci 41 | 42 | - name: Check types and style 43 | run: npm run check 44 | 45 | - name: Test 46 | run: npm test -- --outputFile.junit=./test-results.xml 47 | 48 | - name: Test Report 49 | uses: dorny/test-reporter@v1 50 | if: success() || failure() 51 | with: 52 | name: Test Report (${{ matrix.os }}) 53 | path: ./test-results.xml 54 | reporter: java-junit 55 | -------------------------------------------------------------------------------- /.github/workflows/npmjs-publish.yml: -------------------------------------------------------------------------------- 1 | name: Build and Upload Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | inputs: 8 | version: 9 | description: 'What version to use for the release' 10 | required: true 11 | 12 | jobs: 13 | deploy: 14 | runs-on: ubuntu-latest 15 | 16 | 
steps: 17 | - uses: actions/checkout@v2 18 | 19 | - name: Setup Node 20 | uses: actions/setup-node@v3 21 | with: 22 | node-version: 18 23 | 24 | - name: Set release version 25 | run: | 26 | tag="${{ github.event.inputs.version }}" 27 | if [ -z "$tag" ]; then 28 | tag="${GITHUB_REF_NAME}" 29 | fi 30 | version="${tag#v}" # Strip leading v 31 | 32 | # Bump library tag 33 | npm version --no-git-tag-version "$version" 34 | 35 | git config user.name 'GitHub Actions' 36 | git config user.email eng+github@repl.it 37 | 38 | git commit -m 'Setting version' package.json 39 | 40 | - name: Build and publish 41 | run: | 42 | npm set "//registry.npmjs.org/:_authToken" "${{ secrets.NPMJS_AUTH_TOKEN }}" 43 | npm install --frozen-lockfile 44 | npm run publish 45 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | workflow_dispatch: {} 5 | push: 6 | # branches to consider in the event; optional, defaults to all 7 | branches: 8 | - main 9 | # pull_request event is required only for autolabeler 10 | pull_request: 11 | # Only the following types are handled by the action, but one can default to all as well 12 | types: [opened, reopened, synchronize] 13 | # pull_request_target event is required for autolabeler to support PRs from forks 14 | pull_request_target: 15 | types: [opened, reopened, synchronize] 16 | 17 | permissions: 18 | contents: read 19 | 20 | jobs: 21 | update_release_draft: 22 | permissions: 23 | # write permission is required to create a github release 24 | contents: write 25 | # write permission is required for autolabeler 26 | # otherwise, read permission is required at least 27 | pull-requests: write 28 | runs-on: ubuntu-latest 29 | steps: 30 | # Drafts your next Release notes as Pull Requests are merged into "main" 31 | - uses: release-drafter/release-drafter@v5 32 | env: 33 |
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | node_modules 3 | example 4 | 5 | # Nix 6 | /.direnv/ 7 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | .cache 2 | node_modules 3 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "quoteProps": "as-needed", 3 | "trailingComma": "all", 4 | "tabWidth": 2, 5 | "semi": true, 6 | "singleQuote": true, 7 | "bracketSpacing": true, 8 | "useTabs": false, 9 | "arrowParens": "always" 10 | } 11 | -------------------------------------------------------------------------------- /.replit: -------------------------------------------------------------------------------- 1 | run = "npm run test" 2 | modules = ["nodejs-20:v8-20230920-bd784b9"] 3 | hidden = [".config", "package-lock.json"] 4 | 5 | disableGuessImports = true 6 | disableInstallBeforeRun = true 7 | 8 | [nix] 9 | channel = "stable-23_05" 10 | 11 | [[ports]] 12 | localPort = 3000 13 | externalPort = 80 14 | 15 | [languages.eslint] 16 | pattern = "**{*.ts,*.js,*.tsx,*.jsx}" 17 | [languages.eslint.languageServer] 18 | start = "vscode-eslint-language-server --stdio" 19 | [languages.eslint.languageServer.configuration] 20 | nodePath = "node" # this should resolve to nvm 21 | validate = "probe" 22 | useESLintClass = false 23 | format = false 24 | quiet = false 25 | run = "onType" 26 | packageManager = "npm" 27 | rulesCustomizations = [] 28 | onIgnoredFiles = "off" 29 | [languages.eslint.languageServer.configuration.codeActionOnSave] 30 | mode = "auto" 31 | [languages.eslint.languageServer.configuration.workspaceFolder] 32 | name 
= "river" 33 | # we seem to not be able to use ${REPL_HOME} here as the vscode package does 34 | # not evaluate the environment variable, and we need a `/` prefix so it 35 | # knows we gave it an absolute path 36 | uri = "file:///home/runner/${REPL_SLUG}" 37 | [languages.eslint.languageServer.configuration.experimental] 38 | useFlatConfig = false 39 | [languages.eslint.languageServer.configuration.problems] 40 | shortenToSingleLine = false 41 | [languages.eslint.languageServer.configuration.codeAction.disableRuleComment] 42 | enable = true 43 | location = "separateLine" 44 | commentStyle = "line" 45 | [languages.eslint.languageServer.configuration.codeAction.showDocumentation] 46 | enable = true 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Repl.it 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # River 2 | 3 | ⚠️ Not production ready: while Replit is using parts of River in production, we are still going through rapid breaking changes. The first production-ready version will be `1.x.x` ⚠️ 4 | 5 | River allows multiple clients to connect to and make remote procedure calls to a remote server as if they were local procedures. 6 | 7 | ## Long-lived streaming remote procedure calls 8 | 9 | River provides a framework for long-lived streaming Remote Procedure Calls (RPCs) in modern web applications, featuring advanced error handling and customizable retry policies to ensure seamless communication between clients and servers. 10 | 11 | River provides a framework similar to [tRPC](https://trpc.io/) and [gRPC](https://grpc.io/) but with additional features: 12 | 13 | - JSON Schema Support + run-time schema validation 14 | - full-duplex streaming 15 | - service multiplexing 16 | - result types and error handling 17 | - snappy DX (no code generation) 18 | - transparent reconnect support for long-lived sessions 19 | - over any transport (WebSockets out of the box) 20 | 21 | See [PROTOCOL.md](./PROTOCOL.md) for more information on the protocol. 22 | 23 | ### Prerequisites 24 | 25 | Before proceeding, ensure you have TypeScript 5 installed and configured appropriately: 26 | 27 | 1.
**Ensure your `tsconfig.json` is configured correctly**: 28 | 29 | You must verify that: 30 | 31 | - `compilerOptions.moduleResolution` is set to `"bundler"` 32 | - `compilerOptions.strictFunctionTypes` is set to `true` 33 | - `compilerOptions.strictNullChecks` is set to `true` 34 | 35 | or, preferably, that: 36 | 37 | - `compilerOptions.moduleResolution` is set to `"bundler"` 38 | - `compilerOptions.strict` is set to `true` 39 | 40 | Like so: 41 | 42 | ```jsonc 43 | { 44 | "compilerOptions": { 45 | "moduleResolution": "bundler", 46 | "strict": true 47 | // Other compiler options... 48 | } 49 | } 50 | ``` 51 | 52 | If these options already exist in your `tsconfig.json` and don't match what is shown above, modify them. River is designed for `"strict": true`, but technically only `strictFunctionTypes` and `strictNullChecks` need to be set to `true`. Failing to set these will cause unresolvable type errors when defining services. 53 | 54 | 2. **Install River and its dependencies**: 55 | 56 | To use River, install the required packages using npm: 57 | 58 | ```bash 59 | npm i @replit/river @sinclair/typebox 60 | ``` 61 | 62 | ## Writing services 63 | 64 | ### Concepts 65 | 66 | - Router: a collection of services, namespaced by service name. 67 | - Service: a collection of procedures with a shared state. 68 | - Procedure: a single remotely callable function within a service. A procedure declares its type, a request data type, a response data type, optionally a response error type, and the associated handler. Valid types are: 69 | - `rpc`, single request, single response 70 | - `upload`, multiple requests, single response 71 | - `subscription`, single request, multiple responses 72 | - `stream`, multiple requests, multiple responses 73 | - Transport: manages the lifecycle (creation/deletion) of connections and multiplexes reads/writes from clients. Both the client and the server must be given an instance of a `Transport` subclass to work.
74 | - Connection: the raw, underlying transport-level connection (e.g. a single WebSocket) 75 | - Session: a higher-level abstraction that operates over the span of potentially multiple transport-level connections 76 | - Codec: encodes and decodes messages exchanged between clients and servers; the transport uses it to serialize messages before sending them across the wire. 77 | 78 | ### A basic router 79 | 80 | First, we create a service using `ServiceSchema`: 81 | 82 | ```ts 83 | import { ServiceSchema, Procedure, Ok } from '@replit/river'; 84 | import { Type } from '@sinclair/typebox'; 85 | 86 | export const ExampleService = ServiceSchema.define( 87 | // configuration 88 | { 89 | // initializer for shared state 90 | initializeState: () => ({ count: 0 }), 91 | }, 92 | // procedures 93 | { 94 | add: Procedure.rpc({ 95 | requestInit: Type.Object({ n: Type.Number() }), 96 | responseData: Type.Object({ result: Type.Number() }), 97 | requestErrors: Type.Never(), 98 | // note that a handler is unique per user RPC 99 | async handler({ ctx, reqInit: { n } }) { 100 | // access and mutate shared state 101 | ctx.state.count += n; 102 | return Ok({ result: ctx.state.count }); 103 | }, 104 | }), 105 | }, 106 | ); 107 | ``` 108 | 109 | Then, we create the server: 110 | 111 | ```ts 112 | import http from 'http'; 113 | import { WebSocketServer } from 'ws'; 114 | import { WebSocketServerTransport } from '@replit/river/transport/ws/server'; 115 | import { createServer } from '@replit/river'; 116 | 117 | // start websocket server on port 3000 118 | const httpServer = http.createServer(); 119 | const port = 3000; 120 | const wss = new WebSocketServer({ server: httpServer }); 121 | const transport = new WebSocketServerTransport(wss, 'SERVER'); 122 | 123 | export const server = createServer(transport, { 124 | example: ExampleService, 125 | }); 126 | 127 | export type ServiceSurface = typeof server; 128 | 129 | httpServer.listen(port); 130 | ``` 131 | 132 | Finally, we create the client in another file (so that it has its own entrypoint): 133 | 134 | ```ts 135 | import {
WebSocketClientTransport } from '@replit/river/transport/ws/client'; 136 | import { createClient } from '@replit/river'; 137 | import { WebSocket } from 'ws'; 138 | 139 | const transport = new WebSocketClientTransport( 140 | async () => new WebSocket('ws://localhost:3000'), 141 | 'my-client-id', 142 | ); 143 | 144 | const client = createClient( 145 | transport, 146 | 'SERVER', // transport id of the server in the previous step 147 | { eagerlyConnect: true }, // whether to eagerly connect to the server on creation (optional argument) 148 | ); 149 | 150 | // we get full type safety on `client` 151 | // client.<service>.<procedure>.<procedure_type>() 152 | // e.g. 153 | const result = await client.example.add.rpc({ n: 3 }); 154 | if (result.ok) { 155 | const msg = result.payload; 156 | console.log(msg.result); // 0 + 3 = 3 157 | } 158 | ``` 159 | 160 | ### Logging 161 | 162 | To add logging, you can bind a logging function to a transport. 163 | 164 | ```ts 165 | import { coloredStringLogger } from '@replit/river/logging'; 166 | 167 | const transport = new WebSocketClientTransport( 168 | async () => new WebSocket('ws://localhost:3000'), 169 | 'my-client-id', 170 | ); 171 | 172 | transport.bindLogger(console.log); 173 | // or 174 | transport.bindLogger(coloredStringLogger); 175 | ``` 176 | 177 | You can define your own logging functions that satisfy the `LogFn` type. 178 | 179 | ### Connection status 180 | 181 | River defines two types of reconnects: 182 | 183 | 1. **Transparent reconnects:** These occur when the connection is temporarily lost and reestablished without losing any messages. From the application's perspective, this process is seamless and does not disrupt ongoing operations. 184 | 2. **Hard reconnects:** These occur when all server state is lost, requiring the client to reinitialize anything stateful (e.g. subscriptions). 185 | 186 | Hard reconnects are signaled via `sessionStatus` events.
187 | 188 | If your application is stateful on either the server or the client, the service consumer _should_ wrap all the client-side setup with `transport.addEventListener('sessionStatus', (evt) => ...)` to do appropriate setup and teardown. 189 | 190 | ```ts 191 | transport.addEventListener('sessionStatus', (evt) => { 192 | if (evt.status === 'created') { 193 | // do something 194 | } else if (evt.status === 'closing') { 195 | // do other things 196 | } else if (evt.status === 'closed') { 197 | // note that evt.session only has id + to 198 | // this is useful for doing things like creating a new session if 199 | // a session just got yanked 200 | } 201 | }); 202 | 203 | // or, listen for specific session states 204 | transport.addEventListener('sessionTransition', (evt) => { 205 | if (evt.state === SessionState.Connected) { 206 | // switch on various transition states 207 | } else if (evt.state === SessionState.NoConnection) { 208 | // do something 209 | } 210 | }); 211 | ``` 212 | 213 | ### Custom Handshake 214 | 215 | River allows you to extend the protocol-level handshake so you can add additional logic to 216 | validate incoming connections. 
217 | 218 | You can do this by passing extra options to `createClient` and `createServer` and extending the `ParsedMetadata` interface: 219 | 220 | ```ts 221 | declare module '@replit/river' { 222 | interface ParsedMetadata { 223 | userId: number; 224 | } 225 | } 226 | 227 | const schema = Type.Object({ token: Type.String() }); 228 | createClient(new MockClientTransport('client'), 'SERVER', { 229 | eagerlyConnect: false, 230 | handshakeOptions: createClientHandshakeOptions(schema, async () => ({ 231 | // the type of this function is 232 | // () => Static<typeof schema> | Promise<Static<typeof schema>> 233 | token: '123', 234 | })), 235 | }); 236 | 237 | createServer(new MockServerTransport('SERVER'), services, { 238 | handshakeOptions: createServerHandshakeOptions( 239 | schema, 240 | (metadata, previousMetadata) => { 241 | // the type of this function is 242 | // (metadata: Static<typeof schema>, previousMetadata?: ParsedMetadata) => 243 | // | false | Promise<false> (if you reject it) 244 | // | ParsedMetadata | Promise<ParsedMetadata> (if you allow it) 245 | // next time a connection happens on the same session, previousMetadata will 246 | // be populated with the last returned value 247 | }, 248 | ), 249 | }); 250 | ``` 251 | 252 | You can then access the `ParsedMetadata` in your procedure handlers: 253 | 254 | ```ts 255 | async handler({ ctx }) { 256 | // this contains the parsed metadata 257 | console.log(ctx.metadata); 258 | } 259 | ``` 260 | 261 | ### Further examples 262 | 263 | We've also provided an end-to-end testing environment using `Next.js`, and a simple backend connected with the WebSocket transport that you can [play with on Replit](https://replit.com/@jzhao-replit/riverbed).
264 | 265 | You can find more service examples in the [E2E test fixtures](https://github.com/replit/river/blob/main/testUtil/fixtures/services.ts) 266 | 267 | ## Developing 268 | 269 | [![Run on Repl.it](https://replit.com/badge/github/replit/river)](https://replit.com/new/github/replit/river) 270 | 271 | - `npm i` -- install dependencies 272 | - `npm run check` -- lint 273 | - `npm run format` -- format 274 | - `npm run test` -- run tests 275 | - `npm run publish` -- cut a new release (should bump version in package.json first) 276 | -------------------------------------------------------------------------------- /__tests__/allocation.test.ts: -------------------------------------------------------------------------------- 1 | import { beforeEach, describe, test, expect, vi, assert } from 'vitest'; 2 | import { TestSetupHelpers, transports } from '../testUtil/fixtures/transports'; 3 | import { BinaryCodec, Codec } from '../codec'; 4 | import { 5 | advanceFakeTimersByHeartbeat, 6 | createPostTestCleanups, 7 | } from '../testUtil/fixtures/cleanup'; 8 | import { createServer } from '../router/server'; 9 | import { createClient } from '../router/client'; 10 | import { TestServiceSchema } from '../testUtil/fixtures/services'; 11 | import { waitFor } from '../testUtil/fixtures/cleanup'; 12 | import { numberOfConnections, closeAllConnections } from '../testUtil'; 13 | import { cleanupTransports } from '../testUtil/fixtures/cleanup'; 14 | import { testFinishesCleanly } from '../testUtil/fixtures/cleanup'; 15 | import { ProtocolError } from '../transport/events'; 16 | 17 | let isOom = false; /* module-level switch: while true, OomableCodec.toBuffer throws; beforeEach resets it to false */ 18 | // simulate RangeError: Array buffer allocation failed 19 | const OomableCodec: Codec = { /* encoding can be made to fail on demand; decoding always delegates to BinaryCodec */ 20 | toBuffer(obj) { 21 | if (isOom) { 22 | throw new RangeError('failed allocation'); 23 | } 24 | 25 | return BinaryCodec.toBuffer(obj); 26 | }, 27 | fromBuffer: (buff: Uint8Array) => { 28 | return BinaryCodec.fromBuffer(buff); 29 | }, 30 | }; 31 | 32 | describe.each(transports)(
33 | 'failed allocation test ($name transport)', 34 | async (transport) => { 35 | const clientOpts = { codec: OomableCodec }; 36 | const serverOpts = { codec: BinaryCodec }; 37 | 38 | const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); 39 | let getClientTransport: TestSetupHelpers['getClientTransport']; 40 | let getServerTransport: TestSetupHelpers['getServerTransport']; 41 | beforeEach(async () => { 42 | // only allow client to oom, server has sane oom handling already 43 | const setup = await transport.setup({ 44 | client: clientOpts, 45 | server: serverOpts, 46 | }); 47 | getClientTransport = setup.getClientTransport; 48 | getServerTransport = setup.getServerTransport; 49 | isOom = false; 50 | 51 | return async () => { /* teardown returned to vitest: run registered post-test cleanups, then tear down the transport setup */ 52 | await postTestCleanup(); 53 | await setup.cleanup(); 54 | }; 55 | }); 56 | 57 | test('oom during heartbeat kills the session, client starts new session', async () => { 58 | // setup 59 | const clientTransport = getClientTransport('client'); 60 | const serverTransport = getServerTransport(); 61 | const services = { test: TestServiceSchema }; 62 | const server = createServer(serverTransport, services); 63 | const client = createClient( /* NOTE(review): a type argument such as <typeof services> may have been stripped from this call when the file was dumped; confirm against the real source */ 64 | clientTransport, 65 | serverTransport.clientId, 66 | ); 67 | 68 | const errMock = vi.fn(); 69 | clientTransport.addEventListener('protocolError', errMock); 70 | addPostTestCleanup(async () => { 71 | clientTransport.removeEventListener('protocolError', errMock); 72 | await cleanupTransports([clientTransport, serverTransport]); 73 | }); 74 | 75 | // establish initial connection 76 | const result = await client.test.add.rpc({ n: 1 }); 77 | expect(result).toStrictEqual({ ok: true, payload: { result: 1 } }); 78 | 79 | await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); 80 | await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); 81 | const oldClientSession = serverTransport.sessions.get('client'); /* the server's session with 'client' */ 82 | const oldServerSession =
clientTransport.sessions.get('SERVER'); /* the client's session with 'SERVER' */ 83 | assert(oldClientSession); 84 | assert(oldServerSession); 85 | 86 | // simulate some OOM during heartbeat 87 | for (let i = 0; i < 5; i++) { 88 | isOom = i % 2 === 0; /* OOM on iterations 0, 2 and 4; healthy on 1 and 3 */ 89 | await advanceFakeTimersByHeartbeat(); 90 | } 91 | 92 | // verify session on client is dead 93 | await waitFor(() => expect(clientTransport.sessions.size).toBe(0)); 94 | 95 | // verify we got MessageSendFailure errors 96 | await waitFor(() => { 97 | expect(errMock).toHaveBeenCalledWith( 98 | expect.objectContaining({ 99 | type: ProtocolError.MessageSendFailure, 100 | }), 101 | ); 102 | }); 103 | 104 | // client should be able to reconnect and make new calls 105 | isOom = false; 106 | const result2 = await client.test.add.rpc({ n: 2 }); 107 | expect(result2).toStrictEqual({ ok: true, payload: { result: 3 } }); /* 1 + 2: the earlier add of 1 is still reflected after the session is replaced */ 108 | 109 | // verify new session IDs are different from old ones 110 | const newClientSession = serverTransport.sessions.get('client'); 111 | const newServerSession = clientTransport.sessions.get('SERVER'); 112 | assert(newClientSession); 113 | assert(newServerSession); 114 | expect(newClientSession.id).not.toBe(oldClientSession.id); 115 | expect(newServerSession.id).not.toBe(oldServerSession.id); 116 | 117 | await testFinishesCleanly({ 118 | clientTransports: [clientTransport], 119 | serverTransport, 120 | server, 121 | }); 122 | }); 123 | 124 | test('oom during handshake kills the session, client starts new session', async () => { 125 | // setup 126 | const clientTransport = getClientTransport('client'); 127 | const serverTransport = getServerTransport(); 128 | const services = { test: TestServiceSchema }; 129 | const server = createServer(serverTransport, services); 130 | const client = createClient( 131 | clientTransport, 132 | serverTransport.clientId, 133 | ); 134 | const errMock = vi.fn(); 135 | clientTransport.addEventListener('protocolError', errMock); 136 | addPostTestCleanup(async () => { 137 |
clientTransport.removeEventListener('protocolError', errMock); 138 | await cleanupTransports([clientTransport, serverTransport]); 139 | }); 140 | 141 | // establish initial connection 142 | await client.test.add.rpc({ n: 1 }); 143 | await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); 144 | await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(1)); 145 | 146 | // close connection to force reconnection 147 | closeAllConnections(clientTransport); 148 | await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); 149 | await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); 150 | 151 | // simulate OOM during handshake 152 | isOom = true; /* from here on, every toBuffer call on the client's codec throws */ 153 | clientTransport.connect('SERVER'); /* NOTE(review): hard-coded server id; presumably equal to the serverTransport.clientId passed to createClient above; confirm */ 154 | await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(0)); 155 | await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); 156 | 157 | await waitFor(() => { 158 | expect(errMock).toHaveBeenCalledWith( 159 | expect.objectContaining({ 160 | type: ProtocolError.MessageSendFailure, 161 | }), 162 | ); 163 | }); 164 | 165 | // client should be able to reconnect and make new calls 166 | isOom = false; 167 | const result = await client.test.add.rpc({ n: 2 }); 168 | expect(result).toStrictEqual({ ok: true, payload: { result: 3 } }); 169 | 170 | await testFinishesCleanly({ 171 | clientTransports: [clientTransport], 172 | serverTransport, 173 | server, 174 | }); 175 | }); 176 | }, 177 | ); 178 | -------------------------------------------------------------------------------- /__tests__/bandwidth.bench.ts: -------------------------------------------------------------------------------- 1 | import { afterAll, assert, bench, describe } from 'vitest'; 2 | import { getClientSendFn, waitForMessage } from '../testUtil'; 3 | import { TestServiceSchema } from '../testUtil/fixtures/services'; 4 | import { createServer } from '../router/server'; 5 | import { createClient } from '../router/client'; 6 | import
{ transports } from '../testUtil/fixtures/transports'; 7 | import { nanoid } from 'nanoid'; 8 | 9 | let n = 0; /* incremented by every dummyPayloadSmall() call so consecutive payloads differ */ 10 | const dummyPayloadSmall = () => ({ 11 | streamId: 'test', 12 | controlFlags: 0, 13 | payload: { 14 | msg: 'cool', 15 | n: n++, 16 | }, 17 | }); 18 | 19 | // give time for v8 to warm up 20 | const BENCH_DURATION = 10_000; 21 | describe('bandwidth', async () => { 22 | for (const { name, setup } of transports.filter((t) => t.name !== 'mock')) { /* the mock transport is excluded from the benchmarks */ 23 | const { getClientTransport, getServerTransport, cleanup } = await setup(); 24 | afterAll(cleanup); 25 | 26 | const serverTransport = getServerTransport(); 27 | const clientTransport = getClientTransport('client'); 28 | clientTransport.connect(serverTransport.clientId); 29 | 30 | const services = { test: TestServiceSchema }; 31 | createServer(serverTransport, services); 32 | const client = createClient( 33 | clientTransport, 34 | serverTransport.clientId, 35 | ); 36 | 37 | const sendClosure = getClientSendFn(clientTransport, serverTransport); /* sender used by the raw-transport bench below; presumably bypasses the router (per the bench name); confirm in testUtil */ 38 | bench( 39 | `${name} -- raw transport send and recv`, 40 | async () => { 41 | const msg = dummyPayloadSmall(); 42 | const id = sendClosure(msg); 43 | await waitForMessage(serverTransport, (msg) => msg.id === id); /* NOTE(review): this callback's msg shadows the outer payload msg; it is compared against the id returned by sendClosure */ 44 | 45 | return; 46 | }, 47 | { time: BENCH_DURATION }, 48 | ); 49 | 50 | bench( 51 | `${name} -- rpc`, 52 | async () => { 53 | const result = await client.test.add.rpc({ n: Math.random() }); 54 | assert(result.ok); 55 | }, 56 | { time: BENCH_DURATION }, 57 | ); 58 | 59 | const { reqWritable, resReadable } = client.test.echo.stream({}); /* one echo stream per transport, reused by every iteration of the stream bench below */ 60 | const resIter = resReadable[Symbol.asyncIterator](); 61 | bench( 62 | `${name} -- stream`, 63 | async () => { 64 | reqWritable.write({ msg: nanoid(), ignore: false }); 65 | const result = await resIter.next(); 66 | assert(result.value?.ok); 67 | }, 68 | { time: BENCH_DURATION }, 69 | ); 70 | } 71 | }); 72 | -------------------------------------------------------------------------------- /__tests__/context.test.ts:
-------------------------------------------------------------------------------- 1 | import { beforeEach, describe, expect, test } from 'vitest'; 2 | import { 3 | cleanupTransports, 4 | createPostTestCleanups, 5 | } from '../testUtil/fixtures/cleanup'; 6 | import { testMatrix } from '../testUtil/fixtures/matrix'; 7 | import { TestSetupHelpers } from '../testUtil/fixtures/transports'; 8 | import { 9 | Ok, 10 | Procedure, 11 | ServiceSchema, 12 | createClient, 13 | createServer, 14 | } from '../router'; 15 | import { Type } from '@sinclair/typebox'; 16 | 17 | describe('should handle incompatabilities', async () => { 18 | const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); 19 | let getClientTransport: TestSetupHelpers['getClientTransport']; 20 | let getServerTransport: TestSetupHelpers['getServerTransport']; 21 | beforeEach(async () => { 22 | const { 23 | codec: { codec }, 24 | transport, 25 | } = testMatrix()[0]; 26 | const setup = await transport.setup({ 27 | client: { codec }, 28 | server: { codec }, 29 | }); 30 | getClientTransport = setup.getClientTransport; 31 | getServerTransport = setup.getServerTransport; 32 | 33 | return async () => { 34 | await postTestCleanup(); 35 | await setup.cleanup(); 36 | }; 37 | }); 38 | 39 | test('should pass extended context to procedure', async () => { 40 | // setup 41 | const clientTransport = getClientTransport('client'); 42 | const serverTransport = getServerTransport(); 43 | const services = { 44 | testservice: ServiceSchema.define({ 45 | testrpc: Procedure.rpc({ 46 | requestInit: Type.Object({}), 47 | responseData: Type.String(), 48 | handler: async ({ ctx }) => { 49 | return Ok((ctx as unknown as typeof extendedContext).testctx); 50 | }, 51 | }), 52 | }), 53 | }; 54 | 55 | const extendedContext = { testctx: Math.random().toString() }; 56 | createServer(serverTransport, services, { 57 | extendedContext, 58 | }); 59 | const client = createClient( 60 | clientTransport, 61 | serverTransport.clientId, 62 | 
); 63 | addPostTestCleanup(async () => { 64 | await cleanupTransports([clientTransport, serverTransport]); 65 | }); 66 | 67 | const res = await client.testservice.testrpc.rpc({}); 68 | 69 | expect(res).toEqual({ ok: true, payload: extendedContext.testctx }); 70 | }); 71 | 72 | test('should pass extended context to initializeState', async () => { 73 | // setup 74 | const clientTransport = getClientTransport('client'); 75 | const serverTransport = getServerTransport(); 76 | 77 | const TestServiceScaffold = ServiceSchema.scaffold({ 78 | initializeState: (ctx) => ({ 79 | fromctx: (ctx as unknown as typeof extendedContext).testctx, 80 | }), 81 | }); 82 | const services = { 83 | testservice: TestServiceScaffold.finalize({ 84 | ...TestServiceScaffold.procedures({ 85 | testrpc: Procedure.rpc({ 86 | requestInit: Type.Object({}), 87 | responseData: Type.String(), 88 | handler: async ({ ctx }) => { 89 | return Ok(ctx.state.fromctx); 90 | }, 91 | }), 92 | }), 93 | }), 94 | }; 95 | 96 | const extendedContext = { testctx: Math.random().toString() }; 97 | createServer(serverTransport, services, { 98 | extendedContext, 99 | }); 100 | const client = createClient( 101 | clientTransport, 102 | serverTransport.clientId, 103 | ); 104 | addPostTestCleanup(async () => { 105 | await cleanupTransports([clientTransport, serverTransport]); 106 | }); 107 | 108 | const res = await client.testservice.testrpc.rpc({}); 109 | 110 | expect(res).toEqual({ ok: true, payload: extendedContext.testctx }); 111 | }); 112 | }); 113 | -------------------------------------------------------------------------------- /__tests__/globalSetup.ts: -------------------------------------------------------------------------------- 1 | import { vi } from 'vitest'; 2 | 3 | vi.useFakeTimers({ shouldAdvanceTime: true }); 4 | -------------------------------------------------------------------------------- /__tests__/handler.test.ts: -------------------------------------------------------------------------------- 1 | import 
{ isReadableDone, readNextResult } from '../testUtil'; 2 | import { afterEach, beforeEach, describe, expect, test } from 'vitest'; 3 | import { 4 | DIV_BY_ZERO, 5 | FallibleServiceSchema, 6 | STREAM_ERROR, 7 | TestServiceSchema, 8 | SubscribableServiceSchema, 9 | UploadableServiceSchema, 10 | } from '../testUtil/fixtures/services'; 11 | import { createClient, createServer, UNCAUGHT_ERROR_CODE } from '../router'; 12 | import { createMockTransportNetwork } from '../testUtil/fixtures/mockTransport'; 13 | 14 | describe('server-side test', () => { 15 | let mockTransportNetwork: ReturnType; 16 | 17 | beforeEach(async () => { 18 | mockTransportNetwork = createMockTransportNetwork(); 19 | }); 20 | 21 | afterEach(async () => { 22 | await mockTransportNetwork.cleanup(); 23 | }); 24 | 25 | test('rpc basic', async () => { 26 | const services = { test: TestServiceSchema }; 27 | createServer(mockTransportNetwork.getServerTransport(), services); 28 | const client = createClient( 29 | mockTransportNetwork.getClientTransport('client'), 30 | 'SERVER', 31 | ); 32 | 33 | const result = await client.test.add.rpc({ n: 3 }); 34 | expect(result).toStrictEqual({ ok: true, payload: { result: 3 } }); 35 | }); 36 | 37 | test('fallible rpc', async () => { 38 | const services = { test: FallibleServiceSchema }; 39 | createServer(mockTransportNetwork.getServerTransport(), services); 40 | const client = createClient( 41 | mockTransportNetwork.getClientTransport('client'), 42 | 'SERVER', 43 | ); 44 | 45 | const result = await client.test.divide.rpc({ a: 10, b: 2 }); 46 | expect(result).toStrictEqual({ ok: true, payload: { result: 5 } }); 47 | 48 | const result2 = await client.test.divide.rpc({ a: 10, b: 0 }); 49 | expect(result2).toStrictEqual({ 50 | ok: false, 51 | payload: { 52 | code: DIV_BY_ZERO, 53 | message: 'Cannot divide by zero', 54 | extras: { test: 'abc' }, 55 | }, 56 | }); 57 | }); 58 | 59 | test('stream basic', async () => { 60 | const services = { test: TestServiceSchema }; 61 | 
createServer(mockTransportNetwork.getServerTransport(), services); 62 | const client = createClient( 63 | mockTransportNetwork.getClientTransport('client'), 64 | 'SERVER', 65 | ); 66 | 67 | const { reqWritable, resReadable } = client.test.echo.stream({}); 68 | 69 | reqWritable.write({ msg: 'abc', ignore: false }); 70 | reqWritable.write({ msg: 'def', ignore: true }); 71 | reqWritable.write({ msg: 'ghi', ignore: false }); 72 | reqWritable.close(); 73 | 74 | const result1 = await readNextResult(resReadable); 75 | expect(result1).toStrictEqual({ ok: true, payload: { response: 'abc' } }); 76 | 77 | const result2 = await readNextResult(resReadable); 78 | expect(result2).toStrictEqual({ ok: true, payload: { response: 'ghi' } }); 79 | 80 | expect(await isReadableDone(resReadable)).toEqual(true); 81 | }); 82 | 83 | test('stream empty', async () => { 84 | const services = { test: TestServiceSchema }; 85 | createServer(mockTransportNetwork.getServerTransport(), services); 86 | const client = createClient( 87 | mockTransportNetwork.getClientTransport('client'), 88 | 'SERVER', 89 | ); 90 | 91 | const { reqWritable, resReadable } = client.test.echo.stream({}); 92 | reqWritable.close(); 93 | 94 | expect(await isReadableDone(resReadable)).toEqual(true); 95 | }); 96 | 97 | test('stream with initialization', async () => { 98 | const services = { test: TestServiceSchema }; 99 | createServer(mockTransportNetwork.getServerTransport(), services); 100 | const client = createClient( 101 | mockTransportNetwork.getClientTransport('client'), 102 | 'SERVER', 103 | ); 104 | 105 | const { reqWritable, resReadable } = client.test.echoWithPrefix.stream({ 106 | prefix: 'test', 107 | }); 108 | 109 | reqWritable.write({ msg: 'abc', ignore: false }); 110 | reqWritable.write({ msg: 'def', ignore: true }); 111 | reqWritable.write({ msg: 'ghi', ignore: false }); 112 | reqWritable.close(); 113 | 114 | const result1 = await readNextResult(resReadable); 115 | expect(result1).toStrictEqual({ 116 | ok: 
true, 117 | payload: { response: 'test abc' }, 118 | }); 119 | 120 | const result2 = await readNextResult(resReadable); 121 | expect(result2).toStrictEqual({ 122 | ok: true, 123 | payload: { response: 'test ghi' }, 124 | }); 125 | 126 | expect(await isReadableDone(resReadable)).toEqual(true); 127 | }); 128 | 129 | test('fallible stream', async () => { 130 | const services = { test: FallibleServiceSchema }; 131 | createServer(mockTransportNetwork.getServerTransport(), services); 132 | const client = createClient( 133 | mockTransportNetwork.getClientTransport('client'), 134 | 'SERVER', 135 | ); 136 | 137 | const { reqWritable, resReadable } = client.test.echo.stream({}); 138 | reqWritable.write({ msg: 'abc', throwResult: false, throwError: false }); 139 | 140 | const result1 = await readNextResult(resReadable); 141 | expect(result1).toStrictEqual({ ok: true, payload: { response: 'abc' } }); 142 | 143 | reqWritable.write({ msg: 'def', throwResult: true, throwError: false }); 144 | const result2 = await readNextResult(resReadable); 145 | expect(result2).toStrictEqual({ 146 | ok: false, 147 | payload: { 148 | code: STREAM_ERROR, 149 | message: 'field throwResult was set to true', 150 | }, 151 | }); 152 | 153 | reqWritable.write({ msg: 'ghi', throwResult: false, throwError: true }); 154 | const result3 = await readNextResult(resReadable); 155 | expect(result3).toStrictEqual({ 156 | ok: false, 157 | payload: { 158 | code: UNCAUGHT_ERROR_CODE, 159 | message: 'some message', 160 | }, 161 | }); 162 | 163 | reqWritable.close(); 164 | }); 165 | 166 | test('subscriptions', async () => { 167 | const services = { test: SubscribableServiceSchema }; 168 | createServer(mockTransportNetwork.getServerTransport(), services); 169 | const client = createClient( 170 | mockTransportNetwork.getClientTransport('client'), 171 | 'SERVER', 172 | ); 173 | 174 | const { resReadable } = client.test.value.subscribe({}); 175 | 176 | const streamResult1 = await readNextResult(resReadable); 177 | 
expect(streamResult1).toStrictEqual({ ok: true, payload: { result: 0 } }); 178 | 179 | const result = await client.test.add.rpc({ n: 3 }); 180 | expect(result).toStrictEqual({ ok: true, payload: { result: 3 } }); 181 | 182 | const streamResult2 = await readNextResult(resReadable); 183 | expect(streamResult2).toStrictEqual({ ok: true, payload: { result: 3 } }); 184 | }); 185 | 186 | test('uploads', async () => { 187 | const services = { test: UploadableServiceSchema }; 188 | createServer(mockTransportNetwork.getServerTransport(), services); 189 | const client = createClient( 190 | mockTransportNetwork.getClientTransport('client'), 191 | 'SERVER', 192 | ); 193 | 194 | const { reqWritable, finalize } = client.test.addMultiple.upload({}); 195 | 196 | reqWritable.write({ n: 1 }); 197 | reqWritable.write({ n: 2 }); 198 | reqWritable.close(); 199 | expect(await finalize()).toStrictEqual({ 200 | ok: true, 201 | payload: { result: 3 }, 202 | }); 203 | }); 204 | 205 | test('uploads empty', async () => { 206 | const services = { test: UploadableServiceSchema }; 207 | createServer(mockTransportNetwork.getServerTransport(), services); 208 | const client = createClient( 209 | mockTransportNetwork.getClientTransport('client'), 210 | 'SERVER', 211 | ); 212 | 213 | const { reqWritable, finalize } = client.test.addMultiple.upload({}); 214 | reqWritable.close(); 215 | expect(await finalize()).toStrictEqual({ 216 | ok: true, 217 | payload: { result: 0 }, 218 | }); 219 | }); 220 | 221 | test('uploads with initialization', async () => { 222 | const services = { test: UploadableServiceSchema }; 223 | createServer(mockTransportNetwork.getServerTransport(), services); 224 | const client = createClient( 225 | mockTransportNetwork.getClientTransport('client'), 226 | 'SERVER', 227 | ); 228 | 229 | const { reqWritable, finalize } = client.test.addMultipleWithPrefix.upload({ 230 | prefix: 'test', 231 | }); 232 | 233 | reqWritable.write({ n: 1 }); 234 | reqWritable.write({ n: 2 }); 235 | 
reqWritable.close(); 236 | expect(await finalize()).toStrictEqual({ 237 | ok: true, 238 | payload: { result: 'test 3' }, 239 | }); 240 | }); 241 | }); 242 | -------------------------------------------------------------------------------- /__tests__/negative.test.ts: -------------------------------------------------------------------------------- 1 | import { assert, beforeEach, describe, expect, test, vi } from 'vitest'; 2 | import http from 'node:http'; 3 | import { 4 | cleanupTransports, 5 | testFinishesCleanly, 6 | waitFor, 7 | } from '../testUtil/fixtures/cleanup'; 8 | import { 9 | createLocalWebSocketClient, 10 | createWebSocketServer, 11 | numberOfConnections, 12 | onWsServerReady, 13 | } from '../testUtil'; 14 | import { WebSocketServerTransport } from '../transport/impls/ws/server'; 15 | import { 16 | ControlFlags, 17 | ControlMessageHandshakeRequestSchema, 18 | OpaqueTransportMessage, 19 | handshakeRequestMessage, 20 | } from '../transport/message'; 21 | import { NaiveJsonCodec } from '../codec'; 22 | import { Static } from '@sinclair/typebox'; 23 | import { WebSocketClientTransport } from '../transport/impls/ws/client'; 24 | import { ProtocolError } from '../transport/events'; 25 | import NodeWs from 'ws'; 26 | import { createPostTestCleanups } from '../testUtil/fixtures/cleanup'; 27 | import { generateId } from '../transport/id'; 28 | 29 | describe('should handle incompatabilities', async () => { 30 | let server: http.Server; 31 | let port: number; 32 | let wss: NodeWs.Server; 33 | 34 | const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); 35 | beforeEach(async () => { 36 | server = http.createServer(); 37 | port = await onWsServerReady(server); 38 | wss = createWebSocketServer(server); 39 | 40 | return async () => { 41 | await postTestCleanup(); 42 | wss.close(); 43 | server.close(); 44 | }; 45 | }); 46 | 47 | test('cannot get a bound send function on a closed transport', async () => { 48 | const clientTransport = new 
WebSocketClientTransport( 49 | () => Promise.resolve(createLocalWebSocketClient(port)), 50 | 'client', 51 | ); 52 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 53 | addPostTestCleanup(async () => { 54 | await cleanupTransports([clientTransport, serverTransport]); 55 | }); 56 | 57 | clientTransport.connect(serverTransport.clientId); 58 | const clientSession = clientTransport.sessions.get( 59 | serverTransport.clientId, 60 | ); 61 | assert(clientSession); 62 | 63 | clientTransport.close(); 64 | expect(() => 65 | clientTransport.getSessionBoundSendFn( 66 | serverTransport.clientId, 67 | clientSession.id, 68 | ), 69 | ).toThrow(); 70 | 71 | await testFinishesCleanly({ 72 | clientTransports: [clientTransport], 73 | serverTransport, 74 | }); 75 | }); 76 | 77 | test('retrying single connection attempt should hit retry limit reached', async () => { 78 | const clientTransport = new WebSocketClientTransport( 79 | () => Promise.reject(new Error('fake connection failure')), 80 | 'client', 81 | ); 82 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 83 | const errMock = vi.fn(); 84 | clientTransport.addEventListener('protocolError', errMock); 85 | addPostTestCleanup(async () => { 86 | clientTransport.removeEventListener('protocolError', errMock); 87 | await cleanupTransports([clientTransport, serverTransport]); 88 | }); 89 | 90 | // try connecting and make sure we get the fake connection failure 91 | expect(errMock).toHaveBeenCalledTimes(0); 92 | clientTransport.connect(serverTransport.clientId); 93 | await vi.runAllTimersAsync(); 94 | 95 | await waitFor(() => expect(errMock).toHaveBeenCalledTimes(1)); 96 | expect(errMock).toHaveBeenCalledWith( 97 | expect.objectContaining({ 98 | type: ProtocolError.RetriesExceeded, 99 | }), 100 | ); 101 | 102 | await testFinishesCleanly({ 103 | clientTransports: [clientTransport], 104 | serverTransport, 105 | }); 106 | }); 107 | 108 | test('calling connect consecutively should reuse the same 
connection', async () => { 109 | let connectCalls = 0; 110 | const clientTransport = new WebSocketClientTransport( 111 | () => { 112 | connectCalls++; 113 | 114 | return Promise.resolve(createLocalWebSocketClient(port)); 115 | }, 116 | 'client', 117 | { attemptBudgetCapacity: 3 }, 118 | ); 119 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 120 | const errMock = vi.fn(); 121 | clientTransport.addEventListener('protocolError', errMock); 122 | addPostTestCleanup(async () => { 123 | clientTransport.removeEventListener('protocolError', errMock); 124 | await cleanupTransports([clientTransport, serverTransport]); 125 | }); 126 | 127 | for (let i = 0; i < 3; i++) { 128 | clientTransport.connect(serverTransport.clientId); 129 | } 130 | 131 | expect(errMock).toHaveBeenCalledTimes(0); 132 | await waitFor(() => expect(numberOfConnections(serverTransport)).toBe(1)); 133 | expect(connectCalls).toBe(1); 134 | 135 | await testFinishesCleanly({ 136 | clientTransports: [clientTransport], 137 | serverTransport, 138 | }); 139 | }); 140 | 141 | test('incorrect client handshake', async () => { 142 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 143 | // add listeners 144 | const spy = vi.fn(); 145 | const errMock = vi.fn(); 146 | serverTransport.addEventListener('sessionStatus', spy); 147 | serverTransport.addEventListener('protocolError', errMock); 148 | addPostTestCleanup(async () => { 149 | serverTransport.removeEventListener('sessionStatus', spy); 150 | serverTransport.removeEventListener('protocolError', errMock); 151 | await cleanupTransports([serverTransport]); 152 | }); 153 | 154 | const ws = createLocalWebSocketClient(port); 155 | await new Promise((resolve) => (ws.onopen = resolve)); 156 | ws.send(Buffer.from('bad handshake')); 157 | 158 | // should never connect 159 | // ws should be closed 160 | await waitFor(() => expect(ws.readyState).toBe(ws.CLOSED)); 161 | expect(numberOfConnections(serverTransport)).toBe(0); 162 | 
expect(spy).toHaveBeenCalledTimes(0); 163 | expect(errMock).toHaveBeenCalledTimes(1); 164 | expect(errMock).toHaveBeenCalledWith( 165 | expect.objectContaining({ 166 | type: ProtocolError.HandshakeFailed, 167 | }), 168 | ); 169 | 170 | await testFinishesCleanly({ 171 | clientTransports: [], 172 | serverTransport, 173 | }); 174 | }); 175 | 176 | test('seq number in the future should close connection', async () => { 177 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 178 | 179 | // add listeners 180 | const spy = vi.fn(); 181 | const errMock = vi.fn(); 182 | serverTransport.addEventListener('sessionStatus', spy); 183 | serverTransport.addEventListener('protocolError', errMock); 184 | addPostTestCleanup(async () => { 185 | serverTransport.removeEventListener('sessionStatus', spy); 186 | serverTransport.removeEventListener('protocolError', errMock); 187 | await cleanupTransports([serverTransport]); 188 | }); 189 | 190 | const ws = createLocalWebSocketClient(port); 191 | await new Promise((resolve) => (ws.onopen = resolve)); 192 | const requestMsg = handshakeRequestMessage({ 193 | from: 'client', 194 | to: 'SERVER', 195 | expectedSessionState: { 196 | nextExpectedSeq: 0, 197 | nextSentSeq: 0, 198 | }, 199 | sessionId: 'sessionId', 200 | }); 201 | ws.send(NaiveJsonCodec.toBuffer(requestMsg)); 202 | 203 | // wait for both sides to be happy 204 | await waitFor(() => expect(spy).toHaveBeenCalledTimes(1)); 205 | expect(errMock).toHaveBeenCalledTimes(0); 206 | expect(spy).toHaveBeenCalledWith( 207 | expect.objectContaining({ 208 | status: 'created', 209 | }), 210 | ); 211 | 212 | // send one with bad sequence number 213 | const msg: OpaqueTransportMessage = { 214 | id: 'msgid', 215 | to: 'SERVER', 216 | from: 'client', 217 | seq: 50, 218 | ack: 0, 219 | controlFlags: ControlFlags.StreamOpenBit, 220 | streamId: 'streamid', 221 | payload: {}, 222 | }; 223 | ws.send(NaiveJsonCodec.toBuffer(msg)); 224 | 225 | await waitFor(() => ws.readyState === 
ws.CLOSED); 226 | expect(serverTransport.sessions.size).toBe(1); 227 | 228 | await testFinishesCleanly({ 229 | clientTransports: [], 230 | serverTransport, 231 | }); 232 | }); 233 | 234 | test('mismatched protocol version', async () => { 235 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 236 | // add listeners 237 | const spy = vi.fn(); 238 | const errMock = vi.fn(); 239 | serverTransport.addEventListener('sessionStatus', spy); 240 | serverTransport.addEventListener('protocolError', errMock); 241 | addPostTestCleanup(async () => { 242 | serverTransport.removeEventListener('protocolError', errMock); 243 | serverTransport.removeEventListener('sessionStatus', spy); 244 | await cleanupTransports([serverTransport]); 245 | }); 246 | 247 | const ws = createLocalWebSocketClient(port); 248 | await new Promise((resolve) => (ws.onopen = resolve)); 249 | 250 | const requestMsg = { 251 | id: generateId(), 252 | from: 'client', 253 | to: 'SERVER', 254 | seq: 0, 255 | ack: 0, 256 | streamId: generateId(), 257 | controlFlags: 0, 258 | payload: { 259 | type: 'HANDSHAKE_REQ', 260 | protocolVersion: 'v0', 261 | sessionId: 'sessionId', 262 | expectedSessionState: { 263 | nextExpectedSeq: 0, 264 | nextSentSeq: 0, 265 | }, 266 | } satisfies Static, 267 | }; 268 | ws.send(NaiveJsonCodec.toBuffer(requestMsg)); 269 | 270 | // should never connect 271 | // ws should be closed 272 | await waitFor(() => expect(ws.readyState).toBe(ws.CLOSED)); 273 | expect(numberOfConnections(serverTransport)).toBe(0); 274 | expect(spy).toHaveBeenCalledTimes(0); 275 | expect(errMock).toHaveBeenCalledTimes(1); 276 | expect(errMock).toHaveBeenCalledWith( 277 | expect.objectContaining({ 278 | type: ProtocolError.HandshakeFailed, 279 | }), 280 | ); 281 | 282 | await testFinishesCleanly({ 283 | clientTransports: [], 284 | serverTransport, 285 | }); 286 | }); 287 | }); 288 | -------------------------------------------------------------------------------- /__tests__/serialize.test.ts: 
-------------------------------------------------------------------------------- 1 | import { expect, describe, test } from 'vitest'; 2 | import { 3 | BinaryFileServiceSchema, 4 | FallibleServiceSchema, 5 | TestServiceSchema, 6 | } from '../testUtil/fixtures/services'; 7 | import { serializeSchema } from '../router'; 8 | import { Type } from '@sinclair/typebox'; 9 | 10 | describe('serialize server to jsonschema', () => { 11 | test('serialize entire service schema', () => { 12 | const schema = { test: TestServiceSchema }; 13 | const handshakeSchema = Type.Object({ 14 | token: Type.String(), 15 | }); 16 | 17 | expect(serializeSchema(schema, handshakeSchema)).toMatchSnapshot(); 18 | }); 19 | }); 20 | 21 | describe('serialize service to jsonschema', () => { 22 | test('serialize basic service', () => { 23 | expect(TestServiceSchema.serialize()).toMatchSnapshot(); 24 | }); 25 | 26 | test('serialize service with binary', () => { 27 | expect(BinaryFileServiceSchema.serialize()).toMatchSnapshot(); 28 | }); 29 | 30 | test('serialize service with errors', () => { 31 | expect(FallibleServiceSchema.serialize()).toMatchSnapshot(); 32 | }); 33 | 34 | test('serialize backwards compatible with v1', () => { 35 | expect(TestServiceSchema.serializeV1Compat()).toMatchSnapshot(); 36 | }); 37 | }); 38 | -------------------------------------------------------------------------------- /codec/adapter.ts: -------------------------------------------------------------------------------- 1 | import { Value } from '@sinclair/typebox/value'; 2 | import { 3 | OpaqueTransportMessage, 4 | OpaqueTransportMessageSchema, 5 | } from '../transport'; 6 | import { Codec } from './types'; 7 | import { DeserializeResult, SerializeResult } from '../transport/results'; 8 | import { coerceErrorString } from '../transport/stringifyError'; 9 | 10 | /** 11 | * Adapts a {@link Codec} to the {@link OpaqueTransportMessage} format, 12 | * accounting for fallibility of toBuffer and fromBuffer and wrapping 13 | * it 
with a Result type. 14 | */ 15 | export class CodecMessageAdapter { 16 | constructor(private readonly codec: Codec) {} 17 | 18 | toBuffer(msg: OpaqueTransportMessage): SerializeResult { 19 | try { 20 | return { 21 | ok: true, 22 | value: this.codec.toBuffer(msg), 23 | }; 24 | } catch (e) { 25 | return { 26 | ok: false, 27 | reason: coerceErrorString(e), 28 | }; 29 | } 30 | } 31 | 32 | fromBuffer(buf: Uint8Array): DeserializeResult { 33 | try { 34 | const parsedMsg = this.codec.fromBuffer(buf); 35 | if (!Value.Check(OpaqueTransportMessageSchema, parsedMsg)) { 36 | return { 37 | ok: false, 38 | reason: 'transport message schema mismatch', 39 | }; 40 | } 41 | 42 | return { 43 | ok: true, 44 | value: parsedMsg, 45 | }; 46 | } catch (e) { 47 | return { 48 | ok: false, 49 | reason: coerceErrorString(e), 50 | }; 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /codec/binary.ts: -------------------------------------------------------------------------------- 1 | import { decode, encode } from '@msgpack/msgpack'; 2 | import { Codec } from './types'; 3 | 4 | /** 5 | * Binary codec, uses [msgpack](https://www.npmjs.com/package/@msgpack/msgpack) under the hood 6 | * @type {Codec} 7 | */ 8 | export const BinaryCodec: Codec = { 9 | toBuffer(obj) { 10 | return encode(obj, { ignoreUndefined: true }); 11 | }, 12 | fromBuffer: (buff: Uint8Array) => { 13 | const res = decode(buff); 14 | if (typeof res !== 'object' || res === null) { 15 | throw new Error('unpacked msg is not an object'); 16 | } 17 | 18 | return res; 19 | }, 20 | }; 21 | -------------------------------------------------------------------------------- /codec/codec.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, test, expect } from 'vitest'; 2 | import { codecs } from '../testUtil/fixtures/codec'; 3 | 4 | describe.each(codecs)('codec -- $name', ({ codec }) => { 5 | test('empty object', () => { 6 | 
const msg = {}; 7 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 8 | }); 9 | 10 | test('simple test', () => { 11 | const msg = { abc: 123, def: 'cool' }; 12 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 13 | }); 14 | 15 | test('encodes null properly', () => { 16 | const msg = { test: null }; 17 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 18 | }); 19 | 20 | test('encodes the empty buffer properly', () => { 21 | const msg = { test: new Uint8Array(0) }; 22 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 23 | }); 24 | 25 | test('skips optional fields', () => { 26 | const msg = { test: undefined }; 27 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual({}); 28 | }); 29 | 30 | test('deeply nested test', () => { 31 | const msg = { 32 | array: [{ object: true }], 33 | deeply: { 34 | nested: { 35 | nice: null, 36 | }, 37 | }, 38 | }; 39 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 40 | }); 41 | 42 | test('buffer test', () => { 43 | const msg = { 44 | buff: Uint8Array.from([0, 42, 100, 255]), 45 | }; 46 | expect(codec.fromBuffer(codec.toBuffer(msg))).toStrictEqual(msg); 47 | }); 48 | 49 | test('invalid json throws', () => { 50 | expect(() => codec.fromBuffer(Buffer.from(''))).toThrow(); 51 | expect(() => codec.fromBuffer(Buffer.from('['))).toThrow(); 52 | expect(() => codec.fromBuffer(Buffer.from('[{}'))).toThrow(); 53 | expect(() => codec.fromBuffer(Buffer.from('{"a":1}[]'))).toThrow(); 54 | }); 55 | }); 56 | -------------------------------------------------------------------------------- /codec/index.ts: -------------------------------------------------------------------------------- 1 | export { BinaryCodec } from './binary'; 2 | export { NaiveJsonCodec } from './json'; 3 | export type { Codec } from './types'; 4 | export { CodecMessageAdapter } from './adapter'; 5 | -------------------------------------------------------------------------------- 
// codec/json.ts
import type { Codec } from './types';

const encoder = new TextEncoder();
const decoder = new TextDecoder();

/**
 * In-band wire representation of a `Uint8Array` inside a JSON message:
 * the bytes are base64-encoded and wrapped in an object under the `$t` key.
 */
interface Base64EncodedValue {
  $t: string;
}

/**
 * Encodes raw bytes as base64.
 * `btoa` only accepts "binary strings" (one char per byte), so each byte is
 * first mapped to the character with that char code.
 */
function uint8ArrayToBase64(bytes: Uint8Array): string {
  let binary = '';
  bytes.forEach((byte) => {
    binary += String.fromCharCode(byte);
  });

  return btoa(binary);
}

/**
 * Decodes a base64 string back into raw bytes.
 * Throws (via `atob`) if `base64` is not valid base64.
 */
function base64ToUint8Array(base64: string): Uint8Array {
  const binary = atob(base64);
  const bytes = new Uint8Array(binary.length);
  for (let i = 0; i < binary.length; i++) {
    bytes[i] = binary.charCodeAt(i);
  }

  return bytes;
}

/**
 * Returns true when `val` looks like a {@link Base64EncodedValue}: a non-null
 * object with a *string* `$t` property. Non-string `$t` values (e.g. `null`
 * or numbers) are deliberately rejected; `atob` would otherwise string-coerce
 * them ("null", "123") and silently produce bogus bytes.
 */
function isBase64EncodedValue(val: unknown): val is Base64EncodedValue {
  return (
    typeof val === 'object' &&
    val !== null &&
    '$t' in val &&
    typeof val.$t === 'string'
  );
}

/**
 * Naive JSON codec implementation using JSON.stringify and JSON.parse.
 *
 * `Uint8Array` values anywhere in the object graph are encoded as
 * `{ "$t": "<base64>" }` and revived into `Uint8Array`s on decode. Because the
 * marker is in-band, any decoded object carrying a string `$t` property is
 * replaced by the decoded bytes (its other properties are discarded).
 */
export const NaiveJsonCodec: Codec = {
  toBuffer: (obj: object) => {
    return encoder.encode(
      JSON.stringify(
        obj,
        function replacer(this: Record<string, unknown>, key: string) {
          // Read the raw property from the holder (`this`) rather than the
          // `value` argument: JSON.stringify applies `toJSON()` before calling
          // the replacer, and Node's Buffer (a Uint8Array subclass) defines
          // `toJSON()`, which would otherwise hide the bytes.
          const val = this[key];
          if (val instanceof Uint8Array) {
            return { $t: uint8ArrayToBase64(val) } satisfies Base64EncodedValue;
          }

          return val;
        },
      ),
    );
  },
  /**
   * Throws if the payload is not valid JSON, if a `$t` string is not valid
   * base64, or if the top-level value is not a non-null object (arrays pass).
   * NOTE: the default TextDecoder is non-fatal, so invalid UTF-8 is replaced
   * with U+FFFD rather than rejected.
   */
  fromBuffer: (buff: Uint8Array) => {
    const parsed = JSON.parse(
      decoder.decode(buff),
      (_key: string, val: unknown): unknown =>
        isBase64EncodedValue(val) ? base64ToUint8Array(val.$t) : val,
    ) as unknown;

    if (typeof parsed !== 'object' || parsed === null) {
      throw new Error('unpacked msg is not an object');
    }

    return parsed;
  },
};
/codec/types.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Codec interface for encoding and decoding objects to and from Uint8 buffers. 3 | * Used to prepare messages for use by the transport layer. 4 | */ 5 | export interface Codec { 6 | /** 7 | * Encodes an object to a Uint8 buffer. 8 | * @param obj - The object to encode. 9 | * @returns The encoded Uint8 buffer. 10 | */ 11 | toBuffer(obj: object): Uint8Array; 12 | /** 13 | * Decodes an object from a Uint8 buffer. 14 | * @param buf - The Uint8 buffer to decode. 15 | * @returns The decoded object, or null if decoding failed. 16 | */ 17 | fromBuffer(buf: Uint8Array): object; 18 | } 19 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "nixpkgs": { 4 | "locked": { 5 | "lastModified": 1716479278, 6 | "narHash": "sha256-2eh7rYxQOntkUjFXtlPH7lBuUDd4isu/YHRjNJW7u1Q=", 7 | "owner": "nixos", 8 | "repo": "nixpkgs", 9 | "rev": "2ee89d5a0167a8aa0f2a5615d2b8aefb1f299cd4", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "nixos", 14 | "repo": "nixpkgs", 15 | "type": "github" 16 | } 17 | }, 18 | "root": { 19 | "inputs": { 20 | "nixpkgs": "nixpkgs" 21 | } 22 | } 23 | }, 24 | "root": "root", 25 | "version": 7 26 | } 27 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. 
Transport agnostic!"; 3 | 4 | inputs.nixpkgs.url = "github:nixos/nixpkgs"; 5 | 6 | outputs = { self, nixpkgs }: 7 | let 8 | mkDevShell = system: 9 | let 10 | pkgs = nixpkgs.legacyPackages.${system}; 11 | in 12 | pkgs.mkShell { 13 | nativeBuildInputs = with pkgs; [ 14 | nodejs 15 | nodePackages.typescript-language-server 16 | ]; 17 | }; 18 | in 19 | { 20 | devShells.aarch64-linux.default = mkDevShell "aarch64-linux"; 21 | devShells.aarch64-darwin.default = mkDevShell "aarch64-darwin"; 22 | devShells.x86_64-linux.default = mkDevShell "x86_64-linux"; 23 | devShells.x86_64-darwin.default = mkDevShell "x86_64-darwin"; 24 | }; 25 | } 26 | -------------------------------------------------------------------------------- /flake.sh: -------------------------------------------------------------------------------- 1 | for run in {1..50}; do 2 | npm run test:single || { echo 'flake detected :((' ; exit 1; }; 3 | sleep 0.5; 4 | done 5 | -------------------------------------------------------------------------------- /logging/index.ts: -------------------------------------------------------------------------------- 1 | export { stringLogger, coloredStringLogger, jsonLogger } from './log'; 2 | export type { Logger, LogFn, MessageMetadata } from './log'; 3 | -------------------------------------------------------------------------------- /logging/log.ts: -------------------------------------------------------------------------------- 1 | import { OpaqueTransportMessage, ProtocolVersion } from '../transport/message'; 2 | import { context, trace } from '@opentelemetry/api'; 3 | 4 | const LoggingLevels = { 5 | debug: -1, 6 | info: 0, 7 | warn: 1, 8 | error: 2, 9 | } as const; 10 | export type LoggingLevel = keyof typeof LoggingLevels; 11 | 12 | export type LogFn = ( 13 | msg: string, 14 | ctx?: MessageMetadata, 15 | level?: LoggingLevel, 16 | ) => void; 17 | export type Logger = { 18 | [key in LoggingLevel]: (msg: string, metadata?: MessageMetadata) => void; 19 | }; 20 | 21 | export 
type Tags = 22 | | 'invariant-violation' 23 | | 'state-transition' 24 | | 'invalid-request' 25 | | 'unhealthy-session' 26 | | 'uncaught-handler-error'; 27 | 28 | const cleanedLogFn = (log: LogFn) => { 29 | return (msg: string, metadata?: MessageMetadata) => { 30 | // nothing to enrich or clean without metadata 31 | if (!metadata) { 32 | log(msg, metadata); 33 | return; 34 | } 35 | // shallow-clone so the caller's metadata object is never mutated 36 | const cleaned: MessageMetadata = { ...metadata }; 37 | 38 | // try to infer telemetry from the active span 39 | if (!cleaned.telemetry) { 40 | const span = trace.getSpan(context.active()); 41 | if (span) { 42 | const { traceId, spanId } = span.spanContext(); 43 | cleaned.telemetry = { traceId, spanId }; 44 | } 45 | } 46 | 47 | // drop the payload from the logged copy of the transport message 48 | if (cleaned.transportMessage) { 49 | const { payload, ...rest } = cleaned.transportMessage; 50 | cleaned.transportMessage = rest; 51 | } 52 | log(msg, cleaned); 53 | }; 54 | }; 55 | 56 | export type MessageMetadata = Partial<{ 57 | protocolVersion: ProtocolVersion; 58 | clientId: string; 59 | connectedTo: string; 60 | sessionId: string; 61 | connId: string; 62 | transportMessage: Partial; 63 | validationErrors: Array<{ path: string; message: string }>; 64 | tags: Array; 65 | telemetry: { 66 | traceId: string; 67 | spanId: string; 68 | }; 69 | extras?: Record; 70 | }>; 71 | 72 | export class BaseLogger implements Logger { 73 | minLevel: LoggingLevel; 74 | private output: LogFn; 75 | 76 | constructor(output: LogFn, minLevel: LoggingLevel = 'info') { 77 | this.minLevel = minLevel; 78 | this.output = output; 79 | } 80 | 81 | debug(msg: string, metadata?: MessageMetadata) { 82 | if (LoggingLevels[this.minLevel] <= LoggingLevels.debug) { 83 | this.output(msg, metadata ?? {}, 'debug'); 84 | } 85 | } 86 | 87 | info(msg: string, metadata?: MessageMetadata) { 88 | if (LoggingLevels[this.minLevel] <= LoggingLevels.info) { 89 | this.output(msg, metadata ??
{}, 'info'); 90 | } 91 | } 92 | 93 | warn(msg: string, metadata?: MessageMetadata) { 94 | if (LoggingLevels[this.minLevel] <= LoggingLevels.warn) { 95 | this.output(msg, metadata ?? {}, 'warn'); 96 | } 97 | } 98 | 99 | error(msg: string, metadata?: MessageMetadata) { 100 | if (LoggingLevels[this.minLevel] <= LoggingLevels.error) { 101 | this.output(msg, metadata ?? {}, 'error'); 102 | } 103 | } 104 | } 105 | 106 | export const stringLogger: LogFn = (msg, ctx, level = 'info') => { 107 | const from = ctx?.clientId ? `${ctx.clientId} -- ` : ''; 108 | console.log(`[river:${level}] ${from}${msg}`); 109 | }; 110 | 111 | const colorMap = { 112 | debug: '\u001b[34m', 113 | info: '\u001b[32m', 114 | warn: '\u001b[33m', 115 | error: '\u001b[31m', 116 | }; 117 | 118 | export const coloredStringLogger: LogFn = (msg, ctx, level = 'info') => { 119 | const color = colorMap[level]; 120 | const from = ctx?.clientId ? `${ctx.clientId} -- ` : ''; 121 | console.log(`[river:${color}${level}\u001b[0m] ${from}${msg}`); 122 | }; 123 | 124 | export const jsonLogger: LogFn = (msg, ctx, level) => { 125 | console.log(JSON.stringify({ msg, ctx, level })); 126 | }; 127 | 128 | export const createLogProxy = (log: Logger) => ({ 129 | debug: cleanedLogFn(log.debug.bind(log)), 130 | info: cleanedLogFn(log.info.bind(log)), 131 | warn: cleanedLogFn(log.warn.bind(log)), 132 | error: cleanedLogFn(log.error.bind(log)), 133 | }); 134 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@replit/river", 3 | "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. 
Transport agnostic!", 4 | "version": "0.208.3", 5 | "type": "module", 6 | "exports": { 7 | ".": { 8 | "import": "./dist/router/index.js", 9 | "require": "./dist/router/index.cjs" 10 | }, 11 | "./logging": { 12 | "import": "./dist/logging/index.js", 13 | "require": "./dist/logging/index.cjs" 14 | }, 15 | "./codec": { 16 | "import": "./dist/codec/index.js", 17 | "require": "./dist/codec/index.cjs" 18 | }, 19 | "./transport": { 20 | "import": "./dist/transport/index.js", 21 | "require": "./dist/transport/index.cjs" 22 | }, 23 | "./transport/ws/client": { 24 | "import": "./dist/transport/impls/ws/client.js", 25 | "require": "./dist/transport/impls/ws/client.cjs" 26 | }, 27 | "./transport/ws/server": { 28 | "import": "./dist/transport/impls/ws/server.js", 29 | "require": "./dist/transport/impls/ws/server.cjs" 30 | }, 31 | "./transport/uds/client": { 32 | "import": "./dist/transport/impls/uds/client.js", 33 | "require": "./dist/transport/impls/uds/client.cjs" 34 | }, 35 | "./transport/uds/server": { 36 | "import": "./dist/transport/impls/uds/server.js", 37 | "require": "./dist/transport/impls/uds/server.cjs" 38 | }, 39 | "./test-util": { 40 | "import": "./dist/testUtil/index.js", 41 | "require": "./dist/testUtil/index.cjs" 42 | } 43 | }, 44 | "sideEffects": [ 45 | "./dist/logging/index.js" 46 | ], 47 | "files": [ 48 | "dist" 49 | ], 50 | "dependencies": { 51 | "@msgpack/msgpack": "^3.0.0-beta2", 52 | "nanoid": "^5.0.9", 53 | "ws": "^8.17.0" 54 | }, 55 | "peerDependencies": { 56 | "@opentelemetry/api": "^1.7.0", 57 | "@sinclair/typebox": "~0.34.0" 58 | }, 59 | "devDependencies": { 60 | "@opentelemetry/api": "^1.7.0", 61 | "@opentelemetry/context-async-hooks": "^1.26.0", 62 | "@opentelemetry/core": "^1.7.0", 63 | "@opentelemetry/sdk-trace-base": "^1.24.1", 64 | "@sinclair/typebox": "~0.34.0", 65 | "@stylistic/eslint-plugin": "^2.6.4", 66 | "@types/ws": "^8.5.5", 67 | "@typescript-eslint/eslint-plugin": "^7.8.0", 68 | "@typescript-eslint/parser": "^7.8.0", 69 | 
"@vitest/ui": "^3.1.1", 70 | "eslint": "^8.57.0", 71 | "eslint-config-prettier": "^9.1.0", 72 | "eslint-plugin-prettier": "^5.1.3", 73 | "prettier": "^3.0.0", 74 | "tsup": "^8.4.0", 75 | "typescript": "^5.4.5", 76 | "vitest": "^3.1.1" 77 | }, 78 | "scripts": { 79 | "check": "tsc --noEmit && npm run format && npm run lint", 80 | "format": "npx prettier . --check", 81 | "format:fix": "npx prettier . --write", 82 | "lint": "eslint .", 83 | "lint:fix": "eslint . --fix", 84 | "fix": "npm run format:fix && npm run lint:fix", 85 | "build": "rm -rf dist && tsup && du -sh dist", 86 | "prepack": "npm run build", 87 | "release": "npm publish --access public", 88 | "test:ui": "echo \"remember to go to /__vitest__ in the webview\" && vitest --ui --api.host 0.0.0.0 --api.port 3000", 89 | "test": "vitest", 90 | "test:single": "vitest run --reporter=dot", 91 | "test:flake": "./flake.sh", 92 | "bench": "vitest bench" 93 | }, 94 | "engines": { 95 | "node": ">=16" 96 | }, 97 | "keywords": [ 98 | "rpc", 99 | "websockets", 100 | "jsonschema" 101 | ], 102 | "author": "Jacky Zhao", 103 | "license": "MIT" 104 | } 105 | -------------------------------------------------------------------------------- /replit.nix: -------------------------------------------------------------------------------- 1 | { pkgs }: { 2 | deps = [ 3 | pkgs.nodePackages.vscode-langservers-extracted 4 | ]; 5 | } 6 | -------------------------------------------------------------------------------- /router/context.ts: -------------------------------------------------------------------------------- 1 | import { Span } from '@opentelemetry/api'; 2 | import { TransportClientId } from '../transport/message'; 3 | import { SessionId } from '../transport/sessionStateMachine/common'; 4 | import { ErrResult } from './result'; 5 | import { CancelErrorSchema } from './errors'; 6 | import { Static } from '@sinclair/typebox'; 7 | 8 | /** 9 | * ServiceContext exist for the purpose of declaration merging 10 | * to extend the context 
with additional properties. 11 | * 12 | * For example: 13 | * 14 | * ```ts 15 | * declare module '@replit/river' { 16 | * interface ServiceContext { 17 | * db: Database; 18 | * } 19 | * } 20 | * 21 | * createServer(someTransport, myServices, { extendedContext: { db: myDb } }); 22 | * ``` 23 | * 24 | * Once you do this, your {@link ProcedureHandlerContext} will have `db` property on it. 25 | */ 26 | /* eslint-disable-next-line @typescript-eslint/no-empty-interface */ 27 | export interface ServiceContext {} 28 | 29 | /** 30 | * The parsed metadata schema for a service. This is the 31 | * return value of the {@link ServerHandshakeOptions.validate} 32 | * if the handshake extension is used. 33 | * 34 | * You should use declaration merging to extend this interface 35 | * with the sanitized metadata. 36 | * 37 | * ```ts 38 | * declare module '@replit/river' { 39 | * interface ParsedMetadata { 40 | * userId: number; 41 | * } 42 | * } 43 | * ``` 44 | */ 45 | /* eslint-disable-next-line @typescript-eslint/no-empty-interface */ 46 | export interface ParsedMetadata extends Record {} 47 | 48 | /** 49 | * This is passed to every procedure handler and contains various context-level 50 | * information and utilities. This may be extended, see {@link ServiceContext} 51 | */ 52 | export type ProcedureHandlerContext = ServiceContext & { 53 | /** 54 | * State for this service as defined by the service definition. 55 | */ 56 | state: State; 57 | /** 58 | * The span for this procedure call. You can use this to add attributes, events, and 59 | * links to the span. 60 | */ 61 | span: Span; 62 | /** 63 | * Metadata parsed on the server. See {@link ParsedMetadata} 64 | */ 65 | metadata: ParsedMetadata; 66 | /** 67 | * The ID of the session that sent this request. 68 | */ 69 | sessionId: SessionId; 70 | /** 71 | * The ID of the client that sent this request. There may be multiple sessions per client. 
72 | */ 73 | from: TransportClientId; 74 | /** 75 | * This is used to cancel the procedure call from the handler and notify the client that the 76 | * call was cancelled. 77 | * 78 | * Cancelling is not the same as closing procedure calls gracefully, please refer to 79 | * the river documentation to understand the difference between the two concepts. 80 | */ 81 | cancel: (message?: string) => ErrResult>; 82 | /** 83 | * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) 84 | * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing 85 | * for _any_ reason, for example: 86 | * - client explicit cancellation 87 | * - procedure handler explicit cancellation via {@link cancel} 88 | * - client session disconnect 89 | * - server cancellation due to client invalid payload 90 | * - invocation finishes cleanly, this depends on the type of the procedure (i.e. rpc handler return, or in a stream after the client-side has closed the request writable and the server-side has closed the response writable) 91 | * 92 | * You can use this to pass it on to asynchronous operations (such as fetch). 93 | * 94 | * You may also want to explicitly register callbacks on the 95 | * ['abort' event](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/abort_event) 96 | * as a way to cleanup after the request is finished. 97 | * 98 | * Note that (per standard AbortSignals) callbacks registered _after_ the procedure invocation 99 | * is done are not triggered. In such cases, you can check the "aborted" property and cleanup 100 | * immediately if needed. 
101 | */ 102 | signal: AbortSignal; 103 | }; 104 | -------------------------------------------------------------------------------- /router/errors.ts: -------------------------------------------------------------------------------- 1 | import { 2 | Kind, 3 | Static, 4 | TLiteral, 5 | TNever, 6 | TObject, 7 | TSchema, 8 | TString, 9 | TUnion, 10 | Type, 11 | } from '@sinclair/typebox'; 12 | import { ValueErrorIterator } from '@sinclair/typebox/errors'; 13 | 14 | /** 15 | * {@link UNCAUGHT_ERROR_CODE} is the code that is used when an error is thrown 16 | * inside a procedure handler and is not caught by that handler. 17 | */ 18 | export const UNCAUGHT_ERROR_CODE = 'UNCAUGHT_ERROR'; 19 | /** 20 | * {@link UNEXPECTED_DISCONNECT_CODE} is the code used when the stream's session 21 | * disconnects unexpectedly. 22 | */ 23 | export const UNEXPECTED_DISCONNECT_CODE = 'UNEXPECTED_DISCONNECT'; 24 | /** 25 | * {@link INVALID_REQUEST_CODE} is the code used when a client's request is invalid. 26 | */ 27 | export const INVALID_REQUEST_CODE = 'INVALID_REQUEST'; 28 | /** 29 | * {@link CANCEL_CODE} is the code used when either server or client cancels the stream.
30 | */ 31 | export const CANCEL_CODE = 'CANCEL'; 32 | 33 | type TLiteralString = TLiteral; 34 | 35 | export type BaseErrorSchemaType = 36 | | TObject<{ 37 | code: TLiteralString; 38 | message: TLiteralString | TString; 39 | }> 40 | | TObject<{ 41 | code: TLiteralString; 42 | message: TLiteralString | TString; 43 | extras: TSchema; 44 | }>; 45 | 46 | /** 47 | * Takes in a specific error schema and returns the matching error-result schema, `{ ok: false, payload: t }`. 48 | */ 49 | export const ErrResultSchema = (t: T) => 50 | Type.Object({ 51 | ok: Type.Literal(false), 52 | payload: t, 53 | }); 54 | 55 | const ValidationErrorDetails = Type.Object({ 56 | path: Type.String(), 57 | message: Type.String(), 58 | }); 59 | 60 | export const ValidationErrors = Type.Array(ValidationErrorDetails); 61 | export function castTypeboxValueErrors( 62 | errors: ValueErrorIterator, 63 | ): Static { 64 | const result = []; 65 | for (const error of errors) { 66 | result.push({ 67 | path: error.path, 68 | message: error.message, 69 | }); 70 | } 71 | 72 | return result; 73 | } 74 | 75 | /** 76 | * A schema for cancel payloads ({@link CANCEL_CODE}); cancellation may be initiated by either the client or the server. 77 | */ 78 | export const CancelErrorSchema = Type.Object({ 79 | code: Type.Literal(CANCEL_CODE), 80 | message: Type.String(), 81 | }); 82 | 83 | export const CancelResultSchema = ErrResultSchema(CancelErrorSchema); 84 | 85 | /** 86 | * {@link ReaderErrorSchema} is the schema for all the built-in river errors that 87 | * can be emitted to a reader (request reader on the server, and response reader 88 | * on the client).
89 | */ 90 | export const ReaderErrorSchema = Type.Union([ 91 | Type.Object({ 92 | code: Type.Literal(UNCAUGHT_ERROR_CODE), 93 | message: Type.String(), 94 | }), 95 | Type.Object({ 96 | code: Type.Literal(UNEXPECTED_DISCONNECT_CODE), 97 | message: Type.String(), 98 | }), 99 | Type.Object({ 100 | code: Type.Literal(INVALID_REQUEST_CODE), 101 | message: Type.String(), 102 | extras: Type.Optional( 103 | Type.Object({ 104 | firstValidationErrors: Type.Array(ValidationErrorDetails), 105 | totalErrors: Type.Number(), 106 | }), 107 | ), 108 | }), 109 | CancelErrorSchema, 110 | ]) satisfies ProcedureErrorSchemaType; 111 | 112 | export const ReaderErrorResultSchema = ErrResultSchema(ReaderErrorSchema); 113 | 114 | /** 115 | * Represents an acceptable error schema to pass to a procedure. 116 | * This is the type of a schema, not a schema value. 117 | * 118 | */ 119 | export type ProcedureErrorSchemaType = 120 | | TNever 121 | | BaseErrorSchemaType 122 | | TUnion>; 123 | 124 | // arbitrarily nested unions 125 | // river doesn't accept this by default, use the `flattenErrorType` helper 126 | type NestableProcedureErrorSchemaType = 127 | | BaseErrorSchemaType 128 | | TUnion; 129 | 130 | // use an interface to defer the type definition to be evaluated lazily 131 | // eslint-disable-next-line @typescript-eslint/no-empty-interface 132 | interface NestableProcedureErrorSchemaTypeArray 133 | extends Array {} 134 | 135 | function isUnion(schema: TSchema): schema is TUnion { 136 | return schema[Kind] === 'Union'; 137 | } 138 | 139 | type Flatten = T extends BaseErrorSchemaType 140 | ? T 141 | : T extends TUnion> 142 | ? Flatten 143 | : unknown; 144 | 145 | /** 146 | * In the case where API consumers want to use 147 | * arbitrarily nested unions, this helper flattens them to a single level. 148 | * 149 | * Note that this loses some metadata information on the nested unions, like 150 | * nested description fields, etc.
151 | * 152 | * @param errType - An arbitrarily union-nested error schema. 153 | * @returns The flattened error schema. 154 | */ 155 | export function flattenErrorType( 156 | errType: T, 157 | ): Flatten; 158 | export function flattenErrorType( 159 | errType: NestableProcedureErrorSchemaType, 160 | ): ProcedureErrorSchemaType { 161 | if (!isUnion(errType)) { 162 | return errType; 163 | } 164 | 165 | const flattenedTypes: Array = []; 166 | function flatten(type: NestableProcedureErrorSchemaType) { 167 | if (isUnion(type)) { 168 | for (const t of type.anyOf) { 169 | flatten(t); 170 | } 171 | } else { 172 | flattenedTypes.push(type); 173 | } 174 | } 175 | 176 | flatten(errType); 177 | 178 | return Type.Union(flattenedTypes); 179 | } 180 | -------------------------------------------------------------------------------- /router/handshake.ts: -------------------------------------------------------------------------------- 1 | import { Static, TSchema } from '@sinclair/typebox'; 2 | import { ParsedMetadata } from './context'; 3 | import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message'; 4 | 5 | type ConstructHandshake = () => 6 | | Static 7 | | Promise>; 8 | 9 | type ValidateHandshake = ( 10 | metadata: Static, 11 | previousParsedMetadata?: ParsedMetadata, 12 | ) => 13 | | Static 14 | | ParsedMetadata 15 | | Promise< 16 | | Static 17 | | ParsedMetadata 18 | >; 19 | 20 | export interface ClientHandshakeOptions< 21 | MetadataSchema extends TSchema = TSchema, 22 | > { 23 | /** 24 | * Schema for the metadata that the client sends to the server 25 | * during the handshake. 26 | */ 27 | schema: MetadataSchema; 28 | 29 | /** 30 | * Gets the {@link HandshakeRequestMetadata} to send to the server. 
31 | */ 32 | construct: ConstructHandshake; 33 | } 34 | 35 | export interface ServerHandshakeOptions< 36 | MetadataSchema extends TSchema = TSchema, 37 | > { 38 | /** 39 | * Schema for the metadata that the server receives from the client 40 | * during the handshake. 41 | */ 42 | schema: MetadataSchema; 43 | 44 | /** 45 | * Parses the handshake metadata sent by the client (typed by `schema`), 46 | * transforming it into {@link ParsedMetadata}. 47 | * 48 | * May instead return a {@link HandshakeErrorCustomHandlerFatalResponseCodes} value to fail the handshake. 49 | * 50 | * @param metadata - The metadata sent by the client. 51 | * @param previousParsedMetadata - Metadata previously parsed for this session, if any 52 | * (NOTE(review): presumably only provided when a client reconnects — confirm). 53 | * @returns The parsed metadata or an error code, either directly or as a Promise. 54 | */ 55 | validate: ValidateHandshake; 56 | } 57 | 58 | export function createClientHandshakeOptions< 59 | MetadataSchema extends TSchema = TSchema, 60 | >( 61 | schema: MetadataSchema, 62 | construct: ConstructHandshake, 63 | ): ClientHandshakeOptions { 64 | return { schema, construct }; 65 | } 66 | 67 | export function createServerHandshakeOptions< 68 | MetadataSchema extends TSchema = TSchema, 69 | >( 70 | schema: MetadataSchema, 71 | validate: ValidateHandshake, 72 | ): ServerHandshakeOptions { 73 | return { schema, validate: validate as ValidateHandshake }; 74 | } 75 | -------------------------------------------------------------------------------- /router/index.ts: -------------------------------------------------------------------------------- 1 | export type { 2 | Service, 3 | ServiceConfiguration, 4 | ProcHandler, 5 | ProcInit, 6 | ProcRequest, 7 | ProcResponse, 8 | ProcErrors, 9 | ProcType, 10 | } from './services'; 11 | export { 12 | ServiceSchema, 13 | serializeSchema, 14 | SerializedServerSchema, 15 | SerializedServiceSchema, 16 | SerializedProcedureSchema, 17 | serializeSchemaV1Compat, 18 | SerializedServerSchemaProtocolv1, 19 | SerializedServiceSchemaProtocolv1, 20 |
SerializedProcedureSchemaProtocolv1, 21 | } from './services'; 22 | export type { 23 | ValidProcType, 24 | PayloadType, 25 | ProcedureMap, 26 | RpcProcedure as RPCProcedure, 27 | UploadProcedure, 28 | SubscriptionProcedure, 29 | StreamProcedure, 30 | } from './procedures'; 31 | export type { Writable, Readable } from './streams'; 32 | export { Procedure } from './procedures'; 33 | export { 34 | ProcedureErrorSchemaType, 35 | flattenErrorType, 36 | UNCAUGHT_ERROR_CODE, 37 | UNEXPECTED_DISCONNECT_CODE, 38 | INVALID_REQUEST_CODE, 39 | CANCEL_CODE, 40 | ReaderErrorSchema, 41 | BaseErrorSchemaType, 42 | } from './errors'; 43 | export { createClient } from './client'; 44 | export type { Client } from './client'; 45 | export { createServer } from './server'; 46 | export type { 47 | Server, 48 | Middleware, 49 | MiddlewareParam, 50 | MiddlewareContext, 51 | } from './server'; 52 | export type { 53 | ParsedMetadata, 54 | ServiceContext, 55 | ProcedureHandlerContext, 56 | } from './context'; 57 | export { Ok, Err } from './result'; 58 | export type { 59 | Result, 60 | ErrResult, 61 | OkResult, 62 | ResultUnwrapOk, 63 | ResultUnwrapErr, 64 | ResponseData, 65 | } from './result'; 66 | export { 67 | createClientHandshakeOptions, 68 | createServerHandshakeOptions, 69 | } from './handshake'; 70 | export { version as RIVER_VERSION } from '../package.json'; 71 | -------------------------------------------------------------------------------- /router/result.ts: -------------------------------------------------------------------------------- 1 | import { Static, Type } from '@sinclair/typebox'; 2 | import { Client } from './client'; 3 | import { Readable } from './streams'; 4 | import { BaseErrorSchemaType } from './errors'; 5 | 6 | /** 7 | * AnyResultSchema is a schema to validate any result. 
8 | */ 9 | export const AnyResultSchema = Type.Union([ 10 | Type.Object({ 11 | ok: Type.Literal(false), 12 | payload: Type.Object({ 13 | code: Type.String(), 14 | message: Type.String(), 15 | extras: Type.Optional(Type.Unknown()), 16 | }), 17 | }), 18 | 19 | Type.Object({ 20 | ok: Type.Literal(true), 21 | payload: Type.Unknown(), 22 | }), 23 | ]); 24 | 25 | export interface OkResult { 26 | ok: true; 27 | payload: T; 28 | } 29 | export interface ErrResult> { 30 | ok: false; 31 | payload: Err; 32 | } 33 | export type Result> = 34 | | OkResult 35 | | ErrResult; 36 | 37 | export function Ok>(p: T): OkResult; 38 | export function Ok>(p: T): OkResult; 39 | export function Ok(payload: T): OkResult; 40 | export function Ok(payload: T): OkResult { 41 | return { 42 | ok: true, 43 | payload, 44 | }; 45 | } 46 | 47 | export function Err>( 48 | error: Err, 49 | ): ErrResult { 50 | return { 51 | ok: false, 52 | payload: error, 53 | }; 54 | } 55 | 56 | /** 57 | * Refine a {@link Result} type to its returned payload. 58 | */ 59 | export type ResultUnwrapOk = R extends Result 60 | ? T 61 | : never; 62 | 63 | /** 64 | * Unwrap a {@link Result} type and return the payload if successful, 65 | * otherwise throw an error that includes the error payload's code and message. 66 | * @param result - The result to unwrap. 67 | * @throws Will throw an error if the result is not ok. 68 | */ 69 | export function unwrapOrThrow>( 70 | result: Result, 71 | ): T { 72 | if (result.ok) { 73 | return result.payload; 74 | } 75 | 76 | throw new Error( 77 | `Cannot unwrap non-ok result, got: ${result.payload.code} - ${result.payload.message}`, 78 | ); 79 | } 80 | 81 | /** 82 | * Refine a {@link Result} type to its error payload. 83 | */ 84 | export type ResultUnwrapErr = R extends Result 85 | ? Err 86 | : never; 87 | 88 | /** 89 | * Retrieve the response type for a procedure, represented as a {@link Result} 90 | * type.
91 | * Example: 92 | * ``` 93 | * type Message = ResponseData 94 | * ``` 95 | */ 96 | export type ResponseData< 97 | RiverClient, 98 | ServiceName extends keyof RiverClient, 99 | ProcedureName extends keyof RiverClient[ServiceName], 100 | Procedure = RiverClient[ServiceName][ProcedureName], 101 | Fn extends (...args: never) => unknown = (...args: never) => unknown, 102 | > = RiverClient extends Client 103 | ? Procedure extends object 104 | ? Procedure extends object & { rpc: infer RpcFn extends Fn } 105 | ? Awaited> 106 | : Procedure extends object & { upload: infer UploadFn extends Fn } 107 | ? ReturnType extends { 108 | finalize: (...args: never) => Promise; 109 | } 110 | ? UploadOutputMessage 111 | : never 112 | : Procedure extends object & { stream: infer StreamFn extends Fn } 113 | ? ReturnType extends { 114 | resReadable: Readable< 115 | infer StreamOutputMessage, 116 | Static 117 | >; 118 | } 119 | ? StreamOutputMessage 120 | : never 121 | : Procedure extends object & { 122 | subscribe: infer SubscriptionFn extends Fn; 123 | } 124 | ? Awaited> extends { 125 | resReadable: Readable< 126 | infer SubscriptionOutputMessage, 127 | Static 128 | >; 129 | } 130 | ? 
SubscriptionOutputMessage 131 | : never 132 | : never 133 | : never 134 | : never; 135 | -------------------------------------------------------------------------------- /testUtil/duplex/duplexPair.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, expect, test } from 'vitest'; 2 | import { duplexPair } from './duplexPair'; 3 | 4 | describe('duplexPair', () => { 5 | test('should create a pair of duplex streams', () => { 6 | const [a, b] = duplexPair(); 7 | expect(a).toBeDefined(); 8 | expect(b).toBeDefined(); 9 | 10 | a.write(Uint8Array.from([0x00, 0x01, 0x02])); 11 | expect(b.read()).toStrictEqual(Buffer.from([0x00, 0x01, 0x02])); 12 | 13 | b.write(Uint8Array.from([0x03, 0x04, 0x05])); 14 | expect(a.read()).toStrictEqual(Buffer.from([0x03, 0x04, 0x05])); 15 | }); 16 | }); 17 | -------------------------------------------------------------------------------- /testUtil/duplex/duplexPair.ts: -------------------------------------------------------------------------------- 1 | import { Duplex } from 'node:stream'; 2 | import assert from 'assert'; 3 | 4 | const kCallback = Symbol('Callback'); 5 | const kInitOtherSide = Symbol('InitOtherSide'); 6 | 7 | // yoinked from https://github.com/nodejs/node/blob/c3a7b29e56a5ada6327ebb622ba746d022685742/lib/internal/streams/duplexpair.js#L55 8 | // but with types 9 | class DuplexSide extends Duplex { 10 | private otherSide: DuplexSide | null; 11 | private [kCallback]: (() => void) | null; 12 | 13 | constructor() { 14 | super(); 15 | this[kCallback] = null; 16 | this.otherSide = null; 17 | } 18 | 19 | [kInitOtherSide](otherSide: DuplexSide) { 20 | if (this.otherSide === null) { 21 | this.otherSide = otherSide; 22 | } 23 | } 24 | 25 | _read() { 26 | const callback = this[kCallback]; 27 | if (callback) { 28 | this[kCallback] = null; 29 | callback(); 30 | } 31 | } 32 | 33 | _write( 34 | chunk: Uint8Array, 35 | _encoding: BufferEncoding, 36 | callback: (error?: Error | null) 
=> void, 37 | ) { 38 | assert(this.otherSide !== null); 39 | assert(this.otherSide[kCallback] === null); 40 | if (chunk.length === 0) { 41 | process.nextTick(callback); 42 | } else { 43 | this.otherSide.push(chunk); 44 | this.otherSide[kCallback] = callback; 45 | } 46 | } 47 | 48 | _final(callback: (error?: Error | null) => void) { 49 | this.otherSide?.on('end', callback); 50 | this.otherSide?.push(null); 51 | } 52 | } 53 | 54 | export function duplexPair(): [DuplexSide, DuplexSide] { 55 | const side0 = new DuplexSide(); 56 | const side1 = new DuplexSide(); 57 | side0[kInitOtherSide](side1); 58 | side1[kInitOtherSide](side0); 59 | side0.on('close', () => { 60 | setImmediate(() => { 61 | side1.destroy(); 62 | }); 63 | }); 64 | 65 | side1.on('close', () => { 66 | setImmediate(() => { 67 | side0.destroy(); 68 | }); 69 | }); 70 | 71 | return [side0, side1]; 72 | } 73 | -------------------------------------------------------------------------------- /testUtil/fixtures/cleanup.ts: -------------------------------------------------------------------------------- 1 | import { expect, vi } from 'vitest'; 2 | import { 3 | ClientTransport, 4 | Connection, 5 | OpaqueTransportMessage, 6 | ServerTransport, 7 | Transport, 8 | } from '../../transport'; 9 | import { Server } from '../../router'; 10 | import { AnyServiceSchemaMap } from '../../router/services'; 11 | import { numberOfConnections, testingSessionOptions } from '..'; 12 | import { Value } from '@sinclair/typebox/value'; 13 | import { ControlMessageAckSchema } from '../../transport/message'; 14 | 15 | const waitUntilOptions = { 16 | timeout: 500, // account for possibility of conn backoff 17 | interval: 5, // check every 5ms 18 | }; 19 | 20 | export async function advanceFakeTimersByHeartbeat() { 21 | await vi.advanceTimersByTimeAsync(testingSessionOptions.heartbeatIntervalMs); 22 | } 23 | 24 | export async function advanceFakeTimersByDisconnectGrace() { 25 | for (let i = 0; i < testingSessionOptions.heartbeatsUntilDead + 
1; i++) { 26 | await advanceFakeTimersByHeartbeat(); 27 | } 28 | } 29 | 30 | export async function advanceFakeTimersBySessionGrace() { 31 | await vi.advanceTimersByTimeAsync( 32 | testingSessionOptions.sessionDisconnectGraceMs, 33 | ); 34 | } 35 | 36 | export async function advanceFakeTimersByConnectionBackoff() { 37 | await vi.advanceTimersByTimeAsync(500); 38 | } 39 | 40 | export async function ensureTransportIsClean(t: Transport) { 41 | await advanceFakeTimersBySessionGrace(); 42 | await waitFor(() => 43 | expect( 44 | t.sessions, 45 | `[post-test cleanup] transport ${t.clientId} should not have open sessions after the test`, 46 | ).toStrictEqual(new Map()), 47 | ); 48 | await waitFor(() => 49 | expect( 50 | numberOfConnections(t), 51 | `[post-test cleanup] transport ${t.clientId} should not have open connections after the test`, 52 | ).toBe(0), 53 | ); 54 | } 55 | 56 | export function waitFor(cb: () => T | Promise) { 57 | return vi.waitFor(cb, waitUntilOptions); 58 | } 59 | 60 | export async function ensureTransportBuffersAreEventuallyEmpty( 61 | t: Transport, 62 | ) { 63 | // wait for send buffers to be flushed 64 | // ignore heartbeat messages 65 | await waitFor(() => 66 | expect( 67 | new Map( 68 | [...t.sessions] 69 | .map(([client, sess]) => { 70 | // get all messages that are not heartbeats 71 | const buff = sess.sendBuffer.filter((msg) => { 72 | return !Value.Check(ControlMessageAckSchema, msg.payload); 73 | }); 74 | 75 | return [client, buff] as [ 76 | string, 77 | ReadonlyArray, 78 | ]; 79 | }) 80 | .filter((entry) => entry[1].length > 0), 81 | ), 82 | `[post-test cleanup] transport ${t.clientId} should not have any messages waiting to send after the test`, 83 | ).toStrictEqual(new Map()), 84 | ); 85 | } 86 | 87 | export async function ensureServerIsClean(s: Server) { 88 | return waitFor(() => 89 | expect( 90 | s.streams, 91 | `[post-test cleanup] server should not have any open streams after the test`, 92 | ).toStrictEqual(new Map()), 93 | ); 94 | } 
95 | 96 | export async function cleanupTransports( 97 | transports: Array>, 98 | ) { 99 | for (const t of transports) { 100 | if (t.getStatus() !== 'closed') { 101 | t.log?.info('*** end of test cleanup ***', { clientId: t.clientId }); 102 | t.close(); 103 | } 104 | } 105 | } 106 | 107 | export async function testFinishesCleanly({ 108 | clientTransports, 109 | serverTransport, 110 | server, 111 | }: Partial<{ 112 | clientTransports: Array>; 113 | serverTransport: ServerTransport; 114 | server: Server; 115 | }>) { 116 | // pre-close invariants 117 | // invariant check servers first as heartbeats are authoritative on their side 118 | const allTransports = [ 119 | ...(serverTransport ? [serverTransport] : []), 120 | ...(clientTransports ?? []), 121 | ]; 122 | 123 | for (const t of allTransports) { 124 | t.log?.info('*** end of test invariant checks ***', { 125 | clientId: t.clientId, 126 | }); 127 | } 128 | 129 | // wait for one round of heartbeats to propagate 130 | await advanceFakeTimersByHeartbeat(); 131 | 132 | // make sure clients have sent everything 133 | for (const t of clientTransports ?? 
[]) { 134 | await ensureTransportBuffersAreEventuallyEmpty(t); 135 | } 136 | 137 | // wait for one round of heartbeats to propagate 138 | await advanceFakeTimersByHeartbeat(); 139 | 140 | // make sure servers finally received everything 141 | if (serverTransport) { 142 | await ensureTransportBuffersAreEventuallyEmpty(serverTransport); 143 | } 144 | 145 | if (server) { 146 | await ensureServerIsClean(server); 147 | } 148 | 149 | // close all the things 150 | await cleanupTransports(allTransports); 151 | 152 | // post-close invariants 153 | for (const t of allTransports) { 154 | await ensureTransportIsClean(t); 155 | } 156 | } 157 | 158 | export const createPostTestCleanups = () => { 159 | const cleanupFns: Array<() => Promise> = []; 160 | 161 | return { 162 | addPostTestCleanup: (fn: () => Promise) => { 163 | cleanupFns.push(fn); 164 | }, 165 | postTestCleanup: async () => { 166 | while (cleanupFns.length > 0) { 167 | await cleanupFns.pop()?.(); 168 | } 169 | }, 170 | }; 171 | }; 172 | -------------------------------------------------------------------------------- /testUtil/fixtures/codec.ts: -------------------------------------------------------------------------------- 1 | import { BinaryCodec, Codec, NaiveJsonCodec } from '../../codec'; 2 | 3 | export type ValidCodecs = 'naive' | 'binary'; 4 | export const codecs: Array<{ 5 | name: ValidCodecs; 6 | codec: Codec; 7 | }> = [ 8 | { name: 'naive', codec: NaiveJsonCodec }, 9 | { name: 'binary', codec: BinaryCodec }, 10 | ]; 11 | -------------------------------------------------------------------------------- /testUtil/fixtures/matrix.ts: -------------------------------------------------------------------------------- 1 | import { Codec } from '../../codec'; 2 | import { ValidCodecs, codecs } from './codec'; 3 | import { 4 | TransportMatrixEntry, 5 | ValidTransports, 6 | transports, 7 | } from './transports'; 8 | 9 | interface TestMatrixEntry { 10 | transport: TransportMatrixEntry; 11 | codec: { 12 | name: string; 13 
| codec: Codec; 14 | }; 15 | } 16 | 17 | /** 18 | * Defines a selector type that pairs a valid transport with a valid codec. 19 | */ 20 | type Selector = [ValidTransports | 'all', ValidCodecs | 'all']; 21 | 22 | /** 23 | * Generates a matrix of test entries for each combination of transport and codec. 24 | * If a selector is provided, it filters the matrix to only include the specified transport and codec combination. 25 | * 26 | * @param selector An optional tuple specifying a transport and codec to filter the matrix. 27 | * @returns An array of TestMatrixEntry objects representing the combinations of transport and codec. 28 | */ 29 | export const testMatrix = ( 30 | [transportSelector, codecSelector]: Selector = ['all', 'all'], 31 | ): Array => { 32 | const filteredTransports = transports.filter( 33 | (t) => transportSelector === 'all' || t.name === transportSelector, 34 | ); 35 | 36 | const filteredCodecs = codecs.filter( 37 | (c) => codecSelector === 'all' || c.name === codecSelector, 38 | ); 39 | 40 | return filteredTransports 41 | .map((transport) => 42 | filteredCodecs.map((codec) => ({ 43 | transport, 44 | codec, 45 | })), 46 | ) 47 | .flat(); 48 | }; 49 | -------------------------------------------------------------------------------- /testUtil/fixtures/mockTransport.ts: -------------------------------------------------------------------------------- 1 | import { TransportClientId } from '../../transport'; 2 | import { ClientTransport } from '../../transport/client'; 3 | import { Connection } from '../../transport/connection'; 4 | import { ServerTransport } from '../../transport/server'; 5 | import { Observable } from '../observable/observable'; 6 | import { ProvidedServerTransportOptions } from '../../transport/options'; 7 | import { TestSetupHelpers, TestTransportOptions } from './transports'; 8 | import { Duplex } from 'node:stream'; 9 | import { duplexPair } from '../duplex/duplexPair'; 10 | import { nanoid } from 'nanoid'; 11 | 12 | export class 
InMemoryConnection extends Connection { 13 | conn: Duplex; 14 | 15 | constructor(pipe: Duplex) { 16 | super(); 17 | this.conn = pipe; 18 | this.conn.allowHalfOpen = false; 19 | 20 | this.conn.on('data', (data: Uint8Array) => { 21 | this.dataListener?.(data); 22 | }); 23 | 24 | this.conn.on('close', () => { 25 | this.closeListener?.(); 26 | }); 27 | 28 | this.conn.on('error', (err) => { 29 | this.errorListener?.(err); 30 | }); 31 | } 32 | 33 | send(payload: Uint8Array): boolean { 34 | setImmediate(() => { 35 | this.conn.write(payload); 36 | }); 37 | 38 | return true; 39 | } 40 | 41 | close(): void { 42 | setImmediate(() => { 43 | this.conn.end(); 44 | this.conn.emit('close'); 45 | }); 46 | } 47 | } 48 | 49 | interface BidiConnection { 50 | id: string; 51 | clientToServer: Duplex; 52 | serverToClient: Duplex; 53 | clientId: TransportClientId; 54 | serverId: TransportClientId; 55 | handled: boolean; 56 | } 57 | 58 | // we construct a network of transports connected by node streams here 59 | // so that we can test the transport layer without needing to actually 60 | // use real network/websocket connections 61 | // this is useful for testing the transport layer in isolation 62 | // and allows us to control network conditions in a way that would be 63 | // difficult with real network connections (e.g. simulating a phantom 64 | // disconnect, .pause() vs .removeAllListeners('data'), congestion, 65 | // latency, differences in ws implementations between node and browsers, etc.) 
66 | export function createMockTransportNetwork( 67 | opts?: TestTransportOptions, 68 | ): TestSetupHelpers { 69 | // conn id -> BidiConnection (holds both the client->server and server->client pipes) 70 | const connections = new Observable>({}); 71 | 72 | const transports: Array = []; 73 | class MockClientTransport extends ClientTransport { 74 | async createNewOutgoingConnection( 75 | to: TransportClientId, 76 | ): Promise { 77 | const [clientToServer, serverToClient] = duplexPair(); 78 | await new Promise((resolve) => setImmediate(resolve)); 79 | 80 | const connId = nanoid(); 81 | connections.set((prev) => ({ 82 | ...prev, 83 | [connId]: { 84 | id: connId, 85 | clientToServer, 86 | serverToClient, 87 | clientId: this.clientId, 88 | serverId: to, 89 | handled: false, 90 | }, 91 | })); 92 | 93 | return new InMemoryConnection(clientToServer); 94 | } 95 | } 96 | 97 | class MockServerTransport extends ServerTransport { 98 | subscribeCleanup: () => void; 99 | 100 | constructor( 101 | clientId: TransportClientId, 102 | options?: ProvidedServerTransportOptions, 103 | ) { 104 | super(clientId, options); 105 | 106 | this.subscribeCleanup = connections.observe((conns) => { 107 | // look for any unhandled connections 108 | for (const conn of Object.values(conns)) { 109 | // if we've already handled this connection, skip it 110 | // or if it's not for us, skip it 111 | if (conn.handled || conn.serverId !== this.clientId) { 112 | continue; 113 | } 114 | 115 | conn.handled = true; 116 | const connection = new InMemoryConnection(conn.serverToClient); 117 | this.handleConnection(connection); 118 | } 119 | }); 120 | } 121 | 122 | close() { 123 | this.subscribeCleanup(); 124 | super.close(); 125 | } 126 | } 127 | 128 | return { 129 | getClientTransport: (id, handshakeOptions) => { 130 | const clientTransport = new MockClientTransport(id, opts?.client); 131 | if (handshakeOptions) { 132 | clientTransport.extendHandshake(handshakeOptions); 133 | } 134 | 135 | transports.push(clientTransport); 136 | 137 | return
clientTransport; 138 | }, 139 | getServerTransport: (id = 'SERVER', handshakeOptions) => { 140 | const serverTransport = new MockServerTransport(id, opts?.server); 141 | if (handshakeOptions) { 142 | serverTransport.extendHandshake(handshakeOptions); 143 | } 144 | 145 | transports.push(serverTransport); 146 | 147 | return serverTransport; 148 | }, 149 | simulatePhantomDisconnect() { 150 | for (const conn of Object.values(connections.get())) { 151 | conn.serverToClient.pause(); 152 | conn.clientToServer.pause(); 153 | } 154 | }, 155 | async restartServer() { 156 | for (const transport of transports) { 157 | if (transport.clientId !== 'SERVER') continue; 158 | transport.close(); 159 | } 160 | 161 | // kill all connections while we're at it 162 | for (const conn of Object.values(connections.get())) { 163 | conn.serverToClient.destroy(); 164 | conn.clientToServer.destroy(); 165 | } 166 | }, 167 | cleanup() { 168 | for (const conn of Object.values(connections.get())) { 169 | conn.serverToClient.destroy(); 170 | conn.clientToServer.destroy(); 171 | } 172 | }, 173 | }; 174 | } 175 | -------------------------------------------------------------------------------- /testUtil/fixtures/transports.ts: -------------------------------------------------------------------------------- 1 | import http from 'node:http'; 2 | import { 3 | createLocalWebSocketClient, 4 | createWebSocketServer, 5 | getTransportConnections, 6 | onWsServerReady, 7 | } from '..'; 8 | import { WebSocketClientTransport } from '../../transport/impls/ws/client'; 9 | import { WebSocketServerTransport } from '../../transport/impls/ws/server'; 10 | import { 11 | ClientHandshakeOptions, 12 | ServerHandshakeOptions, 13 | } from '../../router/handshake'; 14 | import { createMockTransportNetwork } from './mockTransport'; 15 | import { 16 | ProvidedClientTransportOptions, 17 | ProvidedServerTransportOptions, 18 | } from '../../transport/options'; 19 | import { TransportClientId } from '../../transport/message'; 20 | 
import { ClientTransport } from '../../transport/client'; 21 | import { Connection } from '../../transport/connection'; 22 | import { ServerTransport } from '../../transport/server'; 23 | 24 | export type ValidTransports = 'ws' | 'mock'; 25 | 26 | export interface TestTransportOptions { 27 | client?: ProvidedClientTransportOptions; 28 | server?: ProvidedServerTransportOptions; 29 | } 30 | 31 | export interface TestSetupHelpers { 32 | getClientTransport: ( 33 | id: TransportClientId, 34 | handshakeOptions?: ClientHandshakeOptions, 35 | ) => ClientTransport; 36 | getServerTransport: ( 37 | id?: TransportClientId, 38 | handshakeOptions?: ServerHandshakeOptions, 39 | ) => ServerTransport; 40 | simulatePhantomDisconnect: () => void; 41 | restartServer: () => Promise; 42 | cleanup: () => Promise | void; 43 | } 44 | 45 | export interface TransportMatrixEntry { 46 | name: ValidTransports; 47 | setup: (opts?: TestTransportOptions) => Promise; 48 | } 49 | 50 | export const transports: Array = [ 51 | { 52 | name: 'ws', 53 | setup: async (opts) => { 54 | let server = http.createServer(); 55 | const port = await onWsServerReady(server); 56 | let wss = createWebSocketServer(server); 57 | 58 | const transports: Array< 59 | WebSocketClientTransport | WebSocketServerTransport 60 | > = []; 61 | 62 | return { 63 | simulatePhantomDisconnect() { 64 | for (const transport of transports) { 65 | for (const conn of getTransportConnections(transport)) { 66 | conn.ws.onmessage = null; 67 | } 68 | } 69 | }, 70 | getClientTransport: (id, handshakeOptions) => { 71 | const clientTransport = new WebSocketClientTransport( 72 | () => Promise.resolve(createLocalWebSocketClient(port)), 73 | id, 74 | opts?.client, 75 | ); 76 | 77 | if (handshakeOptions) { 78 | clientTransport.extendHandshake(handshakeOptions); 79 | } 80 | 81 | clientTransport.bindLogger((msg, ctx, level) => { 82 | if (ctx?.tags?.includes('invariant-violation')) { 83 | console.error('invariant violation', { msg, ctx, level }); 84 | 
throw new Error( 85 | `Invariant violation encountered: [${level}] ${msg}`, 86 | ); 87 | } 88 | }, 'debug'); 89 | 90 | transports.push(clientTransport); 91 | 92 | return clientTransport; 93 | }, 94 | getServerTransport(id = 'SERVER', handshakeOptions) { 95 | const serverTransport = new WebSocketServerTransport( 96 | wss, 97 | id, 98 | opts?.server, 99 | ); 100 | 101 | serverTransport.bindLogger((msg, ctx, level) => { 102 | if (ctx?.tags?.includes('invariant-violation')) { 103 | console.error('invariant violation', { msg, ctx, level }); 104 | throw new Error( 105 | `Invariant violation encountered: [${level}] ${msg}`, 106 | ); 107 | } 108 | }, 'debug'); 109 | 110 | if (handshakeOptions) { 111 | serverTransport.extendHandshake(handshakeOptions); 112 | } 113 | 114 | transports.push(serverTransport); 115 | 116 | return serverTransport as ServerTransport; 117 | }, 118 | async restartServer() { 119 | for (const transport of transports) { 120 | if (transport.clientId !== 'SERVER') continue; 121 | transport.close(); 122 | } 123 | 124 | await new Promise((resolve) => { 125 | server.close(() => resolve()); 126 | }); 127 | server = http.createServer(); 128 | await new Promise((resolve) => { 129 | server.listen(port, resolve); 130 | }); 131 | wss = createWebSocketServer(server); 132 | }, 133 | cleanup: async () => { 134 | wss.close(); 135 | server.close(); 136 | }, 137 | }; 138 | }, 139 | }, 140 | { 141 | name: 'mock', 142 | setup: async (opts) => { 143 | const network = createMockTransportNetwork(opts); 144 | 145 | return network; 146 | }, 147 | }, 148 | ]; 149 | -------------------------------------------------------------------------------- /testUtil/index.ts: -------------------------------------------------------------------------------- 1 | import NodeWs, { WebSocketServer } from 'ws'; 2 | import http from 'node:http'; 3 | import { Static } from '@sinclair/typebox'; 4 | import { 5 | OpaqueTransportMessage, 6 | PartialTransportMessage, 7 | currentProtocolVersion, 8 | } 
from '../transport/message'; 9 | import { Transport } from '../transport/transport'; 10 | import { Readable, ReadableResult, ReadableIterator } from '../router/streams'; 11 | import { WsLike } from '../transport/impls/ws/wslike'; 12 | import { 13 | defaultClientTransportOptions, 14 | defaultTransportOptions, 15 | } from '../transport/options'; 16 | import { Connection } from '../transport/connection'; 17 | import { SessionState } from '../transport/sessionStateMachine/common'; 18 | import { SessionStateGraph } from '../transport/sessionStateMachine/transitions'; 19 | import { BaseErrorSchemaType } from '../router/errors'; 20 | import { ClientTransport } from '../transport/client'; 21 | import { ServerTransport } from '../transport/server'; 22 | import { getTracer } from '../tracing'; 23 | 24 | export { 25 | createMockTransportNetwork, 26 | InMemoryConnection, 27 | } from './fixtures/mockTransport'; 28 | 29 | /** 30 | * Creates a WebSocket client that connects to a local server at the specified port. 31 | * This should only be used for testing. 32 | * @param port - The port number to connect to. 33 | * @returns The WebSocket client instance (returned synchronously, not wrapped in a Promise). 34 | */ 35 | export function createLocalWebSocketClient(port: number): WsLike { 36 | const sock = new NodeWs(`ws://localhost:${port}`); 37 | sock.binaryType = 'arraybuffer'; 38 | 39 | return sock; 40 | } 41 | 42 | /** 43 | * Creates a WebSocket server instance using the provided HTTP server. 44 | * Only used as a helper for testing. 45 | * @param server - The HTTP server instance to use for the WebSocket server. 46 | * @returns The created WebSocket server instance (returned synchronously). 47 | */ 48 | export function createWebSocketServer(server: http.Server) { 49 | return new WebSocketServer({ server }); 50 | } 51 | 52 | /** 53 | * Starts listening on the given server and resolves with the automatically allocated port number. 54 | * This should only be used for testing.
55 | * @param server - The http server to listen on. 56 | * @returns A promise that resolves with the allocated port number. 57 | * The returned promise rejects with an error if a port cannot be allocated. 58 | */ 59 | export function onWsServerReady(server: http.Server): Promise { 60 | return new Promise((resolve, reject) => { 61 | server.listen(() => { 62 | const addr = server.address(); 63 | if (typeof addr === 'object' && addr) { 64 | resolve(addr.port); 65 | } else { 66 | reject(new Error("couldn't find a port to allocate")); 67 | } 68 | }); 69 | }); 70 | } 71 | 72 | const readableIterators = new WeakMap< 73 | Readable>, 74 | ReadableIterator> 75 | >(); 76 | 77 | /** 78 | * A safe way to access {@link Readable}'s iterator multiple times in test helpers. 79 | * 80 | * If there are other iteration attempts outside of the test helpers 81 | * (this function, {@link readNextResult}, and {@link isReadableDone}) 82 | * it will throw an error. 83 | */ 84 | export function getReadableIterator>( 85 | readable: Readable, 86 | ): ReadableIterator { 87 | let iter = readableIterators.get(readable) as 88 | | ReadableIterator 89 | | undefined; 90 | 91 | if (!iter) { 92 | iter = readable[Symbol.asyncIterator](); 93 | readableIterators.set(readable, iter); 94 | } 95 | 96 | return iter; 97 | } 98 | 99 | /** 100 | * Retrieves the next value from {@link Readable}, or throws an error if the Readable is done. 101 | * 102 | * Calling semantics are similar to {@link getReadableIterator}. 103 | */ 104 | export async function readNextResult>( 105 | readable: Readable, 106 | ): Promise> { 107 | const res = await getReadableIterator(readable).next(); 108 | 109 | if (res.done) { 110 | throw new Error('readNext from a done Readable'); 111 | } 112 | 113 | return res.value; 114 | } 115 | 116 | /** 117 | * Checks whether the readable is done iterating. Note that this consumes an iteration in the process.
118 | * 119 | * Calling semantics are similar to {@link getReadableIterator}. 120 | */ 121 | export async function isReadableDone>( 122 | readable: Readable, 123 | ) { 124 | const res = await getReadableIterator(readable).next(); 125 | 126 | return res.done; 127 | } 128 | 129 | export function payloadToTransportMessage( 130 | payload: Payload, 131 | ): PartialTransportMessage { 132 | return { 133 | streamId: 'stream', 134 | controlFlags: 0, 135 | payload, 136 | }; 137 | } 138 | 139 | export function createDummyTransportMessage() { 140 | return payloadToTransportMessage({ 141 | msg: 'cool', 142 | test: Math.random(), 143 | }); 144 | } 145 | 146 | /** 147 | * Waits for a message on the transport. 148 | * @param t - The transport to listen to. 149 | * @param filter - An optional filter applied to each received message; when omitted, the first message matches. If `rejectMismatch` is true, the first non-matching message rejects the promise instead of being skipped. 150 | * @returns A promise that resolves with the payload of the first message that passes the filter. 151 | */ 152 | export async function waitForMessage( 153 | t: Transport, 154 | filter?: (msg: OpaqueTransportMessage) => boolean, 155 | rejectMismatch?: boolean, 156 | ) { 157 | return new Promise((resolve, reject) => { 158 | function cleanup() { 159 | t.removeEventListener('message', onMessage); 160 | } 161 | 162 | function onMessage(msg: OpaqueTransportMessage) { 163 | if (!filter || filter(msg)) { 164 | cleanup(); 165 | resolve(msg.payload); 166 | } else if (rejectMismatch) { 167 | cleanup(); 168 | reject(new Error('message didnt match the filter')); 169 | } 170 | } 171 | 172 | t.addEventListener('message', onMessage); 173 | }); 174 | } 175 | 176 | export const testingSessionOptions = defaultTransportOptions; 177 | export const testingClientSessionOptions = defaultClientTransportOptions; 178 | 179 | export function dummySession() { 180 | return SessionStateGraph.entrypoints.NoConnection( 181 | 'client', 182 | 'server', 183 | { 184 | onSessionGracePeriodElapsed: () => { 185 | /* noop */ 186 | }, 187 | }, 188 |
testingSessionOptions, 189 | currentProtocolVersion, 190 | getTracer(), 191 | ); 192 | } 193 | 194 | export function getClientSendFn( 195 | clientTransport: ClientTransport, 196 | serverTransport: ServerTransport, 197 | ) { 198 | const session = 199 | clientTransport.sessions.get(serverTransport.clientId) ?? 200 | clientTransport.createUnconnectedSession(serverTransport.clientId); 201 | 202 | return clientTransport.getSessionBoundSendFn( 203 | serverTransport.clientId, 204 | session.id, 205 | ); 206 | } 207 | 208 | export function getServerSendFn( 209 | serverTransport: ServerTransport, 210 | clientTransport: ClientTransport, 211 | ) { 212 | const session = serverTransport.sessions.get(clientTransport.clientId); 213 | if (!session) { 214 | throw new Error('session not found'); 215 | } 216 | 217 | return serverTransport.getSessionBoundSendFn( 218 | clientTransport.clientId, 219 | session.id, 220 | ); 221 | } 222 | 223 | export function getTransportConnections( 224 | transport: Transport, 225 | ): Array { 226 | const connections = []; 227 | for (const session of transport.sessions.values()) { 228 | if (session.state === SessionState.Connected) { 229 | connections.push(session.conn); 230 | } 231 | } 232 | 233 | return connections; 234 | } 235 | 236 | export function numberOfConnections( 237 | transport: Transport, 238 | ): number { 239 | return getTransportConnections(transport).length; 240 | } 241 | 242 | export function closeAllConnections( 243 | transport: Transport, 244 | ) { 245 | for (const conn of getTransportConnections(transport)) { 246 | conn.close(); 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /testUtil/observable/observable.test.ts: -------------------------------------------------------------------------------- 1 | import { Observable } from './observable'; 2 | import { describe, expect, test, vitest } from 'vitest'; 3 | 4 | describe('Observable', () => { 5 | test('should set initial value 
correctly', () => { 6 | const initialValue = 10; 7 | const observable = new Observable(initialValue); 8 | expect(observable.value).toBe(initialValue); 9 | }); 10 | 11 | test('should update value correctly', () => { 12 | const observable = new Observable(10); 13 | const newValue = 20; 14 | observable.set(() => newValue); 15 | expect(observable.value).toBe(newValue); 16 | }); 17 | 18 | test('should notify listeners when value changes', () => { 19 | const observable = new Observable(10); 20 | const listener = vitest.fn(); 21 | observable.observe(listener); 22 | expect(listener).toHaveBeenCalledTimes(1); 23 | 24 | const newValue = 20; 25 | observable.set(() => newValue); 26 | 27 | expect(listener).toHaveBeenCalledTimes(2); 28 | expect(listener).toHaveBeenCalledWith(newValue); 29 | }); 30 | 31 | test('should unsubscribe from notifications', () => { 32 | const observable = new Observable(10); 33 | const listener = vitest.fn(); 34 | const unsubscribe = observable.observe(listener); 35 | expect(listener).toHaveBeenCalledTimes(1); 36 | 37 | const newValue = 20; 38 | observable.set(() => newValue); 39 | 40 | expect(listener).toHaveBeenCalledTimes(2); 41 | expect(listener).toHaveBeenCalledWith(newValue); 42 | 43 | unsubscribe(); 44 | 45 | const anotherValue = 30; 46 | observable.set(() => anotherValue); 47 | 48 | expect(listener).toHaveBeenCalledTimes(2); // should not be called again after unsubscribing 49 | }); 50 | }); 51 | -------------------------------------------------------------------------------- /testUtil/observable/observable.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Represents an observable value that can be subscribed to for changes. 3 | * This should only be used in tests 4 | * @template T - The type of the value being observed. 
5 | */ 6 | export class Observable { 7 | value: T; 8 | private listeners: Set<(val: T) => void>; 9 | 10 | constructor(initialValue: T) { 11 | this.value = initialValue; 12 | this.listeners = new Set(); 13 | } 14 | 15 | /** 16 | * Gets the current value of the observable. 17 | */ 18 | get() { 19 | return this.value; 20 | } 21 | 22 | /** 23 | * Sets the current value of the observable by applying `tx` to the previous value. All listeners will get an update with the new value. 24 | * @param tx - A transform that receives the previous value and returns the new value. 25 | */ 26 | set(tx: (preValue: T) => T) { 27 | const newValue = tx(this.value); 28 | this.value = newValue; 29 | this.listeners.forEach((listener) => listener(newValue)); 30 | } 31 | 32 | /** 33 | * Subscribes to changes in the observable value. 34 | * @param listener - A callback invoked immediately with the current value, and again every time the value is set (even if it is unchanged). 35 | * @returns A function that can be called to unsubscribe from further notifications. 36 | */ 37 | observe(listener: (val: T) => void) { 38 | this.listeners.add(listener); 39 | listener(this.get()); 40 | 41 | return () => this.listeners.delete(listener); 42 | } 43 | 44 | /** 45 | * Returns the number of listeners currently observing the observable. 46 | */ 47 | get listenerCount(): number { 48 | return this.listeners.size; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /tracing/index.ts: -------------------------------------------------------------------------------- 1 | import { 2 | Context, 3 | Span, 4 | SpanKind, 5 | SpanStatusCode, 6 | context, 7 | propagation, 8 | trace, 9 | Tracer, 10 | } from '@opentelemetry/api'; 11 | import { BaseErrorSchemaType, RIVER_VERSION, ValidProcType } from '../router'; 12 | import { Connection } from '../transport'; 13 | import { MessageMetadata } from '../logging'; 14 | import { ClientSession } from '../transport/sessionStateMachine/transitions'; 15 | import { IdentifiedSession } from '../transport/sessionStateMachine/common'; 16 | import { Static } from
'@sinclair/typebox'; 17 | 18 | export interface PropagationContext { 19 | traceparent: string; 20 | tracestate: string; 21 | } 22 | 23 | export interface TelemetryInfo { 24 | span: Span; 25 | ctx: Context; 26 | } 27 | 28 | export function getPropagationContext( 29 | ctx: Context, 30 | ): PropagationContext | undefined { 31 | const tracing = { 32 | traceparent: '', 33 | tracestate: '', 34 | }; 35 | propagation.inject(ctx, tracing); 36 | 37 | return tracing; 38 | } 39 | 40 | export function createSessionTelemetryInfo( 41 | tracer: Tracer, 42 | sessionId: string, 43 | to: string, 44 | from: string, 45 | propagationCtx?: PropagationContext, 46 | ): TelemetryInfo { 47 | const parentCtx = propagationCtx 48 | ? propagation.extract(context.active(), propagationCtx) 49 | : context.active(); 50 | 51 | const span = tracer.startSpan( 52 | `river.session`, 53 | { 54 | attributes: { 55 | component: 'river', 56 | 'river.session.id': sessionId, 57 | 'river.session.to': to, 58 | 'river.session.from': from, 59 | }, 60 | }, 61 | parentCtx, 62 | ); 63 | 64 | const ctx = trace.setSpan(parentCtx, span); 65 | 66 | return { span, ctx }; 67 | } 68 | 69 | export function createConnectionTelemetryInfo( 70 | tracer: Tracer, 71 | connection: Connection, 72 | info: TelemetryInfo, 73 | ): TelemetryInfo { 74 | const span = tracer.startSpan( 75 | `river.connection`, 76 | { 77 | attributes: { 78 | component: 'river', 79 | 'river.connection.id': connection.id, 80 | }, 81 | links: [{ context: info.span.spanContext() }], 82 | }, 83 | info.ctx, 84 | ); 85 | 86 | const ctx = trace.setSpan(info.ctx, span); 87 | 88 | return { span, ctx }; 89 | } 90 | 91 | export function createProcTelemetryInfo( 92 | tracer: Tracer, 93 | session: ClientSession, 94 | kind: ValidProcType, 95 | serviceName: string, 96 | procedureName: string, 97 | streamId: string, 98 | ): TelemetryInfo { 99 | const baseCtx = context.active(); 100 | const span = tracer.startSpan( 101 | `river.client.${serviceName}.${procedureName}`, 102 | { 
103 | attributes: { 104 | component: 'river', 105 | 'river.method.kind': kind, 106 | 'river.method.service': serviceName, 107 | 'river.method.name': procedureName, 108 | 'river.streamId': streamId, 109 | 'span.kind': 'client', 110 | }, 111 | links: [{ context: session.telemetry.span.spanContext() }], 112 | kind: SpanKind.CLIENT, 113 | }, 114 | baseCtx, 115 | ); 116 | 117 | const ctx = trace.setSpan(baseCtx, span); 118 | const metadata: MessageMetadata = { 119 | ...session.loggingMetadata, 120 | transportMessage: { 121 | procedureName, 122 | serviceName, 123 | }, 124 | }; 125 | 126 | if (span.isRecording()) { 127 | metadata.telemetry = { 128 | traceId: span.spanContext().traceId, 129 | spanId: span.spanContext().spanId, 130 | }; 131 | } 132 | 133 | session.log?.info(`invoked ${serviceName}.${procedureName}`, metadata); 134 | 135 | return { span, ctx }; 136 | } 137 | 138 | export function createHandlerSpan unknown>( 139 | tracer: Tracer, 140 | session: IdentifiedSession, 141 | kind: ValidProcType, 142 | serviceName: string, 143 | procedureName: string, 144 | streamId: string, 145 | tracing: PropagationContext | undefined, 146 | fn: Fn, 147 | ): ReturnType { 148 | const ctx = tracing 149 | ? 
propagation.extract(context.active(), tracing) 150 | : context.active(); 151 | 152 | return tracer.startActiveSpan( 153 | `river.server.${serviceName}.${procedureName}`, 154 | { 155 | attributes: { 156 | component: 'river', 157 | 'river.method.kind': kind, 158 | 'river.method.service': serviceName, 159 | 'river.method.name': procedureName, 160 | 'river.streamId': streamId, 161 | 'span.kind': 'server', 162 | }, 163 | links: [{ context: session.telemetry.span.spanContext() }], 164 | kind: SpanKind.SERVER, 165 | }, 166 | ctx, 167 | fn, 168 | ); 169 | } 170 | 171 | export function recordRiverError( 172 | span: Span, 173 | error: Static, 174 | ): void { 175 | span.setStatus({ 176 | code: SpanStatusCode.ERROR, 177 | message: error.message, 178 | }); 179 | span.setAttributes({ 180 | 'river.error_code': error.code, 181 | 'river.error_message': error.message, 182 | }); 183 | } 184 | 185 | export function getTracer(): Tracer { 186 | return trace.getTracer('river', RIVER_VERSION); 187 | } 188 | -------------------------------------------------------------------------------- /tracing/tracing.test.ts: -------------------------------------------------------------------------------- 1 | import { 2 | trace, 3 | context, 4 | propagation, 5 | Span, 6 | SpanStatusCode, 7 | } from '@opentelemetry/api'; 8 | import { describe, test, expect, vi, assert, beforeEach } from 'vitest'; 9 | import { dummySession, readNextResult } from '../testUtil'; 10 | 11 | import { 12 | BasicTracerProvider, 13 | InMemorySpanExporter, 14 | SimpleSpanProcessor, 15 | } from '@opentelemetry/sdk-trace-base'; 16 | import { W3CTraceContextPropagator } from '@opentelemetry/core'; 17 | import { AsyncHooksContextManager } from '@opentelemetry/context-async-hooks'; 18 | import { createSessionTelemetryInfo, getPropagationContext } from './index'; 19 | import { testMatrix } from '../testUtil/fixtures/matrix'; 20 | import { 21 | cleanupTransports, 22 | testFinishesCleanly, 23 | waitFor, 24 | } from 
'../testUtil/fixtures/cleanup'; 25 | import { TestSetupHelpers } from '../testUtil/fixtures/transports'; 26 | import { createPostTestCleanups } from '../testUtil/fixtures/cleanup'; 27 | import { FallibleServiceSchema } from '../testUtil/fixtures/services'; 28 | import { createServer } from '../router/server'; 29 | import { createClient } from '../router/client'; 30 | import { UNCAUGHT_ERROR_CODE } from '../router'; 31 | import { LogFn } from '../logging'; 32 | 33 | const provider = new BasicTracerProvider(); 34 | const spanExporter = new InMemorySpanExporter(); 35 | provider.addSpanProcessor(new SimpleSpanProcessor(spanExporter)); 36 | const contextManager = new AsyncHooksContextManager(); 37 | contextManager.enable(); 38 | trace.setGlobalTracerProvider(provider); 39 | context.setGlobalContextManager(contextManager); 40 | propagation.setGlobalPropagator(new W3CTraceContextPropagator()); 41 | 42 | describe('Basic tracing tests', () => { 43 | test('createSessionTelemetryInfo', () => { 44 | const parentCtx = context.active(); 45 | const tracer = trace.getTracer('test'); 46 | const span = tracer.startSpan('empty span', {}, parentCtx); 47 | const ctx = trace.setSpan(parentCtx, span); 48 | 49 | const propCtx = getPropagationContext(ctx); 50 | expect(propCtx?.traceparent).toBeTruthy(); 51 | const session = dummySession(); 52 | const teleInfo = createSessionTelemetryInfo( 53 | tracer, 54 | session.id, 55 | session.to, 56 | session.from, 57 | propCtx, 58 | ); 59 | 60 | // @ts-expect-error: hacking to get parentSpanId 61 | expect(propCtx?.traceparent).toContain(teleInfo.span.parentSpanId); 62 | expect( 63 | teleInfo.ctx.getValue( 64 | Symbol.for('OpenTelemetry Context Key SPAN'), 65 | ) as Span, 66 | ).toBeTruthy(); 67 | }); 68 | }); 69 | 70 | describe.each(testMatrix())( 71 | 'Integrated tracing tests ($transport.name transport, $codec.name codec)', 72 | async ({ transport, codec }) => { 73 | const opts = { codec: codec.codec }; 74 | 75 | const { addPostTestCleanup,
postTestCleanup } = createPostTestCleanups(); 76 | let getClientTransport: TestSetupHelpers['getClientTransport']; 77 | let getServerTransport: TestSetupHelpers['getServerTransport']; 78 | beforeEach(async () => { 79 | const setup = await transport.setup({ client: opts, server: opts }); 80 | getClientTransport = setup.getClientTransport; 81 | getServerTransport = setup.getServerTransport; 82 | spanExporter.reset(); /* drop finished spans left over from earlier tests */ 83 | 84 | return async () => { 85 | await postTestCleanup(); 86 | await setup.cleanup(); 87 | }; 88 | }); 89 | 90 | test('Traces sessions and connections across network boundary', async () => { 91 | const clientTransport = getClientTransport('client'); 92 | const serverTransport = getServerTransport(); 93 | clientTransport.connect(serverTransport.clientId); 94 | addPostTestCleanup(async () => { 95 | await cleanupTransports([clientTransport, serverTransport]); 96 | }); 97 | 98 | await waitFor(() => { 99 | expect(clientTransport.sessions.size).toBe(1); 100 | expect(serverTransport.sessions.size).toBe(1); 101 | }); 102 | 103 | const clientSession = clientTransport.sessions.get( 104 | serverTransport.clientId, 105 | ); 106 | const serverSession = serverTransport.sessions.get( 107 | clientTransport.clientId, 108 | ); 109 | 110 | assert(clientSession); 111 | assert(serverSession); 112 | 113 | const clientSpan = clientSession.telemetry.span; 114 | const serverSpan = serverSession.telemetry.span; 115 | 116 | // ensure server span is a child of client span 117 | // @ts-expect-error: hacking to get parentSpanId 118 | expect(serverSpan.parentSpanId).toBe(clientSpan.spanContext().spanId); 119 | await testFinishesCleanly({ 120 | clientTransports: [clientTransport], 121 | serverTransport, 122 | }); 123 | }); 124 | 125 | test('implicit telemetry gets picked up from handlers', async () => { 126 | // setup 127 | const clientTransport = getClientTransport('client'); 128 | const clientMockLogger = vi.fn(); 129 | clientTransport.bindLogger(clientMockLogger); 130 | const
serverTransport = getServerTransport(); 131 | const serverMockLogger = vi.fn(); 132 | serverTransport.bindLogger(serverMockLogger); 133 | const services = { 134 | fallible: FallibleServiceSchema, 135 | }; 136 | const server = createServer(serverTransport, services); 137 | const client = createClient( 138 | clientTransport, 139 | serverTransport.clientId, 140 | ); 141 | addPostTestCleanup(async () => { 142 | await cleanupTransports([clientTransport, serverTransport]); 143 | }); 144 | 145 | // test 146 | const { reqWritable, resReadable } = client.fallible.echo.stream({}); 147 | 148 | reqWritable.write({ 149 | msg: 'abc', 150 | throwResult: false, 151 | throwError: false, 152 | }); 153 | let result = await readNextResult(resReadable); 154 | expect(result).toStrictEqual({ 155 | ok: true, 156 | payload: { 157 | response: 'abc', 158 | }, 159 | }); 160 | 161 | // this isn't the first message so doesn't have telemetry info on the message itself 162 | reqWritable.write({ 163 | msg: 'def', 164 | throwResult: false, 165 | throwError: true, 166 | }); 167 | 168 | result = await readNextResult(resReadable); 169 | expect(result).toStrictEqual({ 170 | ok: false, 171 | payload: { 172 | code: UNCAUGHT_ERROR_CODE, 173 | message: 'some message', 174 | }, 175 | }); 176 | 177 | // expect that both client and server loggers logged the uncaught error with the correct telemetry info 178 | const clientInvokeCall = clientMockLogger.mock.calls.find( 179 | (call) => call[0] === 'invoked fallible.echo', 180 | ); 181 | const serverInvokeFail = serverMockLogger.mock.calls.find( 182 | (call) => call[0] === 'fallible.echo handler threw an uncaught error', 183 | ); 184 | expect(clientInvokeCall?.[1]).toBeTruthy(); 185 | expect(serverInvokeFail?.[1]).toBeTruthy(); 186 | expect(clientInvokeCall?.[1]?.telemetry?.traceId).toStrictEqual( 187 | serverInvokeFail?.[1]?.telemetry?.traceId, 188 | ); 189 | 190 | reqWritable.close(); 191 | await testFinishesCleanly({ 192 | clientTransports: [clientTransport],
193 | serverTransport, 194 | server, 195 | }); 196 | }); 197 | 198 | test('river errors are recorded on handler spans', async () => { 199 | // setup 200 | const clientTransport = getClientTransport('client'); 201 | const clientMockLogger = vi.fn(); 202 | clientTransport.bindLogger(clientMockLogger); 203 | const serverTransport = getServerTransport(); 204 | const serverMockLogger = vi.fn(); 205 | serverTransport.bindLogger(serverMockLogger); 206 | const services = { 207 | fallible: FallibleServiceSchema, 208 | }; 209 | const server = createServer(serverTransport, services); 210 | const client = createClient( 211 | clientTransport, 212 | serverTransport.clientId, 213 | ); 214 | addPostTestCleanup(async () => { 215 | await cleanupTransports([clientTransport, serverTransport]); 216 | }); 217 | 218 | const { reqWritable, resReadable } = client.fallible.echo.stream({}); 219 | 220 | reqWritable.write({ 221 | msg: 'abc', 222 | throwResult: false, 223 | throwError: false, 224 | }); 225 | let result = await readNextResult(resReadable); 226 | expect(result).toStrictEqual({ 227 | ok: true, 228 | payload: { 229 | response: 'abc', 230 | }, 231 | }); 232 | 233 | // this isn't the first message so doesn't have telemetry info on the message itself 234 | reqWritable.write({ 235 | msg: 'def', 236 | throwResult: false, 237 | throwError: true, 238 | }); 239 | 240 | result = await readNextResult(resReadable); 241 | expect(result).toStrictEqual({ 242 | ok: false, 243 | payload: { 244 | code: UNCAUGHT_ERROR_CODE, 245 | message: 'some message', 246 | }, 247 | }); 248 | 249 | const spans = spanExporter.getFinishedSpans(); 250 | 251 | const errSpan = spans.find( 252 | (span) => 253 | span.name === 'river.server.fallible.echo' && 254 | span.status.code === SpanStatusCode.ERROR, 255 | ); 256 | expect(errSpan).toBeTruthy(); 257 | expect(errSpan?.attributes['river.error_code']).toBe(UNCAUGHT_ERROR_CODE); 258 | expect(errSpan?.attributes['river.error_message']).toBe('some message'); 259 | 260 |
await testFinishesCleanly({ 261 | clientTransports: [clientTransport], 262 | serverTransport, 263 | server, 264 | }); 265 | }); 266 | }, 267 | ); 268 | -------------------------------------------------------------------------------- /transport/connection.ts: -------------------------------------------------------------------------------- 1 | import { TelemetryInfo } from '../tracing'; 2 | import { MessageMetadata } from '../logging'; 3 | import { generateId } from './id'; 4 | 5 | /** 6 | * A connection is the actual raw underlying transport connection. 7 | * It's responsible for dispatching data to/from the actual connection itself. 8 | * This should be instantiated as soon as the client/server has a connection. 9 | * It's tied to the lifecycle of the underlying transport connection (i.e. if the WS drops, this connection should be deleted). 10 | */ 11 | export abstract class Connection { 12 | id: string; 13 | telemetry?: TelemetryInfo; 14 | 15 | constructor() { 16 | this.id = `conn-${generateId()}`; // for debugging, no collision safety needed 17 | } 18 | 19 | /** Logging metadata: always the connection id, plus the trace/span ids of the telemetry span while that span is recording. */ get loggingMetadata(): MessageMetadata { 20 | const metadata: MessageMetadata = { connId: this.id }; 21 | 22 | if (this.telemetry?.span.isRecording()) { 23 | const spanContext = this.telemetry.span.spanContext(); 24 | metadata.telemetry = { 25 | traceId: spanContext.traceId, 26 | spanId: spanContext.spanId, 27 | }; 28 | } 29 | 30 | return metadata; 31 | } 32 | 33 | dataListener?: (msg: Uint8Array) => void; 34 | closeListener?: () => void; 35 | errorListener?: (err: Error) => void; 36 | 37 | /** Forward received data to the data listener, if one is set. */ onData(msg: Uint8Array) { 38 | this.dataListener?.(msg); 39 | } 40 | 41 | /** Forward an error to the error listener, if one is set. */ onError(err: Error) { 42 | this.errorListener?.(err); 43 | } 44 | 45 | /** Notify the close listener, if one is set, then end the telemetry span, if any. */ onClose() { 46 | this.closeListener?.(); 47 | this.telemetry?.span.end(); 48 | } 49 | 50 | /** 51 | * Set the callback for when a message is received. 52 | * @param cb The message handler callback.
53 | */ 54 | setDataListener(cb: (msg: Uint8Array) => void) { 55 | this.dataListener = cb; 56 | } 57 | 58 | /** Clear the data listener; onData is then a no-op until a new one is set. */ removeDataListener() { 59 | this.dataListener = undefined; 60 | } 61 | 62 | /** 63 | * Set the callback for when the connection is closed. 64 | * This should also be called if an error happens and after notifying the error listener. 65 | * @param cb The callback to call when the connection is closed. 66 | */ 67 | setCloseListener(cb: () => void): void { 68 | this.closeListener = cb; 69 | } 70 | 71 | /** Clear the close listener; onClose then only ends the telemetry span. */ removeCloseListener(): void { 72 | this.closeListener = undefined; 73 | } 74 | 75 | /** 76 | * Set the callback for when an error is received. 77 | * This should only be used for logging errors; all cleanup 78 | * should be delegated to setCloseListener. 79 | * 80 | * The implementer should take care such that the implemented 81 | * connection will call both the close and error callbacks 82 | * on an error. 83 | * 84 | * @param cb The callback to call when an error is received. 85 | */ 86 | setErrorListener(cb: (err: Error) => void): void { 87 | this.errorListener = cb; 88 | } 89 | 90 | /** Clear the error listener; onError is then a no-op until a new one is set. */ removeErrorListener(): void { 91 | this.errorListener = undefined; 92 | } 93 | 94 | /** 95 | * Sends a message over the connection. 96 | * @param msg The message to send. 97 | * @returns true if the message was sent, false otherwise. 98 | */ 99 | abstract send(msg: Uint8Array): boolean; 100 | 101 | /** 102 | * Closes the connection.
103 | */ 104 | abstract close(): void; 105 | } 106 | -------------------------------------------------------------------------------- /transport/events.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, expect, test, vitest } from 'vitest'; 2 | import { EventDispatcher } from './events'; 3 | import { OpaqueTransportMessage } from '.'; 4 | import { generateId } from './id'; 5 | 6 | function dummyMessage(): OpaqueTransportMessage { 7 | return { 8 | id: generateId(), 9 | from: generateId(), 10 | to: generateId(), 11 | seq: 0, 12 | ack: 0, 13 | streamId: generateId(), 14 | controlFlags: 0, 15 | payload: generateId(), 16 | }; 17 | } 18 | 19 | describe('EventDispatcher', () => { 20 | test('notifies all handlers in order they were registered', () => { 21 | const dispatcher = new EventDispatcher(); 22 | 23 | const handler1 = vitest.fn(); 24 | const handler2 = vitest.fn(); 25 | const sessionStatusHandler = vitest.fn(); 26 | 27 | dispatcher.addEventListener('message', handler1); 28 | dispatcher.addEventListener('message', handler2); 29 | dispatcher.addEventListener('sessionStatus', sessionStatusHandler); 30 | 31 | expect(dispatcher.numberOfListeners('message')).toEqual(2); 32 | 33 | const message = dummyMessage(); 34 | 35 | dispatcher.dispatchEvent('message', message); 36 | 37 | expect(handler1).toHaveBeenCalledTimes(1); 38 | expect(handler2).toHaveBeenCalledTimes(1); 39 | expect(handler1).toHaveBeenCalledWith(message); 40 | expect(handler2).toHaveBeenCalledWith(message); 41 | expect(handler1.mock.invocationCallOrder[0]).toBeLessThan( 42 | handler2.mock.invocationCallOrder[0], 43 | ); 44 | expect(sessionStatusHandler).toHaveBeenCalledTimes(0); 45 | }); 46 | 47 | test('does not notify removed handlers', () => { 48 | const dispatcher = new EventDispatcher(); 49 | 50 | const handler1 = vitest.fn(); 51 | const handler2 = vitest.fn(); 52 | 53 | dispatcher.addEventListener('message', handler1); 54 |
dispatcher.addEventListener('message', handler2); 55 | 56 | dispatcher.removeEventListener('message', handler1); 57 | dispatcher.removeEventListener('message', function neverRegistered() { 58 | /** */ 59 | }); 60 | 61 | expect(dispatcher.numberOfListeners('message')).toEqual(1); 62 | 63 | const message = dummyMessage(); 64 | 65 | dispatcher.dispatchEvent('message', message); 66 | 67 | expect(handler1).toHaveBeenCalledTimes(0); 68 | expect(handler2).toHaveBeenCalledTimes(1); 69 | expect(handler2).toHaveBeenCalledWith(message); 70 | }); 71 | 72 | test('does not notify handlers added while notifying another handler', () => { 73 | const dispatcher = new EventDispatcher(); 74 | 75 | const handler1 = vitest.fn(() => { 76 | dispatcher.addEventListener('message', handler2); 77 | }); 78 | const handler2 = vitest.fn(); 79 | 80 | dispatcher.addEventListener('message', handler1); 81 | 82 | const message = dummyMessage(); 83 | 84 | dispatcher.dispatchEvent('message', message); 85 | 86 | expect(handler1).toHaveBeenCalledTimes(1); 87 | expect(handler2).toHaveBeenCalledTimes(0); 88 | 89 | dispatcher.dispatchEvent('message', message); 90 | 91 | expect(handler1).toHaveBeenCalledTimes(2); 92 | expect(handler2).toHaveBeenCalledTimes(1); 93 | }); 94 | 95 | test('does notify handlers removed while notifying another handler', () => { 96 | const dispatcher = new EventDispatcher(); 97 | 98 | const handler1 = vitest.fn(); 99 | const handler2 = vitest.fn(() => { 100 | dispatcher.removeEventListener('message', handler1); 101 | }); 102 | // register handler2 first so it removes handler1 before handler1 has been notified in this dispatch 103 | dispatcher.addEventListener('message', handler2); 104 | dispatcher.addEventListener('message', handler1); 105 | 106 | const message = dummyMessage(); 107 | 108 | dispatcher.dispatchEvent('message', message); 109 | 110 | expect(handler1).toHaveBeenCalledTimes(1); 111 | expect(handler2).toHaveBeenCalledTimes(1); 112 | 113 | dispatcher.dispatchEvent('message', message); 114 | 115 | expect(handler1).toHaveBeenCalledTimes(1); 116 |
expect(handler2).toHaveBeenCalledTimes(2); 117 | }); 118 | 119 | test('removes all listeners', () => { 120 | const dispatcher = new EventDispatcher(); 121 | 122 | const handler = vitest.fn(); 123 | dispatcher.addEventListener('message', handler); 124 | dispatcher.addEventListener('protocolError', handler); 125 | dispatcher.addEventListener('sessionStatus', handler); 126 | dispatcher.addEventListener('transportStatus', handler); 127 | 128 | dispatcher.removeAllListeners(); 129 | expect(dispatcher.numberOfListeners('message')).toEqual(0); 130 | expect(dispatcher.numberOfListeners('protocolError')).toEqual(0); 131 | expect(dispatcher.numberOfListeners('sessionStatus')).toEqual(0); 132 | expect(dispatcher.numberOfListeners('transportStatus')).toEqual(0); 133 | 134 | // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-argument 135 | dispatcher.dispatchEvent('message', {} as any); 136 | // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-argument 137 | dispatcher.dispatchEvent('protocolError', {} as any); 138 | // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-argument 139 | dispatcher.dispatchEvent('sessionStatus', {} as any); 140 | // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-argument 141 | dispatcher.dispatchEvent('transportStatus', {} as any); 142 | 143 | expect(handler).toHaveBeenCalledTimes(0); 144 | }); 145 | }); 146 | -------------------------------------------------------------------------------- /transport/events.ts: -------------------------------------------------------------------------------- 1 | import { type Static } from '@sinclair/typebox'; 2 | import { Connection } from './connection'; 3 | import { OpaqueTransportMessage, HandshakeErrorResponseCodes } from './message'; 4 | import { Session, SessionState } from './sessionStateMachine'; 5 | import { SessionId } from
'./sessionStateMachine/common'; 6 | import { TransportStatus } from './transport'; 7 | 8 | export const ProtocolError = { 9 | RetriesExceeded: 'conn_retry_exceeded', 10 | HandshakeFailed: 'handshake_failed', 11 | MessageOrderingViolated: 'message_ordering_violated', 12 | InvalidMessage: 'invalid_message', 13 | MessageSendFailure: 'message_send_failure', 14 | } as const; 15 | 16 | export type ProtocolErrorType = 17 | (typeof ProtocolError)[keyof typeof ProtocolError]; 18 | 19 | export interface EventMap { 20 | message: OpaqueTransportMessage; 21 | sessionStatus: 22 | | { 23 | status: 'created' | 'closing'; 24 | session: Session<Connection>; 25 | } 26 | | { 27 | status: 'closed'; 28 | session: Pick<Session<Connection>, 'id' | 'to'>; 29 | }; 30 | sessionTransition: 31 | | { state: SessionState.Connected; id: SessionId } 32 | | { state: SessionState.Handshaking; id: SessionId } 33 | | { state: SessionState.Connecting; id: SessionId } 34 | | { state: SessionState.BackingOff; id: SessionId } 35 | | { state: SessionState.NoConnection; id: SessionId }; 36 | protocolError: 37 | | { 38 | type: (typeof ProtocolError)['HandshakeFailed']; 39 | code: Static<typeof HandshakeErrorResponseCodes>; 40 | message: string; 41 | } 42 | | { /* Exclude, not Omit: ProtocolErrorType is a union of string literals, not an object type */ 43 | type: Exclude< 44 | ProtocolErrorType, 45 | (typeof ProtocolError)['HandshakeFailed'] 46 | >; 47 | message: string; 48 | }; 49 | transportStatus: { 50 | status: TransportStatus; 51 | }; 52 | } 53 | 54 | export type EventTypes = keyof EventMap; 55 | export type EventHandler<K extends EventTypes> = ( 56 | event: EventMap[K], 57 | ) => unknown; 58 | 59 | /** Typed event emitter: handlers are grouped per event type and notified in registration order. */ export class EventDispatcher<T extends EventTypes> { 60 | private eventListeners: { [K in T]?: Set<EventHandler<K>> } = {}; 61 | 62 | /** Remove every handler for every event type. */ removeAllListeners() { 63 | this.eventListeners = {}; 64 | } 65 | 66 | /** Number of handlers currently registered for `eventType` (0 if none). */ numberOfListeners<K extends T>(eventType: K) { 67 | return this.eventListeners[eventType]?.size ??
0; 68 | } 69 | 70 | /** Register `handler` for `eventType`; handlers are kept in a Set, so registering the same handler twice has no extra effect. */ addEventListener<K extends T>(eventType: K, handler: EventHandler<K>) { 71 | if (!this.eventListeners[eventType]) { 72 | this.eventListeners[eventType] = new Set(); 73 | } 74 | 75 | this.eventListeners[eventType]?.add(handler); 76 | } 77 | 78 | /** Unregister `handler` for `eventType`; handlers that were never registered are ignored. */ removeEventListener<K extends T>(eventType: K, handler: EventHandler<K>) { 79 | const handlers = this.eventListeners[eventType]; 80 | if (handlers) { 81 | handlers.delete(handler); 82 | } 83 | } 84 | 85 | /** Call every handler registered for `eventType` with `event`, in registration order. */ dispatchEvent<K extends T>(eventType: K, event: EventMap[K]) { 86 | const handlers = this.eventListeners[eventType]; 87 | if (handlers) { 88 | // copying ensures that adding or removing listeners from within a handler 89 | // doesn't affect the current dispatch. 90 | const copy = [...handlers]; 91 | for (const handler of copy) { 92 | handler(event); 93 | } 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /transport/id.ts: -------------------------------------------------------------------------------- 1 | import { customAlphabet } from 'nanoid'; 2 | 3 | const alphabet = customAlphabet( 4 | '1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', 5 | ); 6 | export const generateId = () => alphabet(12); 7 | -------------------------------------------------------------------------------- /transport/impls/ws/client.ts: -------------------------------------------------------------------------------- 1 | import { ClientTransport } from '../../client'; 2 | import { TransportClientId } from '../../message'; 3 | import { ProvidedClientTransportOptions } from '../../options'; 4 | import { WebSocketConnection } from './connection'; 5 | import { WsLike } from './wslike'; 6 | 7 | /** 8 | * A transport implementation that uses a WebSocket connection with automatic reconnection. 9 | * @class 10 | * @extends ClientTransport 11 | */ 12 | export class WebSocketClientTransport extends ClientTransport<WebSocketConnection> { 13 | /** 14 | * A function that returns a WebSocket instance (or a Promise that resolves to one) for the given peer id.
15 | */ 16 | wsGetter: (to: TransportClientId) => Promise | WsLike; 17 | 18 | /** 19 | * Creates a new WebSocketClientTransport instance. 20 | * @param wsGetter Called with the peer's client id on every connection attempt; returns a WebSocket instance or a Promise that resolves to one. 21 | * @param clientId The ID of the client using the transport. This should be unique per session. 22 | * Note: there is no serverId parameter; the peer to connect to is passed to wsGetter (see createNewOutgoingConnection). 23 | * @param providedOptions An optional object containing configuration options for the transport. 24 | */ 25 | constructor( 26 | wsGetter: (to: TransportClientId) => Promise | WsLike, 27 | clientId: TransportClientId, 28 | providedOptions?: ProvidedClientTransportOptions, 29 | ) { 30 | super(clientId, providedOptions); 31 | this.wsGetter = wsGetter; 32 | } 33 | 34 | /** Obtain a socket for `to` from wsGetter, wait until it is open, and wrap it in a WebSocketConnection. Rejects if the socket is already closing/closed, or closes/errors before opening. */ async createNewOutgoingConnection(to: string) { 35 | this.log?.info(`establishing a new websocket to ${to}`, { 36 | clientId: this.clientId, 37 | connectedTo: to, 38 | }); 39 | 40 | const ws = await this.wsGetter(to); 41 | // onclose/onerror set below only serve this wait: the WebSocketConnection constructor replaces them once the socket is open (onopen is left as-is) 42 | await new Promise((resolve, reject) => { 43 | if (ws.readyState === ws.OPEN) { 44 | resolve(); 45 | 46 | return; 47 | } 48 | 49 | if (ws.readyState === ws.CLOSING || ws.readyState === ws.CLOSED) { 50 | reject(new Error('ws is closing or closed')); 51 | 52 | return; 53 | } 54 | 55 | ws.onopen = () => { 56 | resolve(); 57 | }; 58 | 59 | ws.onclose = (evt) => { 60 | reject(new Error(evt.reason)); 61 | }; 62 | 63 | ws.onerror = (err) => { 64 | reject(new Error(err.message)); 65 | }; 66 | }); 67 | 68 | const conn = new WebSocketConnection(ws); 69 | this.log?.info(`raw websocket to ${to} ok`, { 70 | clientId: this.clientId, 71 | connectedTo: to, 72 | ...conn.loggingMetadata, 73 | }); 74 | 75 | return conn; 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /transport/impls/ws/connection.ts: -------------------------------------------------------------------------------- 1 | import { Connection } from '../../connection'; 2 | import
{ WsLike } from './wslike'; 3 | 4 | interface ConnectionInfoExtras extends Record { 5 | headers: Record; 6 | } 7 | 8 | const WS_HEALTHY_CLOSE_CODE = 1000; 9 | 10 | export class WebSocketConnection extends Connection { 11 | ws: WsLike; 12 | extras?: ConnectionInfoExtras; 13 | 14 | /** Base connection metadata, plus `extras` when extras were passed to the constructor. */ get loggingMetadata() { 15 | const metadata = super.loggingMetadata; 16 | if (this.extras) { 17 | metadata.extras = this.extras; 18 | } 19 | 20 | return metadata; 21 | } 22 | 23 | constructor(ws: WsLike, extras?: ConnectionInfoExtras) { 24 | super(); 25 | this.ws = ws; 26 | this.extras = extras; 27 | this.ws.binaryType = 'arraybuffer'; 28 | 29 | // WebSocket error events carry no useful detail beyond the fact that 30 | // an error occurred, so we only remember that one happened and report it, 31 | // together with the close code and reason, once the close event arrives. 32 | let didError = false; 33 | this.ws.onerror = () => { 34 | didError = true; 35 | }; 36 | 37 | this.ws.onclose = ({ code, reason }) => { 38 | if (didError) { 39 | const err = new Error( 40 | `websocket closed with code and reason: ${code} - ${reason}`, 41 | ); 42 | 43 | this.onError(err); 44 | } 45 | 46 | this.onClose(); 47 | }; 48 | // NOTE(review): with binaryType 'arraybuffer', msg.data is presumably an ArrayBuffer rather than a Uint8Array; confirm onData consumers accept both 49 | this.ws.onmessage = (msg) => { 50 | this.onData(msg.data as Uint8Array); 51 | }; 52 | } 53 | 54 | send(payload: Uint8Array) { 55 | try { 56 | this.ws.send(payload); 57 | 58 | return true; 59 | } catch { 60 | return false; 61 | } 62 | } 63 | 64 | close() { 65 | // we close with 1000 normal even if it's not really healthy at the river level 66 | // if we don't specify this, it defaults to 1005 which 67 | // some proxies/loggers detect as an error 68 | this.ws.close(WS_HEALTHY_CLOSE_CODE); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /transport/impls/ws/server.ts: -------------------------------------------------------------------------------- 1 | import { TransportClientId } from '../../message'; 2 | import { WebSocketServer } from 'ws'; 3 | import { WebSocketConnection } from
'./connection'; 4 | import { WsLike } from './wslike'; 5 | import { ServerTransport } from '../../server'; 6 | import { ProvidedServerTransportOptions } from '../../options'; 7 | import { type IncomingMessage } from 'http'; 8 | 9 | /** Copy request headers into connection metadata. Headers whose name starts with `sec-`, and empty values, are dropped; multi-valued headers keep only their first value. Accepts both `req.headers` and `req.headersDistinct`. */ function cleanHeaders( 10 | headers: Record<string, string | Array<string> | undefined>, 11 | ): Record<string, string> { 12 | const cleanedHeaders: Record<string, string> = {}; 13 | 14 | for (const [key, value] of Object.entries(headers)) { 15 | if (!key.startsWith('sec-') && value) { 16 | const cleanedValue = Array.isArray(value) ? value[0] : value; 17 | if (cleanedValue) { cleanedHeaders[key] = cleanedValue; } 18 | } 19 | } 20 | 21 | return cleanedHeaders; 22 | } 23 | 24 | export class WebSocketServerTransport extends ServerTransport<WebSocketConnection> { 25 | wss: WebSocketServer; 26 | 27 | constructor( 28 | wss: WebSocketServer, 29 | clientId: TransportClientId, 30 | providedOptions?: ProvidedServerTransportOptions, 31 | ) { 32 | super(clientId, providedOptions); 33 | this.wss = wss; 34 | this.wss.on('connection', this.connectionHandler); 35 | } 36 | 37 | connectionHandler = (ws: WsLike, req: IncomingMessage) => { 38 | const conn = new WebSocketConnection(ws, { 39 | headers: cleanHeaders(req.headersDistinct), 40 | }); 41 | 42 | this.handleConnection(conn); 43 | }; 44 | 45 | close() { 46 | super.close(); 47 | this.wss.off('connection', this.connectionHandler); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /transport/impls/ws/ws.test.ts: -------------------------------------------------------------------------------- 1 | import http from 'node:http'; 2 | import { describe, test, expect, beforeEach } from 'vitest'; 3 | import { 4 | createWebSocketServer, 5 | onWsServerReady, 6 | waitForMessage, 7 | createDummyTransportMessage, 8 | payloadToTransportMessage, 9 | createLocalWebSocketClient, 10 | numberOfConnections, 11 | getTransportConnections, 12 | getClientSendFn, 13 | getServerSendFn, 14 | } from '../../../testUtil'; 15 | import { WebSocketServerTransport } from './server'; 16
| import { WebSocketClientTransport } from './client'; 17 | import { 18 | advanceFakeTimersBySessionGrace, 19 | cleanupTransports, 20 | testFinishesCleanly, 21 | waitFor, 22 | } from '../../../testUtil/fixtures/cleanup'; 23 | import { PartialTransportMessage } from '../../message'; 24 | import type NodeWs from 'ws'; 25 | import { createPostTestCleanups } from '../../../testUtil/fixtures/cleanup'; 26 | 27 | describe('sending and receiving across websockets works', async () => { 28 | let server: http.Server; 29 | let port: number; 30 | let wss: NodeWs.Server; 31 | 32 | const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); 33 | beforeEach(async () => { 34 | server = http.createServer(); 35 | port = await onWsServerReady(server); 36 | wss = createWebSocketServer(server); 37 | 38 | return async () => { 39 | await postTestCleanup(); 40 | wss.close(); 41 | server.close(); 42 | }; 43 | }); 44 | 45 | test('basic send/receive', async () => { 46 | const clientTransport = new WebSocketClientTransport( 47 | () => Promise.resolve(createLocalWebSocketClient(port)), 48 | 'client', 49 | ); 50 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 51 | clientTransport.connect(serverTransport.clientId); 52 | const clientSendFn = getClientSendFn(clientTransport, serverTransport); 53 | 54 | addPostTestCleanup(async () => { 55 | await cleanupTransports([clientTransport, serverTransport]); 56 | }); 57 | 58 | const msg = createDummyTransportMessage(); 59 | const msgId = clientSendFn(msg); 60 | await expect( 61 | waitForMessage(serverTransport, (recv) => recv.id === msgId), 62 | ).resolves.toStrictEqual(msg.payload); 63 | 64 | await testFinishesCleanly({ 65 | clientTransports: [clientTransport], 66 | serverTransport, 67 | }); 68 | }); 69 | 70 | test('sending respects to/from fields', async () => { 71 | const makeDummyMessage = (message: string): PartialTransportMessage => { 72 | return payloadToTransportMessage({ message }); 73 | }; 74 | 75 | const
clientId1 = 'client1'; 76 | const clientId2 = 'client2'; 77 | const serverId = 'SERVER'; 78 | const serverTransport = new WebSocketServerTransport(wss, serverId); 79 | 80 | const initClient = async (id: string) => { 81 | const client = new WebSocketClientTransport( 82 | () => Promise.resolve(createLocalWebSocketClient(port)), 83 | id, 84 | ); 85 | 86 | // client to server 87 | client.connect(serverTransport.clientId); 88 | const clientSendFn = getClientSendFn(client, serverTransport); 89 | const initMsg = makeDummyMessage('hello server'); 90 | const initMsgId = clientSendFn(initMsg); 91 | await expect( 92 | waitForMessage(serverTransport, (recv) => recv.id === initMsgId), 93 | ).resolves.toStrictEqual(initMsg.payload); 94 | 95 | return client; 96 | }; 97 | 98 | const client1 = await initClient(clientId1); 99 | const client2 = await initClient(clientId2); 100 | addPostTestCleanup(async () => { 101 | await cleanupTransports([client1, client2, serverTransport]); 102 | }); 103 | 104 | // sending messages from server to client shouldn't leak between clients 105 | const msg1 = makeDummyMessage('hello client1'); 106 | const msg2 = makeDummyMessage('hello client2'); 107 | const msg1Id = getServerSendFn(serverTransport, client1)(msg1); 108 | const msg2Id = getServerSendFn(serverTransport, client2)(msg2); 109 | const promises = Promise.all([ 110 | // true means reject if we receive any message that isn't the one we are expecting 111 | waitForMessage(client2, (recv) => recv.id === msg2Id, true), 112 | waitForMessage(client1, (recv) => recv.id === msg1Id, true), 113 | ]); 114 | await expect(promises).resolves.toStrictEqual( 115 | expect.arrayContaining([msg1.payload, msg2.payload]), 116 | ); 117 | 118 | await testFinishesCleanly({ 119 | clientTransports: [client1, client2], 120 | serverTransport, 121 | }); 122 | }); 123 | 124 | test('hanging ws connection with no handshake is cleaned up after grace', async () => { 125 | const serverTransport = new WebSocketServerTransport(wss,
'SERVER'); 126 | addPostTestCleanup(async () => { 127 | await cleanupTransports([serverTransport]); 128 | }); 129 | 130 | const ws = createLocalWebSocketClient(port); 131 | 132 | // wait for ws to be open 133 | await new Promise((resolve) => (ws.onopen = resolve)); 134 | 135 | // we never sent a handshake so there should be no connections or sessions 136 | expect(numberOfConnections(serverTransport)).toBe(0); 137 | expect(serverTransport.sessions.size).toBe(0); 138 | 139 | // advance time past the grace period 140 | await advanceFakeTimersBySessionGrace(); 141 | 142 | // the connection should have been cleaned up 143 | await waitFor(() => { 144 | expect(numberOfConnections(serverTransport)).toBe(0); 145 | expect(serverTransport.sessions.size).toBe(0); 146 | expect(ws.readyState).toBe(ws.CLOSED); 147 | }); 148 | 149 | await testFinishesCleanly({ 150 | clientTransports: [], 151 | serverTransport, 152 | }); 153 | }); 154 | 155 | test('ws connection is recreated after unclean disconnect', async () => { 156 | const clientTransport = new WebSocketClientTransport( 157 | () => Promise.resolve(createLocalWebSocketClient(port)), 158 | 'client', 159 | ); 160 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 161 | 162 | clientTransport.connect(serverTransport.clientId); 163 | const clientSendFn = getClientSendFn(clientTransport, serverTransport); 164 | 165 | addPostTestCleanup(async () => { 166 | await cleanupTransports([clientTransport, serverTransport]); 167 | }); 168 | 169 | const msg1 = createDummyTransportMessage(); 170 | const msg2 = createDummyTransportMessage(); 171 | 172 | const msg1Id = clientSendFn(msg1); 173 | await expect( 174 | waitForMessage(serverTransport, (recv) => recv.id === msg1Id), 175 | ).resolves.toStrictEqual(msg1.payload); 176 | 177 | // unclean client disconnect 178 | for (const conn of getTransportConnections(clientTransport)) { 179 | (conn.ws as NodeWs).terminate(); 180 | } 181 | 182 | // by this point the client should have
reconnected 183 | const msg2Id = clientSendFn(msg2); 184 | await expect( 185 | waitForMessage(serverTransport, (recv) => recv.id === msg2Id), 186 | ).resolves.toStrictEqual(msg2.payload); 187 | 188 | await testFinishesCleanly({ 189 | clientTransports: [clientTransport], 190 | serverTransport, 191 | }); 192 | }); 193 | 194 | test('ws connection always calls the close callback', async () => { 195 | const clientTransport = new WebSocketClientTransport( 196 | () => Promise.resolve(createLocalWebSocketClient(port)), 197 | 'client', 198 | ); 199 | const serverTransport = new WebSocketServerTransport(wss, 'SERVER'); 200 | clientTransport.connect(serverTransport.clientId); 201 | const clientSendFn = getClientSendFn(clientTransport, serverTransport); 202 | 203 | addPostTestCleanup(async () => { 204 | await cleanupTransports([clientTransport, serverTransport]); 205 | }); 206 | 207 | const msg1 = createDummyTransportMessage(); 208 | const msg2 = createDummyTransportMessage(); 209 | 210 | const msg1Id = clientSendFn(msg1); 211 | await expect( 212 | waitForMessage(serverTransport, (recv) => recv.id === msg1Id), 213 | ).resolves.toStrictEqual(msg1.payload); 214 | 215 | // unclean server disconnect (terminate the server's end of each socket). Note that the Node implementation sends the reason on the 216 | // `onclose`, but (some?) browsers call the `onerror` handler before it since it was an unclean 217 | // exit.
218 | for (const conn of getTransportConnections(serverTransport)) { 219 | (conn.ws as NodeWs).terminate(); 220 | } 221 | 222 | // by this point the client should have reconnected 223 | const msg2Id = clientSendFn(msg2); 224 | await expect( 225 | waitForMessage(serverTransport, (recv) => recv.id === msg2Id), 226 | ).resolves.toStrictEqual(msg2.payload); 227 | 228 | await testFinishesCleanly({ 229 | clientTransports: [clientTransport], 230 | serverTransport, 231 | }); 232 | }); 233 | }); 234 | -------------------------------------------------------------------------------- /transport/impls/ws/wslike.ts: -------------------------------------------------------------------------------- 1 | interface WsEvent extends Event { 2 | type: string; 3 | // we don't care about the target 4 | // because we never use it -- we need to just 5 | // give it any to suppress the underlying type 6 | // see: https://www.typescriptlang.org/docs/handbook/type-compatibility.html#any-unknown-object-void-undefined-null-and-never-assignability 7 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 8 | target: any; 9 | } 10 | 11 | interface ErrorEvent extends WsEvent { 12 | error: unknown; 13 | message: string; 14 | } 15 | 16 | interface CloseEvent extends WsEvent { 17 | wasClean: boolean; 18 | code: number; 19 | reason: string; 20 | } 21 | 22 | interface MessageEvent extends WsEvent { 23 | // same here: we don't know the underlying type of data so we 24 | // need to just give it any to suppress the underlying type 25 | // eslint-disable-next-line @typescript-eslint/no-explicit-any 26 | data: any; 27 | } 28 | 29 | export interface WsLikeWithHandlers { 30 | readonly CONNECTING: 0; 31 | readonly OPEN: 1; 32 | readonly CLOSING: 2; 33 | readonly CLOSED: 3; 34 | 35 | binaryType: string; 36 | readonly readyState: number; 37 | 38 | onclose(ev: CloseEvent): unknown; 39 | onmessage(ev: MessageEvent): unknown; 40 | onopen(ev: WsEvent): unknown; 41 | onerror(ev: ErrorEvent): unknown; 42 | 43 |
send(data: unknown): void; 44 | close(code?: number, reason?: string): void; 45 | } 46 | 47 | // make the listed keys K of T nullable (T[K] | null); all other keys keep their type. 48 | // to my knowledge, this is the only way to get nullable interface methods 49 | // instead of function types 50 | // variance is different for methods and properties 51 | // https://www.typescriptlang.org/docs/handbook/type-compatibility.html#function-parameter-bivariance 52 | type Nullable<T, K extends keyof T> = { 53 | [_K in keyof T]: _K extends K ? T[_K] | null : T[_K]; 54 | }; 55 | 56 | /** 57 | * A websocket-like interface that has all we need; it matches both the 58 | * "lib.dom.d.ts" and npm's "ws" websocket interfaces. 59 | */ 60 | export type WsLike = Nullable< 61 | WsLikeWithHandlers, 62 | 'onclose' | 'onmessage' | 'onopen' | 'onerror' 63 | >; 64 | -------------------------------------------------------------------------------- /transport/index.ts: -------------------------------------------------------------------------------- 1 | export { Transport } from './transport'; 2 | export { ClientTransport } from './client'; 3 | export { ServerTransport } from './server'; 4 | export type { TransportStatus } from './transport'; 5 | export type { 6 | ProvidedTransportOptions as TransportOptions, 7 | ProvidedClientTransportOptions as ClientTransportOptions, 8 | ProvidedServerTransportOptions as ServerTransportOptions, 9 | } from './options'; 10 | export { 11 | Session, 12 | SessionState, 13 | type SessionNoConnection, 14 | type SessionConnecting, 15 | type SessionHandshaking, 16 | type SessionConnected, 17 | type SessionWaitingForHandshake, 18 | } from './sessionStateMachine'; 19 | export { Connection } from './connection'; 20 | export { 21 | TransportMessageSchema, 22 | OpaqueTransportMessageSchema, 23 | isStreamOpen, 24 | isStreamClose, 25 | } from './message'; 26 | export type { 27 | TransportMessage, 28 | OpaqueTransportMessage, 29 | TransportClientId, 30 | } from './message'; 31 | export { 32 | type EventMap, 33 | type EventTypes, 34 | type EventHandler, 35 | ProtocolError, 36 |
type ProtocolErrorType, 37 | } from './events'; 38 | -------------------------------------------------------------------------------- /transport/message.test.ts: -------------------------------------------------------------------------------- 1 | import { TransportMessage } from '.'; 2 | import { 3 | ControlFlags, 4 | handshakeRequestMessage, 5 | handshakeResponseMessage, 6 | isAck, 7 | isStreamClose, 8 | isStreamOpen, 9 | } from './message'; 10 | import { describe, test, expect } from 'vitest'; 11 | 12 | const msg = ( 13 | to: string, 14 | from: string, 15 | streamId: string, 16 | payload: unknown, 17 | serviceName: string, 18 | procedureName: string, 19 | ): TransportMessage => ({ 20 | id: 'abc', 21 | to, 22 | from, 23 | streamId, 24 | payload, 25 | serviceName, 26 | procedureName, 27 | controlFlags: 0, 28 | seq: 0, 29 | ack: 0, 30 | }); 31 | 32 | describe('message helpers', () => { 33 | test('ack', () => { 34 | const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); 35 | m.controlFlags |= ControlFlags.AckBit; 36 | 37 | expect(isAck(m.controlFlags)).toBe(true); 38 | expect(isStreamOpen(m.controlFlags)).toBe(false); 39 | expect(isStreamClose(m.controlFlags)).toBe(false); 40 | }); 41 | 42 | test('streamOpen', () => { 43 | const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); 44 | m.controlFlags |= ControlFlags.StreamOpenBit; 45 | 46 | expect(isAck(m.controlFlags)).toBe(false); 47 | expect(isStreamOpen(m.controlFlags)).toBe(true); 48 | expect(isStreamClose(m.controlFlags)).toBe(false); 49 | }); 50 | 51 | test('streamClose', () => { 52 | const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); 53 | m.controlFlags |= ControlFlags.StreamClosedBit; 54 | 55 | expect(isAck(m.controlFlags)).toBe(false); 56 | expect(isStreamOpen(m.controlFlags)).toBe(false); 57 | expect(isStreamClose(m.controlFlags)).toBe(true); 58 | }); 59 | 60 | test('handshakeRequestMessage', () => { 61 | const m = handshakeRequestMessage({ 62 | from: 'a', 63 | to: 'b', 64 |
expectedSessionState: { 65 | nextExpectedSeq: 0, 66 | nextSentSeq: 0, 67 | }, 68 | sessionId: 'sess', 69 | }); 70 | 71 | expect(m).toMatchObject({ 72 | from: 'a', 73 | to: 'b', 74 | payload: { 75 | sessionId: 'sess', 76 | }, 77 | }); 78 | }); 79 | 80 | test('handshakeResponseMessage', () => { 81 | const mSuccess = handshakeResponseMessage({ 82 | from: 'a', 83 | to: 'b', 84 | status: { 85 | ok: true, 86 | sessionId: 'sess', 87 | }, 88 | }); 89 | const mFail = handshakeResponseMessage({ 90 | from: 'a', 91 | to: 'b', 92 | status: { 93 | ok: false, 94 | reason: 'bad', 95 | code: 'SESSION_STATE_MISMATCH', 96 | }, 97 | }); 98 | 99 | expect(mSuccess.from).toBe('a'); 100 | expect(mSuccess.to).toBe('b'); 101 | expect(mSuccess.payload.status.ok).toBe(true); 102 | 103 | expect(mFail.from).toBe('a'); 104 | expect(mFail.to).toBe('b'); 105 | expect(mFail.payload.status.ok).toBe(false); 106 | }); 107 | 108 | test('default message has no control flags set', () => { 109 | const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); 110 | 111 | expect(isAck(m.controlFlags)).toBe(false); 112 | expect(isStreamOpen(m.controlFlags)).toBe(false); 113 | expect(isStreamClose(m.controlFlags)).toBe(false); 114 | }); 115 | 116 | test('combining control flags works', () => { 117 | const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); 118 | m.controlFlags |= ControlFlags.StreamOpenBit; 119 | 120 | expect(isStreamOpen(m.controlFlags)).toBe(true); 121 | expect(isStreamClose(m.controlFlags)).toBe(false); 122 | 123 | m.controlFlags |= ControlFlags.StreamClosedBit; 124 | expect(isStreamOpen(m.controlFlags)).toBe(true); 125 | expect(isStreamClose(m.controlFlags)).toBe(true); 126 | }); 127 | }); 128 | -------------------------------------------------------------------------------- /transport/options.ts: -------------------------------------------------------------------------------- 1 | import { NaiveJsonCodec } from '../codec/json'; 2 | import { ConnectionRetryOptions } from './rateLimit'; 
3 | import { SessionOptions } from './sessionStateMachine/common'; 4 | 5 | export type TransportOptions = SessionOptions; 6 | 7 | export type ProvidedTransportOptions = Partial; 8 | 9 | export const defaultTransportOptions: TransportOptions = { 10 | heartbeatIntervalMs: 1_000, 11 | heartbeatsUntilDead: 2, 12 | sessionDisconnectGraceMs: 5_000, 13 | connectionTimeoutMs: 2_000, 14 | handshakeTimeoutMs: 1_000, 15 | enableTransparentSessionReconnects: true, 16 | codec: NaiveJsonCodec, 17 | }; 18 | 19 | export type ClientTransportOptions = TransportOptions & ConnectionRetryOptions; 20 | 21 | export type ProvidedClientTransportOptions = Partial; 22 | 23 | const defaultConnectionRetryOptions: ConnectionRetryOptions = { 24 | baseIntervalMs: 150, 25 | maxJitterMs: 200, 26 | maxBackoffMs: 32_000, 27 | attemptBudgetCapacity: 5, 28 | budgetRestoreIntervalMs: 200, 29 | }; 30 | 31 | export const defaultClientTransportOptions: ClientTransportOptions = { 32 | ...defaultTransportOptions, 33 | ...defaultConnectionRetryOptions, 34 | }; 35 | 36 | export type ServerTransportOptions = TransportOptions; 37 | 38 | export type ProvidedServerTransportOptions = Partial; 39 | 40 | export const defaultServerTransportOptions: ServerTransportOptions = { 41 | ...defaultTransportOptions, 42 | }; 43 | -------------------------------------------------------------------------------- /transport/rateLimit.test.ts: -------------------------------------------------------------------------------- 1 | import { 2 | LeakyBucketRateLimit, 3 | ConnectionRetryOptions, 4 | } from '../transport/rateLimit'; 5 | import { describe, test, expect, vi } from 'vitest'; 6 | 7 | describe('LeakyBucketRateLimit', () => { 8 | const options: ConnectionRetryOptions = { 9 | attemptBudgetCapacity: 10, 10 | budgetRestoreIntervalMs: 1000, 11 | baseIntervalMs: 100, 12 | maxJitterMs: 50, 13 | maxBackoffMs: 5000, 14 | }; 15 | 16 | test('should return 0 backoff time for new user', () => { 17 | const rateLimit = new 
LeakyBucketRateLimit(options); 18 | const backoffMs = rateLimit.getBackoffMs(); 19 | expect(backoffMs).toBe(0); 20 | }); 21 | 22 | test('should return 0 budget consumed for new user', () => { 23 | const rateLimit = new LeakyBucketRateLimit(options); 24 | const budgetConsumed = rateLimit.getBudgetConsumed(); 25 | expect(budgetConsumed).toBe(0); 26 | }); 27 | 28 | test('should consume budget correctly', () => { 29 | const rateLimit = new LeakyBucketRateLimit(options); 30 | rateLimit.consumeBudget(); 31 | expect(rateLimit.getBudgetConsumed()).toBe(1); 32 | }); 33 | 34 | test('keeps growing until startRestoringBudget', () => { 35 | const rateLimit = new LeakyBucketRateLimit(options); 36 | rateLimit.consumeBudget(); 37 | rateLimit.consumeBudget(); 38 | expect(rateLimit.getBudgetConsumed()).toBe(2); 39 | 40 | // Advanding time before startRestoringBudget should be noop 41 | vi.advanceTimersByTime(options.budgetRestoreIntervalMs); 42 | expect(rateLimit.getBudgetConsumed()).toBe(2); 43 | 44 | rateLimit.startRestoringBudget(); 45 | expect(rateLimit.getBudgetConsumed()).toBe(2); 46 | vi.advanceTimersByTime(options.budgetRestoreIntervalMs); 47 | expect(rateLimit.getBudgetConsumed()).toBe(1); 48 | }); 49 | 50 | test('stops restoring budget when we consume budget again', () => { 51 | const rateLimit = new LeakyBucketRateLimit(options); 52 | rateLimit.consumeBudget(); 53 | rateLimit.consumeBudget(); 54 | expect(rateLimit.getBudgetConsumed()).toBe(2); 55 | 56 | rateLimit.startRestoringBudget(); 57 | expect(rateLimit.getBudgetConsumed()).toBe(2); 58 | 59 | rateLimit.consumeBudget(); 60 | expect(rateLimit.getBudgetConsumed()).toBe(3); 61 | vi.advanceTimersByTime(options.budgetRestoreIntervalMs); 62 | expect(rateLimit.getBudgetConsumed()).toBe(3); 63 | }); 64 | 65 | test('respects maximum backoff time', () => { 66 | const maxBackoffMs = 50; 67 | const rateLimit = new LeakyBucketRateLimit({ ...options, maxBackoffMs }); 68 | 69 | rateLimit.consumeBudget(); 70 | 71 | 
expect(rateLimit.getBackoffMs()).toBeLessThanOrEqual( 72 | maxBackoffMs + options.maxJitterMs, 73 | ); 74 | expect(rateLimit.getBackoffMs()).toBeGreaterThanOrEqual(maxBackoffMs); 75 | }); 76 | 77 | test('backoff increases', () => { 78 | const rateLimit = new LeakyBucketRateLimit(options); 79 | 80 | rateLimit.consumeBudget(); 81 | const backoffMs1 = rateLimit.getBackoffMs(); 82 | rateLimit.consumeBudget(); 83 | const backoffMs2 = rateLimit.getBackoffMs(); 84 | expect(backoffMs2).toBeGreaterThan(backoffMs1); 85 | rateLimit.consumeBudget(); 86 | const backoffMs3 = rateLimit.getBackoffMs(); 87 | expect(backoffMs3).toBeGreaterThan(backoffMs2); 88 | }); 89 | 90 | test('reports remaining budget correctly', () => { 91 | const maxAttempts = 3; 92 | const rateLimit = new LeakyBucketRateLimit({ 93 | ...options, 94 | attemptBudgetCapacity: maxAttempts, 95 | }); 96 | 97 | for (let i = 0; i < maxAttempts; i++) { 98 | expect(rateLimit.hasBudget()).toBe(true); 99 | rateLimit.consumeBudget(); 100 | } 101 | 102 | expect(rateLimit.hasBudget()).toBe(false); 103 | rateLimit.consumeBudget(); 104 | }); 105 | }); 106 | -------------------------------------------------------------------------------- /transport/rateLimit.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * Options to control the backoff and retry behavior of the client transport's connection behaviour. 3 | * 4 | * River implements exponential backoff with jitter to prevent flooding the server 5 | * when there's an issue with connection establishment. 6 | * 7 | * The backoff is calculated via the following: 8 | * backOff = min(jitter + {@link baseIntervalMs} * 2 ^ budget_consumed, {@link maxBackoffMs}) 9 | * 10 | * We use a leaky bucket rate limit with a budget of {@link attemptBudgetCapacity} reconnection attempts. 11 | * Budget only starts to restore after a successful handshake at a rate of one budget per {@link budgetRestoreIntervalMs}. 
12 | */ 13 | export interface ConnectionRetryOptions { 14 | /** 15 | * The base interval to wait before retrying a connection. 16 | */ 17 | baseIntervalMs: number; 18 | 19 | /** 20 | * The maximum random jitter to add to the total backoff time. 21 | */ 22 | maxJitterMs: number; 23 | 24 | /** 25 | * The maximum amount of time to wait before retrying a connection. 26 | * This does not include the jitter. 27 | */ 28 | maxBackoffMs: number; 29 | 30 | /** 31 | * The max number of times to attempt a connection before a successful handshake. 32 | * This persists across connections but starts restoring budget after a successful handshake. 33 | * The restoration interval depends on {@link budgetRestoreIntervalMs} 34 | */ 35 | attemptBudgetCapacity: number; 36 | 37 | /** 38 | * After a successful connection attempt, how long to wait before we restore a single budget. 39 | */ 40 | budgetRestoreIntervalMs: number; 41 | } 42 | 43 | export class LeakyBucketRateLimit { 44 | private budgetConsumed: number; 45 | private intervalHandle?: ReturnType; 46 | private readonly options: ConnectionRetryOptions; 47 | 48 | constructor(options: ConnectionRetryOptions) { 49 | this.options = options; 50 | this.budgetConsumed = 0; 51 | } 52 | 53 | getBackoffMs() { 54 | if (this.getBudgetConsumed() === 0) { 55 | return 0; 56 | } 57 | 58 | const exponent = Math.max(0, this.getBudgetConsumed() - 1); 59 | const jitter = Math.floor(Math.random() * this.options.maxJitterMs); 60 | const backoffMs = Math.min( 61 | this.options.baseIntervalMs * 2 ** exponent, 62 | this.options.maxBackoffMs, 63 | ); 64 | 65 | return backoffMs + jitter; 66 | } 67 | 68 | get totalBudgetRestoreTime() { 69 | return ( 70 | this.options.budgetRestoreIntervalMs * this.options.attemptBudgetCapacity 71 | ); 72 | } 73 | 74 | consumeBudget() { 75 | // If we're consuming again, let's ensure that we're not leaking 76 | this.stopLeak(); 77 | this.budgetConsumed = this.getBudgetConsumed() + 1; 78 | } 79 | 80 | getBudgetConsumed() { 81 | 
return this.budgetConsumed; 82 | } 83 | 84 | hasBudget() { 85 | return this.getBudgetConsumed() < this.options.attemptBudgetCapacity; 86 | } 87 | 88 | startRestoringBudget() { 89 | if (this.intervalHandle) { 90 | return; 91 | } 92 | 93 | const restoreBudgetForUser = () => { 94 | const currentBudget = this.budgetConsumed; 95 | if (!currentBudget) { 96 | this.stopLeak(); 97 | 98 | return; 99 | } 100 | 101 | const newBudget = currentBudget - 1; 102 | if (newBudget === 0) { 103 | return; 104 | } 105 | 106 | this.budgetConsumed = newBudget; 107 | }; 108 | 109 | this.intervalHandle = setInterval( 110 | restoreBudgetForUser, 111 | this.options.budgetRestoreIntervalMs, 112 | ); 113 | } 114 | 115 | private stopLeak() { 116 | if (!this.intervalHandle) { 117 | return; 118 | } 119 | 120 | clearInterval(this.intervalHandle); 121 | this.intervalHandle = undefined; 122 | } 123 | 124 | close() { 125 | this.stopLeak(); 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /transport/results.ts: -------------------------------------------------------------------------------- 1 | import { OpaqueTransportMessage } from './message'; 2 | 3 | // internal use only, not to be used in public API 4 | type SessionApiResult = 5 | | { 6 | ok: true; 7 | value: T; 8 | } 9 | | { 10 | ok: false; 11 | reason: string; 12 | }; 13 | 14 | export type SendResult = SessionApiResult; 15 | export type SendBufferResult = SessionApiResult; 16 | export type SerializeResult = SessionApiResult; 17 | export type DeserializeResult = SessionApiResult; 18 | -------------------------------------------------------------------------------- /transport/sessionStateMachine/SessionBackingOff.ts: -------------------------------------------------------------------------------- 1 | import { 2 | IdentifiedSessionWithGracePeriod, 3 | IdentifiedSessionWithGracePeriodListeners, 4 | IdentifiedSessionWithGracePeriodProps, 5 | SessionState, 6 | } from './common'; 7 | 8 | export 
interface SessionBackingOffListeners 9 | extends IdentifiedSessionWithGracePeriodListeners { 10 | onBackoffFinished: () => void; 11 | } 12 | 13 | export interface SessionBackingOffProps 14 | extends IdentifiedSessionWithGracePeriodProps { 15 | backoffMs: number; 16 | listeners: SessionBackingOffListeners; 17 | } 18 | 19 | /* 20 | * A session that is backing off before attempting to connect. 21 | * See transitions.ts for valid transitions. 22 | */ 23 | export class SessionBackingOff extends IdentifiedSessionWithGracePeriod { 24 | readonly state = SessionState.BackingOff as const; 25 | listeners: SessionBackingOffListeners; 26 | 27 | backoffTimeout?: ReturnType; 28 | 29 | constructor(props: SessionBackingOffProps) { 30 | super(props); 31 | this.listeners = props.listeners; 32 | 33 | this.backoffTimeout = setTimeout(() => { 34 | this.listeners.onBackoffFinished(); 35 | }, props.backoffMs); 36 | } 37 | 38 | _handleClose(): void { 39 | super._handleClose(); 40 | } 41 | 42 | _handleStateExit(): void { 43 | super._handleStateExit(); 44 | 45 | if (this.backoffTimeout) { 46 | clearTimeout(this.backoffTimeout); 47 | this.backoffTimeout = undefined; 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /transport/sessionStateMachine/SessionConnected.ts: -------------------------------------------------------------------------------- 1 | import { Static } from '@sinclair/typebox'; 2 | import { 3 | ControlFlags, 4 | ControlMessageAckSchema, 5 | OpaqueTransportMessage, 6 | PartialTransportMessage, 7 | TransportMessage, 8 | isAck, 9 | } from '../message'; 10 | import { 11 | IdentifiedSession, 12 | IdentifiedSessionProps, 13 | sendMessage, 14 | SessionState, 15 | } from './common'; 16 | import { Connection } from '../connection'; 17 | import { SpanStatusCode } from '@opentelemetry/api'; 18 | import { SendBufferResult, SendResult } from '../results'; 19 | 20 | export interface SessionConnectedListeners { 21 | onConnectionErrored: 
(err: unknown) => void; 22 | onConnectionClosed: () => void; 23 | onMessage: (msg: OpaqueTransportMessage) => void; 24 | onMessageSendFailure: (msg: PartialTransportMessage, reason: string) => void; 25 | onInvalidMessage: (reason: string) => void; 26 | } 27 | 28 | export interface SessionConnectedProps 29 | extends IdentifiedSessionProps { 30 | conn: ConnType; 31 | listeners: SessionConnectedListeners; 32 | } 33 | 34 | /* 35 | * A session that is connected and can send and receive messages. 36 | * See transitions.ts for valid transitions. 37 | */ 38 | export class SessionConnected< 39 | ConnType extends Connection, 40 | > extends IdentifiedSession { 41 | readonly state = SessionState.Connected as const; 42 | conn: ConnType; 43 | listeners: SessionConnectedListeners; 44 | 45 | private heartbeatHandle?: ReturnType | undefined; 46 | private heartbeatMissTimeout?: ReturnType | undefined; 47 | private isActivelyHeartbeating = false; 48 | 49 | updateBookkeeping(ack: number, seq: number) { 50 | this.sendBuffer = this.sendBuffer.filter((unacked) => unacked.seq >= ack); 51 | this.ack = seq + 1; 52 | 53 | if (this.heartbeatMissTimeout) { 54 | clearTimeout(this.heartbeatMissTimeout); 55 | } 56 | 57 | this.startMissingHeartbeatTimeout(); 58 | } 59 | 60 | private assertSendOrdering(constructedMsg: TransportMessage) { 61 | if (constructedMsg.seq > this.seqSent + 1) { 62 | const msg = `invariant violation: would have sent out of order msg (seq: ${constructedMsg.seq}, expected: ${this.seqSent} + 1)`; 63 | this.log?.error(msg, { 64 | ...this.loggingMetadata, 65 | transportMessage: constructedMsg, 66 | tags: ['invariant-violation'], 67 | }); 68 | 69 | throw new Error(msg); 70 | } 71 | } 72 | 73 | send(msg: PartialTransportMessage): SendResult { 74 | const constructedMsg = this.constructMsg(msg); 75 | this.assertSendOrdering(constructedMsg); 76 | this.sendBuffer.push(constructedMsg); 77 | const res = sendMessage(this.conn, this.codec, constructedMsg); 78 | if (!res.ok) { 79 | 
this.listeners.onMessageSendFailure(constructedMsg, res.reason); 80 | 81 | return res; 82 | } 83 | 84 | this.seqSent = constructedMsg.seq; 85 | 86 | return res; 87 | } 88 | 89 | constructor(props: SessionConnectedProps) { 90 | super(props); 91 | this.conn = props.conn; 92 | this.listeners = props.listeners; 93 | 94 | this.conn.setDataListener(this.onMessageData); 95 | this.conn.setCloseListener(this.listeners.onConnectionClosed); 96 | this.conn.setErrorListener(this.listeners.onConnectionErrored); 97 | } 98 | 99 | sendBufferedMessages(): SendBufferResult { 100 | // send any buffered messages 101 | // dont explicity clear the buffer, we'll just filter out old messages 102 | // when we receive an ack 103 | if (this.sendBuffer.length > 0) { 104 | this.log?.info( 105 | `sending ${ 106 | this.sendBuffer.length 107 | } buffered messages, starting at seq ${this.nextSeq()}`, 108 | this.loggingMetadata, 109 | ); 110 | 111 | for (const msg of this.sendBuffer) { 112 | this.assertSendOrdering(msg); 113 | const res = sendMessage(this.conn, this.codec, msg); 114 | if (!res.ok) { 115 | this.listeners.onMessageSendFailure(msg, res.reason); 116 | 117 | return res; 118 | } 119 | 120 | this.seqSent = msg.seq; 121 | } 122 | } 123 | 124 | return { ok: true, value: undefined }; 125 | } 126 | 127 | get loggingMetadata() { 128 | return { 129 | ...super.loggingMetadata, 130 | ...this.conn.loggingMetadata, 131 | }; 132 | } 133 | 134 | startMissingHeartbeatTimeout() { 135 | const maxMisses = this.options.heartbeatsUntilDead; 136 | const missDuration = maxMisses * this.options.heartbeatIntervalMs; 137 | this.heartbeatMissTimeout = setTimeout(() => { 138 | this.log?.info( 139 | `closing connection to ${this.to} due to inactivity (missed ${maxMisses} heartbeats which is ${missDuration}ms)`, 140 | this.loggingMetadata, 141 | ); 142 | this.telemetry.span.addEvent( 143 | 'closing connection due to missing heartbeat', 144 | ); 145 | 146 | this.conn.close(); 147 | }, missDuration); 148 | } 149 | 150 
| startActiveHeartbeat() { 151 | this.isActivelyHeartbeating = true; 152 | this.heartbeatHandle = setInterval(() => { 153 | this.sendHeartbeat(); 154 | }, this.options.heartbeatIntervalMs); 155 | } 156 | 157 | private sendHeartbeat(): void { 158 | this.log?.debug('sending heartbeat', this.loggingMetadata); 159 | const heartbeat = { 160 | streamId: 'heartbeat', 161 | controlFlags: ControlFlags.AckBit, 162 | payload: { 163 | type: 'ACK', 164 | } satisfies Static, 165 | } satisfies PartialTransportMessage; 166 | 167 | this.send(heartbeat); 168 | } 169 | 170 | onMessageData = (msg: Uint8Array) => { 171 | const parsedMsgRes = this.codec.fromBuffer(msg); 172 | if (!parsedMsgRes.ok) { 173 | this.listeners.onInvalidMessage( 174 | `could not parse message: ${parsedMsgRes.reason}`, 175 | ); 176 | 177 | return; 178 | } 179 | 180 | const parsedMsg = parsedMsgRes.value; 181 | 182 | // check message ordering here 183 | if (parsedMsg.seq !== this.ack) { 184 | if (parsedMsg.seq < this.ack) { 185 | this.log?.debug( 186 | `received duplicate msg (got seq: ${parsedMsg.seq}, wanted seq: ${this.ack}), discarding`, 187 | { 188 | ...this.loggingMetadata, 189 | transportMessage: parsedMsg, 190 | }, 191 | ); 192 | } else { 193 | const reason = `received out-of-order msg, closing connection (got seq: ${parsedMsg.seq}, wanted seq: ${this.ack})`; 194 | this.log?.error(reason, { 195 | ...this.loggingMetadata, 196 | transportMessage: parsedMsg, 197 | tags: ['invariant-violation'], 198 | }); 199 | 200 | this.telemetry.span.setStatus({ 201 | code: SpanStatusCode.ERROR, 202 | message: reason, 203 | }); 204 | 205 | // try to recover by closing the connection and re-handshaking 206 | // with the session intact 207 | this.conn.close(); 208 | } 209 | 210 | return; 211 | } 212 | 213 | // message is ok to update bookkeeping with 214 | this.log?.debug(`received msg`, { 215 | ...this.loggingMetadata, 216 | transportMessage: parsedMsg, 217 | }); 218 | 219 | this.updateBookkeeping(parsedMsg.ack, 
parsedMsg.seq); 220 | 221 | // dispatch directly if its not an explicit ack 222 | if (!isAck(parsedMsg.controlFlags)) { 223 | this.listeners.onMessage(parsedMsg); 224 | 225 | return; 226 | } 227 | 228 | // discard acks (unless we aren't heartbeating in which case just respond) 229 | this.log?.debug(`discarding msg (ack bit set)`, { 230 | ...this.loggingMetadata, 231 | transportMessage: parsedMsg, 232 | }); 233 | 234 | // if we are not actively heartbeating, we are in passive 235 | // heartbeat mode and should send a response to the ack 236 | if (!this.isActivelyHeartbeating) { 237 | this.sendHeartbeat(); 238 | } 239 | }; 240 | 241 | _handleStateExit(): void { 242 | super._handleStateExit(); 243 | this.conn.removeDataListener(); 244 | this.conn.removeCloseListener(); 245 | this.conn.removeErrorListener(); 246 | 247 | if (this.heartbeatHandle) { 248 | clearInterval(this.heartbeatHandle); 249 | this.heartbeatHandle = undefined; 250 | } 251 | 252 | if (this.heartbeatMissTimeout) { 253 | clearTimeout(this.heartbeatMissTimeout); 254 | this.heartbeatMissTimeout = undefined; 255 | } 256 | } 257 | 258 | _handleClose(): void { 259 | super._handleClose(); 260 | this.conn.close(); 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /transport/sessionStateMachine/SessionConnecting.ts: -------------------------------------------------------------------------------- 1 | import { Connection } from '../connection'; 2 | import { 3 | IdentifiedSessionWithGracePeriod, 4 | IdentifiedSessionWithGracePeriodListeners, 5 | IdentifiedSessionWithGracePeriodProps, 6 | SessionState, 7 | } from './common'; 8 | 9 | export interface SessionConnectingListeners 10 | extends IdentifiedSessionWithGracePeriodListeners { 11 | onConnectionEstablished: (conn: Connection) => void; 12 | onConnectionFailed: (err: unknown) => void; 13 | 14 | // timeout related 15 | onConnectionTimeout: () => void; 16 | } 17 | 18 | export interface SessionConnectingProps 19 | 
extends IdentifiedSessionWithGracePeriodProps { 20 | connPromise: Promise; 21 | listeners: SessionConnectingListeners; 22 | } 23 | 24 | /* 25 | * A session that is connecting but we don't have access to the raw connection yet. 26 | * See transitions.ts for valid transitions. 27 | */ 28 | export class SessionConnecting< 29 | ConnType extends Connection, 30 | > extends IdentifiedSessionWithGracePeriod { 31 | readonly state = SessionState.Connecting as const; 32 | connPromise: Promise; 33 | listeners: SessionConnectingListeners; 34 | 35 | connectionTimeout?: ReturnType; 36 | 37 | constructor(props: SessionConnectingProps) { 38 | super(props); 39 | this.connPromise = props.connPromise; 40 | this.listeners = props.listeners; 41 | 42 | this.connPromise.then( 43 | (conn) => { 44 | if (this._isConsumed) return; 45 | this.listeners.onConnectionEstablished(conn); 46 | }, 47 | (err) => { 48 | if (this._isConsumed) return; 49 | this.listeners.onConnectionFailed(err); 50 | }, 51 | ); 52 | 53 | this.connectionTimeout = setTimeout(() => { 54 | this.listeners.onConnectionTimeout(); 55 | }, this.options.connectionTimeoutMs); 56 | } 57 | 58 | // close a pending connection if it resolves, ignore errors if the promise 59 | // ends up rejected anyways 60 | bestEffortClose() { 61 | // these can technically be stale if the connPromise resolves after the 62 | // state has transitioned, but that's fine, this is best effort anyways 63 | // we pull these out so even if the state has transitioned, we can still log 64 | // without erroring out 65 | const logger = this.log; 66 | const metadata = this.loggingMetadata; 67 | 68 | this.connPromise 69 | .then((conn) => { 70 | conn.close(); 71 | logger?.info( 72 | 'connection eventually resolved but session has transitioned, closed connection', 73 | { 74 | ...metadata, 75 | ...conn.loggingMetadata, 76 | }, 77 | ); 78 | }) 79 | .catch(() => { 80 | // ignore errors 81 | }); 82 | } 83 | 84 | _handleStateExit(): void { 85 | super._handleStateExit(); 86 | 
if (this.connectionTimeout) { 87 | clearTimeout(this.connectionTimeout); 88 | this.connectionTimeout = undefined; 89 | } 90 | } 91 | 92 | _handleClose(): void { 93 | super._handleClose(); 94 | 95 | // close the pending connection if it resolves 96 | this.bestEffortClose(); 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /transport/sessionStateMachine/SessionHandshaking.ts: -------------------------------------------------------------------------------- 1 | import { Static } from '@sinclair/typebox'; 2 | import { Connection } from '../connection'; 3 | import { 4 | OpaqueTransportMessage, 5 | TransportMessage, 6 | HandshakeErrorResponseCodes, 7 | } from '../message'; 8 | import { 9 | IdentifiedSessionWithGracePeriod, 10 | IdentifiedSessionWithGracePeriodListeners, 11 | IdentifiedSessionWithGracePeriodProps, 12 | sendMessage, 13 | SessionState, 14 | } from './common'; 15 | import { SendResult } from '../results'; 16 | 17 | export interface SessionHandshakingListeners 18 | extends IdentifiedSessionWithGracePeriodListeners { 19 | onConnectionErrored: (err: unknown) => void; 20 | onConnectionClosed: () => void; 21 | onHandshake: (msg: OpaqueTransportMessage) => void; 22 | onInvalidHandshake: ( 23 | reason: string, 24 | code: Static, 25 | ) => void; 26 | 27 | // timeout related 28 | onHandshakeTimeout: () => void; 29 | } 30 | 31 | export interface SessionHandshakingProps 32 | extends IdentifiedSessionWithGracePeriodProps { 33 | conn: ConnType; 34 | listeners: SessionHandshakingListeners; 35 | } 36 | 37 | /* 38 | * A session that is handshaking and waiting for the other side to identify itself. 39 | * See transitions.ts for valid transitions. 
40 | */ 41 | export class SessionHandshaking< 42 | ConnType extends Connection, 43 | > extends IdentifiedSessionWithGracePeriod { 44 | readonly state = SessionState.Handshaking as const; 45 | conn: ConnType; 46 | listeners: SessionHandshakingListeners; 47 | 48 | handshakeTimeout?: ReturnType; 49 | 50 | constructor(props: SessionHandshakingProps) { 51 | super(props); 52 | this.conn = props.conn; 53 | this.listeners = props.listeners; 54 | 55 | this.handshakeTimeout = setTimeout(() => { 56 | this.listeners.onHandshakeTimeout(); 57 | }, this.options.handshakeTimeoutMs); 58 | 59 | this.conn.setDataListener(this.onHandshakeData); 60 | this.conn.setErrorListener(this.listeners.onConnectionErrored); 61 | this.conn.setCloseListener(this.listeners.onConnectionClosed); 62 | } 63 | 64 | get loggingMetadata() { 65 | return { 66 | ...super.loggingMetadata, 67 | ...this.conn.loggingMetadata, 68 | }; 69 | } 70 | 71 | onHandshakeData = (msg: Uint8Array) => { 72 | const parsedMsgRes = this.codec.fromBuffer(msg); 73 | if (!parsedMsgRes.ok) { 74 | this.listeners.onInvalidHandshake( 75 | `could not parse handshake message: ${parsedMsgRes.reason}`, 76 | 'MALFORMED_HANDSHAKE', 77 | ); 78 | 79 | return; 80 | } 81 | 82 | this.listeners.onHandshake(parsedMsgRes.value); 83 | }; 84 | 85 | sendHandshake(msg: TransportMessage): SendResult { 86 | return sendMessage(this.conn, this.codec, msg); 87 | } 88 | 89 | _handleStateExit(): void { 90 | super._handleStateExit(); 91 | this.conn.removeDataListener(); 92 | this.conn.removeErrorListener(); 93 | this.conn.removeCloseListener(); 94 | 95 | if (this.handshakeTimeout) { 96 | clearTimeout(this.handshakeTimeout); 97 | this.handshakeTimeout = undefined; 98 | } 99 | } 100 | 101 | _handleClose(): void { 102 | super._handleClose(); 103 | this.conn.close(); 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /transport/sessionStateMachine/SessionNoConnection.ts: 
--------------------------------------------------------------------------------
1 | import {
2 |   IdentifiedSessionWithGracePeriod,
3 |   IdentifiedSessionWithGracePeriodListeners,
4 |   IdentifiedSessionWithGracePeriodProps,
5 |   SessionState,
6 | } from './common';
7 | 
8 | export type SessionNoConnectionListeners =
9 |   IdentifiedSessionWithGracePeriodListeners;
10 | 
11 | export type SessionNoConnectionProps = IdentifiedSessionWithGracePeriodProps;
12 | 
13 | /**
14 |  * A session that has no connection and cannot exchange messages on the wire.
15 |  * Messages passed to `send` are only appended to the send buffer, and the
16 |  * inherited grace-period timer fires `onSessionGracePeriodElapsed` if the
17 |  * session is still in this state at `graceExpiryTime`.
18 |  * See transitions.ts for valid transitions.
19 |  */
20 | export class SessionNoConnection extends IdentifiedSessionWithGracePeriod {
21 |   readonly state = SessionState.NoConnection as const;
22 | 
23 |   // nothing extra to tear down: the base classes clear the send buffer and
24 |   // end the telemetry span
25 |   _handleClose(): void {
26 |     super._handleClose();
27 |   }
28 | 
29 |   // nothing extra to tear down: the base class clears the grace-period timer
30 |   _handleStateExit(): void {
31 |     super._handleStateExit();
32 |   }
33 | }
34 | 
--------------------------------------------------------------------------------
/transport/sessionStateMachine/SessionWaitingForHandshake.ts:
--------------------------------------------------------------------------------
1 | import { Static } from '@sinclair/typebox';
2 | import { Connection } from '../connection';
3 | import {
4 |   HandshakeErrorResponseCodes,
5 |   OpaqueTransportMessage,
6 |   TransportMessage,
7 | } from '../message';
8 | import {
9 |   CommonSession,
10 |   CommonSessionProps,
11 |   sendMessage,
12 |   SessionState,
13 | } from './common';
14 | import { SendResult } from '../results';
15 | 
16 | export interface SessionWaitingForHandshakeListeners {
17 |   onConnectionErrored: (err: unknown) => void;
18 |   onConnectionClosed: () => void;
19 |   // receives every message decoded while in this state (expected to be the handshake)
20 |   onHandshake: (msg: OpaqueTransportMessage) => void;
21 |   // this state reports undecodable data with the 'MALFORMED_HANDSHAKE' code
22 |   onInvalidHandshake: (
23 |     reason: string,
24 |     code: Static<typeof HandshakeErrorResponseCodes>,
25 |   ) => void;
26 | 
27 |   // timeout related: fires if the state is not exited within handshakeTimeoutMs
28 |   onHandshakeTimeout: () => void;
29 | }
30 | 
31 | export interface SessionWaitingForHandshakeProps<ConnType extends Connection>
32 |   extends CommonSessionProps {
33 |   conn: ConnType;
34 |   listeners: SessionWaitingForHandshakeListeners;
35 | }
36 | 
37 | /**
38 |  * Server-side session that has a connection but is waiting for the client to identify itself.
39 |  * On construction it starts the handshake timer and routes the connection's
40 |  * data, error and close events to the provided listeners.
41 |  * See transitions.ts for valid transitions.
42 |  */
43 | export class SessionWaitingForHandshake<
44 |   ConnType extends Connection,
45 | > extends CommonSession {
46 |   readonly state = SessionState.WaitingForHandshake as const;
47 |   conn: ConnType;
48 |   listeners: SessionWaitingForHandshakeListeners;
49 | 
50 |   handshakeTimeout?: ReturnType<typeof setTimeout>;
51 | 
52 |   constructor(props: SessionWaitingForHandshakeProps<ConnType>) {
53 |     super(props);
54 |     this.conn = props.conn;
55 |     this.listeners = props.listeners;
56 | 
57 |     this.handshakeTimeout = setTimeout(() => {
58 |       this.listeners.onHandshakeTimeout();
59 |     }, this.options.handshakeTimeoutMs);
60 | 
61 |     this.conn.setDataListener(this.onHandshakeData);
62 |     this.conn.setErrorListener(this.listeners.onConnectionErrored);
63 |     this.conn.setCloseListener(this.listeners.onConnectionClosed);
64 |   }
65 | 
66 |   get loggingMetadata() {
67 |     return {
68 |       clientId: this.from,
69 |       connId: this.conn.id,
70 |       ...this.conn.loggingMetadata,
71 |     };
72 |   }
73 | 
74 |   // decode raw bytes from the connection and hand the result to the listeners
75 |   onHandshakeData = (msg: Uint8Array) => {
76 |     const parsedMsgRes = this.codec.fromBuffer(msg);
77 |     if (!parsedMsgRes.ok) {
78 |       this.listeners.onInvalidHandshake(
79 |         `could not parse handshake message: ${parsedMsgRes.reason}`,
80 |         'MALFORMED_HANDSHAKE',
81 |       );
82 | 
83 |       return;
84 |     }
85 | 
86 |     // after this fires, the listener is responsible for transitioning the session
87 |     // and thus removing the handshake timeout
88 |     this.listeners.onHandshake(parsedMsgRes.value);
89 |   };
90 | 
91 |   // write a handshake message straight to the connection (no send buffer)
92 |   sendHandshake(msg: TransportMessage): SendResult {
93 |     return sendMessage(this.conn, this.codec, msg);
94 |   }
95 | 
96 |   // detach every connection listener and cancel the handshake timer
97 |   _handleStateExit(): void {
98 |     this.conn.removeDataListener();
99 |     this.conn.removeErrorListener();
100 |     this.conn.removeCloseListener();
101 |     clearTimeout(this.handshakeTimeout);
102 |     this.handshakeTimeout = undefined;
103 |   }
104 | 
105 |   // when invoked through the state proxy, _handleStateExit has already run
106 |   _handleClose(): void {
107 |     this.conn.close();
108 |   }
109 | }
110 | 
--------------------------------------------------------------------------------
/transport/sessionStateMachine/common.ts:
--------------------------------------------------------------------------------
1 | import { Logger, MessageMetadata } from '../../logging';
2 | import { TelemetryInfo } from '../../tracing';
3 | import {
4 |   OpaqueTransportMessage,
5 |   PartialTransportMessage,
6 |   ProtocolVersion,
7 |   TransportClientId,
8 |   TransportMessage,
9 | } from '../message';
10 | import { Codec, CodecMessageAdapter } from '../../codec';
11 | import { generateId } from '../id';
12 | import { Tracer } from '@opentelemetry/api';
13 | import { SendResult } from '../results';
14 | import { Connection } from '../connection';
15 | 
16 | export const enum SessionState {
17 |   NoConnection = 'NoConnection',
18 |   BackingOff = 'BackingOff',
19 |   Connecting = 'Connecting',
20 |   Handshaking = 'Handshaking',
21 |   Connected = 'Connected',
22 |   WaitingForHandshake = 'WaitingForHandshake',
23 | }
24 | 
25 | export const ERR_CONSUMED = `session state has been consumed and is no longer valid`;
26 | 
27 | abstract class StateMachineState {
28 |   abstract readonly state: SessionState;
29 | 
30 |   /**
31 |    * Whether this state has been consumed
32 |    * and we've moved on to another state
33 |    */
34 |   _isConsumed: boolean;
35 | 
36 |   // called when we're transitioning to another state
37 |   // note that this is internal and should not be called directly
38 |   // by consumers, the proxy will call this when the state is consumed
39 |   // and we're transitioning to another state
40 |   abstract _handleStateExit(): void;
41 | 
42 |   // called when we exit the state machine entirely
43 |   // note that this is internal and should not be called directly
44 |   // by consumers, the proxy will call this when .close() is called
45 |   abstract _handleClose(): void;
46 | 
47 |   /**
48 |    * Cleanup this state machine state and mark it as consumed.
49 |    * After calling close, it is an error to access any properties on the state.
50 |    * You should never need to call this as a consumer.
51 |    *
52 |    * If you're looking to close the session from the client,
53 |    * use `.hardDisconnect` on the client transport.
54 |    */
55 |   close(): void {
56 |     this._handleClose();
57 |   }
58 | 
59 |   constructor() {
60 |     this._isConsumed = false;
61 | 
62 |     // proxy helps us prevent access to properties after the state has been consumed
63 |     // e.g. if we hold a reference to a state and try to access it after it's been consumed
64 |     // we intercept the access and throw an error to help catch bugs
65 |     return new Proxy(this, {
66 |       get(target, prop) {
67 |         // always allow access to _isConsumed, id, and state
68 |         if (prop === '_isConsumed' || prop === 'id' || prop === 'state') {
69 |           return Reflect.get(target, prop);
70 |         }
71 | 
72 |         // wrap _handleStateExit so that exiting also marks the state consumed
73 |         if (prop === '_handleStateExit') {
74 |           return () => {
75 |             target._isConsumed = true;
76 |             target._handleStateExit();
77 |           };
78 |         }
79 | 
80 |         // wrap _handleClose: mark consumed, run state-exit cleanup, then close
81 |         if (prop === '_handleClose') {
82 |           return () => {
83 |             // target is the non-proxied object, we need to set _isConsumed again
84 |             target._isConsumed = true;
85 |             target._handleStateExit();
86 |             target._handleClose();
87 |           };
88 |         }
89 | 
90 |         if (target._isConsumed) {
91 |           throw new Error(
92 |             `${ERR_CONSUMED}: getting ${prop.toString()} on consumed state`,
93 |           );
94 |         }
95 | 
96 |         return Reflect.get(target, prop);
97 |       },
98 |       set(target, prop, value) {
99 |         if (target._isConsumed) {
100 |           throw new Error(
101 |             `${ERR_CONSUMED}: setting ${prop.toString()} on consumed state`,
102 |           );
103 |         }
104 | 
105 |         return Reflect.set(target, prop, value);
106 |       },
107 |     });
108 |   }
109 | }
110 | 
111 | export interface SessionOptions {
112 |   /**
113 |    * Interval, in milliseconds, at which heartbeat acknowledgements are sent
114 |    */
115 |   heartbeatIntervalMs: number;
116 |   /**
117 |    * Number of elapsed heartbeats without a response message before we consider
118 |    * the connection dead.
119 |    */
120 |   heartbeatsUntilDead: number;
121 |   /**
122 |    * Max duration that a session can be without a connection before we consider
123 |    * it dead. This deadline is carried between states and is used to determine
124 |    * when to consider the session a lost cause and delete it entirely.
125 |    * Generally, this should be strictly greater than the sum of
126 |    * {@link connectionTimeoutMs} and {@link handshakeTimeoutMs}.
127 |    */
128 |   sessionDisconnectGraceMs: number;
129 |   /**
130 |    * Connection timeout in milliseconds
131 |    */
132 |   connectionTimeoutMs: number;
133 |   /**
134 |    * Handshake timeout in milliseconds
135 |    */
136 |   handshakeTimeoutMs: number;
137 |   /**
138 |    * Whether to enable transparent session reconnects
139 |    */
140 |   enableTransparentSessionReconnects: boolean;
141 |   /**
142 |    * The codec to use for encoding/decoding messages over the wire
143 |    */
144 |   codec: Codec;
145 | }
146 | 
147 | // props shared by every session state: our client id, options, codec, tracer and logger
148 | export interface CommonSessionProps {
149 |   from: TransportClientId;
150 |   options: SessionOptions;
151 |   codec: CodecMessageAdapter;
152 |   tracer: Tracer;
153 |   log: Logger | undefined;
154 | }
155 | 
156 | export abstract class CommonSession extends StateMachineState {
157 |   readonly from: TransportClientId;
158 |   readonly options: SessionOptions;
159 | 
160 |   readonly codec: CodecMessageAdapter;
161 |   tracer: Tracer;
162 |   log?: Logger;
163 |   abstract get loggingMetadata(): MessageMetadata;
164 | 
165 |   constructor({ from, options, log, tracer, codec }: CommonSessionProps) {
166 |     super();
167 |     this.from = from;
168 |     this.options = options;
169 |     this.log = log;
170 |     this.tracer = tracer;
171 |     this.codec = codec;
172 |   }
173 | }
174 | 
175 | export type InheritedProperties = Pick<
176 |   IdentifiedSession,
177 |   'id' | 'from' | 'to' | 'seq' | 'ack' | 'sendBuffer' | 'telemetry' | 'options'
178 | >;
179 | 
180 | export type SessionId = string;
181 | 
182 | // all sessions where we know the other side's client id
183 | export interface IdentifiedSessionProps extends CommonSessionProps {
184 |   id: SessionId;
185 |   to: TransportClientId;
186 |   seq: number;
187 |   ack: number;
188 |   seqSent: number;
189 |   sendBuffer: Array<OpaqueTransportMessage>;
190 |   telemetry: TelemetryInfo;
191 |   protocolVersion: ProtocolVersion;
192 | }
193 | 
194 | /**
195 |  * Base class for session states that know which client they talk to.
196 |  * Owns the seq/ack bookkeeping and the buffer of outgoing messages.
197 |  */
198 | export abstract class IdentifiedSession extends CommonSession {
199 |   readonly id: SessionId;
200 |   readonly telemetry: TelemetryInfo;
201 |   readonly to: TransportClientId;
202 |   readonly protocolVersion: ProtocolVersion;
203 | 
204 |   /**
205 |    * Index of the message we will send next (excluding handshake)
206 |    */
207 |   seq: number;
208 | 
209 |   /**
210 |    * Last seq we sent over the wire this session (excluding handshake) and retransmissions
211 |    */
212 |   seqSent: number;
213 | 
214 |   /**
215 |    * Number of unique messages we've received this session (excluding handshake)
216 |    */
217 |   ack: number;
218 | 
219 |   /**
220 |    * Outgoing messages queued by {@link IdentifiedSession.send},
221 |    * appended in increasing `seq` order.
222 |    */
223 |   sendBuffer: Array<OpaqueTransportMessage>;
224 | 
225 |   constructor(props: IdentifiedSessionProps) {
226 |     const {
227 |       id,
228 |       to,
229 |       seq,
230 |       ack,
231 |       sendBuffer,
232 |       telemetry,
233 |       protocolVersion,
234 |       seqSent,
235 |     } = props;
236 |     // CommonSession already assigns from, options, log, tracer and codec
237 |     super(props);
238 |     this.id = id;
239 |     this.to = to;
240 |     this.seq = seq;
241 |     this.ack = ack;
242 |     this.sendBuffer = sendBuffer;
243 |     this.telemetry = telemetry;
244 |     this.protocolVersion = protocolVersion;
245 |     this.seqSent = seqSent;
246 |   }
247 | 
248 |   get loggingMetadata(): MessageMetadata {
249 |     const metadata: MessageMetadata = {
250 |       clientId: this.from,
251 |       connectedTo: this.to,
252 |       sessionId: this.id,
253 |     };
254 | 
255 |     // only attach trace ids when the span is actually being recorded
256 |     if (this.telemetry.span.isRecording()) {
257 |       const spanContext = this.telemetry.span.spanContext();
258 |       metadata.telemetry = {
259 |         traceId: spanContext.traceId,
260 |         spanId: spanContext.spanId,
261 |       };
262 |     }
263 | 
264 |     return metadata;
265 |   }
266 | 
267 |   /**
268 |    * Stamp a partial message with a fresh id, this session's addressing and
269 |    * the current seq/ack, then advance {@link seq}. The result is neither
270 |    * buffered nor sent.
271 |    */
272 |   constructMsg<Payload>(
273 |     partialMsg: PartialTransportMessage<Payload>,
274 |   ): TransportMessage<Payload> {
275 |     const msg = {
276 |       ...partialMsg,
277 |       id: generateId(),
278 |       to: this.to,
279 |       from: this.from,
280 |       seq: this.seq,
281 |       ack: this.ack,
282 |     };
283 | 
284 |     this.seq++;
285 | 
286 |     return msg;
287 |   }
288 | 
289 |   /**
290 |    * Seq of the oldest buffered message, or the next seq to be assigned when
291 |    * the send buffer is empty.
292 |    */
293 |   nextSeq(): number {
294 |     return this.sendBuffer.length > 0 ? this.sendBuffer[0].seq : this.seq;
295 |   }
296 | 
297 |   /**
298 |    * Construct a message and append it to the send buffer.
299 |    * Nothing is written to the wire here; this always succeeds and returns
300 |    * the id of the buffered message.
301 |    */
302 |   send(msg: PartialTransportMessage): SendResult {
303 |     const constructedMsg = this.constructMsg(msg);
304 |     this.sendBuffer.push(constructedMsg);
305 | 
306 |     return {
307 |       ok: true,
308 |       value: constructedMsg.id,
309 |     };
310 |   }
311 | 
312 |   _handleStateExit(): void {
313 |     // noop
314 |   }
315 | 
316 |   _handleClose(): void {
317 |     // zero out the buffer and finish the session's telemetry span
318 |     this.sendBuffer.length = 0;
319 |     this.telemetry.span.end();
320 |   }
321 | }
322 | 
323 | export interface IdentifiedSessionWithGracePeriodListeners {
324 |   onSessionGracePeriodElapsed: () => void;
325 | }
326 | 
327 | export interface IdentifiedSessionWithGracePeriodProps
328 |   extends IdentifiedSessionProps {
329 |   graceExpiryTime: number;
330 |   listeners: IdentifiedSessionWithGracePeriodListeners;
331 | }
332 | 
333 | /**
334 |  * Base for session states without a usable connection. On construction it
335 |  * schedules `onSessionGracePeriodElapsed` for `graceExpiryTime`, an absolute
336 |  * timestamp compared against `Date.now()`; the timer is cleared on state exit.
337 |  */
338 | export abstract class IdentifiedSessionWithGracePeriod extends IdentifiedSession {
339 |   graceExpiryTime: number;
340 |   protected gracePeriodTimeout?: ReturnType<typeof setTimeout>;
341 | 
342 |   listeners: IdentifiedSessionWithGracePeriodListeners;
343 | 
344 |   constructor(props: IdentifiedSessionWithGracePeriodProps) {
345 |     super(props);
346 |     this.listeners = props.listeners;
347 | 
348 |     this.graceExpiryTime = props.graceExpiryTime;
349 |     // a deadline already in the past gives a non-positive delay, so it fires asap
350 |     this.gracePeriodTimeout = setTimeout(() => {
351 |       this.listeners.onSessionGracePeriodElapsed();
352 |     }, this.graceExpiryTime - Date.now());
353 |   }
354 | 
355 |   _handleStateExit(): void {
356 |     super._handleStateExit();
357 | 
358 |     if (this.gracePeriodTimeout) {
359 |       clearTimeout(this.gracePeriodTimeout);
360 |       this.gracePeriodTimeout = undefined;
361 |     }
362 |   }
363 | 
364 |   _handleClose(): void {
365 |     super._handleClose();
366 |   }
367 | }
368 | 
369 | /**
370 |  * Encode `msg` with `codec` and write the bytes to `conn`.
371 |  * Returns the codec's failure if encoding fails, a failure result if
372 |  * `conn.send` reports it could not send, and otherwise the message id.
373 |  */
374 | export function sendMessage(
375 |   conn: Connection,
376 |   codec: CodecMessageAdapter,
377 |   msg: TransportMessage,
378 | ): SendResult {
379 |   const buff = codec.toBuffer(msg);
380 |   if (!buff.ok) {
381 |     return buff;
382 |   }
383 | 
384 |   const sent = conn.send(buff.value);
385 |   if (!sent) {
386 |     return {
387 |       ok: false,
388 |       reason: 'failed to send message',
389 |     };
390 |   }
391 | 
392 |   return {
393 |     ok: true,
394 |     value: msg.id,
395 |   };
396 | }
397 | 
--------------------------------------------------------------------------------
/transport/sessionStateMachine/index.ts:
--------------------------------------------------------------------------------
1 | export { SessionState } from './common';
2 | export { type SessionWaitingForHandshake } from './SessionWaitingForHandshake';
3 | export { type SessionConnecting } from './SessionConnecting';
4 | export { type SessionNoConnection } from './SessionNoConnection';
5 | export { type SessionHandshaking } from './SessionHandshaking';
6 | export { type SessionConnected } from './SessionConnected';
7 | export {
8 |   ClientSessionStateGraph,
9 |   ServerSessionStateGraph,
10 |   type Session,
11 | } from './transitions';
12 | 
--------------------------------------------------------------------------------
/transport/stringifyError.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * Turn an arbitrary thrown value into a human-readable reason string.
3 |  * `Error` instances yield their message ('unknown reason' when it is empty);
4 |  * any other value is stringified and tagged as coerced.
5 |  */
6 | export function coerceErrorString(err: unknown): string {
7 |   if (err instanceof Error) {
8 |     return err.message || 'unknown reason';
9 |   }
10 | 
11 |   return `[coerced to error] ${String(err)}`;
12 | }
13 | 
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "compilerOptions": {
3 |     "target": "esnext",
4 |     "module": "esnext",
5 |     "lib": [],
6 |     "declaration": true,
7 |     "declarationDir": "dist",
8 |     "declarationMap": true,
9 |     "outDir": "./dist",
10 |     "strict": true,
11 |     "noImplicitAny": true,
12 |     "strictNullChecks": true,
13 |     "strictFunctionTypes": true,
14 |     "strictBindCallApply": true,
15 |     "strictPropertyInitialization": true,
16 |     "noImplicitThis": true,
17 |     "alwaysStrict": true,
18 |     "noUnusedLocals": true,
19 |     "noUnusedParameters": true,
20 |     "noImplicitReturns": true,
21 |     "resolveJsonModule": true,
22 |     "moduleResolution": "bundler",
23 |     "esModuleInterop": true,
24 |     "skipLibCheck": true,
25 |     "forceConsistentCasingInFileNames": true
26 |   }
27 | }
28 | 
--------------------------------------------------------------------------------
/tsup.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'tsup';
2 | 
3 | export default defineConfig({
4 |   // public entry points; each path must exist in the repository
5 |   entry: [
6 |     'router/index.ts',
7 |     'logging/index.ts',
8 |     'codec/index.ts',
9 |     'testUtil/index.ts',
10 |     'transport/index.ts',
11 |     'transport/impls/ws/client.ts',
12 |     'transport/impls/ws/server.ts',
13 |   ],
14 |   format: ['esm', 'cjs'],
15 |   sourcemap: true,
16 |   clean: true,
17 |   dts: true,
18 | });
19 | 
--------------------------------------------------------------------------------
/vitest.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from 'vite';
2 | import { configDefaults, coverageConfigDefaults } from 'vitest/config';
3 | 
4 | export default defineConfig({
5 |   test: {
6 |     exclude: [...configDefaults.exclude, '**/.direnv/**'],
7 |     coverage: {
8 |       exclude: [...coverageConfigDefaults.exclude, '**/.direnv/**'],
9 |     },
10 |     sequence: {
11 |       hooks: 'stack',
12 |     },
13 |     reporters: process.env.GITHUB_ACTIONS
14 |       ? ['basic', 'github-actions', 'junit']
15 |       : ['default'],
16 |     pool: 'forks',
17 |     testTimeout: 1000,
18 |     setupFiles: './__tests__/globalSetup.ts',
19 |   },
20 | });
21 | 
--------------------------------------------------------------------------------