├── .github ├── FUNDING.yml └── workflows │ └── go.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── client.go ├── client_connection.go ├── client_test.go ├── cmd ├── complex │ ├── main.go │ └── metrics.go ├── complex_client │ └── main.go ├── helloworld │ └── main.go ├── helloworld_client │ └── client.go └── llm │ └── main.go ├── codecov.yml ├── event.go ├── event_test.go ├── go.mod ├── go.sum ├── internal ├── parser │ ├── chunk.go │ ├── chunk_test.go │ ├── field.go │ ├── field_parser.go │ ├── field_parser_test.go │ ├── parser.go │ ├── parser_test.go │ ├── split_func_test.go │ └── test_helpers_test.go └── tests │ ├── expect.go │ └── time.go ├── joe.go ├── joe_test.go ├── message.go ├── message_fields.go ├── message_fields_test.go ├── message_test.go ├── replay.go ├── replay_test.go ├── server.go ├── server_test.go ├── session.go └── session_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: tmaxmax -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | paths-ignore: 7 | - "**.md" 8 | - "cmd/**" 9 | pull_request: 10 | branches: [master] 11 | paths-ignore: 12 | - "**.md" 13 | - "cmd/**" 14 | 15 | jobs: 16 | lint: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v3 20 | - uses: golangci/golangci-lint-action@v2 21 | test: 22 | name: Test (latest major) 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v3 26 | - uses: actions/setup-go@v4 27 | with: 28 | go-version: "~1.23.4" 29 | - name: Test 30 | run: go test -v -timeout=1s -coverprofile=coverage.txt -covermode=atomic ./... 31 | - name: Test (race) 32 | run: go test -v -timeout=1s -race ./... 
33 | - name: Coverage 34 | uses: codecov/codecov-action@v3 35 | if: github.ref == 'refs/heads/master' 36 | with: 37 | token: ${{ secrets.CODECOV_TOKEN }} 38 | files: ./coverage.txt 39 | test-old: 40 | name: Test (previous major) 41 | runs-on: ubuntu-latest 42 | steps: 43 | - uses: actions/checkout@v3 44 | - uses: actions/setup-go@v4 45 | with: 46 | go-version: "~1.22.10" 47 | - name: Test 48 | run: go test -v -timeout=1s ./... 49 | - name: Test (race) 50 | run: go test -v -timeout=1s -race ./... 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - errcheck 4 | - gosimple 5 | - govet 6 | - ineffassign 7 | - staticcheck 8 | - typecheck 9 | - dogsled 10 | - dupl 11 | - errorlint 12 | - exhaustive 13 | - nestif 14 | - goconst 15 | - gocritic 16 | - gocyclo 17 | - godot 18 | - godox 19 | - gofmt 20 | - gofumpt 21 | - goheader 22 | - goimports 23 | - gomoddirectives 24 | - gomodguard 25 | - gosec 26 | - importas 27 | - makezero 28 | - misspell 29 | - prealloc 30 | - promlinter 31 | - predeclared 32 | - nolintlint 33 | - revive 34 | - stylecheck 35 | - tagliatelle 36 | - thelper 37 | - unparam 38 | - unused 39 | - whitespace 40 | linters-settings: 41 | gosec: 42 | excludes: 43 | - G404 44 | gocritic: 45 | disabled-checks: 46 | - ifElseChain 47 | - unnamedResult 48 | - hugeParam 49 | enabled-tags: 50 | - performance 51 | - diagnostic 52 | - experimental 53 | - opinionated 54 | nestif: 55 | min-complexity: 8 56 | govet: 57 | enable: 58 | - fieldalignment 59 | -------------------------------------------------------------------------------- /CHANGELOG.md:
-------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | This file tracks changes to this project. It follows the [Keep a Changelog format](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 4 | 5 | ## [0.11.0] - 2025-05-14 6 | 7 | The `sse.Server` logging and session handling were revamped to have more familiar, more flexible and less error prone interfaces for users. 8 | 9 | ### Removed 10 | 11 | - `Logger` and `LogLevel` enum have been removed. `Server.Logger` has transitioned to the standard `slog` library for better compatibility with the ecosystem 12 | 13 | ### Changed 14 | 15 | - `Server.Logger` is now of type `func(r *http.Request) *slog.Logger` instead of `sse.Logger` – it is possible to customize the logger on a per-request basis, by for example retrieving it from the context. 16 | - `Server.OnSession` signature changed from `func(s *Session) (Subscription, bool)` to `func(w http.ResponseWriter, r *http.Request) (topics []string, accepted bool)` – its initial role was to essentially just provide the topics, so the need to fiddle with `Session` and `Subscription` was redundant anyway 17 | - `Joe.Subscribe` now always returns `ErrProviderClosed` when a `Joe` instance is closed while subscriptions are active. Previously it would return it only if `Joe` was already shut down before subscribing. 18 | - `Joe` will print a stack trace for `Replayer` panics. 19 | 20 | ### Fixed 21 | 22 | - `sse.Session` doesn't write the header explicitly anymore. This would cause a `http: superfluous response.WriteHeader call` warning being logged when `sse.Server.OnSession` writes a response code itself when accepting a session. The change was initially introduced to remove the warning for users of certain external libraries (see #41) but this is the issue of the external library, not of `go-sse`. 
If you encounter this warning when using an external library, write the response code yourself in the HTTP handler before subscribing the `sse.Session`, as described in the linked discussion. 23 | - An insidious synchronization issue in `Joe` causing a channel double close in an edge case scenario (see #50, see code for details) 24 | 25 | ## [0.10.0] - 2024-12-29 26 | 27 | If you're working with LLMs in Go, this update will make you happy! `sse.Read` is now a thing – it just parses all events from an `io.Reader`. Use it with your response bodies and forget about any `sse.Client` configuration. It also makes use of the new Go 1.23 iterators to keep your code neat and tidy. 28 | 29 | ### Added 30 | 31 | - `Read` and `ReadConfig` 32 | 33 | ## [0.9.0] - 2024-12-26 34 | 35 | This is the replayer update. Oh, what is a "replayer"? It's what we call replay providers starting with this version! Anyway, besides renaming, this update removes many replaying bugs, improves performance, robustness and error handling, and better defines expected behavior for `ReplayProviders`... err, `Replayers`. 36 | 37 | More such overhauls are planned. I'm leaving it up to you to guess which comes next – the server or the client? ;) 38 | 39 | ### Removed 40 | 41 | - `FiniteReplayer.{Count, AutoIDs}` – use the constructor instead. 42 | - `ValidReplayer.{TTL, AutoIDs}` – use the constructor instead. 43 | 44 | ### Changed 45 | 46 | - The `ReplayProvider` and related entities are renamed to just `Replayer`. `go-sse` strives to have a minimal and expressive API, and minimal and expressive names are an important step in that direction. The changelog will use the new names onwards. 47 | - Due to a change in the internal implementation, the `FiniteReplayer` is now able to replay events only if the event with the LastEventID provided by the client is still buffered. Previously if the LastEventID was that of the latest removed event, events would still be replayed.
This detail added complexity to the implementation without an apparent significant win, so it was dropped. 48 | - `FiniteReplayer.GCInterval` should be set to `0` now in order to disable GC. 49 | - Automatic ID generation for both replayers does not overwrite already existing message IDs and errors instead. Ensure that your events do not have IDs when using replayers configured to generate IDs. 50 | - `Replayer.Put` now returns an error instead of being required to panic. Read the method documentation for more info. `Joe` also propagates this error through `Joe.Publish`. 51 | - Replayers are now required to not overwrite message IDs and return errors instead. Sending unsupported messages to replayers is a bug which should not go unnoticed. Both replayers in this library now implement this behavior. 52 | - `Joe` does not log replayer panics to the console anymore. Handle these panics inside the replay provider itself. 53 | 54 | ### Added 55 | 56 | - `NewFiniteReplayer` constructor 57 | - `NewValidReplayer` constructor 58 | - `Connection.Buffer` 59 | 60 | ### Fixed 61 | 62 | - `FiniteReplayer` doesn't leak memory anymore and respects the stored messages count it was given. Previously when a new message was put after the messages count was reached and some other messages were removed, the total messages count would grow unexpectedly and `FiniteReplayer` would store and replay more events than it was configured to. 63 | - `ValidReplayer` was also susceptible to a similar memory leak, which is also fixed now. 64 | - #41 – `sse.Session` now writes the header explicitly when upgrading. 65 | 66 | ## [0.8.0] - 2024-01-30 67 | 68 | This version removes all external dependencies of `go-sse`. All our bugs are belong to us! It also does some API and documentation cleanups. 69 | 70 | ### Removed 71 | 72 | - `Client.DefaultReconnectionTime`, `Client.MaxRetries` have been replaced with the new `Client.Backoff` configuration field. See the Added section for more info. 
73 | - `ErrReplayFailed` is removed from the public API. 74 | - `ReplayProviderWithGC` and `Joe.ReplayGCInterval` are no more. The responsibility for garbage collection is assigned to the replay providers. 75 | 76 | ### Changed 77 | 78 | - `Server.Logger` is now of a new type: the `Logger` interface. The dependency on x/exp/slog is removed. This opens up the possibility to adapt any existing logger to be usable with `Server`. 79 | - The default backoff behavior has changed. The _previous_ defaults map to the new `Backoff` configuration as follows: 80 | ```go 81 | sse.Backoff{ 82 | InitialInterval: 5 * time.Second, // currently 500ms 83 | Multiplier: 1.5, // currently the same 84 | Jitter: 0.5, // currently the same 85 | MaxInterval: 60 * time.Second, // currently unbounded 86 | MaxElapsedDuration: 15 * time.Minute, // currently unbounded 87 | MaxRetries: -1, // previously no retries by default, currently unbounded 88 | } 89 | ``` 90 | - `Joe` now accepts new subscriptions even if replay providers panic (previously `ErrReplayFailed` would be returned). 91 | - `Server.ServeHTTP` panics if a custom `OnSession` handler returns a `Subscription` with 0 topics 92 | 93 | ### Added 94 | 95 | - The `Logger` interface, `LogLevel` type, and `LogLevel(Info|Warn|Error)` values. 96 | - `Backoff` and `Client.Backoff` – the backoff strategy is now fully configurable. See the code documentation for info. 97 | - `ValidReplayProvider.GCInterval`, to configure at which interval expired events should be cleaned up. 98 | 99 | ## [0.7.0] - 2023-11-19 100 | 101 | This version overhauls connection retry and fixes the connection event dispatch order issue. Some internal changes to Joe were also made, which makes it faster and more resilient. 
102 | 103 | ### Removed 104 | 105 | - `ConnectionError.Temporary` 106 | - `ConnectionError.Timeout` 107 | 108 | ### Changed 109 | 110 | - Go's `Timeout` and `Temporary` interfaces are not used anymore – the client makes no assumptions and retries on every network or response read error. The only cases when `Connection.Connect` returns now are either when there are no more retries left (when the number is not infinite), or when the request context was cancelled. 111 | - `*url.Error`s that occur on the HTTP request are now unwrapped and their cause is put inside a `ConnectionError`. 112 | - `Connection.Connect` doesn't suppress any errors anymore: the request context errors are returned as is, all other errors are wrapped inside `ConnectionError`. 113 | - On reconnection attempt, the response reset error is now wrapped inside `ConnectionError`. With this change, all errors other than the context errors are wrapped inside `ConnectionError`. 114 | - Subscription callbacks are no longer called in individual goroutines. This caused messages to be received in an indeterminate order. Make sure that your callbacks do not block for too long! 115 | 116 | ### Changed 117 | 118 | - If a `ReplayProvider` method panics when called by `Joe`, instead of closing itself completely it just stops replaying, putting or GC-ing messages to upcoming clients. `Joe` continues to function as if no replay provider was given. A stack trace is printed to stderr when such a panic occurs. 119 | 120 | ## [0.6.0] - 2023-07-22 121 | 122 | This version brings a number of refactors to the server-side tooling the library offers. Constructors and construction related types are removed, for ease of use and reduced API size, concerns regarding topics and expiry were separated from `Message`, logging of the `Server` is upgraded to structured logging and messages can be now published to multiple topics at once. 
Request upgrading has also been refactored to provide a more functional API, and the `Server` logic can now be customized without having to create a distinct handler. 123 | 124 | ### Removed 125 | 126 | - `Message.ExpiresAt` is no more. 127 | - `Message.Topic` is no more. See the changes to `Server`, `Provider` and `ReplayProvider` for handling topics – you can now publish a message to multiple topics at once. 128 | - `Message.Writer` is no more. The API was redundant – one can achieve the same using `strings.Builder` and `Message.AppendData`. See the `MessageWriter` example for more. 129 | - `NewValidReplayProvider` is no more. 130 | - `NewFiniteReplayProvider` is no more. 131 | - `NewJoe` is no more. 132 | - `JoeConfig` is no more. 133 | - `Server.Subscribe` is no more – it never made sense. 134 | - `Server.Provider` is no more. 135 | - `NewServer`, `ServerOption` and friends are no more. 136 | - The `Logger` interface and the capability of the `Server` to use types that implement `Logger` as logging systems is removed. 137 | - `SubscriptionCallback` is no more (see the change to the `Subscription` type in the "Changed" section). 138 | 139 | ### Added 140 | 141 | - Because the `ValidReplayProvider` constructor was removed, the fields `ValidReplayProvider.{TTL,AutoIDs}` were added for configuration. 142 | - Because the `FiniteReplayProvider` constructor was removed, the fields `FiniteReplayProvider.{Count,AutoIDs}` were added for configuration. 143 | - Because the `Joe` constructor was removed, the fields `Joe.{ReplayProvider,ReplayGCInterval}` were added for configuration. 144 | - Because the `Server` constructor was removed, the field `Server.Provider` was added for configuration. 145 | - New `MessageWriter` interface; used by providers to send messages and implemented by `Session` (previously named `Request`). 146 | - New `ResponseWriter` interface, which is a `http.ResponseWriter` augmented with a `Flush` method. 
147 | - `ValidReplayProvider` has a new field `Now` which allows providing a custom current time getter, like `time.Now`, to the provider. Enables deterministic testing of dependents on `ValidReplayProvider`. 148 | - New `Server.OnSession` field, which enables customization of `Server`'s response and subscriptions. 149 | - New `Server.Logger` field, which enables structured logging with logger retrieved from the request and customizable config of logged information. 150 | 151 | ### Changed 152 | 153 | - `ReplayProvider.Put` takes a simple `*Message` and returns a `*Message`, instead of changing the `*Message` to which the `**Message` parameter points. 154 | It also takes a slice of topics, given that the `Message` doesn't hold the topic itself anymore. If the Message cannot be put, the method must now panic – see documentation for info. 155 | - Because `Message.ExpiresAt` is removed, the `ValidReplayProvider` sets the expiry itself. 156 | - `Server.Publish` now takes a list of topics. 157 | - `Provider.Publish` now takes a non-empty slice of topics. 158 | - `ReplayProvider.Put` now takes a non-empty slice of topics. 159 | - `Provider.Stop` is now `Provider.Shutdown` and takes now a `context.Context` as a parameter. 160 | - `Server.Shutdown` takes now a `context.Context` as a parameter. 161 | - `Request` is now named `Session` and exposes the HTTP request, response writer, and the last event ID of the request. 162 | - A new method `Flush` is added to `Session`; messages are no longer flushed by default, which allows providers, replay providers to batch send messages. 163 | - `Upgrade` now takes an `*http.Request` as its second parameter. 164 | - `Subscription` now has a `Client` field of type `MessageWriter` instead of a `Callback`. 165 | - Given the `Subscription` change, `Provider.Subscribe` and `ReplayProvider.Replay` now report message sending errors. 
166 | 167 | 168 | ## [0.5.2] - 2023-07-12 169 | 170 | ### Added 171 | 172 | - The new `Message.Writer` – write to the `Message` as if it is an `io.Writer`. 173 | 174 | ### Fixed 175 | 176 | - `Message.UnmarshalText` now strips the leading Unicode BOM, if it exists, as per the specification. 177 | - When parsing events client-side, BOM removal was attempted on each event input. Now the BOM is correctly removed only when parsing is started. 178 | 179 | ## [0.5.1] - 2023-07-12 180 | 181 | ### Fixed 182 | 183 | - `Message.WriteTo` now writes nothing if `Message` is empty. 184 | - `Message.WriteTo` does not attempt to write the `retry` field if `Message.Retry` is not at least 1ms. 185 | - `NewType` error message is updated to say "event type", not "event name". 186 | 187 | ## [0.5.0] - 2023-07-11 188 | 189 | This version comes with a series of internal refactorings that improve code readability and performance. It also replaces usage of `[]byte` for event data with `string` – SSE is a UTF-8 encoded text-based protocol, so raw bytes never made sense. This migration improves code safety (less `unsafe` usage and less worry about ownership) and reduces the memory footprint of some objects. 190 | 191 | Creating events on the server is also revised – fields that required getters and setters, apart from `data` and comments, are now simple public fields on the `sse.Message` struct. 192 | 193 | Across the codebase, to refer to the value of the `event` field the name "event type" is used, which is the nomenclature used in the SSE specification. 194 | 195 | Documentation and examples were also fixed and improved. 196 | 197 | ### Added 198 | 199 | - New `sse.EventName` type, which holds valid values for the `event` field, together with constructors (`sse.Name` and `sse.NewName`). 200 | 201 | ### Removed 202 | 203 | - `sse.Message`: `AppendText` was removed, as part of the migration from byte slices to strings. SSE is a UTF-8 encoded text-based protocol – raw bytes never made sense. 
204 | 205 | ### Changed 206 | 207 | - Minimum supported Go version was bumped from 1.16 to 1.19. From now on, the latest two major Go versions will be supported. 208 | - `sse.Message`: `AppendData` takes `string`s instead of `[]byte`. 209 | - `sse.Message`: `Comment` is now named `AppendComment`, for consistency with `AppendData`. 210 | - `sse.Message`: The message's expiration is not reset anymore by `UnmarshalText`. 211 | - `sse.Message`: `UnmarshalText` now unmarshals comments as well. 212 | - `sse.Message`: `WriteTo` (and `MarshalText` and `String` as a result) replaces all newline sequences in data with LF. 213 | - `sse.Message`: The `Expiry` getter and `SetExpiresAt`, `SetTTL` setters are replaced by the public field `ExpiresAt`. 214 | - `sse.Message`: Event ID getter and setter are replaced by the public `ID` field. 215 | - `sse.Message`: Event type (previously named `Name`) getter and setter are replaced by the public `Type` field. 216 | - `sse.Message`: The `retry` field value is now a public field on the struct. As a byproduct, `WriteTo` will now make 1 allocation when writing events with the `retry` field set. 217 | - `sse.NewEventID` is now `sse.NewID`, and `sse.MustEventID` is `sse.ID`. 218 | - `sse.Event`: The `Data` field is now of type `string`, not `[]byte`. 219 | - `sse.Event`: The `Name` field is now named `Type`. 220 | 221 | ### Fixed 222 | 223 | - `sse.Message`: `Clone` now copies the topic of the message to the new value. 224 | - `sse.Message`: ID fields that contain NUL characters are now ignored, as required by the spec, in `UnmarshalText`. 
225 | 226 | ## [0.4.3] - 2023-07-08 227 | 228 | ### Fixed 229 | 230 | - Messages longer than 4096 bytes are no longer being dropped ([#2], thanks [@aldld]) 231 | - Event parsing no longer panics on empty field with colon after name, see [test case](https://github.com/tmaxmax/go-sse/blob/4938f99db3bf7a8f057cb3e21ca88df57db3c0e0/internal/parser/field_parser_test.go#L37-L45) for example ([#5]) 232 | 233 | ## [0.4.2] - 2021-10-17 234 | 235 | ### Added 236 | 237 | - Get the event name of a Message 238 | 239 | ## [0.4.1] - 2021-10-15 240 | 241 | ### Added 242 | 243 | - Set a custom logger for Server 244 | 245 | ## [0.4.0] - 2021-10-15 246 | 247 | ### Changed 248 | 249 | - Server does not set any other headers besides `Content-Type`. 250 | - UpgradedRequest does not return a SendError anymore when Write errors. 251 | - Providers don't handle callback errors anymore. Callbacks return a flag that indicates whether the provider should keep calling it for new messages instead. 252 | 253 | ### Fixed 254 | 255 | - Client's default response validator now ignores `Content-Type` parameters when checking if the response's content type is `text/event-stream`. 256 | - Various optimizations 257 | 258 | ## [0.3.0] - 2021-09-18 259 | 260 | ### Added 261 | 262 | - ReplayProviderWithGC interface, which must be satisfied by replay providers that must be cleaned up periodically. 263 | 264 | ### Changed 265 | 266 | - Subscriptions now take a callback function instead of a channel. 267 | - Server response headers are now sent on the first Send call, not when Upgrade is called. 268 | - Providers are not required to add the default topic anymore. Callers of Subscribe should ensure at least a topic is specified. 269 | - Providers' Subscribe method now blocks until the subscriber is removed. 270 | - Server's Subscribe method automatically adds the default topic if no topic is specified. 271 | - ReplayProvider does not require for GC to be implemented. 
272 | - Client connections take callback functions instead of channels as event listeners. 273 | - Client connections' Unsubscribe methods are replaced by functions returned by their Subscribe counterparts. 274 | 275 | ### Fixed 276 | 277 | - Fix replay providers not replaying the oldest message if the ID provided is of the one before that one. 278 | - Fix replay providers hanging the caller's goroutine when a write error occurs using the default ServeHTTP implementation. 279 | - Fix providers hanging when a write error occurs using the default ServeHTTP implementation. 280 | 281 | ## [0.2.0] - 2021-09-13 282 | 283 | ### Added 284 | 285 | - Text/JSON marshalers and unmarshalers, and SQL scanners and valuers for the EventID type (previously event.ID). 286 | - Check for http.NoBody before resetting the request body on client reconnect. 287 | 288 | ### Changed 289 | 290 | - Package structure. The module is now refactored into a single package with an idiomatic name. This has resulted in various name changes: 291 | - `client.Error` - `sse.ConnectionError` 292 | - `event.Event` - `sse.Message` (previous `server.Message` is removed, see next change) 293 | - `event.ID` - `sse.EventID` 294 | - `event.NewID` - `sse.NewEventID` 295 | - `event.MustID` - `sse.MustEventID` 296 | - `server.Connection` - `sse.UpgradedRequest` 297 | - `server.NewConnection` - `sse.Upgrade` 298 | - `server.ErrUnsupported` - `sse.ErrUpgradeUnsupported` 299 | - `server.New` - `sse.NewServer`. 300 | - `event.Event` is merged with `server.Message`, becoming `sse.Message`. This affects the `sse.Server.Publish` function, which doesn't take a `topic` parameter anymore. 301 | - The server's constructor doesn't take an `Provider` as a parameter. It instead takes multiple optional `ServerOptions`. The `WithProvider` option is now used to pass custom providers to the server. 302 | - The `ReplayProvider` interface's `Put` method now takes a `**Message` instead of a `*Message`. 
This change also affects the replay providers in this package: `ValidReplayProvider` and `FiniteReplayProvider`. 303 | - The `Provider` interface's `Publish` method now takes a `*Message` instead of a `Message`. This change also affects `Joe`, the provider in this package. 304 | - The `UpgradedRequest`'s `Send` method now takes a `*Message` as parameter. 305 | 306 | ## [0.1.0] - 2021-09-11 First release 307 | 308 | [@aldld]: https://github.com/aldld 309 | 310 | [#5]: https://github.com/tmaxmax/go-sse/pull/5 311 | [#2]: https://github.com/tmaxmax/go-sse/pull/2 312 | 313 | [0.11.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.11.0 314 | [0.10.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.10.0 315 | [0.9.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.9.0 316 | [0.8.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.8.0 317 | [0.7.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.7.0 318 | [0.6.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.6.0 319 | [0.5.2]: https://github.com/tmaxmax/go-sse/releases/tag/v0.5.2 320 | [0.5.1]: https://github.com/tmaxmax/go-sse/releases/tag/v0.5.1 321 | [0.5.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.5.0 322 | [0.4.3]: https://github.com/tmaxmax/go-sse/releases/tag/v0.4.3 323 | [0.4.2]: https://github.com/tmaxmax/go-sse/releases/tag/v0.4.2 324 | [0.4.1]: https://github.com/tmaxmax/go-sse/releases/tag/v0.4.1 325 | [0.4.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.4.0 326 | [0.3.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.3.0 327 | [0.2.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.2.0 328 | [0.1.0]: https://github.com/tmaxmax/go-sse/releases/tag/v0.1.0 329 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Teodor Maxim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-sse 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/tmaxmax/go-sse.svg)](https://pkg.go.dev/github.com/tmaxmax/go-sse) 4 | ![CI](https://github.com/tmaxmax/go-sse/actions/workflows/go.yml/badge.svg) 5 | [![codecov](https://codecov.io/gh/tmaxmax/go-sse/branch/master/graph/badge.svg?token=EP52XJI4RO)](https://codecov.io/gh/tmaxmax/go-sse) 6 | [![Go Report Card](https://goreportcard.com/badge/github.com/tmaxmax/go-sse)](https://goreportcard.com/report/github.com/tmaxmax/go-sse) 7 | 8 | Lightweight, fully spec-compliant HTML5 server-sent events library. 
9 | 10 | ## Table of contents 11 | 12 | - [go-sse](#go-sse) 13 | - [Table of contents](#table-of-contents) 14 | - [Installation and usage](#installation-and-usage) 15 | - [Cut to the chase – how do I read my LLM's response?](#cut-to-the-chase--how-do-i-read-my-llms-response) 16 | - [Implementing a server](#implementing-a-server) 17 | - [Providers and why they are vital](#providers-and-why-they-are-vital) 18 | - [Meet Joe, the default provider](#meet-joe-the-default-provider) 19 | - [Publish your first event](#publish-your-first-event) 20 | - [The server-side "Hello world"](#the-server-side-hello-world) 21 | - [Using the client](#using-the-client) 22 | - [Creating a client](#creating-a-client) 23 | - [Initiating a connection](#initiating-a-connection) 24 | - [Subscribing to events](#subscribing-to-events) 25 | - [Establishing the connection](#establishing-the-connection) 26 | - [Connection lost?](#connection-lost) 27 | - [The "Hello world" server's client](#the-hello-world-servers-client) 28 | - [License](#license) 29 | - [Contributing](#contributing) 30 | 31 | ## Installation and usage 32 | 33 | Install the package using `go get`: 34 | 35 | ```sh 36 | go get -u github.com/tmaxmax/go-sse 37 | ``` 38 | 39 | It is strongly recommended to use tagged versions of `go-sse` in your projects. The `master` branch has tested but unreleased and maybe undocumented changes, which may break backwards compatibility - use with caution. 40 | 41 | The library provides both server-side and client-side implementations of the protocol. The implementations are completely decoupled and unopinionated: you can connect to a server created using `go-sse` from the browser and you can connect to any server that emits events using the client! 42 | 43 | If you are not familiar with the protocol or not sure how it works, read [MDN's guide for using server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events). 
[The spec](https://html.spec.whatwg.org/multipage/server-sent-events.html) is also a useful read! 44 | 45 | `go-sse` promises to support the [Go versions supported by the Go team](https://go.dev/doc/devel/release#policy) – that is, the 2 most recent major releases. 46 | 47 | ## Cut to the chase – how do I read my LLM's response? 48 | 49 | If you're here just to read ChatGPT's, Claude's or whichever LLM's response stream, you're in the right place! Let's take a look at [`sse.Read`](https://pkg.go.dev/github.com/tmaxmax/go-sse#Read): you just make your HTTP request the same way you'd do for any other API and call it on the response body. Here's some code: 50 | 51 | ```go 52 | req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.yourllm.com/v1/chat/completions", payload) 53 | req.Header.Set("Content-Type", "application/json") 54 | req.Header.Set("Authorization", "Bearer "+yourKey) 55 | 56 | res, err := http.DefaultClient.Do(req) 57 | if err != nil { 58 | // handle error 59 | } 60 | defer res.Body.Close() // don't forget!! 61 | 62 | for ev, err := range sse.Read(res.Body, nil) { 63 | if err != nil { 64 | // handle read error 65 | break // can end the loop as Read stops on first error anyway 66 | } 67 | // Do something with the events, parse the JSON or whatever. 68 | } 69 | ``` 70 | 71 | See the [LLM example](cmd/llm/main.go) for a fully working Go program. 72 | 73 | Go 1.23 iterators (officially ["range-over-func"](https://go.dev/blog/range-functions)) are used for this feature. If you are still on Go 1.22, use the `GOEXPERIMENT=rangefunc` environment variable (e.g.
`GOEXPERIMENT=rangefunc go run main.go`) or call the iterator directly, without the syntactic sugar: 74 | ```go 75 | sse.Read(res.Body, nil)(func(ev sse.Event, err error) bool { 76 | // do something with the event, or handle err 77 | return true // or false to stop iteration 78 | }) 79 | ``` 80 | 81 | `sse.Read` is also useful if you're implementing an LLM SDK: call it in your code and spare yourself time and maintenance burden by not reimplementing event stream parsing. 82 | 83 | ## Implementing a server 84 | 85 | ### Providers and why they are vital 86 | 87 | First, a server instance has to be created: 88 | 89 | ```go 90 | import "github.com/tmaxmax/go-sse" 91 | 92 | s := &sse.Server{} // zero value ready to use! 93 | ``` 94 | 95 | The `sse.Server` type also implements the `http.Handler` interface, but a server is framework-agnostic: see the [`ServeHTTP` implementation](https://github.com/tmaxmax/go-sse/blob/master/server.go) to learn how to implement your own custom logic. It also has some additional configuration options: 96 | 97 | ```go 98 | s := &sse.Server{ 99 | Provider: /* what goes here? find out next! */, 100 | OnSession: /* see Go docs for this one */, 101 | Logger: /* see Go docs for this one, too */, 102 | } 103 | ``` 104 | 105 | What is this "provider"? A provider is an implementation of the publish-subscribe messaging system: 106 | 107 | ```go 108 | type Provider interface { 109 | // Publish a message to all subscribers of the given topics. 110 | Publish(msg *Message, topics []string) error 111 | // Add a new subscriber that is unsubscribed when the context is done. 112 | Subscribe(ctx context.Context, sub Subscription) error 113 | // Clean up all resources and stop publishing messages or accepting subscriptions. 114 | Shutdown(ctx context.Context) error 115 | } 116 | ``` 117 | 118 | The provider is what dispatches events to clients. When you publish a message (an event), the provider distributes it to all connections (subscribers).
It is the central piece of the server: it determines the maximum number of clients your server can handle, the latency between broadcasting events and receiving them client-side, and the maximum message throughput supported by your server. As different use cases have different needs, `go-sse` allows you to plug in your own system. Some examples of such external systems are: 119 | 120 | - [RabbitMQ streams](https://blog.rabbitmq.com/posts/2021/07/rabbitmq-streams-overview/) 121 | - [Redis pub-sub](https://redis.io/topics/pubsub) 122 | - [Apache Kafka](https://kafka.apache.org/) 123 | - Your own! For example, you can mock providers in testing. 124 | 125 | If an external system is required, an adapter that satisfies the `Provider` interface must be created so it can then be used with `go-sse`. To implement such an adapter, read [the Provider documentation][2] for implementation requirements! And maybe share it with others: `go-sse` is built with reusability in mind! 126 | 127 | But in most cases the power and scalability that these external systems bring are not necessary, so `go-sse` comes with a built-in default provider. Read further! 128 | 129 | ### Meet Joe, the default provider 130 | 131 | The server works out of the box, even if you don't configure a provider. `go-sse` brings you Joe: the trusty, pure Go pub-sub implementation, who handles all your events by default! Befriend Joe as follows: 132 | 133 | ```go 134 | import "github.com/tmaxmax/go-sse" 135 | 136 | joe := &sse.Joe{} // the zero value is ready to use! 137 | ``` 138 | 139 | and he'll dispatch events all day! By default, he has no memory of what events he has received, but you can help him remember and replay older messages to new clients using a `Replayer`: 140 | 141 | ```go 142 | type Replayer interface { 143 | // Put a new event in the provider's buffer. 144 | // If the provider automatically adds IDs as well, 145 | // the returned message will also have the ID set, 146 | // otherwise the input value is returned.
147 | Put(msg *Message, topics []string) (*Message, error) 148 | // Replay valid events to a subscriber. 149 | Replay(sub Subscription) error 150 | } 151 | ``` 152 | 153 | `go-sse` provides two replayers by default, which both hold the events in-memory: the `ValidReplayer` and `FiniteReplayer`. The first replays events that are valid, not expired, the second replays a finite number of the most recent events. For example: 154 | 155 | ```go 156 | // Let's have events expire after 5 minutes. For this example we don't enable automatic ID generation. 157 | r, err := sse.NewValidReplayer(time.Minute * 5, false) 158 | if err != nil { 159 | // TTL was 0 or negative. 160 | // Useful to have this error if the value comes from a config which happens to be faulty. 161 | } 162 | 163 | joe = &sse.Joe{Replayer: r} 164 | ``` 165 | 166 | will tell Joe to replay all valid events! Replayers can do so much more (for example, add IDs to events automatically): read the [docs][3] on how to use the existing ones and how to implement yours. 167 | 168 | You can also implement your own replayers: maybe you need persistent storage for your events? Or event validity is determined based on other criteria than expiry time? And if you think your replayer may be useful to others, you are encouraged to share it! 169 | 170 | `go-sse` created the `Replayer` interface mainly for `Joe`, but it encourages you to integrate it with your own `Provider` implementations, where suitable. 171 | 172 | ### Publish your first event 173 | 174 | To publish events from the server, we use the `sse.Message` struct: 175 | 176 | ```go 177 | import "github.com/tmaxmax/go-sse" 178 | 179 | m := &sse.Message{} 180 | m.AppendData("Hello world!", "Nice\nto see you.") 181 | ``` 182 | 183 | Now let's send it to our clients: 184 | 185 | ```go 186 | var s *sse.Server 187 | 188 | s.Publish(m) 189 | ``` 190 | 191 | This is how clients will receive our event: 192 | 193 | ```txt 194 | data: Hello world! 
195 | data: Nice 196 | data: to see you. 197 | ``` 198 | 199 | You can also see that `go-sse` takes care of splitting input by lines into new fields, as required by the specification. 200 | 201 | Keep in mind that replayers, such as the `ValidReplayer` used above, will return an error for, and won't replay, events without an ID (unless, of course, they assign the IDs themselves). To have our event expire, as configured, we must set an ID for the event: 202 | 203 | ```go 204 | m.ID = sse.ID("unique") 205 | ``` 206 | 207 | This is how the event will look: 208 | 209 | ```txt 210 | id: unique 211 | data: Hello world! 212 | data: Nice 213 | data: to see you. 214 | ``` 215 | 216 | Now that it has an ID, the event will be considered expired 5 minutes after it's been published – it won't be replayed to clients after it expires! 217 | 218 | `sse.ID` is a function that returns an `EventID` – a special type that denotes an event's ID. An ID must not have newlines, so we must use special functions which validate the value beforehand. The `ID` constructor function we've used above panics on invalid input (it is useful when creating IDs from static strings), but there's also `NewID`, which returns an error indicating whether the value was successfully converted to an ID or not: 219 | 220 | ```go 221 | id, err := sse.NewID("invalid\nID") 222 | ``` 223 | 224 | Here, `err` will be non-nil and `id` will be an unset value: no `id` field will be sent to clients if you set an event's ID using that value! 225 | 226 | Setting the event's type (the `event` field) is equally easy: 227 | 228 | ```go 229 | m.Type = sse.Type("The event's name") 230 | ``` 231 | 232 | Like IDs, types cannot have newlines. You are provided with constructors that follow the same convention: `Type` panics, `NewType` returns an error. Read the [docs][4] to find out more about messages and how to use them! 233 | 234 | ### The server-side "Hello world" 235 | 236 | Now, let's put everything that we've learned together!
We'll create a server that sends a "Hello world!" message every second to all its clients, with Joe's help: 237 | 238 | ```go 239 | package main 240 | 241 | import ( 242 | "log" 243 | "net/http" 244 | "time" 245 | 246 | "github.com/tmaxmax/go-sse" 247 | ) 248 | 249 | func main() { 250 | s := &sse.Server{} 251 | 252 | go func() { 253 | m := &sse.Message{} 254 | m.AppendData("Hello world!") 255 | 256 | for range time.Tick(time.Second) { 257 | _ = s.Publish(m) 258 | } 259 | }() 260 | 261 | if err := http.ListenAndServe(":8000", s); err != nil { 262 | log.Fatalln(err) 263 | } 264 | } 265 | ``` 266 | 267 | Joe is our default provider here, as no provider is set on the server. The server is already an `http.Handler`, so we can use it directly with `http.ListenAndServe`. 268 | 269 | [Also see a more complex example!](cmd/complex/main.go) 270 | 271 | This is by no means a complete presentation – make sure to read the docs in order to use `go-sse` to its full potential! 272 | 273 | ## Using the client 274 | 275 | ### Creating a client 276 | 277 | We will use the `sse.Client` type for connecting to event streams: 278 | 279 | ```go 280 | type Client struct { 281 | HTTPClient *http.Client 282 | OnRetry func(error, time.Duration) 283 | ResponseValidator ResponseValidator 284 | // Backoff configures the reconnection strategy. 285 | Backoff Backoff 286 | } 287 | ``` 288 | 289 | As you can see, it uses a `net/http` client. When a connection to a server is lost, it automatically reconnects using the backoff strategy configured through the `Backoff` field. Read the [client docs][5] to find out how to configure the client. We'll use the default client the package provides for further examples.
290 | 291 | ### Initiating a connection 292 | 293 | We must first create an `http.Request` - yup, a fully customizable request: 294 | 295 | ```go 296 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, "host", http.NoBody) 297 | ``` 298 | 299 | Any kind of request is valid as long as your server handler supports it: you can do a GET, a POST, send a body; do whatever! The context is used as always for cancellation - to stop receiving events you will have to cancel the context. 300 | Let's initiate a connection with this request: 301 | 302 | ```go 303 | import "github.com/tmaxmax/go-sse" 304 | 305 | conn := sse.DefaultClient.NewConnection(req) 306 | // you can also do sse.NewConnection(req) 307 | // it is a utility function that calls the 308 | // NewConnection method on the default client 309 | ``` 310 | 311 | ### Subscribing to events 312 | 313 | Great! Let's imagine the event stream looks as follows: 314 | 315 | ```txt 316 | data: some unnamed event 317 | 318 | event: I have a name 319 | data: some data 320 | 321 | event: Another name 322 | data: some data 323 | ``` 324 | 325 | To receive the unnamed events, we subscribe to them as follows: 326 | 327 | ```go 328 | unsubscribe := conn.SubscribeMessages(func (event sse.Event) { 329 | // do something with the event 330 | }) 331 | ``` 332 | 333 | To receive the events named "I have a name": 334 | 335 | ```go 336 | unsubscribe := conn.SubscribeEvent("I have a name", func (event sse.Event) { 337 | // do something with the event 338 | }) 339 | ``` 340 | 341 | If you want to subscribe to all events, regardless of their name: 342 | 343 | ```go 344 | unsubscribe := conn.SubscribeToAll(func (event sse.Event) { 345 | // do something with the event 346 | }) 347 | ``` 348 | 349 | All `Subscribe` methods return a function that when called tells the connection to stop calling the corresponding callback.
350 | 351 | In order to work with events, the `sse.Event` type has some fields and methods exposed: 352 | 353 | ```go 354 | type Event struct { 355 | LastEventID string 356 | Type string 357 | Data string 358 | } 359 | ``` 360 | 361 | Pretty self-explanatory, but make sure to read the [docs][6]! 362 | 363 | Now, with this knowledge, let's subscribe to all unnamed events and, when the connection is established, print their data: 364 | 365 | ```go 366 | unsubscribe := conn.SubscribeMessages(func(event sse.Event) { 367 | fmt.Printf("Received an unnamed event: %s\n", event.Data) 368 | }) 369 | ``` 370 | 371 | ### Establishing the connection 372 | 373 | Great, we are subscribed now! Let's start receiving events: 374 | 375 | ```go 376 | err := conn.Connect() 377 | ``` 378 | 379 | By calling `Connect`, the request created above will be sent to the server, and if successful, the subscribed callbacks will be called when new events are received. `Connect` returns only after all callbacks have finished executing. 380 | To stop calling a certain callback, call the unsubscribe function returned when subscribing. You can also subscribe new callbacks from a different goroutine after calling `Connect`. 381 | When using a `context.Context` to stop the connection, the error returned will be the context error – be it `context.Canceled`, `context.DeadlineExceeded` or a custom cause (when using `context.WithCancelCause`). In other words, a successfully closed `Connection` will always return an error – if the context error is not relevant, you can ignore it. For example: 382 | 383 | ```go 384 | if err := conn.Connect(); !errors.Is(err, context.Canceled) { 385 | // handle error 386 | } 387 | ``` 388 | 389 | A context created with `context.WithCancel`, or one with `context.WithCancelCause` and cancelled with the error `context.Canceled` is assumed above.
390 | 391 | There may be situations where the connection does not have to live indefinitely – for example when using the OpenAI API. In those situations, configure the client to not retry the connection and ignore `io.EOF` on return: 392 | 393 | ```go 394 | client := sse.Client{ 395 | Backoff: sse.Backoff{ 396 | MaxRetries: -1, 397 | }, 398 | // other settings... 399 | } 400 | 401 | req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/...", body) 402 | conn := client.NewConnection(req) 403 | 404 | conn.SubscribeMessages(/* callback */) 405 | 406 | if err := conn.Connect(); !errors.Is(err, io.EOF) { 407 | // handle error 408 | } 409 | ``` 410 | 411 | ### Connection lost? 412 | 413 | Either way, after receiving so many events, something went wrong and the server is temporarily down. Oh no! As a last hope, it has sent us the following event: 414 | 415 | ```text 416 | retry: 60000 417 | : that's a minute in milliseconds and this 418 | : is a comment which is ignored by the client 419 | ``` 420 | 421 | No sweat, though! The connection will automatically be reattempted after about a minute, by which time we hope the server is back up again. Canceling the request's context will cancel any reconnection attempt, too. 422 | 423 | If the server doesn't set a retry time, the wait time is derived from the client's `Backoff.InitialInterval`.
424 | 425 | ### The "Hello world" server's client 426 | 427 | Let's use what we know to create a client for the previous server example: 428 | 429 | ```go 430 | package main 431 | 432 | import ( 433 | "fmt" 434 | "net/http" 435 | "os" 436 | 437 | "github.com/tmaxmax/go-sse" 438 | ) 439 | 440 | func main() { 441 | r, _ := http.NewRequest(http.MethodGet, "http://localhost:8000", nil) 442 | conn := sse.NewConnection(r) 443 | 444 | conn.SubscribeMessages(func(ev sse.Event) { 445 | fmt.Printf("%s\n\n", ev.Data) 446 | }) 447 | 448 | if err := conn.Connect(); err != nil { 449 | fmt.Fprintln(os.Stderr, err) 450 | } 451 | } 452 | ``` 453 | 454 | Yup, this is it! We are using the default client to receive all the unnamed events from the server. The output will look like this, when both programs are run in parallel: 455 | 456 | ```txt 457 | Hello world! 458 | 459 | Hello world! 460 | 461 | Hello world! 462 | 463 | Hello world! 464 | 465 | ... 466 | ``` 467 | 468 | [See the complex example's client too!](cmd/complex_client/main.go) 469 | 470 | ## License 471 | 472 | This project is licensed under the [MIT license](LICENSE). 473 | 474 | ## Contributing 475 | 476 | The library's in its early stages, so contributions are vital - I'm so glad you wish to improve `go-sse`! Maybe start by opening an issue first, to describe the intended modifications and further discuss how to integrate them. Open PRs to the `master` branch and wait for CI to complete. If all is clear, your changes will soon be merged! Also, make sure your changes come with an extensive set of tests and the code is formatted. 477 | 478 | Thank you for contributing! 
479 | 480 | [1]: https://github.com/cenkalti/backoff 481 | [2]: https://pkg.go.dev/github.com/tmaxmax/go-sse#Provider 482 | [3]: https://pkg.go.dev/github.com/tmaxmax/go-sse#Replayer 483 | [4]: https://pkg.go.dev/github.com/tmaxmax/go-sse#Message 484 | [5]: https://pkg.go.dev/github.com/tmaxmax/go-sse#Client 485 | [6]: https://pkg.go.dev/github.com/tmaxmax/go-sse#Event 486 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "net/http" 7 | "strings" 8 | "time" 9 | "unicode" 10 | ) 11 | 12 | // The ResponseValidator type defines the type of the function 13 | // that checks whether server responses are valid, before starting 14 | // to read events from them. See the Client's documentation for more info. 15 | // 16 | // These errors are considered permanent and thus if the client is configured 17 | // to retry on error no retry is attempted and the error is returned. 18 | type ResponseValidator func(*http.Response) error 19 | 20 | // The Client struct is used to initialize new connections to different servers. 21 | // It is safe for concurrent use. 22 | // 23 | // After connections are created, the Connect method must be called to start 24 | // receiving events. 25 | type Client struct { 26 | // The HTTP client to be used. Defaults to http.DefaultClient. 27 | HTTPClient *http.Client 28 | // A callback that's executed whenever a reconnection attempt starts. 29 | // It receives the error that caused the retry and the reconnection time. 30 | OnRetry func(error, time.Duration) 31 | // A function to check if the response from the server is valid. 32 | // Defaults to a function that checks the response's status code is 200 33 | // and the content type is text/event-stream. 
34 | // 35 | // If the error type returned has a Temporary or a Timeout method, 36 | // they will be used to determine whether to reattempt the connection. 37 | // Otherwise, the error will be considered permanent and no reconnections 38 | // will be attempted. 39 | ResponseValidator ResponseValidator 40 | // Backoff configures the backoff strategy. See the documentation of 41 | // each field for more information. 42 | Backoff Backoff 43 | } 44 | 45 | // Backoff configures the reconnection strategy of a Connection. 46 | type Backoff struct { 47 | // The initial wait time before a reconnection is attempted. 48 | // Must be >0. Defaults to 500ms. 49 | InitialInterval time.Duration 50 | // How much should the reconnection time grow on subsequent attempts. 51 | // Must be >=1; 1 = constant interval. Defaults to 1.5. 52 | Multiplier float64 53 | // How much does the reconnection time vary relative to the base value. 54 | // This is useful to prevent multiple clients to reconnect at the exact 55 | // same time, as it makes the wait times distinct. 56 | // Must be in range (0, 1); -1 = no randomization. Defaults to 0.5. 57 | Jitter float64 58 | // How much can the wait time grow. 59 | // If <=0 = the wait time can infinitely grow. Defaults to infinite growth. 60 | MaxInterval time.Duration 61 | // How much time can retries be attempted. 62 | // For example, if this is 5 seconds, after 5 seconds the client 63 | // will stop retrying. 64 | // If <=0 = no limit. Defaults to no limit. 65 | MaxElapsedTime time.Duration 66 | // How many retries are allowed. 67 | // <0 = no retries, 0 = infinite. Defaults to infinite retries. 68 | MaxRetries int 69 | } 70 | 71 | // NewConnection initializes and configures a connection. On connect, the given 72 | // request is sent and if successful the connection starts receiving messages. 73 | // Use the request's context to stop the connection. 
74 | // 75 | // If the request has a body, it is necessary to provide a GetBody function in order 76 | // for the connection to be reattempted, in case of an error. Using readers 77 | // such as bytes.Reader, strings.Reader or bytes.Buffer when creating a request 78 | // using http.NewRequestWithContext will ensure this function is present on the request. 79 | func (c *Client) NewConnection(r *http.Request) *Connection { 80 | if r == nil { 81 | panic("go-sse.client.NewConnection: request cannot be nil") 82 | } 83 | 84 | mergeDefaults(c) 85 | 86 | conn := &Connection{ 87 | client: *c, // we clone the client so the config cannot be modified from outside 88 | request: r.Clone(r.Context()), // we clone the request so its fields cannot be modified from outside 89 | callbacks: map[string]map[int]EventCallback{}, 90 | callbacksAll: map[int]EventCallback{}, 91 | } 92 | 93 | return conn 94 | } 95 | 96 | // DefaultValidator is the default client response validation function. As per the spec, 97 | // It checks the content type to be text/event-stream and the response status code to be 200 OK. 98 | // 99 | // If this validator fails, errors are considered permanent. No retry attempts are made. 100 | // 101 | // See https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model. 102 | var DefaultValidator ResponseValidator = func(r *http.Response) error { 103 | if r.StatusCode != http.StatusOK { 104 | return fmt.Errorf("expected status code %d %s, received %d %s", http.StatusOK, http.StatusText(http.StatusOK), r.StatusCode, http.StatusText(r.StatusCode)) 105 | } 106 | cts := r.Header.Get("Content-Type") 107 | ct := contentType(cts) 108 | if expected := "text/event-stream"; ct != expected { 109 | return fmt.Errorf("expected content type to have %q, received %q", expected, cts) 110 | } 111 | return nil 112 | } 113 | 114 | // NoopValidator is a client response validator function that treats all responses as valid. 
115 | var NoopValidator ResponseValidator = func(_ *http.Response) error { 116 | return nil 117 | } 118 | 119 | // DefaultClient is the client that is used when creating a new connection using the NewConnection function. 120 | // Unset properties on new clients are replaced with the ones set for the default client. 121 | var DefaultClient = &Client{ 122 | HTTPClient: http.DefaultClient, 123 | ResponseValidator: DefaultValidator, 124 | Backoff: Backoff{ 125 | InitialInterval: time.Millisecond * 500, 126 | Multiplier: 1.5, 127 | Jitter: 0.5, 128 | }, 129 | } 130 | 131 | // NewConnection creates a connection using the default client. 132 | func NewConnection(r *http.Request) *Connection { 133 | return DefaultClient.NewConnection(r) 134 | } 135 | 136 | func mergeDefaults(c *Client) { 137 | if c.HTTPClient == nil { 138 | c.HTTPClient = DefaultClient.HTTPClient 139 | } 140 | if c.Backoff.InitialInterval <= 0 { 141 | c.Backoff.InitialInterval = DefaultClient.Backoff.InitialInterval 142 | } 143 | if c.Backoff.Multiplier < 1 { 144 | c.Backoff.Multiplier = DefaultClient.Backoff.Multiplier 145 | } 146 | if c.Backoff.Jitter != -1 && (c.Backoff.Jitter <= 0 || c.Backoff.Jitter >= 1) { // -1 explicitly disables jitter (see nextInterval), so it must not be overwritten 147 | c.Backoff.Jitter = DefaultClient.Backoff.Jitter 148 | } 149 | if c.ResponseValidator == nil { 150 | c.ResponseValidator = DefaultClient.ResponseValidator 151 | } 152 | } 153 | 154 | func contentType(header string) string { 155 | cts := strings.FieldsFunc(header, func(r rune) bool { 156 | return unicode.IsSpace(r) || r == ';' || r == ',' 157 | }) 158 | if len(cts) == 0 { 159 | return "" 160 | } 161 | return strings.ToLower(cts[0]) 162 | } 163 | 164 | type backoffController struct { 165 | start time.Time 166 | rng *rand.Rand 167 | b *Backoff 168 | interval time.Duration 169 | numRetries int 170 | } 171 | 172 | func (b *Backoff) new() backoffController { 173 | now := time.Now() 174 | return backoffController{ 175 | start: now, 176 | rng: rand.New(rand.NewSource(now.UnixNano())), 177 | b: b, 178 | interval:
b.InitialInterval, 179 | numRetries: 0, 180 | } 181 | } 182 | 183 | // reset the backoff to the initial state, i.e. as if no retries have occurred. 184 | // If newInterval is greater than 0, the initial interval is changed to it. 185 | func (c *backoffController) reset(newInterval time.Duration) { 186 | if newInterval > 0 { 187 | c.interval = newInterval 188 | } else { 189 | c.interval = c.b.InitialInterval 190 | } 191 | c.numRetries = 0 192 | c.start = time.Now() 193 | } 194 | 195 | func (c *backoffController) next() (interval time.Duration, shouldRetry bool) { 196 | if c.b.MaxRetries < 0 || (c.b.MaxRetries > 0 && c.numRetries == c.b.MaxRetries) { 197 | return 0, false 198 | } 199 | 200 | c.numRetries++ 201 | elapsed := time.Since(c.start) 202 | next := nextInterval(c.b.Jitter, c.rng, c.interval) 203 | c.interval = growInterval(c.interval, c.b.MaxInterval, c.b.Multiplier) 204 | 205 | if c.b.MaxElapsedTime > 0 && elapsed+next > c.b.MaxElapsedTime { 206 | return 0, false 207 | } 208 | 209 | return next, true 210 | } 211 | 212 | func nextInterval(jitter float64, rng *rand.Rand, current time.Duration) time.Duration { 213 | if jitter == -1 { 214 | return current 215 | } 216 | 217 | delta := jitter * float64(current) 218 | minInterval := float64(current) - delta 219 | maxInterval := float64(current) + delta 220 | 221 | return time.Duration(minInterval + (rng.Float64() * (maxInterval - minInterval + 1))) 222 | } 223 | 224 | func growInterval(current, maxInterval time.Duration, mul float64) time.Duration { 225 | if maxInterval > 0 && float64(current) >= float64(maxInterval)/mul { 226 | return maxInterval 227 | } 228 | return time.Duration(float64(current) * mul) 229 | } 230 | -------------------------------------------------------------------------------- /client_connection.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | 
"sync" 11 | "time" 12 | 13 | "github.com/tmaxmax/go-sse/internal/parser" 14 | ) 15 | 16 | // EventCallback is a function that is used to receive events from a Connection. 17 | type EventCallback func(Event) 18 | 19 | // EventCallbackRemover is a function that removes an already registered callback 20 | // from a connection. Calling it multiple times is a no-op. 21 | type EventCallbackRemover func() 22 | 23 | // Connection is a connection to an events stream. Created using the Client struct, 24 | // a Connection processes the incoming events and calls the subscribed event callbacks. 25 | // If the connection to the server temporarily fails, the connection will be reattempted. 26 | // Retry values received from servers will be taken into account. 27 | // 28 | // Connections must not be copied after they are created. 29 | type Connection struct { //nolint:govet // The current order aids readability. 30 | mu sync.RWMutex 31 | request *http.Request 32 | callbacks map[string]map[int]EventCallback 33 | callbacksAll map[int]EventCallback 34 | lastEventID string 35 | client Client 36 | buf []byte 37 | bufMaxSize int 38 | callbackID int 39 | isRetry bool 40 | } 41 | 42 | // SubscribeMessages subscribes the given callback to all events without type (without or with empty `event` field). 43 | // Remove the callback by calling the returned function. 44 | func (c *Connection) SubscribeMessages(cb EventCallback) EventCallbackRemover { 45 | return c.SubscribeEvent("", cb) 46 | } 47 | 48 | // SubscribeEvent subscribes the given callback to all the events with the provided type 49 | // (the `event` field has the value given here). 50 | // Remove the callback by calling the returned function. 51 | func (c *Connection) SubscribeEvent(typ string, cb EventCallback) EventCallbackRemover { 52 | return c.addSubscriber(typ, cb) 53 | } 54 | 55 | // SubscribeToAll subscribes the given callback to all events, with or without type. 56 | // Remove the callback by calling the returned function. 
57 | func (c *Connection) SubscribeToAll(cb EventCallback) EventCallbackRemover { 58 | return c.addSubscriberToAll(cb) 59 | } 60 | 61 | func (c *Connection) addSubscriberToAll(cb EventCallback) EventCallbackRemover { 62 | c.mu.Lock() 63 | defer c.mu.Unlock() 64 | 65 | id := c.callbackID 66 | c.callbacksAll[id] = cb 67 | c.callbackID++ // IDs are never reused, so calling a remover again is a no-op 68 | 69 | return func() { 70 | c.mu.Lock() 71 | defer c.mu.Unlock() 72 | 73 | delete(c.callbacksAll, id) 74 | } 75 | } 76 | 77 | func (c *Connection) addSubscriber(event string, cb EventCallback) EventCallbackRemover { 78 | c.mu.Lock() 79 | defer c.mu.Unlock() 80 | 81 | if _, ok := c.callbacks[event]; !ok { 82 | c.callbacks[event] = map[int]EventCallback{} 83 | } 84 | 85 | id := c.callbackID 86 | c.callbacks[event][id] = cb 87 | c.callbackID++ // IDs are never reused, so calling a remover again is a no-op 88 | 89 | return func() { 90 | c.mu.Lock() 91 | defer c.mu.Unlock() 92 | 93 | delete(c.callbacks[event], id) 94 | if len(c.callbacks[event]) == 0 { 95 | delete(c.callbacks, event) 96 | } 97 | } 98 | } 99 | 100 | // Buffer sets the underlying buffer to be used when scanning events. 101 | // Use this if you need to read very large events (bigger than the default 102 | // of 65K bytes). Call it before Connect: the values it sets are not guarded by the connection's mutex. 103 | // 104 | // Read the documentation of bufio.Scanner.Buffer for more information. 105 | func (c *Connection) Buffer(buf []byte, maxSize int) { 106 | c.buf = buf 107 | c.bufMaxSize = maxSize 108 | } 109 | 110 | // ConnectionError wraps the errors that occur during a connection attempt (request reset, sending the request, response validation, or reading the stream). 111 | type ConnectionError struct { 112 | // The request for which the connection failed. 113 | Req *http.Request 114 | // The underlying error that caused the failure; it is also what Unwrap returns. 115 | Err error 116 | // A short description of the step that failed, e.g. "response validation failed".
117 | Reason string 118 | } 119 | 120 | func (e *ConnectionError) Error() string { 121 | return fmt.Sprintf("request failed: %s: %v", e.Reason, e.Err) 122 | } 123 | 124 | func (e *ConnectionError) Unwrap() error { 125 | return e.Err 126 | } 127 | 128 | func (c *Connection) resetRequest() error { 129 | if !c.isRetry { 130 | c.isRetry = true 131 | return nil 132 | } 133 | if err := resetRequestBody(c.request); err != nil { 134 | return err 135 | } 136 | if c.lastEventID == "" { 137 | c.request.Header.Del("Last-Event-ID") 138 | } else { 139 | c.request.Header.Set("Last-Event-ID", c.lastEventID) 140 | } 141 | return nil 142 | } 143 | 144 | func (c *Connection) dispatch(ev Event) { 145 | // Snapshot the callbacks under the read lock and invoke them after releasing it, 146 | // so a callback may call an EventCallbackRemover (which takes the write lock) without deadlocking. 147 | c.mu.RLock() 148 | cbs := make([]EventCallback, 0, len(c.callbacks[ev.Type])+len(c.callbacksAll)) 149 | for _, cb := range c.callbacks[ev.Type] { 150 | cbs = append(cbs, cb) 151 | } 152 | for _, cb := range c.callbacksAll { 153 | cbs = append(cbs, cb) 154 | } 155 | c.mu.RUnlock() 156 | 157 | for _, cb := range cbs { 158 | cb(ev) 159 | } 160 | } 161 | 162 | func (c *Connection) read(r io.Reader, setRetry func(time.Duration)) error { 163 | pf := func() *parser.Parser { 164 | p := parser.New(r) 165 | if c.buf != nil || c.bufMaxSize > 0 { 166 | p.Buffer(c.buf, c.bufMaxSize) 167 | } 168 | return p 169 | } 170 | 171 | var readErr error 172 | read(pf, c.lastEventID, func(r int64) { setRetry(time.Duration(r) * time.Millisecond) }, false)(func(e Event, err error) bool { 173 | if err != nil { 174 | readErr = err 175 | return false 176 | } 177 | c.lastEventID = e.LastEventID 178 | c.dispatch(e) 179 | return true 180 | }) 181 | 182 | return readErr 183 | } 184 | 185 | // Connect sends the request the connection was created with to the server 186 | // and, if successful, it starts receiving events. The caller goroutine 187 | // is blocked until the request's context is done or an error occurs. 188 | // 189 | // If the request's context is cancelled, Connect returns its error.
190 | // Otherwise, if the maximum number of retries is made, the last error 191 | // that occurred is returned. Connect never returns otherwise – either 192 | // the context is cancelled, or it's done retrying. 193 | // 194 | // All errors returned other than the context errors will be wrapped 195 | // inside a *ConnectionError. 196 | func (c *Connection) Connect() error { 197 | ctx := c.request.Context() 198 | backoff := c.client.Backoff.new() 199 | 200 | c.request.Header.Set("Accept", "text/event-stream") 201 | c.request.Header.Set("Connection", "keep-alive") 202 | c.request.Header.Set("Cache-Control", "no-cache") // "Cache" is not a standard header; caches and proxies honor Cache-Control 203 | 204 | t := time.NewTimer(0) 205 | defer t.Stop() 206 | 207 | for { 208 | select { 209 | case <-t.C: 210 | shouldRetry, err := c.doConnect(ctx, backoff.reset) 211 | if !shouldRetry { 212 | return err 213 | } 214 | 215 | next, shouldRetry := backoff.next() 216 | if !shouldRetry { 217 | return err 218 | } 219 | 220 | if c.client.OnRetry != nil { 221 | c.client.OnRetry(err, next) 222 | } 223 | 224 | t.Reset(next) 225 | case <-ctx.Done(): 226 | return ctx.Err() 227 | } 228 | } 229 | } 230 | 231 | func (c *Connection) doConnect(ctx context.Context, setRetry func(time.Duration)) (shouldRetry bool, err error) { 232 | if err := c.resetRequest(); err != nil { 233 | return false, &ConnectionError{Req: c.request, Reason: "request reset failed", Err: err} 234 | } 235 | 236 | res, err := c.client.HTTPClient.Do(c.request) 237 | if err != nil { 238 | concrete := err.(*url.Error) //nolint:errorlint // We know the concrete type here 239 | if errors.Is(err, ctx.Err()) { 240 | return false, concrete.Err 241 | } 242 | return true, &ConnectionError{Req: c.request, Reason: "connection to server failed", Err: concrete.Err} 243 | } 244 | defer res.Body.Close() 245 | 246 | if err := c.client.ResponseValidator(res); err != nil { 247 | return false, &ConnectionError{Req: c.request, Reason: "response validation failed", Err: err} 248 | } 249 | 250 | setRetry(0) 251 | 252 | err =
c.read(res.Body, setRetry) 253 | if errors.Is(err, ctx.Err()) { 254 | return false, err 255 | } 256 | 257 | return true, &ConnectionError{Req: c.request, Reason: "connection to server lost", Err: err} 258 | } 259 | 260 | // ErrNoGetBody is a sentinel error returned when the connection cannot be reattempted 261 | // due to GetBody not existing on the original request. 262 | var ErrNoGetBody = errors.New("the GetBody function doesn't exist on the request") 263 | 264 | func resetRequestBody(r *http.Request) error { 265 | if r.Body == nil || r.Body == http.NoBody { 266 | return nil 267 | } 268 | if r.GetBody == nil { 269 | return ErrNoGetBody 270 | } 271 | body, err := r.GetBody() 272 | if err != nil { 273 | return err 274 | } 275 | r.Body = body 276 | return nil 277 | } 278 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/tmaxmax/go-sse" 15 | "github.com/tmaxmax/go-sse/internal/parser" 16 | "github.com/tmaxmax/go-sse/internal/tests" 17 | ) 18 | 19 | type roundTripperFunc func(*http.Request) (*http.Response, error) 20 | 21 | func (r roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { 22 | return r(req) 23 | } 24 | 25 | func reqCtx(tb testing.TB, ctx context.Context, method, address string, body io.Reader) *http.Request { //nolint 26 | tb.Helper() 27 | 28 | r, err := http.NewRequestWithContext(ctx, method, address, body) 29 | tests.Equal(tb, err, nil, "failed to create request") 30 | 31 | return r 32 | } 33 | 34 | func req(tb testing.TB, method, address string, body io.Reader) *http.Request { //nolint 35 | tb.Helper() 36 | return reqCtx(tb, context.Background(), method, address, body) 37 | } 38 | 39 | func toEv(tb testing.TB, s string) 
(ev sse.Event) { 40 | tb.Helper() 41 | 42 | defer func() { 43 | if l := len(ev.Data); l > 0 { 44 | ev.Data = ev.Data[:l-1] 45 | } 46 | }() 47 | 48 | p := parser.NewFieldParser(s) 49 | 50 | for f := (parser.Field{}); p.Next(&f); { 51 | switch f.Name { //nolint:exhaustive // Comment fields are not parsed. 52 | case parser.FieldNameData: 53 | ev.Data += f.Value + "\n" 54 | case parser.FieldNameID: 55 | ev.LastEventID = string(f.Value) 56 | case parser.FieldNameEvent: 57 | ev.Type = string(f.Value) 58 | case parser.FieldNameRetry: 59 | default: 60 | return 61 | } 62 | } 63 | 64 | tests.Equal(tb, p.Err(), nil, "unexpected toEv fail") 65 | 66 | return 67 | } 68 | 69 | func TestClient_NewConnection(t *testing.T) { 70 | tests.Panics(t, func() { 71 | sse.NewConnection(nil) 72 | }, "a connection cannot be created without a request") 73 | 74 | c := sse.Client{} 75 | r := req(t, "", "", nil) 76 | _ = c.NewConnection(r) 77 | 78 | tests.Equal(t, c.HTTPClient, http.DefaultClient, "incorrect default HTTP client") 79 | } 80 | 81 | func TestConnection_Connect_retry(t *testing.T) { 82 | var firstReconnectionTime time.Duration 83 | var retryAttempts int 84 | 85 | testErr := errors.New("done") 86 | 87 | c := &sse.Client{ 88 | HTTPClient: &http.Client{ 89 | Transport: roundTripperFunc(func(_ *http.Request) (*http.Response, error) { 90 | return nil, testErr 91 | }), 92 | }, 93 | OnRetry: func(_ error, duration time.Duration) { 94 | retryAttempts++ 95 | if retryAttempts == 1 { 96 | firstReconnectionTime = duration 97 | } 98 | }, 99 | Backoff: sse.Backoff{ 100 | MaxRetries: 3, 101 | InitialInterval: time.Millisecond, 102 | }, 103 | } 104 | r := req(t, "", "", http.NoBody) 105 | err := c.NewConnection(r).Connect() 106 | 107 | tests.ErrorIs(t, err, testErr, "invalid error received from Connect") 108 | tests.Equal(t, retryAttempts, c.Backoff.MaxRetries, "connection was not retried enough times") 109 | 110 | timeDelta := time.Duration(float64(c.Backoff.InitialInterval) * 
sse.DefaultClient.Backoff.Jitter) 111 | tests.Expect(t, c.Backoff.InitialInterval-timeDelta <= firstReconnectionTime && firstReconnectionTime <= c.Backoff.InitialInterval+timeDelta, "reconnection time incorrectly set") 112 | } 113 | 114 | func TestConnection_Connect_noRetryCtxErr(t *testing.T) { 115 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 | ticker := time.NewTicker(time.Millisecond) 117 | defer ticker.Stop() 118 | 119 | for { 120 | select { 121 | case ts := <-ticker.C: 122 | fmt.Fprintf(w, "id: %s\n\n", ts) 123 | case <-r.Context().Done(): 124 | return 125 | } 126 | } 127 | })) 128 | t.Cleanup(ts.Close) 129 | 130 | ctx, cancel := context.WithCancel(context.Background()) 131 | t.Cleanup(cancel) 132 | 133 | c := &sse.Client{ 134 | HTTPClient: ts.Client(), 135 | ResponseValidator: sse.NoopValidator, 136 | } 137 | 138 | r := reqCtx(t, ctx, "", ts.URL, http.NoBody) 139 | go func() { 140 | time.Sleep(time.Millisecond) 141 | cancel() 142 | }() 143 | err := c.NewConnection(r).Connect() 144 | tests.ErrorIs(t, err, ctx.Err(), "invalid connect error") 145 | } 146 | 147 | type readerWrapper struct { 148 | io.Reader 149 | } 150 | 151 | func TestConnection_Connect_resetBody(t *testing.T) { 152 | type test struct { 153 | body io.Reader 154 | err error 155 | getBody func() (io.ReadCloser, error) 156 | name string 157 | } 158 | 159 | getBodyErr := errors.New("haha") 160 | 161 | tt := []test{ 162 | { 163 | name: "No body", 164 | }, 165 | { 166 | name: "Body for which GetBody is set", 167 | body: strings.NewReader("nice"), 168 | }, 169 | { 170 | name: "Body without GetBody", 171 | body: readerWrapper{strings.NewReader("haha")}, 172 | err: sse.ErrNoGetBody, 173 | }, 174 | { 175 | name: "GetBody that returns error", 176 | err: getBodyErr, 177 | body: readerWrapper{nil}, 178 | getBody: func() (io.ReadCloser, error) { 179 | return nil, getBodyErr 180 | }, 181 | }, 182 | } 183 | 184 | ts := httptest.NewServer(http.HandlerFunc(func(_ 
http.ResponseWriter, _ *http.Request) { 185 | time.Sleep(time.Millisecond * 5) 186 | })) 187 | defer ts.Close() 188 | httpClient := ts.Client() 189 | rt := httpClient.Transport 190 | 191 | c := &sse.Client{ 192 | HTTPClient: httpClient, 193 | ResponseValidator: sse.NoopValidator, 194 | Backoff: sse.Backoff{ 195 | MaxRetries: 1, 196 | InitialInterval: time.Nanosecond, 197 | }, 198 | } 199 | 200 | for _, test := range tt { 201 | t.Run(test.name, func(t *testing.T) { 202 | firstTry := true 203 | c.HTTPClient.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { 204 | if firstTry { 205 | firstTry = false 206 | return nil, errors.New("fail") 207 | } 208 | return rt.RoundTrip(r) 209 | }) 210 | 211 | ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*3) 212 | defer cancel() 213 | 214 | r := reqCtx(t, ctx, "", ts.URL, test.body) 215 | if test.getBody != nil { 216 | r.GetBody = test.getBody 217 | } 218 | 219 | err := c.NewConnection(r).Connect() 220 | if test.err != nil { 221 | tests.ErrorIs(t, err, test.err, "incorrect error received from Connect") 222 | } else { 223 | tests.Equal(t, err, ctx.Err(), "connection error should be context error") 224 | } 225 | }) 226 | } 227 | } 228 | 229 | func TestConnection_Connect_validator(t *testing.T) { 230 | validatorErr := errors.New("invalid") 231 | 232 | type test struct { 233 | err error 234 | validator sse.ResponseValidator 235 | name string 236 | } 237 | 238 | tt := []test{ 239 | { 240 | name: "No validation error", 241 | validator: sse.NoopValidator, 242 | err: io.EOF, 243 | }, 244 | { 245 | name: "Validation error", 246 | validator: func(_ *http.Response) error { 247 | return validatorErr 248 | }, 249 | err: validatorErr, 250 | }, 251 | } 252 | 253 | ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) 254 | defer ts.Close() 255 | 256 | c := &sse.Client{ 257 | HTTPClient: ts.Client(), 258 | Backoff: sse.Backoff{ 259 | MaxRetries: -1, 260 | }, 
261 | } 262 | 263 | for _, test := range tt { 264 | t.Run(test.name, func(t *testing.T) { 265 | c.ResponseValidator = test.validator 266 | 267 | err := c.NewConnection(req(t, "", ts.URL, nil)).Connect() 268 | tests.ErrorIs(t, err, test.err, "incorrect error received from Connect") 269 | }) 270 | } 271 | } 272 | 273 | func TestConnection_Connect_defaultValidator(t *testing.T) { 274 | type test struct { 275 | handler http.Handler 276 | name string 277 | expectErr bool 278 | } 279 | 280 | tt := []test{ 281 | { 282 | name: "Valid request", 283 | handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 284 | w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") 285 | w.WriteHeader(http.StatusOK) 286 | }), 287 | }, 288 | { 289 | name: "Invalid content type", 290 | handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 291 | _, _ = io.WriteString(w, "plain text") 292 | }), 293 | expectErr: true, 294 | }, 295 | { 296 | name: "Empty content type", 297 | handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 298 | w.Header().Set("Content-Type", "") 299 | w.WriteHeader(http.StatusOK) 300 | }), 301 | expectErr: true, 302 | }, 303 | { 304 | name: "Invalid response status code", 305 | handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 306 | w.Header().Set("Content-Type", "text/event-stream") 307 | w.WriteHeader(http.StatusUnauthorized) 308 | }), 309 | expectErr: true, 310 | }, 311 | } 312 | 313 | for _, test := range tt { 314 | t.Run(test.name, func(t *testing.T) { 315 | ts := httptest.NewServer(test.handler) 316 | defer ts.Close() 317 | 318 | c := &sse.Client{HTTPClient: ts.Client(), Backoff: sse.Backoff{MaxRetries: -1}} 319 | err := c.NewConnection(req(t, "", ts.URL, nil)).Connect() 320 | 321 | if test.expectErr { 322 | tests.Expect(t, err != nil, "expected Connect error") 323 | } else { 324 | tests.ErrorIs(t, err, io.EOF, "should propagate EOF error") 325 | } 326 | }) 327 | } 328 | } 329 | 
330 | func events(tb testing.TB, c *sse.Connection, topics ...string) (events <-chan []sse.Event, unsubscribe sse.EventCallbackRemover) { 331 | tb.Helper() 332 | 333 | ch := make(chan []sse.Event) 334 | recv := make(chan sse.Event, 1) 335 | done := make(chan struct{}) 336 | var unsub sse.EventCallbackRemover 337 | cb := func(e sse.Event) { 338 | select { 339 | case <-done: 340 | case recv <- e: 341 | } 342 | } 343 | events = ch 344 | 345 | if l := len(topics); l == 1 { 346 | if t := topics[0]; t == "" { 347 | unsub = c.SubscribeMessages(cb) // for coverage, SubscribeEvent("", recv) would be equivalent 348 | } else { 349 | unsub = c.SubscribeEvent(t, cb) 350 | } 351 | } else { 352 | if l == 0 { 353 | unsub = c.SubscribeToAll(cb) 354 | } else { 355 | unsubFns := make([]sse.EventCallbackRemover, 0, len(topics)) 356 | for _, t := range topics { 357 | unsubFns = append(unsubFns, c.SubscribeEvent(t, cb)) 358 | } 359 | 360 | unsub = func() { 361 | for _, fn := range unsubFns { 362 | fn() 363 | } 364 | } 365 | } 366 | } 367 | 368 | unsubscribe = func() { 369 | defer func() { _ = recover() }() 370 | defer close(done) 371 | unsub() 372 | } 373 | 374 | go func() { 375 | defer close(ch) 376 | var evs []sse.Event 377 | for { 378 | select { 379 | case ev := <-recv: 380 | evs = append(evs, ev) 381 | case <-done: 382 | for len(recv) > 0 { // select picks randomly among ready cases: keep events buffered before unsubscribing 383 | evs = append(evs, <-recv) 384 | } 385 | ch <- evs 386 | return 387 | } 388 | } 389 | }() 390 | return 391 | } 392 | 393 | func TestConnection_Subscriptions(t *testing.T) { 394 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 395 | data := "retry: 1000\n\nevent: test\ndata: something\nid: 1\n\nevent: test2\ndata: something else\n\ndata: unnamed\nid: 2\n\ndata: this shouldn't be received" 396 | 397 | for _, s := range strings.SplitAfter(data, "\n\n") { 398 | _, _ = io.WriteString(w, s) 399 | w.(http.Flusher).Flush() 400 | time.Sleep(time.Millisecond) 401 | } 402 | })) 403 | defer ts.Close() 404 | 405 | c := &sse.Client{ 406 | HTTPClient:
ts.Client(), 407 | ResponseValidator: sse.NoopValidator, 408 | Backoff: sse.Backoff{MaxRetries: -1}, 409 | } 410 | conn := c.NewConnection(req(t, "", ts.URL, nil)) 411 | 412 | firstEvent := sse.Event{} 413 | secondEvent := sse.Event{Type: "test", Data: "something", LastEventID: "1"} 414 | thirdEvent := sse.Event{Type: "test2", Data: "something else", LastEventID: "1"} 415 | fourthEvent := sse.Event{Data: "unnamed", LastEventID: "2"} 416 | 417 | all, unsubAll := events(t, conn) 418 | defer unsubAll() 419 | expectedAll := []sse.Event{firstEvent, secondEvent, thirdEvent, fourthEvent} 420 | 421 | test, unsubTest := events(t, conn, "test") 422 | defer unsubTest() 423 | expectedTest := []sse.Event{secondEvent} 424 | 425 | test2, unsubTest2 := events(t, conn, "test2") 426 | defer unsubTest2() 427 | expectedTest2 := []sse.Event{thirdEvent} 428 | 429 | messages, unsubMessages := events(t, conn, "") 430 | defer unsubMessages() 431 | expectedMessages := []sse.Event{firstEvent, fourthEvent} 432 | 433 | tests.ErrorIs(t, conn.Connect(), sse.ErrUnexpectedEOF, "incorrect Connect error") 434 | unsubAll() 435 | tests.DeepEqual(t, <-all, expectedAll, "unexpected events for all") 436 | unsubTest() 437 | tests.DeepEqual(t, <-test, expectedTest, "unexpected events for test") 438 | unsubTest2() 439 | tests.DeepEqual(t, <-test2, expectedTest2, "unexpected events for test2") 440 | unsubMessages() 441 | tests.DeepEqual(t, <-messages, expectedMessages, "unexpected events for messages") 442 | } 443 | 444 | func TestConnection_dispatchDirty(t *testing.T) { 445 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 446 | _, _ = io.WriteString(w, "data: hello\ndata: world\n") 447 | })) 448 | defer ts.Close() 449 | 450 | c := &sse.Client{ 451 | HTTPClient: ts.Client(), 452 | ResponseValidator: sse.NoopValidator, 453 | Backoff: sse.Backoff{ 454 | MaxRetries: -1, 455 | }, 456 | } 457 | conn := c.NewConnection(req(t, "", ts.URL, nil)) 458 | expected :=
sse.Event{Data: "hello\nworld"} 459 | var got sse.Event 460 | 461 | conn.SubscribeMessages(func(e sse.Event) { 462 | got = e 463 | }) 464 | 465 | tests.ErrorIs(t, conn.Connect(), io.EOF, "unexpected Connect error") 466 | tests.Equal(t, got, expected, "unexpected event received") 467 | } 468 | 469 | func TestConnection_Unsubscriptions(t *testing.T) { 470 | evs := make(chan string) 471 | 472 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 473 | flusher, ok := w.(http.Flusher) 474 | if !ok { 475 | panic(http.ErrAbortHandler) 476 | } 477 | for ev := range evs { 478 | _, _ = io.WriteString(w, ev) 479 | flusher.Flush() 480 | } 481 | })) 482 | defer ts.Close() 483 | 484 | c := &sse.Client{ 485 | HTTPClient: ts.Client(), 486 | ResponseValidator: sse.NoopValidator, 487 | Backoff: sse.Backoff{ 488 | MaxRetries: -1, 489 | }, 490 | } 491 | conn := c.NewConnection(req(t, "", ts.URL, nil)) 492 | 493 | all, unsubAll := events(t, conn) 494 | some, unsubSome := events(t, conn, "a", "b") 495 | one, unsubOne := events(t, conn, "a") 496 | messages, unsubMessages := events(t, conn, "") 497 | 498 | type action struct { 499 | unsub func() 500 | message string 501 | } 502 | 503 | actions := []action{ 504 | {message: "data: unnamed\n\n", unsub: unsubMessages}, 505 | {message: "data: for one and some\nevent: a\n\n", unsub: unsubOne}, 506 | {message: "data: for some\nevent: b\n\n", unsub: unsubSome}, 507 | {message: "data: for one and some again\nevent: a\n\n", unsub: unsubAll}, 508 | {message: "data: unnamed again\n\n"}, 509 | {message: "data: for some again\nevent: b\n\n"}, 510 | } 511 | 512 | firstEvent := toEv(t, actions[0].message) 513 | secondEvent := toEv(t, actions[1].message) 514 | thirdEvent := toEv(t, actions[2].message) 515 | fourthEvent := toEv(t, actions[3].message) 516 | 517 | expectedAll := []sse.Event{firstEvent, secondEvent, thirdEvent, fourthEvent} 518 | expectedSome := []sse.Event{secondEvent, thirdEvent} 519 | expectedOne :=
[]sse.Event{secondEvent} 520 | expectedMessages := []sse.Event{firstEvent} 521 | 522 | go func() { 523 | defer close(evs) 524 | for _, action := range actions { 525 | evs <- action.message 526 | // we wait for the subscribers to receive the event 527 | time.Sleep(time.Millisecond * 5) 528 | if action.unsub != nil { 529 | action.unsub() 530 | } 531 | } 532 | }() 533 | 534 | tests.ErrorIs(t, conn.Connect(), io.EOF, "unexpected Connect error") 535 | tests.DeepEqual(t, <-all, expectedAll, "unexpected events for all") 536 | tests.DeepEqual(t, <-some, expectedSome, "unexpected events for some") 537 | tests.DeepEqual(t, <-one, expectedOne, "unexpected events for one") 538 | tests.DeepEqual(t, <-messages, expectedMessages, "unexpected events for messages") 539 | } 540 | 541 | func TestConnection_serverError(t *testing.T) { 542 | type action struct { 543 | message string 544 | cancel bool 545 | } 546 | evs := make(chan action) 547 | 548 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 549 | flusher, ok := w.(http.Flusher) 550 | if !ok { 551 | panic(http.ErrAbortHandler) 552 | } 553 | for ev := range evs { 554 | if ev.cancel { 555 | panic(http.ErrAbortHandler) 556 | } 557 | _, _ = io.WriteString(w, ev.message) 558 | flusher.Flush() 559 | } 560 | })) 561 | defer ts.Close() 562 | 563 | c := sse.Client{ 564 | HTTPClient: ts.Client(), 565 | ResponseValidator: sse.NoopValidator, 566 | Backoff: sse.Backoff{MaxRetries: -1}, 567 | } 568 | ctx, cancel := context.WithCancel(context.Background()) 569 | defer cancel() 570 | conn := c.NewConnection(reqCtx(t, ctx, "", ts.URL, nil)) 571 | 572 | all, unsubAll := events(t, conn) 573 | defer unsubAll() 574 | 575 | actions := []action{ 576 | {message: "data: first\n"}, 577 | {message: "data: second\n\n", cancel: true}, 578 | {message: "data: third\n\n"}, 579 | } 580 | expected := []sse.Event(nil) 581 | 582 | go func() { 583 | defer close(evs) 584 | for _, action := range actions { 585 | evs <- action
586 | if action.cancel { 587 | break 588 | } 589 | time.Sleep(time.Millisecond) 590 | } 591 | }() 592 | 593 | tests.Expect(t, conn.Connect() != nil, "expected Connect error") 594 | unsubAll() 595 | tests.DeepEqual(t, <-all, expected, "unexpected values for all") 596 | } 597 | 598 | func TestConnection_reconnect(t *testing.T) { 599 | try := 0 600 | lastEventIDs := []string(nil) 601 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 602 | lastEventIDs = append(lastEventIDs, r.Header.Get("Last-Event-Id")) 603 | try++ 604 | fmt.Fprintf(w, "id: %d\n\n", try) 605 | })) 606 | t.Cleanup(ts.Close) 607 | 608 | ctx, cancel := context.WithCancel(context.Background()) 609 | t.Cleanup(cancel) 610 | 611 | retries := 0 612 | c := sse.Client{ 613 | HTTPClient: ts.Client(), 614 | OnRetry: func(_ error, _ time.Duration) { 615 | retries++ 616 | if retries == 3 { 617 | cancel() 618 | } 619 | }, 620 | Backoff: sse.Backoff{ 621 | InitialInterval: time.Nanosecond, 622 | }, 623 | ResponseValidator: sse.NoopValidator, 624 | } 625 | 626 | r := reqCtx(t, ctx, "", ts.URL, http.NoBody) 627 | err := c.NewConnection(r).Connect() 628 | 629 | tests.Equal(t, err, ctx.Err(), "expected context error") 630 | tests.DeepEqual(t, lastEventIDs, []string{"", "1", "2"}, "incorrect last event IDs") 631 | } 632 | -------------------------------------------------------------------------------- /cmd/complex/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "log/slog" 9 | "math/rand" 10 | "net/http" 11 | "os" 12 | "os/signal" 13 | "strconv" 14 | "syscall" 15 | "time" 16 | 17 | "github.com/tmaxmax/go-sse" 18 | ) 19 | 20 | const ( 21 | topicRandomNumbers = "numbers" 22 | topicMetrics = "metrics" 23 | ) 24 | 25 | func newSSE() *sse.Server { 26 | rp, _ := sse.NewValidReplayer(time.Minute*5, true) 27 | rp.GCInterval = time.Minute 28 | 29 | return
&sse.Server{ 30 | Provider: &sse.Joe{Replayer: rp}, 31 | // If you are using a 3rd party library to generate a per-request logger, this 32 | // can just be a simple wrapper over it. 33 | Logger: func(r *http.Request) *slog.Logger { 34 | return getLogger(r.Context()) 35 | }, 36 | OnSession: func(w http.ResponseWriter, r *http.Request) (topics []string, permitted bool) { 37 | topics = r.URL.Query()["topic"] 38 | for _, topic := range topics { 39 | if topic != topicRandomNumbers && topic != topicMetrics { 40 | // NOTE: if you are returning false to reject the subscription, we strongly recommend writing 41 | // your own response code. Clients will receive a 200 code otherwise, which may be confusing. 42 | // The code must be written before the body: the first Write implicitly sends a 200. 43 | w.WriteHeader(http.StatusBadRequest) 44 | fmt.Fprintf(w, "invalid topic %q; supported are %q, %q", topic, topicRandomNumbers, topicMetrics) 45 | return nil, false 46 | } 47 | } 48 | if len(topics) == 0 { 49 | // Provide default topics, if none are given. 50 | topics = []string{topicRandomNumbers, topicMetrics} 51 | } 52 | 53 | // the shutdown message is sent on the default topic 54 | return append(topics, sse.DefaultTopic), true 55 | }, 56 | } 57 | } 58 | 59 | func cors(h http.Handler) http.Handler { 60 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 | w.Header().Set("Access-Control-Allow-Origin", "*") 62 | h.ServeHTTP(w, r) 63 | }) 64 | } 65 | 66 | func main() { 67 | ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) 68 | defer cancel() 69 | 70 | handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}) 71 | logger := slog.New(handler) 72 | logMiddleware := withLogger(logger) 73 | 74 | sseHandler := newSSE() 75 | 76 | mux := http.NewServeMux() 77 | mux.HandleFunc("/stop", func(w http.ResponseWriter, _ *http.Request) { 78 | cancel() 79 | w.WriteHeader(http.StatusOK) 80 | }) 81 | mux.Handle("/", SnapshotHTTPEndpoint) 82 | mux.Handle("/events", sseHandler) 83 | 84 |
httpLogger := slog.NewLogLogger(handler, slog.LevelWarn) 85 | s := &http.Server{ 86 | Addr: "0.0.0.0:8080", 87 | Handler: cors(logMiddleware(mux)), 88 | ReadHeaderTimeout: time.Second * 10, 89 | ErrorLog: httpLogger, 90 | } 91 | s.RegisterOnShutdown(func() { 92 | e := &sse.Message{Type: sse.Type("close")} 93 | // Adding data is necessary because spec-compliant clients 94 | // do not dispatch events without data. 95 | e.AppendData("bye") 96 | // Broadcast a close message so clients can gracefully disconnect. 97 | _ = sseHandler.Publish(e) 98 | 99 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 100 | defer cancel() 101 | 102 | // We use a context with a timeout so the program doesn't wait indefinitely 103 | // for connections to terminate. There may be misbehaving connections 104 | // which may hang for an unknown timespan, so we just stop waiting on Shutdown 105 | // after a certain duration. 106 | _ = sseHandler.Shutdown(ctx) 107 | }) 108 | 109 | go recordMetric(ctx, sseHandler, "ops", time.Second*2) 110 | go recordMetric(ctx, sseHandler, "cycles", time.Millisecond*500) 111 | 112 | go func() { 113 | duration := func() time.Duration { 114 | return time.Duration(2000+rand.Intn(1000)) * time.Millisecond 115 | } 116 | 117 | timer := time.NewTimer(duration()) 118 | defer timer.Stop() 119 | 120 | for { 121 | select { 122 | case <-timer.C: 123 | _ = sseHandler.Publish(generateRandomNumbers(), topicRandomNumbers) 124 | case <-ctx.Done(): 125 | return 126 | } 127 | 128 | timer.Reset(duration()) 129 | } 130 | }() 131 | 132 | if err := runServer(ctx, s); err != nil { 133 | log.Println("server closed", err) 134 | } 135 | } 136 | 137 | func recordMetric(ctx context.Context, sseHandler *sse.Server, metric string, frequency time.Duration) { 138 | ticker := time.NewTicker(frequency) 139 | defer ticker.Stop() 140 | 141 | for { 142 | select { 143 | case <-ticker.C: 144 | v := Inc(metric) 145 | 146 | e := &sse.Message{ 147 | Type: sse.Type(metric), 148 | }
149 | e.AppendData(strconv.FormatInt(v, 10)) 150 | 151 | _ = sseHandler.Publish(e, topicMetrics) 152 | case <-ctx.Done(): 153 | return 154 | } 155 | } 156 | } 157 | 158 | func runServer(ctx context.Context, s *http.Server) error { 159 | shutdownError := make(chan error, 1) // buffered: the send below must not block if ListenAndServe failed and nobody receives 160 | 161 | go func() { 162 | <-ctx.Done() 163 | 164 | sctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 165 | defer cancel() 166 | 167 | shutdownError <- s.Shutdown(sctx) 168 | }() 169 | 170 | if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { 171 | return err 172 | } 173 | 174 | return <-shutdownError 175 | } 176 | 177 | func generateRandomNumbers() *sse.Message { 178 | e := &sse.Message{} 179 | count := 1 + rand.Intn(5) 180 | 181 | for i := 0; i < count; i++ { 182 | e.AppendData(strconv.FormatUint(rand.Uint64(), 10)) 183 | } 184 | 185 | return e 186 | } 187 | 188 | type loggerCtxKey struct{} 189 | 190 | // withLogger is a net/http compatible middleware that generates a logger with request-specific fields 191 | // added to it and attaches it to the request context for later retrieval with getLogger(). 192 | // Third party logging packages may offer similar middlewares to add a logger to the request or maybe 193 | // just a helper to add a logger to context; in the second case you can build your own middleware 194 | // function around it, similar to this one. 195 | func withLogger(logger *slog.Logger) func(h http.Handler) http.Handler { 196 | return func(h http.Handler) http.Handler { 197 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 198 | l := logger.With( 199 | // Request-specific fields attached to every record logged through l. 200 | "UserAgent", r.UserAgent(), 201 | "RemoteAddr", r.RemoteAddr, 202 | "Host", r.Host, 203 | "Origin", r.Header.Get("origin"), 204 | ) 205 | r = r.WithContext(context.WithValue(r.Context(), loggerCtxKey{}, l)) 206 | h.ServeHTTP(w, r) 207 | }) 208 | } 209 | } 210 | 211 | // getLogger retrieves the request-specific logger from a request's context.
This is 212 | // similar to how existing per-request http logging libraries work, just very simplified. 213 | func getLogger(ctx context.Context) *slog.Logger { 214 | logger, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger) 215 | if !ok { 216 | // No logger was attached (e.g. the request didn't pass through withLogger). Return nil 217 | // explicitly so that case is obvious; callers must cope with a nil *slog.Logger. 218 | return nil 219 | } 220 | return logger 221 | } 222 | -------------------------------------------------------------------------------- /cmd/complex/metrics.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | var metrics sync.Map 10 | 11 | // Add atomically adds the given value to a metric. It creates the metric if it doesn't exist. 12 | func Add(metric string, value int64) int64 { 13 | prev, loaded := metrics.LoadOrStore(metric, value) 14 | for loaded && !metrics.CompareAndSwap(metric, prev, prev.(int64)+value) { 15 | prev, loaded = metrics.LoadOrStore(metric, value) // another Add won the race; reload and retry 16 | } 17 | if !loaded { 18 | return value 19 | } 20 | return prev.(int64) + value 21 | } 22 | 23 | // Inc increments the given metric. It creates the metric if it doesn't exist. 24 | func Inc(metric string) int64 { 25 | return Add(metric, 1) 26 | } 27 | 28 | // Range loops through all metrics and calls the given function for each metric. 29 | func Range(fn func(key string, value int64) bool) { 30 | metrics.Range(func(key, value any) bool { 31 | return fn(key.(string), value.(int64)) 32 | }) 33 | } 34 | 35 | // Snapshot returns a map containing all the metrics at the time of snapshotting. 36 | func Snapshot() map[string]int64 { 37 | snapshot := make(map[string]int64) 38 | 39 | Range(func(key string, value int64) bool { 40 | snapshot[key] = value 41 | 42 | return true 43 | }) 44 | 45 | return snapshot 46 | } 47 | 48 | // SnapshotHTTPEndpoint is an HTTP handler that sends a JSON representation of 49 | // the metrics snapshot to the request initiator.
50 | var SnapshotHTTPEndpoint http.HandlerFunc = func(w http.ResponseWriter, _ *http.Request) { 51 | payload, err := json.MarshalIndent(Snapshot(), "", " ") 52 | if err != nil { 53 | w.WriteHeader(http.StatusInternalServerError) 54 | } else { 55 | w.Header().Set("Content-Type", "application/json") 56 | _, _ = w.Write(payload) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /cmd/complex_client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "math/big" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "os/signal" 12 | "strings" 13 | "syscall" 14 | 15 | "github.com/tmaxmax/go-sse" 16 | ) 17 | 18 | func main() { 19 | var sub string 20 | flag.StringVar(&sub, "sub", "all", "The topics to subscribe to. Valid values are: all, numbers, metrics") 21 | flag.Parse() 22 | 23 | ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) 24 | defer cancel() 25 | 26 | r, _ := http.NewRequestWithContext(ctx, http.MethodGet, getRequestURL(sub), http.NoBody) 27 | conn := sse.NewConnection(r) 28 | 29 | conn.SubscribeToAll(func(event sse.Event) { 30 | switch event.Type { 31 | case "cycles", "ops": 32 | fmt.Printf("Metric %s: %s\n", event.Type, event.Data) 33 | case "close": 34 | fmt.Println("Server closed!") 35 | cancel() 36 | default: // no event name 37 | var sum, num big.Int 38 | for _, n := range strings.Split(event.Data, "\n") { 39 | if _, ok := num.SetString(n, 10); ok { // num is undefined after a failed parse, so skip such lines 40 | sum.Add(&sum, &num) 41 | } 42 | } 43 | fmt.Printf("Sum of random numbers: %s\n", &sum) 44 | } 45 | }) 46 | 47 | if err := conn.Connect(); err != nil { 48 | fmt.Fprintln(os.Stderr, err) 49 | } 50 | } 51 | 52 | func getRequestURL(sub string) string { 53 | q := url.Values{} 54 | switch sub { 55 | case "all": 56 | q.Add("topic", "numbers") 57 | q.Add("topic", "metrics") 58 | case "numbers", "metrics": 59 | q.Set("topic", sub) 60 |
default: 61 | panic(fmt.Errorf("unexpected subscription topic %q", sub)) 62 | } 63 | 64 | return "http://localhost:8080/events?" + q.Encode() 65 | } 66 | -------------------------------------------------------------------------------- /cmd/helloworld/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/tmaxmax/go-sse" 9 | ) 10 | 11 | func main() { 12 | s := &sse.Server{} 13 | 14 | go func() { 15 | ev := &sse.Message{} 16 | ev.AppendData("Hello world") 17 | 18 | for range time.Tick(time.Second) { 19 | _ = s.Publish(ev) 20 | } 21 | }() 22 | 23 | //nolint:gosec // Use http.Server in your code instead, to be able to set timeouts. 24 | if err := http.ListenAndServe(":8000", s); err != nil { 25 | log.Fatalln(err) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /cmd/helloworld_client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "os" 7 | 8 | "github.com/tmaxmax/go-sse" 9 | ) 10 | 11 | func main() { 12 | r, _ := http.NewRequest(http.MethodGet, "http://localhost:8000", http.NoBody) 13 | conn := sse.NewConnection(r) 14 | 15 | conn.SubscribeMessages(func(event sse.Event) { 16 | fmt.Printf("%s\n\n", event.Data) 17 | }) 18 | 19 | if err := conn.Connect(); err != nil { 20 | fmt.Fprintln(os.Stderr, err) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /cmd/llm/main.go: -------------------------------------------------------------------------------- 1 | //go:build go1.23 2 | 3 | package main 4 | 5 | import ( 6 | "bufio" 7 | "context" 8 | "encoding/json" 9 | "fmt" 10 | "net/http" 11 | "os" 12 | "os/signal" 13 | "strings" 14 | 15 | "github.com/tmaxmax/go-sse" 16 | ) 17 | 18 | func main() { 19 | ctx, cancel :=
signal.NotifyContext(context.Background(), os.Interrupt) 20 | defer cancel() 21 | 22 | in := bufio.NewScanner(os.Stdin) 23 | if !in.Scan() { 24 | fmt.Fprintf(os.Stderr, "message read error: %v\n", in.Err()) 25 | return 26 | } 27 | 28 | // Encode the message as a JSON string literal, quotes included. strconv.Quote is not a 29 | // substitute: it emits Go syntax, and its quotes would be doubled by a quoted placeholder. 30 | // Marshalling a string cannot fail (invalid UTF-8 is coerced to U+FFFD), hence the ignored error. 31 | content, _ := json.Marshal(in.Text()) 32 | 33 | // I've picked ChatGPT for this example just to have a working program. 34 | // I do not endorse nor have any affiliation with any LLM out there. 35 | payload := strings.NewReader(fmt.Sprintf(`{ 36 | "model": "gpt-4o", 37 | "messages": [ 38 | { 39 | "role": "developer", 40 | "content": "You are a helpful assistant." 41 | }, 42 | { 43 | "role": "user", 44 | "content": %s 45 | } 46 | ], 47 | "stream": true 48 | }`, content)) 49 | 50 | req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/chat/completions", payload) 51 | req.Header.Set("Content-Type", "application/json") 52 | req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY")) 53 | 54 | res, err := http.DefaultClient.Do(req) 55 | if err != nil { 56 | fmt.Fprintf(os.Stderr, "request error: %v\n", err) 57 | return 58 | } 59 | defer res.Body.Close() 60 | 61 | if res.StatusCode != http.StatusOK { 62 | fmt.Fprintf(os.Stderr, "response errored with code %s\n", res.Status) 63 | return 64 | } 65 | 66 | for ev, err := range sse.Read(res.Body, nil) { 67 | if err != nil { 68 | fmt.Fprintf(os.Stderr, "while reading response body: %v\n", err) 69 | // Can return – Read stops after first error and no subsequent events are parsed.
70 | return 71 | } 72 | // The stream is terminated by a "data: [DONE]" sentinel, which is not JSON; stop there. 73 | if ev.Data == "[DONE]" { 74 | break 75 | } 76 | 77 | var data struct { 78 | Choices []struct { 79 | Delta struct { 80 | Content string `json:"content"` 81 | } `json:"delta"` 82 | } `json:"choices"` 83 | } 84 | if err := json.Unmarshal([]byte(ev.Data), &data); err != nil { 85 | fmt.Fprintf(os.Stderr, "while unmarshalling response: %v\n", err) 86 | return 87 | } 88 | 89 | // Deltas are raw text fragments that carry their own spacing, so print them verbatim. 90 | // Guard the index so a chunk without choices doesn't panic. 91 | if len(data.Choices) > 0 { 92 | fmt.Print(data.Choices[0].Delta.Content) 93 | } 94 | } 95 | fmt.Println() 96 | } 97 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - cmd 3 | - internal/tests -------------------------------------------------------------------------------- /event.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "io" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/tmaxmax/go-sse/internal/parser" 9 | ) 10 | 11 | // The Event struct represents an event sent to the client by a server. 12 | type Event struct { 13 | // The last event ID received (an empty "id" field resets it to empty). This may 14 | // not be the ID of the latest event! 15 | LastEventID string 16 | // The event's type. It is empty if the event is unnamed. 17 | Type string 18 | // The event's payload. 19 | Data string 20 | } 21 | 22 | // ReadConfig is used to configure how Read behaves. 23 | type ReadConfig struct { 24 | // MaxEventSize is the maximum expected length of the byte sequence 25 | // representing a single event. Parsing events longer than that 26 | // will result in an error. 27 | // 28 | // By default this limit is 64KB. You don't need to set this if it 29 | // is enough for your needs (e.g. the events you receive don't contain 30 | // larger amounts of data). 31 | MaxEventSize int 32 | } 33 | 34 | // Read parses an SSE stream and yields all incoming events. 35 | // On any encountered errors iteration stops and no further events are parsed – 36 | // the loop can safely be ended on error.
If EOF is reached, the Read operation 37 | // is considered successful and no error is returned. An Event will never 38 | // be yielded together with an error. 39 | // 40 | // Read is especially useful for parsing responses from services which 41 | // communicate using SSE but not over long-lived connections – for example, 42 | // LLM APIs. 43 | // 44 | // Read handles the Event.LastEventID value just as the browser SSE client 45 | // (EventSource) would – for every event, the last encountered event ID will be given, 46 | // even if the ID is not the current event's ID. Read, unlike EventSource, does 47 | // not set Event.Type to "message" if no "event" field is received, leaving 48 | // it blank. 49 | // 50 | // Read provides no way to handle the "retry" field and doesn't handle retrying. 51 | // Use a Client and a Connection if you need to retry requests. 52 | func Read(r io.Reader, cfg *ReadConfig) func(func(Event, error) bool) { 53 | pf := func() *parser.Parser { 54 | p := parser.New(r) 55 | if cfg != nil && cfg.MaxEventSize > 0 { 56 | // NOTE(tmaxmax): we don't allow setting the buffer at the moment. 57 | // ReadConfig objects might be shared between Read calls executed in 58 | // different goroutines and having an actual []byte in it seems dangerous. 59 | // If there is demand it can be added. 60 | p.Buffer(nil, cfg.MaxEventSize) 61 | } 62 | return p 63 | } 64 | 65 | // We take a factory function for the parser so that Read can be inlined by the compiler.
66 | return read(pf, "", nil, true) 67 | } 68 | 69 | // read implements Read. pf is called lazily, each time the returned iterator runs; lastEventID 70 | // seeds Event.LastEventID; onRetry, if non-nil, receives non-negative "retry" values; ignoreEOF hides io.EOF. 71 | func read(pf func() *parser.Parser, lastEventID string, onRetry func(int64), ignoreEOF bool) func(func(Event, error) bool) { 72 | return func(yield func(Event, error) bool) { 73 | p := pf() 74 | 75 | typ, sb, dirty := "", strings.Builder{}, false 76 | doYield := func(data string) bool { 77 | if data != "" { 78 | data = data[:len(data)-1] 79 | } 80 | return yield(Event{LastEventID: lastEventID, Data: data, Type: typ}, nil) 81 | } 82 | 83 | for f := (parser.Field{}); p.Next(&f); { 84 | switch f.Name { //nolint:exhaustive // Comment fields are not parsed. 85 | case parser.FieldNameData: 86 | sb.WriteString(f.Value) 87 | sb.WriteByte('\n') 88 | dirty = true 89 | case parser.FieldNameEvent: 90 | typ = f.Value 91 | dirty = true 92 | case parser.FieldNameID: 93 | // empty IDs are valid, only IDs that contain the null byte must be ignored: 94 | // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation 95 | if strings.IndexByte(f.Value, 0) != -1 { 96 | break 97 | } 98 | 99 | lastEventID = f.Value 100 | dirty = true 101 | case parser.FieldNameRetry: 102 | n, err := strconv.ParseInt(f.Value, 10, 64) 103 | if err != nil { 104 | break 105 | } 106 | if n >= 0 && onRetry != nil { 107 | onRetry(n) 108 | dirty = true 109 | } 110 | default: 111 | if dirty { 112 | if !doYield(sb.String()) { 113 | return 114 | } 115 | sb.Reset() 116 | typ = "" 117 | dirty = false 118 | } 119 | } 120 | } 121 | 122 | err := p.Err() 123 | isEOF := err == io.EOF //nolint:errorlint // Our scanner returns io.EOF unwrapped 124 | 125 | if dirty && isEOF { 126 | if !doYield(sb.String()) { 127 | return 128 | } 129 | } 130 | 131 | if err != nil && !(ignoreEOF && isEOF) { 132 | yield(Event{}, err) 133 | } 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /event_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3
| import ( 4 | "io" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/tmaxmax/go-sse" 9 | "github.com/tmaxmax/go-sse/internal/tests" 10 | ) 11 | 12 | func TestRead(t *testing.T) { 13 | // NOTE(tmaxmax): Basic test for now, as the functionality 14 | // is tested through the client tests. Most tests pertaining 15 | // to client-side event parsing will be moved here when 16 | // the client will be refactored. 17 | response := strings.NewReader("id:\000\nretry:x\ndata: Hello World!\n\n") // also test null ID and invalid retry edge cases 18 | 19 | var recv []sse.Event 20 | 21 | events := sse.Read(response, nil) 22 | events(func(e sse.Event, err error) bool { 23 | tests.Equal(t, err, nil, "unexpected error") 24 | recv = append(recv, e) 25 | return true 26 | }) 27 | 28 | tests.DeepEqual(t, recv, []sse.Event{{Data: "Hello World!"}}, "incorrect result") 29 | 30 | t.Run("Buffer", func(t *testing.T) { 31 | _, _ = response.Seek(0, io.SeekStart) 32 | 33 | events := sse.Read(response, &sse.ReadConfig{MaxEventSize: 3}) 34 | var err error 35 | events(func(_ sse.Event, e error) bool { err = e; return err == nil }) 36 | tests.Expect(t, err != nil, "should fail because of too small buffer") 37 | }) 38 | 39 | t.Run("Break", func(t *testing.T) { 40 | events := sse.Read(strings.NewReader("id: a\n\nid: b\n\nid: c\n"), nil) // also test EOF edge case 41 | 42 | var recv []sse.Event 43 | events(func(e sse.Event, err error) bool { 44 | tests.Equal(t, err, nil, "unexpected error") 45 | recv = append(recv, e) 46 | return len(recv) < 2 47 | }) 48 | 49 | expected := []sse.Event{{LastEventID: "a"}, {LastEventID: "b"}} 50 | tests.DeepEqual(t, recv, expected, "iterator didn't stop") 51 | 52 | // Cover break check on EOF edge case 53 | // NOTE(tmaxmax): Should also test this with EOF return when possible. 
54 | sse.Read(strings.NewReader("data: x\n"), nil)(func(e sse.Event, err error) bool { 55 | tests.Equal(t, err, nil, "unexpected error") 56 | tests.Equal(t, e, sse.Event{Data: "x"}, "unexpected event") 57 | return false 58 | }) 59 | }) 60 | } 61 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tmaxmax/go-sse 2 | 3 | go 1.22 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmaxmax/go-sse/e3ddbdfdcf69aacfe83cbeafcc20e169f55fccb2/go.sum -------------------------------------------------------------------------------- /internal/parser/chunk.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | // isNewlineChar returns whether the given character is '\n' or '\r'. 4 | func isNewlineChar(b byte) bool { 5 | return b == '\n' || b == '\r' 6 | } 7 | 8 | // NewlineIndex returns the index of the first occurrence of a newline sequence (\n, \r, or \r\n). 9 | // It also returns the sequence's length. If no sequence is found, index is equal to len(s) 10 | // and length is 0. 11 | // 12 | // The newline is defined in the Event Stream standard's documentation: 13 | // https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events 14 | func NewlineIndex(s string) (index, length int) { 15 | for l := len(s); index < l; index++ { 16 | b := s[index] 17 | 18 | if isNewlineChar(b) { 19 | length++ 20 | if b == '\r' && index < l-1 && s[index+1] == '\n' { 21 | length++ 22 | } 23 | 24 | break 25 | } 26 | } 27 | 28 | return 29 | } 30 | 31 | // NextChunk retrieves the next chunk of data from the given string 32 | // along with the data remaining after the returned chunk. 
33 | // A chunk is a string of data delimited by a newline. 34 | // If the returned chunk is the last one, len(remaining) will be 0. 35 | // 36 | // The newline is defined in the Event Stream standard's documentation: 37 | // https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events 38 | func NextChunk(s string) (chunk, remaining string, hasNewline bool) { 39 | index, endlineLen := NewlineIndex(s) 40 | return s[:index], s[index+endlineLen:], endlineLen != 0 41 | } 42 | -------------------------------------------------------------------------------- /internal/parser/chunk_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "unsafe" 7 | ) 8 | 9 | func TestNextChunk(t *testing.T) { 10 | t.Parallel() 11 | 12 | s := "sarmale" 13 | chunk, remaining, hasNewline := NextChunk(s) 14 | 15 | if remaining != "" { 16 | t.Fatalf("No more data should be remaining") 17 | } 18 | 19 | if unsafe.StringData(s) != unsafe.StringData(chunk) { 20 | t.Fatalf("First chunk should always have the same address as the given buffer") 21 | } 22 | 23 | if s != chunk { 24 | t.Fatalf("Expected chunk %q, got %q", s, chunk) 25 | } 26 | 27 | if hasNewline { 28 | t.Fatalf("Ends in newline flag incorrect: expected %t, got %t", false, hasNewline) 29 | } 30 | 31 | s = "sarmale cu\nghimbir\r\nsunt\rsuper\n\ngenial sincer\r\n" 32 | 33 | expected := []string{ 34 | "sarmale cu", 35 | "ghimbir", 36 | "sunt", 37 | "super", 38 | "", 39 | "genial sincer", 40 | } 41 | 42 | var got []string 43 | 44 | for s != "" { 45 | var chunk string 46 | chunk, s, _ = NextChunk(s) 47 | 48 | got = append(got, chunk) 49 | } 50 | 51 | if !reflect.DeepEqual(got, expected) { 52 | t.Fatalf("Bad result:\n\texpected %#v\n\treceived %#v", expected, got) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /internal/parser/field.go: 
-------------------------------------------------------------------------------- 1 | package parser 2 | 3 | // FieldName is the name of the field. 4 | type FieldName string 5 | 6 | // A Field represents an unprocessed field of a single event. The Name is the field's identifier, which is used to 7 | // process the fields afterwards. 8 | // 9 | // As a special case, if a parser (FieldParser or Parser) returns a field without a name, 10 | // it means that a whole event was parsed. In other words, all the fields before the one without a name 11 | // and after another such field are part of the same event. 12 | type Field struct { 13 | Name FieldName 14 | Value string 15 | } 16 | 17 | // Valid field names. 18 | const ( 19 | FieldNameData = FieldName("data") 20 | FieldNameEvent = FieldName("event") 21 | FieldNameRetry = FieldName("retry") 22 | FieldNameID = FieldName("id") 23 | // FieldNameComment is a sentinel value that indicates 24 | // comment fields. It is not a valid field name that should 25 | // be written to an SSE stream. 26 | FieldNameComment = FieldName(":") 27 | 28 | maxFieldNameLength = 5 29 | ) 30 | 31 | func getFieldName(b string) (FieldName, bool) { 32 | switch FieldName(b) { //nolint:exhaustive // Cannot have Comment here 33 | case FieldNameData: 34 | return FieldNameData, true 35 | case FieldNameEvent: 36 | return FieldNameEvent, true 37 | case FieldNameRetry: 38 | return FieldNameRetry, true 39 | case FieldNameID: 40 | return FieldNameID, true 41 | default: 42 | return "", false 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/parser/field_parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | ) 7 | 8 | // FieldParser extracts fields from a byte slice. 
9 | type FieldParser struct { 10 | err error 11 | data string 12 | 13 | started bool 14 | 15 | keepComments bool 16 | removeBOM bool 17 | } 18 | 19 | func trimFirstSpace(c string) string { 20 | if c != "" && c[0] == ' ' { 21 | return c[1:] 22 | } 23 | return c 24 | } 25 | 26 | func (f *FieldParser) scanSegment(chunk string, out *Field) bool { 27 | colonPos, l := strings.IndexByte(chunk, ':'), len(chunk) 28 | if colonPos > maxFieldNameLength { 29 | return false 30 | } 31 | if colonPos == -1 { 32 | colonPos = l 33 | } 34 | 35 | name, ok := getFieldName(chunk[:colonPos]) 36 | if ok { 37 | out.Name = name 38 | out.Value = trimFirstSpace(chunk[min(colonPos+1, l):]) 39 | return true 40 | } else if chunk == "" { 41 | // scanSegment is called only with chunks which end with a newline in the input. 42 | // If chunk is empty, it means that this is a blank line which ends the event, 43 | // so an empty Field needs to be returned. 44 | out.Name = "" 45 | out.Value = "" 46 | return true 47 | } else if colonPos == 0 && f.keepComments { 48 | out.Name = FieldNameComment 49 | out.Value = trimFirstSpace(chunk[min(1, l):]) 50 | return true 51 | } 52 | 53 | return false 54 | } 55 | 56 | // ErrUnexpectedEOF is returned when the input is completely parsed but no complete field was found at the end. 57 | var ErrUnexpectedEOF = errors.New("go-sse: unexpected end of input") 58 | 59 | // Next parses the next available field in the remaining buffer. 60 | // It returns false if there are no more fields to parse. 61 | func (f *FieldParser) Next(r *Field) bool { 62 | for f.data != "" { 63 | f.started = true 64 | 65 | chunk, rem, hasNewline := NextChunk(f.data) 66 | if !hasNewline { 67 | f.err = ErrUnexpectedEOF 68 | return false 69 | } 70 | 71 | f.data = rem 72 | 73 | if !f.scanSegment(chunk, r) { 74 | continue 75 | } 76 | 77 | return true 78 | } 79 | 80 | return false 81 | } 82 | 83 | // Reset changes the buffer from which fields are parsed. 
84 | func (f *FieldParser) Reset(data string) { 85 | f.data = data 86 | f.err = nil 87 | f.started = false 88 | f.doRemoveBOM() 89 | } 90 | 91 | // Err returns the last error encountered by the parser. It is either nil or ErrUnexpectedEOF. 92 | func (f *FieldParser) Err() error { 93 | return f.err 94 | } 95 | 96 | // Started tells whether parsing has started (a call to Next which consumed input was made 97 | // or the BOM was removed, if it existed). Started will be true if the FieldParser has advanced 98 | // through the data. 99 | func (f *FieldParser) Started() bool { 100 | return f.started 101 | } 102 | 103 | // KeepComments configures the FieldParser to parse/ignore comment fields. 104 | // By default comment fields are ignored. 105 | func (f *FieldParser) KeepComments(shouldKeep bool) { 106 | f.keepComments = shouldKeep 107 | } 108 | 109 | // RemoveBOM configures the FieldParser to try and remove the Unicode BOM 110 | // when parsing the first field, if it exists. 111 | // If, at the time this option is set, the input is untouched (no fields were parsed), 112 | // it will also be attempted to remove the BOM. 113 | func (f *FieldParser) RemoveBOM(shouldRemove bool) { 114 | f.removeBOM = shouldRemove 115 | f.doRemoveBOM() 116 | } 117 | 118 | func (f *FieldParser) doRemoveBOM() { 119 | const bom = "\xEF\xBB\xBF" 120 | if f.removeBOM && !f.started && strings.HasPrefix(f.data, bom) { 121 | f.data = f.data[len(bom):] 122 | f.started = true 123 | } 124 | } 125 | 126 | // NewFieldParser creates a parser that extracts fields from the given string. 
127 | func NewFieldParser(data string) *FieldParser { 128 | return &FieldParser{data: data} 129 | } 130 | -------------------------------------------------------------------------------- /internal/parser/field_parser_test.go: -------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/tmaxmax/go-sse/internal/parser" 8 | ) 9 | 10 | func TestFieldParser(t *testing.T) { 11 | t.Parallel() 12 | 13 | t.Run("Empty input", func(t *testing.T) { 14 | p := parser.NewFieldParser("") 15 | if p.Next(nil) { 16 | t.Fatalf("empty input should not yield data") 17 | } 18 | if p.Started() { 19 | t.Fatalf("parsing empty input should have no effects") 20 | } 21 | p.RemoveBOM(true) 22 | if p.Started() { 23 | t.Fatalf("BOM shouldn't be removed on empty input") 24 | } 25 | }) 26 | 27 | type testCase struct { 28 | err error 29 | name string 30 | data string 31 | expected []parser.Field 32 | keepComments bool 33 | } 34 | 35 | tests := []testCase{ 36 | { 37 | name: "Normal data", 38 | data: "event: sarmale\ndata:doresc sarmale\n: comentariu\ndata: multe sarmale \r\n\n", 39 | expected: []parser.Field{ 40 | newEventField(t, "sarmale"), 41 | newDataField(t, "doresc sarmale"), 42 | newDataField(t, " multe sarmale "), 43 | {}, 44 | }, 45 | }, 46 | { 47 | name: "Normal data but no newline at the end", 48 | data: ":comment\r: another comment\ndata: whatever", 49 | err: parser.ErrUnexpectedEOF, 50 | }, 51 | { 52 | name: "Fields without data", 53 | data: "data\ndata \ndata:\n\n", 54 | expected: []parser.Field{ 55 | newDataField(t, ""), 56 | // The second `data ` should be ignored, as it is not a valid field name 57 | // (it would be valid without trailing spaces). 
58 | newDataField(t, ""), 59 | {}, 60 | }, 61 | }, 62 | { 63 | name: "Invalid fields", 64 | data: "i'm an invalid field:\nlmao me too\nretry: 120\nid: 5\r\n\r\n", 65 | expected: []parser.Field{ 66 | newRetryField(t, "120"), 67 | newIDField(t, "5"), 68 | {}, 69 | }, 70 | }, 71 | { 72 | name: "Normal data, only one newline at the end", 73 | data: "data: first chunk\ndata: second chunk\r\n", 74 | expected: []parser.Field{ 75 | newDataField(t, "first chunk"), 76 | newDataField(t, "second chunk"), 77 | }, 78 | }, 79 | { 80 | name: "Normal data with comments", 81 | data: "data: hello\ndata: world\r: comm\r\n:other comm\nevent: test\n", 82 | keepComments: true, 83 | expected: []parser.Field{ 84 | newDataField(t, "hello"), 85 | newDataField(t, "world"), 86 | newCommentField(t, "comm"), 87 | newCommentField(t, "other comm"), 88 | newEventField(t, "test"), 89 | }, 90 | }, 91 | } 92 | 93 | for _, test := range tests { 94 | test := test 95 | 96 | t.Run(test.name, func(t *testing.T) { 97 | t.Parallel() 98 | 99 | p := parser.NewFieldParser(test.data) 100 | p.KeepComments(test.keepComments) 101 | 102 | var segments []parser.Field 103 | 104 | for f := (parser.Field{}); p.Next(&f); { 105 | segments = append(segments, f) 106 | } 107 | 108 | if !p.Started() { 109 | t.Fatalf("parsing should be marked as having started") 110 | } 111 | if p.Err() != test.err { //nolint 112 | t.Fatalf("invalid error: received %v, expected %v", p.Err(), test.err) 113 | } 114 | if !reflect.DeepEqual(test.expected, segments) { 115 | t.Fatalf("invalid segments for test %q:\nreceived %#v\nexpected %#v", test.name, segments, test.expected) 116 | } 117 | }) 118 | } 119 | 120 | t.Run("BOM", func(t *testing.T) { 121 | p := parser.NewFieldParser("\xEF\xBB\xBFid: 5\n") 122 | p.RemoveBOM(true) 123 | 124 | var f parser.Field 125 | if !p.Next(&f) { 126 | t.Fatalf("a field should be available (err=%v)", p.Err()) 127 | } 128 | 129 | expectedF := parser.Field{Name: parser.FieldNameID, Value: "5"} 130 | if f != expectedF 
{ 131 | t.Fatalf("invalid field: received %v, expected %v", f, expectedF) 132 | } 133 | 134 | p.Reset("\xEF\xBB\xBF") 135 | if p.Next(&f) { 136 | t.Fatalf("no fields should be available") 137 | } 138 | if p.Err() != nil { 139 | t.Fatalf("no error is expected after BOM removal") 140 | } 141 | if !p.Started() { 142 | t.Fatalf("BOM removal should mark parsing as having started") 143 | } 144 | 145 | p.Reset("data: no BOM\n") 146 | if p.Started() { 147 | t.Fatalf("data has no BOM so no advancement should be made") 148 | } 149 | }) 150 | } 151 | 152 | func BenchmarkFieldParser(b *testing.B) { 153 | b.ReportAllocs() 154 | 155 | var f parser.Field 156 | 157 | for n := 0; n < b.N; n++ { 158 | p := parser.NewFieldParser(benchmarkText) 159 | for p.Next(&f) { 160 | } 161 | } 162 | 163 | _ = f 164 | } 165 | -------------------------------------------------------------------------------- /internal/parser/parser.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "unsafe" 7 | ) 8 | 9 | // splitFunc is a split function for a bufio.Scanner that splits a sequence of 10 | // bytes into SSE events. Each event ends with two consecutive newline sequences, 11 | // where a newline sequence is defined as either "\n", "\r", or "\r\n". 12 | func splitFunc(data []byte, atEOF bool) (advance int, token []byte, err error) { 13 | if len(data) == 0 { 14 | return 0, nil, nil 15 | } 16 | 17 | var start int 18 | for { 19 | index, endlineLen := NewlineIndex(unsafe.String(unsafe.SliceData(data), len(data))[advance:]) 20 | advance += index + endlineLen 21 | if index == 0 { 22 | // If it was a blank line, skip it. 23 | start += endlineLen 24 | } 25 | // We've reached the end of data or a second newline follows and the line isn't blank. 26 | // The latter means we have an event. 
27 | if advance == len(data) || (isNewlineChar(data[advance]) && index > 0) { 28 | break 29 | } 30 | } 31 | 32 | if l := len(data); advance == l && !atEOF { 33 | // We have reached the end of the buffer but have not yet seen two consecutive 34 | // newline sequences, so we request more data. 35 | return 0, nil, nil 36 | } else if advance < l { 37 | // We have found a newline. Consume the end-of-line sequence. 38 | advance++ 39 | // Consume one more character if end-of-line is "\r\n". 40 | if advance < l && data[advance-1] == '\r' && data[advance] == '\n' { 41 | advance++ 42 | } 43 | } 44 | 45 | token = data[start:advance] 46 | 47 | return advance, token, nil 48 | } 49 | 50 | // Parser extracts fields from a reader. Reading is buffered using a bufio.Scanner. 51 | // The Parser also removes the UTF-8 BOM if it exists. 52 | type Parser struct { 53 | inputScanner *bufio.Scanner 54 | fieldScanner *FieldParser 55 | } 56 | 57 | // Next parses a single field from the reader into f. It returns false when there are no more fields to parse or an error occurred; call Err to tell them apart. 58 | func (r *Parser) Next(f *Field) bool { 59 | // Loop because at the end of input the scanner can return a token with no fields in it (e.g. only trailing 60 | // blank lines, or a final unterminated event made of comments or unknown fields); stopping there would hide the io.EOF reported by Err. 61 | for !r.fieldScanner.Next(f) { 62 | if r.fieldScanner.Err() != nil || r.inputScanner == nil { 63 | return false // The last line is unterminated, or the input was already exhausted (avoids a nil dereference). 64 | } 65 | if !r.inputScanner.Scan() { 66 | // Do this to signal EOF, which bufio.Scanner suppresses. 67 | if r.inputScanner.Err() == nil { 68 | r.inputScanner = nil 69 | } 70 | return false 71 | } 72 | if r.fieldScanner.Started() { 73 | // If scanning was started, then an event was already processed at this point and the BOM was 74 | // already removed, if it existed. We don't need to remove it anymore, so disable the option. 75 | r.fieldScanner.RemoveBOM(false) 76 | }
77 | // The allocation made inside `Text` is not an issue and should even improve performance. 78 | // If the Field returned from `Next` wouldn't own its resources, then the caller would have 79 | // to allocate new memory and copy each field value. This way, not only the caller doesn't 80 | // have to worry about allocations and ownership, but also bigger and less frequent allocations 81 | // are made, compared to the previous usage – allocations are now made per event, not per field value. 82 | r.fieldScanner.Reset(r.inputScanner.Text()) 83 | } 84 | return true 85 | } 86 | 87 | // Err returns the last read error. At the end of input it is io.EOF, or ErrUnexpectedEOF 88 | // if the input's last line was not terminated by a newline. 89 | func (r *Parser) Err() error { 90 | if err := r.fieldScanner.Err(); err != nil { 91 | return err 92 | } 93 | if r.inputScanner == nil { 94 | // Recover the EOF suppressed by bufio.Scanner. 95 | // We need it inside the client, to know when to retry. 96 | return io.EOF 97 | } 98 | return r.inputScanner.Err() 99 | } 100 | 101 | // Buffer sets the buffer used to scan the input. 102 | // For more information, see the documentation on bufio.Scanner.Buffer. 103 | // Do not call this after parsing has started – the method will panic! 104 | func (r *Parser) Buffer(buf []byte, maxSize int) { 105 | r.inputScanner.Buffer(buf, maxSize) 106 | } 107 | 108 | // New returns a Parser that extracts fields from a reader.
109 | func New(r io.Reader) *Parser { 110 | sc := bufio.NewScanner(r) 111 | sc.Split(splitFunc) 112 | 113 | fsc := NewFieldParser("") 114 | fsc.RemoveBOM(true) 115 | 116 | return &Parser{inputScanner: sc, fieldScanner: fsc} 117 | } 118 | -------------------------------------------------------------------------------- /internal/parser/parser_test.go: -------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "io" 7 | "reflect" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/tmaxmax/go-sse/internal/parser" 12 | ) 13 | 14 | type errReader struct { 15 | r io.Reader 16 | } 17 | 18 | var errReadFailed = errors.New("haha error") 19 | 20 | func (e errReader) Read(_ []byte) (int, error) { 21 | return 0, errReadFailed 22 | } 23 | 24 | func TestParser(t *testing.T) { 25 | t.Parallel() 26 | 27 | type test struct { 28 | input io.Reader 29 | err error 30 | name string 31 | expected []parser.Field 32 | } 33 | 34 | longString := strings.Repeat("abcdefghijklmnopqrstuvwxyz", 193) 35 | 36 | tests := []test{ 37 | { 38 | name: "Valid input", 39 | input: strings.NewReader(` 40 | event: scan 41 | data: first chunk 42 | data: second chunk 43 | :comment 44 | data: third chunk 45 | id: 1 46 | 47 | : comment 48 | event: anotherScan 49 | data: nice scan 50 | id: 2 51 | retry: 15 52 | 53 | 54 | event: something glitched before why are there two newlines 55 | data: still, here's some data: you deserve it 56 | `), 57 | expected: []parser.Field{ 58 | newEventField(t, "scan"), 59 | newDataField(t, "first chunk"), 60 | newDataField(t, "second chunk"), 61 | newDataField(t, "third chunk"), 62 | newIDField(t, "1"), 63 | {}, 64 | newEventField(t, "anotherScan"), 65 | newDataField(t, "nice scan"), 66 | newIDField(t, "2"), 67 | newRetryField(t, "15"), 68 | {}, 69 | newEventField(t, "something glitched before why are there two newlines"), 70 | newDataField(t, "still, here's some data: you deserve it"), 71 | }, 72 
| err: io.EOF, 73 | }, 74 | { 75 | name: "Valid input with long string", 76 | input: strings.NewReader("\nid:2\ndata:" + longString + "\n"), 77 | expected: []parser.Field{ 78 | newIDField(t, "2"), 79 | newDataField(t, longString), 80 | }, 81 | err: io.EOF, 82 | }, 83 | { 84 | name: "Error", 85 | input: errReader{nil}, 86 | err: errReadFailed, 87 | }, 88 | { 89 | name: "Error from field parser (no final newline)", 90 | input: strings.NewReader("data: lmao"), 91 | err: parser.ErrUnexpectedEOF, 92 | }, 93 | { 94 | name: "With BOM", 95 | // The second BOM should not be removed, which should result in that field being named 96 | // "\ufeffdata", which is an invalid name and thus the field ending up being ignored. 97 | input: strings.NewReader("\xEF\xBB\xBFdata: hello\n\n\xEF\xBB\xBFdata: world\n\n"), 98 | expected: []parser.Field{ 99 | newDataField(t, "hello"), 100 | {}, 101 | {}, 102 | }, 103 | err: io.EOF, 104 | }, 105 | } 106 | 107 | for _, test := range tests { 108 | test := test 109 | 110 | t.Run(test.name, func(t *testing.T) { 111 | t.Parallel() 112 | 113 | p := parser.New(test.input) 114 | var fields []parser.Field 115 | if l := len(test.expected); l > 0 { 116 | fields = make([]parser.Field, 0, l) 117 | } 118 | 119 | for f := (parser.Field{}); p.Next(&f); { 120 | fields = append(fields, f) 121 | } 122 | 123 | if err := p.Err(); err != test.err { //nolint 124 | t.Fatalf("invalid error: received %v, expected %v", err, test.err) 125 | } 126 | 127 | if !reflect.DeepEqual(test.expected, fields) { 128 | t.Fatalf("parse failed:\nreceived: %#v\nexpected: %#v", fields, test.expected) 129 | } 130 | }) 131 | } 132 | 133 | t.Run("Buffer", func(t *testing.T) { 134 | p := parser.New(strings.NewReader("sarmale")) 135 | p.Buffer(make([]byte, 0, 3), 3) 136 | 137 | if p.Next(nil) { 138 | t.Fatalf("nothing should be parsed") 139 | } 140 | if !errors.Is(p.Err(), bufio.ErrTooLong) { 141 | t.Fatalf("expected error %v, received %v", bufio.ErrTooLong, p.Err()) 142 | } 143 | }) 144 | 145 
| t.Run("Separate CRLF", func(t *testing.T) { 146 | r, w := io.Pipe() 147 | t.Cleanup(func() { r.Close() }) 148 | 149 | p := parser.New(r) 150 | 151 | go func() { 152 | defer w.Close() 153 | _, _ = io.WriteString(w, "data: hello\n\r") 154 | // This LF should be ignored and yield no results. 155 | _, _ = io.WriteString(w, "\n") 156 | _, _ = io.WriteString(w, "data: world\n") 157 | }() 158 | 159 | var fields []parser.Field 160 | for f := (parser.Field{}); p.Next(&f); { 161 | fields = append(fields, f) 162 | } 163 | 164 | expected := []parser.Field{newDataField(t, "hello"), {}, newDataField(t, "world")} 165 | 166 | if !reflect.DeepEqual(fields, expected) { 167 | t.Fatalf("unexpected result:\nreceived %v\nexpected %v", fields, expected) 168 | } 169 | }) 170 | } 171 | 172 | func BenchmarkParser(b *testing.B) { 173 | b.ReportAllocs() 174 | 175 | var f parser.Field 176 | 177 | for n := 0; n < b.N; n++ { 178 | r := strings.NewReader(benchmarkText) 179 | p := parser.New(r) 180 | 181 | for p.Next(&f) { 182 | } 183 | } 184 | 185 | _ = f 186 | } 187 | -------------------------------------------------------------------------------- /internal/parser/split_func_test.go: -------------------------------------------------------------------------------- 1 | package parser 2 | 3 | import ( 4 | "bufio" 5 | "reflect" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestSplitFunc(t *testing.T) { 11 | t.Parallel() 12 | 13 | type testCase struct { 14 | name string 15 | input string 16 | expected []string 17 | } 18 | 19 | longString := strings.Repeat("abcdef\rghijklmn\nopqrstu\r\nvwxyz", 193) 20 | testCases := []testCase{ 21 | { 22 | name: "Short sample", 23 | input: "mama mea e super\nce genial\nsincer n-am ce sa zic\r\n\r\n\nmama tata bunica bunicul\nsarmale\r\n\r\r\naualeu\nce taraboi", 24 | expected: []string{ 25 | "mama mea e super\nce genial\nsincer n-am ce sa zic\r\n\r\n", 26 | "mama tata bunica bunicul\nsarmale\r\n\r", 27 | "aualeu\nce taraboi", 28 | }, 29 | }, 30 | { 31 | name: 
"Long sample", 32 | input: longString + "\n\n" + longString + "\r\r" + longString + "\r\n\r\n" + longString, 33 | expected: []string{ 34 | longString + "\n\n", 35 | longString + "\r\r", 36 | longString + "\r\n\r\n", 37 | longString, 38 | }, 39 | }, 40 | } 41 | 42 | for _, tc := range testCases { 43 | tc := tc 44 | 45 | t.Run(tc.name, func(t *testing.T) { 46 | t.Parallel() 47 | 48 | r := strings.NewReader(tc.input) 49 | s := bufio.NewScanner(r) 50 | s.Split(splitFunc) 51 | 52 | tokens := make([]string, 0, len(tc.expected)) 53 | 54 | for s.Scan() { 55 | tokens = append(tokens, s.Text()) 56 | } 57 | 58 | if s.Err() != nil { 59 | t.Fatalf("an error occurred: %v", s.Err()) 60 | } 61 | 62 | if !reflect.DeepEqual(tokens, tc.expected) { 63 | t.Fatalf("wrong tokens:\nreceived: %#v\nexpected: %#v", tokens, tc.expected) 64 | } 65 | }) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/parser/test_helpers_test.go: -------------------------------------------------------------------------------- 1 | package parser_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tmaxmax/go-sse/internal/parser" 7 | ) 8 | 9 | func newField(tb testing.TB, name parser.FieldName, value string) parser.Field { 10 | tb.Helper() 11 | 12 | return parser.Field{ 13 | Name: name, 14 | Value: value, 15 | } 16 | } 17 | 18 | func newDataField(tb testing.TB, value string) parser.Field { 19 | tb.Helper() 20 | 21 | return newField(tb, parser.FieldNameData, value) 22 | } 23 | 24 | func newEventField(tb testing.TB, value string) parser.Field { 25 | tb.Helper() 26 | 27 | return newField(tb, parser.FieldNameEvent, value) 28 | } 29 | 30 | func newRetryField(tb testing.TB, value string) parser.Field { 31 | tb.Helper() 32 | 33 | return newField(tb, parser.FieldNameRetry, value) 34 | } 35 | 36 | func newIDField(tb testing.TB, value string) parser.Field { 37 | tb.Helper() 38 | 39 | return newField(tb, parser.FieldNameID, value) 40 | } 41 | 42 | func 
newCommentField(tb testing.TB, value string) parser.Field { 43 | tb.Helper() 44 | 45 | return newField(tb, parser.FieldNameComment, value) 46 | } 47 | 48 | const benchmarkText = ` 49 | event:cycles 50 | data:8 51 | id:10 52 | 53 | event:ops 54 | data:2 55 | id:11 56 | 57 | data:10667007354186551956 58 | id:12 59 | 60 | event:cycles 61 | data:9 62 | id:13 63 | 64 | event:cycles 65 | data:10 66 | id:14 67 | 68 | event:cycles 69 | data:11 70 | id:15 71 | 72 | event:cycles 73 | data:12 74 | id:16 75 | 76 | event:ops 77 | data:3 78 | id:17 79 | 80 | event:cycles 81 | data:13 82 | id:18 83 | 84 | data:4751997750760398084 85 | id:19 86 | 87 | event:cycles 88 | data:14 89 | id:20 90 | 91 | event:cycles 92 | data:15 93 | id:21 94 | 95 | event:cycles 96 | data:16 97 | id:22 98 | 99 | event:ops 100 | data:4 101 | id:23 102 | 103 | event:cycles 104 | data:17 105 | id:24 106 | 107 | event:cycles 108 | data:18 109 | id:25 110 | 111 | data:3510942875414458836 112 | data:12156940908066221323 113 | data:4324745483838182873 114 | id:26 115 | 116 | event:cycles 117 | data:19 118 | id:27 119 | 120 | event:cycles 121 | data:20 122 | id:28 123 | 124 | event:ops 125 | data:5 126 | id:29 127 | 128 | event:cycles 129 | data:21 130 | id:30 131 | 132 | event:cycles 133 | data:22 134 | id:31 135 | 136 | event:cycles 137 | data:23 138 | id:32 139 | 140 | data:6263450610539110790 141 | id:33 142 | 143 | event:ops 144 | data:6 145 | id:34 146 | 147 | event:cycles 148 | data:24 149 | id:35 150 | 151 | event:cycles 152 | data:25 153 | id:36 154 | 155 | event:cycles 156 | data:26 157 | id:37 158 | 159 | event:cycles 160 | data:27 161 | id:38 162 | 163 | data:3328451335138149956 164 | id:39 165 | 166 | event:ops 167 | data:7 168 | id:40 169 | 170 | event:cycles 171 | data:28 172 | id:41 173 | 174 | event:cycles 175 | data:29 176 | id:42 177 | 178 | event:cycles 179 | data:30 180 | id:43 181 | 182 | event:cycles 183 | data:31 184 | id:44 185 | 186 | ` 187 | 
-------------------------------------------------------------------------------- /internal/tests/expect.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func Equal[T comparable](tb testing.TB, got, expected T, format string, args ...any) { 11 | tb.Helper() 12 | 13 | if got != expected { 14 | output := fmt.Sprintf(format, args...) 15 | switch v := any(got).(type) { 16 | case string: 17 | output += fmt.Sprintf("\nreceived: %q\nexpected: %q", v, any(expected).(string)) 18 | case fmt.Stringer: 19 | output += fmt.Sprintf("\nreceived: %q\nexpected: %q", v.String(), any(expected).(fmt.Stringer).String()) 20 | default: 21 | a, b := fmt.Sprintf("%v", got), fmt.Sprintf("%v", expected) 22 | if len(a) >= 20 || len(b) >= 20 { 23 | output += fmt.Sprintf("\nreceived: %s\nexpected: %s", a, b) 24 | } else { 25 | output += fmt.Sprintf("\nreceived %s, expected %s", a, b) 26 | } 27 | } 28 | tb.Fatal(output) 29 | } 30 | } 31 | 32 | func DeepEqual[T any](tb testing.TB, got, expected T, format string, args ...any) { 33 | tb.Helper() 34 | 35 | if !reflect.DeepEqual(got, expected) { 36 | output := fmt.Sprintf(format, args...) 37 | output += fmt.Sprintf("\nreceived: %+v\nexpected %+v", got, expected) 38 | tb.Fatal(output) 39 | } 40 | } 41 | 42 | func ErrorIs(tb testing.TB, got, expected error, format string, args ...any) { 43 | tb.Helper() 44 | 45 | if !errors.Is(got, expected) { 46 | output := fmt.Sprintf(format, args...) 47 | output += fmt.Sprintf("\nreceived: %#v\nexpected: %#v", got, expected) 48 | tb.Fatal(output) 49 | } 50 | } 51 | 52 | func Expect(tb testing.TB, cond bool, format string, args ...any) { 53 | tb.Helper() 54 | 55 | if !cond { 56 | tb.Fatalf(format, args...) 
57 | } 58 | } 59 | 60 | func NotPanics(tb testing.TB, fn func(), format string, args ...any) { 61 | tb.Helper() 62 | 63 | defer func() { 64 | tb.Helper() 65 | 66 | if r := recover(); r != nil { 67 | tb.Fatalf(format+"\npanic: %+v", append(args, r)...) 68 | } 69 | }() 70 | 71 | fn() 72 | } 73 | 74 | func Panics(tb testing.TB, fn func(), format string, args ...any) (recovered any) { 75 | tb.Helper() 76 | 77 | defer func() { 78 | tb.Helper() 79 | 80 | if recovered = recover(); recovered == nil { 81 | tb.Fatalf(format, args...) 82 | } 83 | }() 84 | 85 | fn() 86 | 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /internal/tests/time.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import "time" 4 | 5 | type Time struct { 6 | now time.Time 7 | added time.Duration 8 | } 9 | 10 | func (t *Time) Now() time.Time { 11 | if t == nil || t.now.IsZero() { 12 | return time.Now() 13 | } 14 | 15 | return t.now.Add(t.added) 16 | } 17 | 18 | func (t *Time) Fixed() (time.Time, bool) { 19 | if t == nil || t.now.IsZero() { 20 | return time.Time{}, false 21 | } 22 | 23 | return t.now.Add(t.added), true 24 | } 25 | 26 | func (t *Time) Set(ts time.Time) { 27 | if t != nil { 28 | t.now = ts 29 | t.added = 0 30 | } 31 | } 32 | 33 | func (t *Time) Add(d time.Duration) { 34 | if t == nil { 35 | return 36 | } 37 | 38 | if t.now.IsZero() { 39 | t.now = time.Now() 40 | } 41 | 42 | t.added += d 43 | } 44 | 45 | func (t *Time) Reset() { 46 | t.Set(time.Time{}) 47 | } 48 | 49 | func (t *Time) Rewind() { 50 | if t != nil { 51 | t.added = 0 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /joe.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "context" 5 | "runtime/debug" 6 | "sync" 7 | ) 8 | 9 | // A Replayer is a type that can replay older published events to 
new subscribers. 10 | // Replayers use event IDs, the topics the events were published and optionally 11 | // any other criteria to determine which are valid for replay. 12 | // 13 | // While replayers can require events to have IDs beforehand, they can also set the IDs themselves, 14 | // automatically - it's up to the implementation. Replayers should not overwrite or remove any existing 15 | // IDs and return an error instead. 16 | // 17 | // Replayers are not required to be thread-safe - server providers are required to ensure only 18 | // one operation is executed on the replayer at any given time. Server providers may not execute 19 | // replay operation concurrently with other operations, so make sure any action on the replayer 20 | // blocks for as little as possible. If a replayer is thread-safe, some operations may be 21 | // run in a separate goroutine - see the interface's method documentation. 22 | // 23 | // Executing actions that require waiting for a long time on I/O, such as HTTP requests or database 24 | // calls must be handled with great care, so the server provider is not blocked. Reducing them to 25 | // the minimum by using techniques such as caching or by executing them in separate goroutines is 26 | // recommended, as long as the implementation fulfills the requirements. 27 | // 28 | // If not specified otherwise, the errors returned are implementation-specific. 29 | type Replayer interface { 30 | // Put adds a new event to the replay buffer. The Message that is returned may not have the 31 | // same address, if the replayer automatically sets IDs. 32 | // 33 | // Put errors if the message couldn't be queued – if no topics are provided, 34 | // a message without an ID is put into a Replayer which does not 35 | // automatically set IDs, or a message with an ID is put into a Replayer which 36 | // does automatically set IDs. An error should be returned for other failures 37 | // related to the given message. 
When no topics are provided, ErrNoTopic should be 38 | // returned. 39 | // 40 | // The Put operation may be executed by the replayer in another goroutine only if 41 | // it can ensure that any Replay operation called after the Put goroutine is started 42 | // can replay the new received message. This also requires the replayer implementation 43 | // to be thread-safe. 44 | // 45 | // Replayers are not required to guarantee that immediately after Put returns 46 | // the new messages can be replayed. If an error occurs internally when putting the new message 47 | // and retrying the operation would block for too long, it can be aborted. 48 | // 49 | // To indicate a complete replayer failure (i.e. the replayer won't work after this point) 50 | // a panic should be used instead of an error. 51 | Put(message *Message, topics []string) (*Message, error) 52 | // Replay sends to a new subscriber all the valid events received by the replayer 53 | // since the event with the listener's ID. If the ID the listener provides 54 | // is invalid, the provider should not replay any events. 55 | // 56 | // Replay calls must return only after replaying is done. 57 | // Implementations should not keep references to the subscription client 58 | // after Replay returns. 59 | // 60 | // If an error is returned, then at least some messages weren't successfully replayed. 61 | // The error is nil if there were no messages to replay for the particular subscription 62 | // or if all messages were replayed successfully. 63 | // 64 | // If any messages are replayed, Client.Flush must be called by implementations. 
65 | Replay(subscription Subscription) error 66 | } 67 | 68 | type ( 69 | subscriber chan<- error 70 | subscription struct { 71 | done subscriber 72 | Subscription 73 | } 74 | 75 | messageWithTopics struct { 76 | message *Message 77 | topics []string 78 | } 79 | 80 | publishedMessage struct { 81 | replayerErr chan<- error 82 | messageWithTopics 83 | } 84 | ) 85 | 86 | // Joe is a basic server provider that synchronously executes operations by queueing them in channels. 87 | // Events are also sent synchronously to subscribers, so if a subscriber's callback blocks, the others 88 | // have to wait. 89 | // 90 | // Joe optionally supports event replaying with the help of a Replayer. 91 | // 92 | // If the replayer panics, the subscription for which it panicked is considered failed 93 | // and an error is returned, and thereafter the replayer is not used anymore – no replays 94 | // will be attempted for future subscriptions. 95 | // If due to some other unexpected scenario something panics internally, Joe will remove all subscribers 96 | // and close itself, so subscribers don't end up blocked. 97 | // 98 | // He serves simple use-cases well, as he's light on resources, and does not require any external 99 | // services. Also, he is the default provider for Servers. 100 | type Joe struct { 101 | message chan publishedMessage 102 | subscription chan subscription 103 | unsubscription chan subscriber 104 | done chan struct{} 105 | closed chan struct{} 106 | subscribers map[subscriber]Subscription 107 | 108 | // An optional replayer that Joe uses to resend older messages to new subscribers. 109 | Replayer Replayer 110 | 111 | initDone sync.Once 112 | } 113 | 114 | // Subscribe tells Joe to send new messages to this subscriber. The subscription 115 | // is automatically removed when the context is done, a client error occurs 116 | // or Joe is stopped. 
117 | // 118 | // Subscribe returns without error only when the unsubscription is caused 119 | // by the given context being canceled. 120 | func (j *Joe) Subscribe(ctx context.Context, sub Subscription) error { 121 | j.init() 122 | 123 | // Without a buffered channel we risk a deadlock when Subscribe 124 | // stops receiving from this channel on done context and Joe 125 | // encounters an error when sending messages or replaying. 126 | done := make(chan error, 1) 127 | 128 | select { 129 | case <-j.done: 130 | return ErrProviderClosed 131 | case j.subscription <- subscription{done: done, Subscription: sub}: 132 | } 133 | 134 | select { 135 | case err := <-done: 136 | return err 137 | case <-j.closed: 138 | return ErrProviderClosed 139 | case <-ctx.Done(): 140 | } 141 | 142 | select { 143 | case <-j.done: 144 | return ErrProviderClosed 145 | case j.unsubscription <- done: 146 | // NOTE(tmaxmax): should we return ctx.Err() instead? 147 | return nil 148 | } 149 | } 150 | 151 | // Publish tells Joe to send the given message to the subscribers. 152 | // When a message is published to multiple topics, Joe makes sure to 153 | // not send the Message multiple times to clients that are subscribed 154 | // to more than one topic that receive the given Message. Every client 155 | // receives each unique message once, regardless of how many topics it 156 | // is subscribed to or to how many topics the message is published. 157 | // 158 | // It returns ErrNoTopic if no topics are provided, eventual Replayer.Put 159 | // errors or ErrProviderClosed. If the replayer returns an error the 160 | // message will still be sent but most probably it won't be replayed to 161 | // new subscribers, depending on how the error is handled by the replay provider. 
162 | func (j *Joe) Publish(msg *Message, topics []string) error { 163 | if len(topics) == 0 { 164 | return ErrNoTopic 165 | } 166 | 167 | j.init() 168 | 169 | // Buffered to prevent a deadlock when Publish doesn't 170 | // receive from errs due to Joe being shut down and the 171 | // message published causes an error after the shutdown. 172 | errs := make(chan error, 1) 173 | 174 | pub := publishedMessage{replayerErr: errs} 175 | pub.message = msg 176 | pub.topics = topics 177 | 178 | // Waiting on done ensures Publish doesn't block the caller goroutine 179 | // when Joe is stopped and implements the required Provider behavior. 180 | select { 181 | case j.message <- pub: 182 | return <-errs 183 | case <-j.done: 184 | return ErrProviderClosed 185 | } 186 | } 187 | 188 | // Shutdown signals Joe to close all subscribers and stop receiving messages. 189 | // It returns when all the subscribers are closed. 190 | // 191 | // Further calls to Stop will return ErrProviderClosed. 192 | func (j *Joe) Shutdown(ctx context.Context) (err error) { 193 | j.init() 194 | 195 | defer func() { 196 | if r := recover(); r != nil { 197 | err = ErrProviderClosed 198 | } 199 | }() 200 | 201 | close(j.done) 202 | 203 | select { 204 | case <-j.closed: 205 | case <-ctx.Done(): 206 | err = ctx.Err() 207 | } 208 | 209 | return 210 | } 211 | 212 | func (j *Joe) removeSubscriber(sub subscriber) { 213 | l := len(j.subscribers) 214 | delete(j.subscribers, sub) 215 | // We check that an element was deleted as removeSubscriber is called twice 216 | // in the following edge case: the subscriber context is done before a 217 | // published message is sent/flushed, and the send/flush returns an error. 
218 | if l != len(j.subscribers) { 219 | close(sub) 220 | } 221 | } 222 | 223 | func (j *Joe) start(replay Replayer) { 224 | defer close(j.closed) 225 | 226 | for { 227 | select { 228 | case msg := <-j.message: 229 | if replay != nil { 230 | m, err := tryPut(msg.messageWithTopics, &replay) 231 | if _, isPanic := err.(replayPanic); err != nil && !isPanic { //nolint:errorlint // it's our error 232 | // NOTE(tmaxmax): We could return panic errors here but we'd have to expose 233 | // the error type in order for this error to be handled. Let's not change 234 | // the public errors for now. See also the other note below. 235 | msg.replayerErr <- err 236 | } else if m != nil { 237 | msg.message = m 238 | } 239 | } 240 | close(msg.replayerErr) 241 | 242 | for done, sub := range j.subscribers { 243 | if topicsIntersect(sub.Topics, msg.topics) { 244 | err := sub.Client.Send(msg.message) 245 | if err == nil { 246 | err = sub.Client.Flush() 247 | } 248 | 249 | if err != nil { 250 | done <- err 251 | // Technically it would be possible to just send the error, 252 | // as Subscribe would send an unsubscription signal. The problem 253 | // is that if the j.message channel is ready together with j.unsubscription 254 | // and j.message is picked we might send again to this now unsubscribed 255 | // subscriber, which will cause issues (e.g. deadlock on done). 256 | // This line here is the reason why we need to verify we actually 257 | // have this subscriber in removeSubscriber above. 258 | j.removeSubscriber(done) 259 | } 260 | } 261 | } 262 | case sub := <-j.subscription: 263 | var err error 264 | if replay != nil { 265 | err = tryReplay(sub.Subscription, &replay) 266 | } 267 | 268 | // NOTE(tmaxmax): We can't meaningfully handle replay panics in any way 269 | // other than disabling replay altogether. 
This ensures uptime 270 | // in the face of unexpected – returning the panic as an error 271 | // to the subscriber doesn't make sense, as it's probably not the subscriber's fault. 272 | if _, isPanic := err.(replayPanic); err != nil && !isPanic { //nolint:errorlint // it's our error 273 | sub.done <- err 274 | close(sub.done) 275 | } else { 276 | j.subscribers[sub.done] = sub.Subscription 277 | } 278 | case sub := <-j.unsubscription: 279 | j.removeSubscriber(sub) 280 | case <-j.done: 281 | return 282 | } 283 | } 284 | } 285 | 286 | func tryReplay(sub Subscription, replay *Replayer) (err error) { //nolint:gocritic // intended 287 | defer handleReplayerPanic(replay, &err) 288 | 289 | return (*replay).Replay(sub) 290 | } 291 | 292 | func tryPut(msg messageWithTopics, replay *Replayer) (m *Message, err error) { //nolint:gocritic // intended 293 | defer handleReplayerPanic(replay, &err) 294 | 295 | return (*replay).Put(msg.message, msg.topics) 296 | } 297 | 298 | type replayPanic struct{} 299 | 300 | func (replayPanic) Error() string { return "replay provider panicked" } 301 | 302 | func handleReplayerPanic(replay *Replayer, errp *error) { //nolint:gocritic // intended 303 | if r := recover(); r != nil { 304 | *replay = nil 305 | *errp = replayPanic{} 306 | // NOTE(tmaxmax): At least print a stacktrace. It's annoying when libraries recover from panics 307 | // and make them untraceable. Should we provide a way to handle these in a custom manner? 
308 | debug.PrintStack() 309 | } 310 | } 311 | 312 | func (j *Joe) init() { 313 | j.initDone.Do(func() { 314 | j.message = make(chan publishedMessage) 315 | j.subscription = make(chan subscription) 316 | j.unsubscription = make(chan subscriber) 317 | j.done = make(chan struct{}) 318 | j.closed = make(chan struct{}) 319 | j.subscribers = map[subscriber]Subscription{} 320 | 321 | replay := j.Replayer 322 | if replay == nil { 323 | replay = noopReplayer{} 324 | } 325 | go j.start(replay) 326 | }) 327 | } 328 | -------------------------------------------------------------------------------- /joe_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/tmaxmax/go-sse" 11 | "github.com/tmaxmax/go-sse/internal/tests" 12 | ) 13 | 14 | type mockReplayer struct { 15 | putc chan struct{} 16 | replayc chan struct{} 17 | shouldPanic string 18 | } 19 | 20 | func (m *mockReplayer) Put(msg *sse.Message, _ []string) (*sse.Message, error) { 21 | m.putc <- struct{}{} 22 | if strings.Contains(m.shouldPanic, "put") { 23 | panic("panicked") 24 | } 25 | 26 | return msg, nil 27 | } 28 | 29 | func (m *mockReplayer) Replay(_ sse.Subscription) error { 30 | m.replayc <- struct{}{} 31 | if strings.Contains(m.shouldPanic, "replay") { 32 | panic("panicked") 33 | } 34 | 35 | return nil 36 | } 37 | 38 | func (m *mockReplayer) replays() int { 39 | return len(m.replayc) 40 | } 41 | 42 | func (m *mockReplayer) puts() int { 43 | return len(m.putc) 44 | } 45 | 46 | var _ sse.Replayer = (*mockReplayer)(nil) 47 | 48 | func newMockReplayer(shouldPanic string, numExpectedCalls int) *mockReplayer { 49 | return &mockReplayer{ 50 | shouldPanic: shouldPanic, 51 | putc: make(chan struct{}, numExpectedCalls), 52 | replayc: make(chan struct{}, numExpectedCalls), 53 | } 54 | } 55 | 56 | func msg(tb testing.TB, data, id string) *sse.Message { 57 | tb.Helper() 
58 | 59 | e := &sse.Message{} 60 | e.AppendData(data) 61 | if id != "" { 62 | e.ID = sse.ID(id) 63 | } 64 | 65 | return e 66 | } 67 | 68 | type mockClient func(m *sse.Message) error 69 | 70 | func (c mockClient) Send(m *sse.Message) error { return c(m) } 71 | func (c mockClient) Flush() error { return c(nil) } 72 | 73 | func cleanupJoe(tb testing.TB, j *sse.Joe) { 74 | tb.Helper() 75 | tb.Cleanup(func() { 76 | _ = j.Shutdown(context.Background()) 77 | }) 78 | } 79 | 80 | func TestJoe_Shutdown(t *testing.T) { 81 | t.Parallel() 82 | 83 | rp := newMockReplayer("", 0) 84 | j := &sse.Joe{ 85 | Replayer: rp, 86 | } 87 | 88 | tests.Equal(t, j.Shutdown(context.Background()), nil, "joe should close successfully") 89 | tests.Equal(t, j.Shutdown(context.Background()), sse.ErrProviderClosed, "joe should already be closed") 90 | tests.Equal(t, j.Subscribe(context.Background(), sse.Subscription{}), sse.ErrProviderClosed, "no operation should be allowed on closed joe") 91 | tests.Equal(t, j.Publish(nil, nil), sse.ErrNoTopic, "parameter validation should happen first") 92 | tests.Equal(t, j.Publish(nil, []string{sse.DefaultTopic}), sse.ErrProviderClosed, "no operation should be allowed on closed joe") 93 | tests.Equal(t, rp.puts(), 0, "joe should not have used the replay provider") 94 | tests.Equal(t, rp.replays(), 0, "joe should not have used the replay provider") 95 | 96 | j = &sse.Joe{} 97 | // trigger internal initialization, so the concurrent Shutdowns aren't serialized by the internal sync.Once. 
98 | _ = j.Publish(&sse.Message{}, []string{sse.DefaultTopic}) 99 | //nolint 100 | tests.NotPanics(t, func() { 101 | go j.Shutdown(context.Background()) 102 | j.Shutdown(context.Background()) 103 | }, "concurrent shutdown should work") 104 | 105 | j = &sse.Joe{} 106 | subctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*15) 107 | t.Cleanup(cancel) 108 | go j.Subscribe(subctx, sse.Subscription{ //nolint:errcheck // we don't care about this error 109 | Topics: []string{sse.DefaultTopic}, 110 | Client: mockClient(func(m *sse.Message) error { 111 | if m != nil { 112 | time.Sleep(time.Millisecond * 8) 113 | } 114 | return nil 115 | }), 116 | }) 117 | time.Sleep(time.Millisecond) 118 | 119 | _ = j.Publish(&sse.Message{}, []string{sse.DefaultTopic}) 120 | 121 | sctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5) 122 | t.Cleanup(cancel) 123 | tests.ErrorIs(t, j.Shutdown(sctx), context.DeadlineExceeded, "shutdown should stop on closed context") 124 | 125 | <-subctx.Done() 126 | } 127 | 128 | func subscribe(t testing.TB, p sse.Provider, ctx context.Context, topics ...string) <-chan []*sse.Message { //nolint 129 | t.Helper() 130 | 131 | if len(topics) == 0 { 132 | topics = []string{sse.DefaultTopic} 133 | } 134 | 135 | ch := make(chan []*sse.Message, 1) 136 | 137 | go func() { 138 | defer close(ch) 139 | 140 | var msgs []*sse.Message 141 | 142 | c := mockClient(func(m *sse.Message) error { 143 | if m != nil { 144 | msgs = append(msgs, m) 145 | } 146 | return nil 147 | }) 148 | 149 | _ = p.Subscribe(ctx, sse.Subscription{Client: c, Topics: topics}) 150 | 151 | ch <- msgs 152 | }() 153 | 154 | return ch 155 | } 156 | 157 | type mockContext struct { 158 | context.Context 159 | waitingOnDone chan struct{} 160 | } 161 | 162 | func (m *mockContext) Done() <-chan struct{} { 163 | close(m.waitingOnDone) 164 | 165 | return m.Context.Done() 166 | } 167 | 168 | func newMockContext(tb testing.TB) (*mockContext, context.CancelFunc) { 169 | 
tb.Helper() 170 | 171 | ctx, cancel := context.WithCancel(context.Background()) 172 | tb.Cleanup(cancel) 173 | 174 | return &mockContext{Context: ctx, waitingOnDone: make(chan struct{})}, cancel 175 | } 176 | 177 | func TestJoe_SubscribePublish(t *testing.T) { 178 | t.Parallel() 179 | 180 | rp := newMockReplayer("", 2) 181 | j := &sse.Joe{ 182 | Replayer: rp, 183 | } 184 | cleanupJoe(t, j) 185 | 186 | ctx, cancel := newMockContext(t) 187 | 188 | sub := subscribe(t, j, ctx) 189 | <-ctx.waitingOnDone 190 | tests.Equal(t, j.Publish(msg(t, "hello", ""), []string{sse.DefaultTopic}), nil, "publish should succeed") 191 | cancel() 192 | tests.Equal(t, j.Publish(msg(t, "world", ""), []string{sse.DefaultTopic}), nil, "publish should succeed") 193 | msgs := <-sub 194 | tests.Equal(t, "data: hello\n\n", msgs[0].String(), "invalid data received") 195 | 196 | ctx2, _ := newMockContext(t) 197 | 198 | sub2 := subscribe(t, j, ctx2) 199 | <-ctx2.waitingOnDone 200 | 201 | tests.Equal(t, j.Shutdown(context.Background()), nil, "shutdown should succeed") 202 | msgs = <-sub2 203 | tests.Equal(t, len(msgs), 0, "unexpected messages received") 204 | tests.Equal(t, rp.puts(), 2, "invalid put calls") 205 | tests.Equal(t, rp.puts(), 2, "invalid replay calls") 206 | } 207 | 208 | func TestJoe_Subscribe_multipleTopics(t *testing.T) { 209 | t.Parallel() 210 | 211 | j := &sse.Joe{} 212 | cleanupJoe(t, j) 213 | 214 | ctx, _ := newMockContext(t) 215 | 216 | sub := subscribe(t, j, ctx, sse.DefaultTopic, "another topic") 217 | <-ctx.waitingOnDone 218 | 219 | _ = j.Publish(msg(t, "hello", ""), []string{sse.DefaultTopic, "another topic"}) 220 | _ = j.Publish(msg(t, "world", ""), []string{"another topic"}) 221 | 222 | _ = j.Shutdown(context.Background()) 223 | 224 | msgs := <-sub 225 | 226 | expected := `data: hello 227 | 228 | data: world 229 | 230 | ` 231 | tests.Equal(t, expected, msgs[0].String()+msgs[1].String(), "unexpected data received") 232 | } 233 | 234 | func TestJoe_errors(t *testing.T) { 235 
| t.Parallel() 236 | 237 | fin, err := sse.NewFiniteReplayer(2, false) 238 | tests.Equal(t, err, nil, "should create new FiniteReplayProvider") 239 | 240 | j := &sse.Joe{ 241 | Replayer: fin, 242 | } 243 | cleanupJoe(t, j) 244 | 245 | _ = j.Publish(msg(t, "hello", "0"), []string{sse.DefaultTopic}) 246 | _ = j.Publish(msg(t, "hello", "1"), []string{sse.DefaultTopic}) 247 | 248 | callErr := errors.New("artificial fail") 249 | 250 | var called int 251 | client := mockClient(func(m *sse.Message) error { 252 | if m != nil { 253 | called++ 254 | } 255 | return callErr 256 | }) 257 | 258 | err = j.Subscribe(context.Background(), sse.Subscription{ 259 | Client: client, 260 | LastEventID: sse.ID("0"), 261 | Topics: []string{sse.DefaultTopic}, 262 | }) 263 | tests.Equal(t, err, callErr, "error not received from replay") 264 | 265 | _ = j.Publish(msg(t, "world", "2"), []string{sse.DefaultTopic}) 266 | 267 | tests.Equal(t, called, 1, "callback was called after subscribe returned") 268 | 269 | called = 0 270 | ctx, _ := newMockContext(t) 271 | done := make(chan struct{}) 272 | 273 | go func() { 274 | defer close(done) 275 | 276 | <-ctx.waitingOnDone 277 | 278 | _ = j.Publish(msg(t, "", "3"), []string{sse.DefaultTopic}) 279 | _ = j.Publish(msg(t, "", "4"), []string{sse.DefaultTopic}) 280 | }() 281 | 282 | err = j.Subscribe(ctx, sse.Subscription{Client: client, Topics: []string{sse.DefaultTopic}}) 283 | tests.Equal(t, err, callErr, "error not received from send") 284 | // Only the first event should be attempted as nothing is replayed. 
285 | tests.Equal(t, called, 1, "callback was called after subscribe returned") 286 | 287 | <-done 288 | } 289 | 290 | type mockMessageWriter struct { 291 | msg chan *sse.Message 292 | } 293 | 294 | func (m *mockMessageWriter) Send(msg *sse.Message) error { 295 | m.msg <- msg 296 | return nil 297 | } 298 | 299 | func (m *mockMessageWriter) Flush() error { 300 | return nil 301 | } 302 | 303 | func TestJoe_ReplayPanic(t *testing.T) { 304 | t.Parallel() 305 | 306 | rp := newMockReplayer("replay put", 1) 307 | j := &sse.Joe{Replayer: rp} 308 | wr := &mockMessageWriter{msg: make(chan *sse.Message, 1)} 309 | 310 | topics := []string{sse.DefaultTopic} 311 | suberr := make(chan error) 312 | go func() { suberr <- j.Subscribe(context.Background(), sse.Subscription{Client: wr, Topics: topics}) }() 313 | 314 | _, ok := <-rp.replayc 315 | tests.Expect(t, ok, "replay wasn't called") 316 | 317 | msg := &sse.Message{ID: sse.ID("hello")} 318 | tests.Equal(t, j.Publish(msg, topics), nil, "replay put should not be triggered by publishing anymore") 319 | tests.Equal(t, (<-wr.msg).ID, msg.ID, "message was not sent to client") 320 | 321 | go func() { _ = j.Subscribe(context.Background(), sse.Subscription{}) }() 322 | time.Sleep(time.Millisecond) 323 | tests.Equal(t, rp.replays(), 0, "replay was called") 324 | 325 | tests.Equal(t, j.Shutdown(context.Background()), nil, "shutdown should succeed") 326 | tests.Equal(t, <-suberr, sse.ErrProviderClosed, "expected subscribe error due to forceful shutdown") 327 | 328 | rp = newMockReplayer("put", 1) 329 | j = &sse.Joe{Replayer: rp} 330 | go func() { suberr <- j.Subscribe(context.Background(), sse.Subscription{Client: wr, Topics: topics}) }() 331 | 332 | _, ok = <-rp.replayc 333 | tests.Expect(t, ok, "replay was called") 334 | 335 | tests.Equal(t, j.Publish(msg, topics), nil, "replay put error should not be propagated") 336 | tests.Equal(t, (<-wr.msg).ID, msg.ID, "message was not sent to client") 337 | 338 | tests.Equal(t, 
j.Shutdown(context.Background()), nil, "shutdown should succeed") 339 | tests.Equal(t, <-suberr, sse.ErrProviderClosed, "expected subscribe error due to forceful shutdown") 340 | } 341 | 342 | func TestJoe_ClientContextCloseOnError(t *testing.T) { 343 | t.Parallel() 344 | 345 | ctx, cancel := newMockContext(t) 346 | 347 | mw := mockClient(func(m *sse.Message) error { 348 | if m == nil { 349 | cancel() 350 | time.Sleep(time.Millisecond) 351 | return errors.New("flush error") 352 | } 353 | 354 | return nil 355 | }) 356 | 357 | j := &sse.Joe{} 358 | t.Cleanup(func() { _ = j.Shutdown(context.Background()) }) 359 | 360 | errch := make(chan error) 361 | topics := []string{sse.DefaultTopic} 362 | 363 | go func() { 364 | errch <- j.Subscribe(ctx, sse.Subscription{Client: mw, Topics: topics}) 365 | }() 366 | 367 | <-ctx.waitingOnDone 368 | 369 | tests.Equal(t, j.Publish(&sse.Message{ID: sse.ID("trigger")}, topics), nil, "unexpected publish error") 370 | // The error above shouldn't be propagated, as the subscriber's context is done before the replay is done. 371 | // Subscription should end instantly and not wait for anything else other than the unsubscription signal 372 | // to be sent successfully. 373 | tests.Equal(t, <-errch, nil, "unexpected subscribe error") 374 | } 375 | -------------------------------------------------------------------------------- /message.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "strconv" 8 | "strings" 9 | "time" 10 | "unicode/utf8" 11 | "unsafe" 12 | 13 | "github.com/tmaxmax/go-sse/internal/parser" 14 | ) 15 | 16 | func isSingleLine(p string) bool { 17 | _, newlineLen := parser.NewlineIndex(p) 18 | return newlineLen == 0 19 | } 20 | 21 | // fieldBytes holds the byte representation of each field type along with a colon at the end. 
22 | var ( 23 | fieldBytesData = []byte(parser.FieldNameData + ": ") 24 | fieldBytesEvent = []byte(parser.FieldNameEvent + ": ") 25 | fieldBytesRetry = []byte(parser.FieldNameRetry + ": ") 26 | fieldBytesID = []byte(parser.FieldNameID + ": ") 27 | fieldBytesComment = []byte(": ") 28 | ) 29 | 30 | type chunk struct { 31 | content string 32 | isComment bool 33 | } 34 | 35 | var newline = []byte{'\n'} 36 | 37 | func (c *chunk) WriteTo(w io.Writer) (int64, error) { 38 | name := fieldBytesData 39 | if c.isComment { 40 | name = fieldBytesComment 41 | } 42 | n, err := w.Write(name) 43 | if err != nil { 44 | return int64(n), err 45 | } 46 | m, err := writeString(w, c.content) 47 | n += m 48 | if err != nil { 49 | return int64(n), err 50 | } 51 | m, err = w.Write(newline) 52 | return int64(n + m), err 53 | } 54 | 55 | // Message is the representation of an event sent from the server to its clients. 56 | type Message struct { 57 | chunks []chunk 58 | 59 | ID EventID 60 | Type EventType 61 | Retry time.Duration 62 | } 63 | 64 | func (e *Message) appendText(isComment bool, chunks ...string) { 65 | for _, c := range chunks { 66 | var content string 67 | 68 | for c != "" { 69 | content, c, _ = parser.NextChunk(c) 70 | e.chunks = append(e.chunks, chunk{content: content, isComment: isComment}) 71 | } 72 | } 73 | } 74 | 75 | // AppendData adds multiple data fields on the message's event from the given strings. 76 | // Each string will be a distinct data field, and if the strings themselves span multiple lines 77 | // they will be broken into multiple fields. 78 | // 79 | // Server-sent events are not suited for binary data: the event fields are delimited by newlines, 80 | // where a newline can be a LF, CR or CRLF sequence. When the client interprets the fields, 81 | // it joins multiple data fields using LF, so information is altered. Here's an example: 82 | // 83 | // initial payload: This is a\r\nmultiline\rtext.\nIt has multiple\nnewline\r\nvariations. 
84 | // data sent over the wire: 85 | // data: This is a 86 | // data: multiline 87 | // data: text. 88 | // data: It has multiple 89 | // data: newline 90 | // data: variations 91 | // data received by client: This is a\nmultiline\ntext.\nIt has multiple\nnewline\nvariations. 92 | // 93 | // Each line prepended with "data:" is a field; multiple data fields are joined together using LF as the delimiter. 94 | // If you attempted to send the same payload without prepending the "data:" prefix, like so: 95 | // 96 | // data: This is a 97 | // multiline 98 | // text. 99 | // It has multiple 100 | // newline 101 | // variations 102 | // 103 | // there would be only one data field (the first one). The rest would be different fields, named "multiline", "text.", 104 | // "It has multiple" etc., which are invalid fields according to the protocol. 105 | // 106 | // Besides, the protocol explicitly states that event streams must always be UTF-8 encoded: 107 | // https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream. 108 | // 109 | // If you need to send binary data, you can use a Base64 encoder or any other encoder that does not output 110 | // any newline characters (\r or \n) and then append the resulted data. 111 | // 112 | // Given that clients treat all newlines the same and replace the original newlines with LF, 113 | // for internal code simplicity AppendData replaces them as well. 114 | func (e *Message) AppendData(chunks ...string) { 115 | e.appendText(false, chunks...) 116 | } 117 | 118 | // AppendComment adds comment fields to the message's event. 119 | // If the comments span multiple lines, they are broken into multiple comment fields. 120 | func (e *Message) AppendComment(comments ...string) { 121 | e.appendText(true, comments...) 
122 | } 123 | 124 | func (e *Message) writeMessageField(w io.Writer, f messageField, fieldBytes []byte) (int64, error) { 125 | if !f.IsSet() { 126 | return 0, nil 127 | } 128 | 129 | n, err := w.Write(fieldBytes) 130 | if err != nil { 131 | return int64(n), err 132 | } 133 | m, err := writeString(w, f.String()) 134 | n += m 135 | if err != nil { 136 | return int64(n), err 137 | } 138 | m, err = w.Write(newline) 139 | return int64(n + m), err 140 | } 141 | 142 | func (e *Message) writeID(w io.Writer) (int64, error) { 143 | return e.writeMessageField(w, e.ID.messageField, fieldBytesID) 144 | } 145 | 146 | func (e *Message) writeType(w io.Writer) (int64, error) { 147 | return e.writeMessageField(w, e.Type.messageField, fieldBytesEvent) 148 | } 149 | 150 | func (e *Message) writeRetry(w io.Writer) (int64, error) { 151 | millis := e.Retry.Milliseconds() 152 | if millis <= 0 { 153 | return 0, nil 154 | } 155 | 156 | n, err := w.Write(fieldBytesRetry) 157 | if err != nil { 158 | return int64(n), err 159 | } 160 | 161 | var buf [13]byte // log10(INT64_MAX / 1e6) ~= 13 162 | 163 | i := len(buf) - 1 164 | for millis != 0 { 165 | buf[i] = '0' + byte(millis%10) 166 | i-- 167 | millis /= 10 168 | } 169 | 170 | m, err := w.Write(buf[i+1:]) 171 | n += m 172 | if err != nil { 173 | return int64(n), err 174 | } 175 | m, err = w.Write(newline) 176 | return int64(n + m), err 177 | } 178 | 179 | // WriteTo writes the standard textual representation of the message's event to an io.Writer. 180 | // This operation is heavily optimized, so it is strongly preferred over MarshalText or String. 
181 | func (e *Message) WriteTo(w io.Writer) (int64, error) { 182 | n, err := e.writeID(w) 183 | if err != nil { 184 | return n, err 185 | } 186 | m, err := e.writeType(w) 187 | n += m 188 | if err != nil { 189 | return n, err 190 | } 191 | m, err = e.writeRetry(w) 192 | n += m 193 | if err != nil { 194 | return n, err 195 | } 196 | for i := range e.chunks { 197 | m, err = e.chunks[i].WriteTo(w) 198 | n += m 199 | if err != nil { 200 | return n, err 201 | } 202 | } 203 | if n == 0 { 204 | return 0, nil 205 | } 206 | o, err := w.Write(newline) 207 | return int64(o) + n, err 208 | } 209 | 210 | // MarshalText writes the standard textual representation of the message's event. Marshalling and unmarshalling will 211 | // result in a message with an event that has the same fields; topic will be lost. 212 | // 213 | // If you want to preserve everything, create your own custom marshalling logic. 214 | // For an example using encoding/json, see the top-level MessageCustomJSONMarshal example. 215 | // 216 | // Use the WriteTo method if you don't need the byte representation. 217 | // 218 | // The representation is written to a bytes.Buffer, which means the error is always nil. 219 | // If the buffer grows to a size bigger than the maximum allowed, MarshalText will panic. 220 | // See the bytes.Buffer documentation for more info. 221 | func (e *Message) MarshalText() ([]byte, error) { 222 | b := bytes.Buffer{} 223 | _, err := e.WriteTo(&b) 224 | return b.Bytes(), err 225 | } 226 | 227 | // String writes the message's event standard textual representation to a strings.Builder and returns the resulted string. 228 | // It may panic if the representation is too long to be buffered. 229 | // 230 | // Use the WriteTo method if you don't actually need the string representation. 
231 | func (e *Message) String() string { 232 | s := strings.Builder{} 233 | _, _ = e.WriteTo(&s) 234 | return s.String() 235 | } 236 | 237 | // UnmarshalError is the error returned by the Message's UnmarshalText method. 238 | // If the error is related to a specific field, FieldName will be a non-empty string. 239 | // If no fields were found in the target text or any other errors occurred, only 240 | // a Reason will be provided. Reason is always present. 241 | type UnmarshalError struct { 242 | Reason error 243 | FieldName string 244 | // The value of the invalid field. 245 | FieldValue string 246 | } 247 | 248 | func (u *UnmarshalError) Error() string { 249 | if u.FieldName == "" { 250 | return fmt.Sprintf("unmarshal event error: %s", u.Reason.Error()) 251 | } 252 | return fmt.Sprintf("unmarshal event error, %s field invalid: %s. contents: %s", u.FieldName, u.Reason.Error(), u.FieldValue) 253 | } 254 | 255 | func (u *UnmarshalError) Unwrap() error { 256 | return u.Reason 257 | } 258 | 259 | // ErrUnexpectedEOF is returned when unmarshaling a Message from an input that doesn't end in a newline. 260 | // 261 | // If it returned from a Connection, it means that the data from the server has reached EOF 262 | // in the middle of an incomplete event and retries are disabled (normally the client retries 263 | // the connection in this situation). 264 | var ErrUnexpectedEOF = parser.ErrUnexpectedEOF 265 | 266 | func (e *Message) reset() { 267 | e.chunks = nil 268 | e.Type = EventType{} 269 | e.ID = EventID{} 270 | e.Retry = 0 271 | } 272 | 273 | // UnmarshalText extracts the first event found in the given byte slice into the 274 | // receiver. The input is expected to be a wire format event, as defined by the spec. 275 | // Therefore, previous fields present on the Message will be overwritten 276 | // (i.e. event, ID, comments, data, retry). 277 | // 278 | // Unmarshaling ignores fields with invalid names. If no valid fields are found, 279 | // an error is returned. 
For a field to be valid it must end in a newline - if the last 280 | // field of the event doesn't end in one, an error is returned. 281 | // 282 | // All returned errors are of type UnmarshalError. 283 | func (e *Message) UnmarshalText(p []byte) error { 284 | e.reset() 285 | 286 | s := parser.NewFieldParser(string(p)) 287 | s.KeepComments(true) 288 | s.RemoveBOM(true) 289 | 290 | loop: 291 | for f := (parser.Field{}); s.Next(&f); { 292 | switch f.Name { 293 | case parser.FieldNameRetry: 294 | if i := strings.IndexFunc(f.Value, func(r rune) bool { 295 | return r < '0' || r > '9' 296 | }); i != -1 { 297 | r, _ := utf8.DecodeRuneInString(f.Value[i:]) 298 | 299 | return &UnmarshalError{ 300 | FieldName: string(f.Name), 301 | FieldValue: f.Value, 302 | Reason: fmt.Errorf("contains character %q, which is not an ASCII digit", r), 303 | } 304 | } 305 | 306 | milli, err := strconv.ParseInt(f.Value, 10, 64) 307 | if err != nil { 308 | return &UnmarshalError{ 309 | FieldName: string(f.Name), 310 | FieldValue: f.Value, 311 | Reason: fmt.Errorf("invalid retry value: %w", err), 312 | } 313 | } 314 | 315 | e.Retry = time.Duration(milli) * time.Millisecond 316 | case parser.FieldNameData, parser.FieldNameComment: 317 | e.chunks = append(e.chunks, chunk{content: f.Value, isComment: f.Name == parser.FieldNameComment}) 318 | case parser.FieldNameEvent: 319 | e.Type.value = f.Value 320 | e.Type.set = true 321 | case parser.FieldNameID: 322 | if strings.IndexByte(f.Value, 0) != -1 { 323 | break 324 | } 325 | 326 | e.ID.value = f.Value 327 | e.ID.set = true 328 | default: // event end 329 | break loop 330 | } 331 | } 332 | 333 | if len(e.chunks) == 0 && !e.Type.IsSet() && e.Retry == 0 && !e.ID.IsSet() || s.Err() != nil { 334 | e.reset() 335 | return &UnmarshalError{Reason: ErrUnexpectedEOF} 336 | } 337 | return nil 338 | } 339 | 340 | // Clone returns a copy of the message. 
341 | func (e *Message) Clone() *Message { 342 | return &Message{ 343 | // The first AppendData will trigger a reallocation. 344 | // Already appended chunks cannot be modified/removed, so this is safe. 345 | chunks: e.chunks[:len(e.chunks):len(e.chunks)], 346 | Retry: e.Retry, 347 | Type: e.Type, 348 | ID: e.ID, 349 | } 350 | } 351 | 352 | func writeString(w io.Writer, s string) (int, error) { 353 | return w.Write(unsafe.Slice(unsafe.StringData(s), len(s))) 354 | } 355 | -------------------------------------------------------------------------------- /message_fields.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | // EventID is a value of the "id" field. 11 | // It must have a single line. 12 | type EventID struct { 13 | messageField 14 | } 15 | 16 | // NewID creates an event ID value. A valid ID must not have any newlines. 17 | // If the input is not valid, an unset (invalid) ID is returned. 18 | func NewID(value string) (EventID, error) { 19 | f, err := newMessageField(value) 20 | if err != nil { 21 | return EventID{}, fmt.Errorf("invalid event ID: %w", err) 22 | } 23 | 24 | return EventID{f}, nil 25 | } 26 | 27 | // ID creates an event ID and assumes it is valid. 28 | // If it is not valid, it panics. 29 | func ID(value string) EventID { 30 | return must(NewID(value)) 31 | } 32 | 33 | // EventType is a value of the "event" field. 34 | // It must have a single line. 35 | type EventType struct { 36 | messageField 37 | } 38 | 39 | // NewType creates a value for the "event" field. 40 | // It is valid if it does not have any newlines. 41 | // If the input is not valid, an unset (invalid) ID is returned. 
42 | func NewType(value string) (EventType, error) { 43 | f, err := newMessageField(value) 44 | if err != nil { 45 | return EventType{}, fmt.Errorf("invalid event type: %w", err) 46 | } 47 | 48 | return EventType{f}, nil 49 | } 50 | 51 | // Type creates an EventType and assumes it is valid. 52 | // If it is not valid, it panics. 53 | func Type(value string) EventType { 54 | return must(NewType(value)) 55 | } 56 | 57 | func must[T any](v T, err error) T { 58 | if err != nil { 59 | panic(err) 60 | } 61 | return v 62 | } 63 | 64 | // The messageField struct represents any valid field value 65 | // i.e. single line strings. 66 | // Must be passed by value and are comparable. 67 | type messageField struct { 68 | value string 69 | set bool 70 | } 71 | 72 | func newMessageField(value string) (messageField, error) { 73 | if !isSingleLine(value) { 74 | return messageField{}, errors.New("input is multiline") 75 | } 76 | return messageField{value: value, set: true}, nil 77 | } 78 | 79 | // IsSet returns true if the receiver is a valid (set) value. 80 | func (i messageField) IsSet() bool { 81 | return i.set 82 | } 83 | 84 | // String returns the underlying value. The value may be an empty string, 85 | // make sure to check if the value is set before using it. 86 | func (i messageField) String() string { 87 | return i.value 88 | } 89 | 90 | // UnmarshalText sets the underlying value to the given string, if valid. 91 | // If the input is invalid, no changes are made to the receiver. 92 | func (i *messageField) UnmarshalText(data []byte) error { 93 | *i = messageField{} 94 | 95 | id, err := newMessageField(string(data)) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | *i = id 101 | 102 | return nil 103 | } 104 | 105 | // UnmarshalJSON sets the underlying value to the given JSON value 106 | // if the value is a string. The previous value is discarded if the operation fails. 
107 | func (i *messageField) UnmarshalJSON(data []byte) error { 108 | *i = messageField{} 109 | 110 | if string(data) == "null" { 111 | return nil 112 | } 113 | 114 | var input string 115 | 116 | if err := json.Unmarshal(data, &input); err != nil { 117 | return err 118 | } 119 | 120 | id, err := newMessageField(input) 121 | if err != nil { 122 | return err 123 | } 124 | 125 | *i = id 126 | 127 | return nil 128 | } 129 | 130 | // MarshalText returns a copy of the underlying value if it is set. 131 | // It returns an error when trying to marshal an unset value. 132 | func (i *messageField) MarshalText() ([]byte, error) { 133 | if i.IsSet() { 134 | return []byte(i.String()), nil 135 | } 136 | 137 | return nil, fmt.Errorf("can't marshal unset string to text") 138 | } 139 | 140 | // MarshalJSON returns a JSON representation of the underlying value if it is set. 141 | // It otherwise returns the representation of the JSON null value. 142 | func (i *messageField) MarshalJSON() ([]byte, error) { 143 | if i.IsSet() { 144 | return json.Marshal(i.String()) 145 | } 146 | 147 | return json.Marshal(nil) 148 | } 149 | 150 | // Scan implements the sql.Scanner interface. Values can be scanned from: 151 | // - nil interfaces (result: unset value) 152 | // - byte slice 153 | // - string 154 | func (i *messageField) Scan(src interface{}) error { 155 | *i = messageField{} 156 | 157 | if src == nil { 158 | return nil 159 | } 160 | 161 | switch v := src.(type) { 162 | case []byte: 163 | i.value = string(v) 164 | case string: 165 | i.value = string([]byte(v)) 166 | default: 167 | return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, *i) 168 | } 169 | 170 | i.set = true 171 | 172 | return nil 173 | } 174 | 175 | // Value implements the driver.Valuer interface. 
176 | func (i messageField) Value() (driver.Value, error) { 177 | if i.IsSet() { 178 | return i.String(), nil 179 | } 180 | return nil, nil 181 | } 182 | -------------------------------------------------------------------------------- /message_fields_test.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tmaxmax/go-sse/internal/tests" 7 | ) 8 | 9 | func mustMessageField(tb testing.TB, value string) messageField { //nolint:unparam // May receive other values. 10 | tb.Helper() 11 | 12 | f, err := newMessageField(value) 13 | if err != nil { 14 | panic(err) 15 | } 16 | 17 | return f 18 | } 19 | 20 | func TestNewMessageField(t *testing.T) { 21 | t.Parallel() 22 | 23 | id, err := newMessageField("") 24 | tests.Equal(t, err, nil, "field evaluated as invalid") 25 | tests.Expect(t, id.IsSet(), "field is not set") 26 | tests.Equal(t, id.String(), "", "field incorrectly set") 27 | 28 | id, err = newMessageField("in\nvalid") 29 | tests.Expect(t, err != nil, "field evaluated as valid") 30 | tests.Expect(t, !id.IsSet() && id.String() == "", "field isn't unset") 31 | } 32 | 33 | func TestMessageField_UnmarshalJSON(t *testing.T) { 34 | t.Parallel() 35 | 36 | type test struct { 37 | name string 38 | input []byte 39 | output messageField 40 | expectErr bool 41 | } 42 | 43 | tt := []test{ 44 | {name: "Valid input", input: []byte("\"\""), output: mustMessageField(t, "")}, 45 | {name: "Null input", input: []byte("null")}, 46 | {name: "Invalid JSON value", input: []byte("525482"), expectErr: true}, 47 | {name: "Invalid input", input: []byte("\"multi\\nline\""), expectErr: true}, 48 | } 49 | 50 | for _, test := range tt { 51 | test := test 52 | 53 | t.Run(test.name, func(t *testing.T) { 54 | t.Parallel() 55 | 56 | id := messageField{} 57 | err := id.UnmarshalJSON(test.input) 58 | 59 | if test.expectErr { 60 | tests.Expect(t, err != nil, "expected error") 61 | } else { 62 | tests.Equal(t, 
err, nil, "unexpected error") 63 | } 64 | 65 | tests.Equal(t, id, test.output, "unexpected unmarshal result") 66 | }) 67 | } 68 | } 69 | 70 | func TestMessageField_UnmarshalText(t *testing.T) { 71 | t.Parallel() 72 | 73 | var id messageField 74 | err := id.UnmarshalText([]byte("")) 75 | 76 | tests.Equal(t, id, mustMessageField(t, ""), "unexpected unmarshal result") 77 | tests.Equal(t, err, nil, "unexpected error") 78 | 79 | err = id.UnmarshalText([]byte("in\nvalid")) 80 | 81 | tests.Expect(t, err != nil, "expected error") 82 | tests.Expect(t, !id.IsSet() && id.String() == "", "ID is not unset after invalid unmarshal") 83 | } 84 | 85 | func TestMessageField_MarshalJSON(t *testing.T) { 86 | t.Parallel() 87 | 88 | var id messageField 89 | v, err := id.MarshalJSON() 90 | 91 | tests.Equal(t, err, nil, "unexpected error") 92 | tests.Equal(t, string(v), "null", "invalid JSON result") 93 | 94 | id = mustMessageField(t, "") 95 | v, err = id.MarshalJSON() 96 | 97 | tests.Equal(t, err, nil, "unexpected error") 98 | tests.Equal(t, string(v), "\"\"", "invalid JSON result") 99 | } 100 | 101 | func TestMessageField_MarshalText(t *testing.T) { 102 | t.Parallel() 103 | 104 | var id messageField 105 | v, err := id.MarshalText() 106 | 107 | tests.Expect(t, err != nil, "expected error") 108 | tests.DeepEqual(t, v, nil, "invalid result") 109 | 110 | id = mustMessageField(t, "") 111 | v, err = id.MarshalText() 112 | 113 | tests.Equal(t, err, nil, "unexpected error") 114 | tests.DeepEqual(t, v, []byte{}, "unexpected result") 115 | } 116 | 117 | func TestMessageField_Scan(t *testing.T) { 118 | t.Parallel() 119 | 120 | var id messageField 121 | 122 | err := id.Scan(nil) 123 | tests.Equal(t, err, nil, "unexpected error") 124 | tests.Equal(t, id, messageField{}, "unexpected result") 125 | 126 | err = id.Scan("") 127 | tests.Equal(t, err, nil, "unexpected error") 128 | tests.Equal(t, id, mustMessageField(t, ""), "unexpected result") 129 | 130 | err = id.Scan([]byte("")) 131 | tests.Equal(t, 
err, nil, "unexpected error") 132 | tests.Equal(t, id, mustMessageField(t, ""), "unexpected result") 133 | 134 | err = id.Scan(5) 135 | tests.Expect(t, err != nil, "expected error") 136 | tests.Equal(t, id, messageField{}, "invalid result") 137 | } 138 | 139 | func TestMessageField_Value(t *testing.T) { 140 | t.Parallel() 141 | 142 | var id messageField 143 | v, err := id.Value() 144 | tests.Equal(t, err, nil, "unexpected error") 145 | tests.Equal(t, v, nil, "unexpected value") 146 | 147 | id = mustMessageField(t, "") 148 | v, err = id.Value() 149 | tests.Equal(t, err, nil, "unexpected error") 150 | tests.Equal(t, v, "", "unexpected value") 151 | } 152 | 153 | func TestFieldConstructors(t *testing.T) { 154 | t.Parallel() 155 | 156 | _, err := NewID("a\nb") 157 | tests.Equal(t, err.Error(), "invalid event ID: input is multiline", "unexpected error message") 158 | _, err = NewType("a\nb") 159 | tests.Equal(t, err.Error(), "invalid event type: input is multiline", "unexpected error message") 160 | 161 | tests.Panics(t, func() { ID("a\nb") }, "id creation should panic") 162 | tests.Panics(t, func() { Type("a\nb") }, "id creation should panic") 163 | } 164 | -------------------------------------------------------------------------------- /message_test.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/binary" 7 | "encoding/hex" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "os" 12 | "strings" 13 | "testing" 14 | "time" 15 | 16 | "github.com/tmaxmax/go-sse/internal/parser" 17 | "github.com/tmaxmax/go-sse/internal/tests" 18 | ) 19 | 20 | func TestNew(t *testing.T) { 21 | t.Parallel() 22 | 23 | e := Message{Type: Type("x"), ID: ID("lol"), Retry: time.Second} 24 | e.AppendData("whatever", "input", "will\nbe\nchunked", "amazing") 25 | 26 | expected := Message{ 27 | chunks: []chunk{ 28 | {content: "whatever"}, 29 | {content: "input"}, 30 | {content: "will"}, 31 | 
{content: "be"}, 32 | {content: "chunked"}, 33 | {content: "amazing"}, 34 | }, 35 | Retry: time.Second, 36 | Type: Type("x"), 37 | ID: ID("lol"), 38 | } 39 | 40 | tests.DeepEqual(t, e, expected, "invalid event") 41 | } 42 | 43 | func TestEvent_WriteTo(t *testing.T) { 44 | t.Parallel() 45 | 46 | t.Run("Empty", func(t *testing.T) { 47 | e := &Message{} 48 | w := &strings.Builder{} 49 | n, _ := e.WriteTo(w) 50 | tests.Equal(t, n, 0, "bytes were written") 51 | tests.Equal(t, w.String(), "", "message should produce no output") 52 | }) 53 | 54 | t.Run("Valid", func(t *testing.T) { 55 | e := &Message{Type: Type("test_event"), ID: ID("example_id"), Retry: time.Second * 5} 56 | e.AppendData("This is an example\nOf an event", "", "a string here") 57 | e.AppendComment("This test should pass") 58 | e.AppendData("Important data\nImportant again\r\rVery important\r\n") 59 | 60 | output := "id: example_id\nevent: test_event\nretry: 5000\ndata: This is an example\ndata: Of an event\ndata: a string here\n: This test should pass\ndata: Important data\ndata: Important again\ndata: \ndata: Very important\n\n" 61 | expectedWritten := int64(len(output)) 62 | 63 | w := &strings.Builder{} 64 | 65 | written, _ := e.WriteTo(w) 66 | 67 | tests.Equal(t, w.String(), output, "event written incorrectly") 68 | tests.Equal(t, written, expectedWritten, "written byte count wrong") 69 | }) 70 | 71 | type retryTest struct { 72 | expected string 73 | value time.Duration 74 | } 75 | 76 | retryTests := []retryTest{ 77 | {value: -1}, 78 | {value: 0}, 79 | {value: time.Microsecond}, 80 | {value: time.Millisecond, expected: "retry: 1\n\n"}, 81 | } 82 | for _, v := range retryTests { 83 | t.Run(fmt.Sprintf("Retry/%s", v.value), func(t *testing.T) { 84 | e := &Message{Retry: v.value} 85 | tests.Equal(t, e.String(), v.expected, "incorrect output") 86 | }) 87 | } 88 | } 89 | 90 | func TestEvent_UnmarshalText(t *testing.T) { 91 | t.Parallel() 92 | 93 | type test struct { 94 | name string 95 | input string 96 | 
expectedErr error 97 | expected Message 98 | } 99 | 100 | nilEvent := Message{} 101 | nilEvent.reset() 102 | 103 | tt := []test{ 104 | { 105 | name: "No input", 106 | expected: nilEvent, 107 | expectedErr: &UnmarshalError{Reason: ErrUnexpectedEOF}, 108 | }, 109 | { 110 | name: "Invalid retry field", 111 | input: "retry: sigma male\n", 112 | expected: nilEvent, 113 | expectedErr: &UnmarshalError{ 114 | FieldName: string(parser.FieldNameRetry), 115 | FieldValue: "sigma male", 116 | Reason: fmt.Errorf("contains character %q, which is not an ASCII digit", 's'), 117 | }, 118 | }, 119 | { 120 | name: "Valid input, no final newline", 121 | input: "data: first\ndata:second\ndata:third", 122 | expected: nilEvent, 123 | expectedErr: &UnmarshalError{Reason: ErrUnexpectedEOF}, 124 | }, 125 | { 126 | name: "Valid input", 127 | input: "data: raw bytes here\nretry: 500\nretry: 1000\nid: 1000\nid: 2000\nid: \x001\n: with comments\ndata: again raw bytes\ndata: from multiple lines\nevent: overwritten name\nevent: my name here\n\ndata: I should be ignored", 128 | expected: Message{ 129 | chunks: []chunk{ 130 | {content: "raw bytes here"}, 131 | {content: "with comments", isComment: true}, 132 | {content: "again raw bytes"}, 133 | {content: "from multiple lines"}, 134 | }, 135 | Retry: time.Second, 136 | Type: Type("my name here"), 137 | ID: ID("2000"), 138 | }, 139 | }, 140 | } 141 | 142 | for _, test := range tt { 143 | t.Run(test.name, func(t *testing.T) { 144 | e := Message{} 145 | 146 | if err := e.UnmarshalText([]byte(test.input)); (test.expectedErr != nil && err.Error() != test.expectedErr.Error()) || (test.expectedErr == nil && err != nil) { 147 | t.Fatalf("Invalid unmarshal error: got %q, want %q", err, test.expectedErr) 148 | } 149 | tests.DeepEqual(t, e, test.expected, "invalid unmarshal") 150 | }) 151 | } 152 | } 153 | 154 | //nolint:all 155 | func Example_messageWriter() { 156 | e := Message{ 157 | Type: Type("test"), 158 | ID: ID("1"), 159 | } 160 | w := 
&strings.Builder{} 161 | 162 | bw := base64.NewEncoder(base64.StdEncoding, w) 163 | binary.Write(bw, binary.BigEndian, []byte{6, 9, 4, 2, 0}) 164 | binary.Write(bw, binary.BigEndian, []byte("data from sensor")) 165 | bw.Close() 166 | w.WriteByte('\n') // Ensures that the data written above will be a distinct `data` field. 167 | 168 | enc := json.NewEncoder(w) 169 | enc.SetIndent("", " ") 170 | enc.Encode(map[string]string{"hello": "world"}) 171 | // Not necessary to add a newline here – json.Encoder.Encode adds a newline at the end. 172 | 173 | // io.CopyN(hex.NewEncoder(w), rand.Reader, 8) 174 | io.Copy(hex.NewEncoder(w), bytes.NewReader([]byte{5, 1, 6, 34, 234, 12, 143, 91})) 175 | 176 | mw := io.MultiWriter(os.Stdout, w) 177 | // The first newline adds the data written above as a `data field`. 178 | io.WriteString(mw, "\nYou'll see me both in console and in event\n\n") 179 | 180 | // Add the data to the event. It will be split into fields here, 181 | // according to the newlines present in the input. 
182 | e.AppendData(w.String()) 183 | e.WriteTo(os.Stdout) 184 | // Output: 185 | // You'll see me both in console and in event 186 | // 187 | // id: 1 188 | // event: test 189 | // data: BgkEAgBkYXRhIGZyb20gc2Vuc29y 190 | // data: { 191 | // data: "hello": "world" 192 | // data: } 193 | // data: 05010622ea0c8f5b 194 | // data: You'll see me both in console and in event 195 | // data: 196 | } 197 | 198 | func newBenchmarkEvent() *Message { 199 | e := Message{Type: Type("This is the event's name"), ID: ID("example_id"), Retry: time.Minute} 200 | e.AppendData("Example data\nWith multiple rows\r\nThis is interesting") 201 | e.AppendComment("An useless comment here that spans\non\n\nmultiple\nlines") 202 | return &e 203 | } 204 | 205 | var benchmarkEvent = newBenchmarkEvent() 206 | 207 | func BenchmarkEvent_WriteTo(b *testing.B) { 208 | b.ReportAllocs() 209 | 210 | for n := 0; n < b.N; n++ { 211 | _, _ = benchmarkEvent.WriteTo(io.Discard) 212 | } 213 | } 214 | 215 | var benchmarkText = []string{ 216 | "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", 217 | "Pellentesque at dui non quam faucibus ultricies.", 218 | "Quisque non sem gravida, sodales lorem eget, lobortis est.", 219 | "Quisque porttitor nunc eu mollis congue.", 220 | "Vivamus sollicitudin tellus ut mi malesuada lacinia.", 221 | "Aenean aliquet tortor non urna sodales dignissim.", 222 | "Sed quis diam sed dui feugiat aliquam.", 223 | "Etiam sit amet neque cursus, semper nibh non, ornare nunc.", 224 | "Phasellus dignissim lacus vitae felis interdum, eget pharetra augue bibendum.", 225 | "Sed euismod enim sed ante laoreet, non ullamcorper enim dapibus.", 226 | "Ut accumsan arcu venenatis, egestas nisi consectetur, dignissim felis.", 227 | "Praesent lacinia elit ut tristique molestie.", 228 | "Mauris ut nibh id ante ultricies egestas.", 229 | "Mauris porttitor augue quis maximus efficitur.", 230 | "Fusce auctor enim viverra elit imperdiet, non dignissim dolor condimentum.", 231 | "Fusce scelerisque 
quam vel erat tempor elementum.", 232 | "Nullam ac velit in nisl hendrerit rhoncus sed ut dui.", 233 | "Pellentesque laoreet arcu vitae commodo gravida.", 234 | "Pellentesque sagittis enim quis sapien mollis tempor.", 235 | "Phasellus fermentum leo vitae odio efficitur, eu lacinia enim elementum.", 236 | "Morbi faucibus nisi a velit dictum eleifend.", 237 | } 238 | 239 | func BenchmarkEvent_WriteTo_text(b *testing.B) { 240 | ev := Message{} 241 | ev.AppendData(benchmarkText...) 242 | 243 | for n := 0; n < b.N; n++ { 244 | _, _ = ev.WriteTo(io.Discard) 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /replay.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "errors" 5 | "strconv" 6 | "time" 7 | ) 8 | 9 | // NewFiniteReplayer creates a finite replay provider with the given max 10 | // count and auto ID behaviour. 11 | // 12 | // Count is the maximum number of events FiniteReplayer should hold as 13 | // valid. It must be greater than zero. 14 | // 15 | // AutoIDs configures FiniteReplayer to automatically set the IDs of 16 | // events. 17 | func NewFiniteReplayer( 18 | count int, autoIDs bool, 19 | ) (*FiniteReplayer, error) { 20 | if count < 2 { 21 | return nil, errors.New("count must be at least 2") 22 | } 23 | 24 | r := &FiniteReplayer{} 25 | r.buf.buf = make([]messageWithTopics, count) 26 | if autoIDs { 27 | r.currentID = new(uint64) 28 | } 29 | 30 | return r, nil 31 | } 32 | 33 | // FiniteReplayer is a replayer that replays at maximum a certain number of events. 34 | // The events must have an ID unless the replayer is configured to set IDs automatically. 35 | type FiniteReplayer struct { 36 | currentID *uint64 37 | buf queue[messageWithTopics] 38 | } 39 | 40 | // Put puts a message into the replayer's buffer. If there are more messages than the maximum 41 | // number, the oldest message is removed. 
42 | func (f *FiniteReplayer) Put(message *Message, topics []string) (*Message, error) { 43 | if len(topics) == 0 { 44 | return nil, ErrNoTopic 45 | } 46 | 47 | message, err := ensureID(message, f.currentID) 48 | if err != nil { 49 | return nil, err 50 | } 51 | 52 | f.buf.enqueue(messageWithTopics{message: message, topics: topics}) 53 | 54 | return message, nil 55 | } 56 | 57 | // Replay replays the stored messages to the listener. 58 | func (f *FiniteReplayer) Replay(subscription Subscription) error { 59 | i := findIDInQueue(&f.buf, subscription.LastEventID, f.currentID != nil) 60 | if i < 0 { 61 | return nil 62 | } 63 | 64 | var err error 65 | f.buf.each(i)(func(_ int, m messageWithTopics) bool { 66 | if topicsIntersect(subscription.Topics, m.topics) { 67 | if err = subscription.Client.Send(m.message); err != nil { 68 | return false 69 | } 70 | } 71 | return true 72 | }) 73 | if err != nil { 74 | return err 75 | } 76 | 77 | return subscription.Client.Flush() 78 | } 79 | 80 | // ValidReplayer is a Replayer that replays all the buffered non-expired events. 81 | // 82 | // The replayer removes any expired events when a new event is put and after at least 83 | // a GCInterval period passed. 84 | // 85 | // The events must have an ID unless the replayer is configured to set IDs automatically. 86 | type ValidReplayer struct { 87 | lastGC time.Time 88 | 89 | // The function used to retrieve the current time. Defaults to time.Now. 90 | // Useful when testing. 91 | Now func() time.Time 92 | 93 | currentID *uint64 94 | messages queue[messageWithTopicsAndExpiry] 95 | 96 | ttl time.Duration 97 | // After how long the replayer should attempt to clean up expired events. 98 | // By default cleanup is done after a fourth of the TTL has passed; this means 99 | // that messages may be stored for a duration equal to 5/4*TTL. 
If this is not 100 | // desired, set the GC interval to a value sensible for your use case or set 101 | // it to 0 – this disables automatic cleanup, enabling you to do it manually 102 | // using the GC method. 103 | GCInterval time.Duration 104 | } 105 | 106 | // NewValidReplayer creates a ValidReplayer with the given message 107 | // lifetime duration (time-to-live) and auto ID behavior. 108 | // 109 | // The TTL must be a positive duration. It is technically possible to use a very 110 | // big duration in order to store and replay every message put for the lifetime 111 | // of the program; this is not recommended, as memory usage becomes effectively 112 | // unbounded which might lead to a crash. 113 | func NewValidReplayer(ttl time.Duration, autoIDs bool) (*ValidReplayer, error) { 114 | if ttl <= 0 { 115 | return nil, errors.New("event TTL must be greater than zero") 116 | } 117 | 118 | r := &ValidReplayer{ 119 | Now: time.Now, 120 | GCInterval: ttl / 4, 121 | ttl: ttl, 122 | } 123 | 124 | if autoIDs { 125 | r.currentID = new(uint64) 126 | } 127 | 128 | return r, nil 129 | } 130 | 131 | // Put puts the message into the replayer's buffer. 
132 | func (v *ValidReplayer) Put(message *Message, topics []string) (*Message, error) { 133 | if len(topics) == 0 { 134 | return nil, ErrNoTopic 135 | } 136 | 137 | now := v.Now() 138 | if v.lastGC.IsZero() { 139 | v.lastGC = now 140 | } 141 | 142 | if v.shouldGC(now) { 143 | v.doGC(now) 144 | v.lastGC = now 145 | } 146 | 147 | message, err := ensureID(message, v.currentID) 148 | if err != nil { 149 | return nil, err 150 | } 151 | 152 | if v.messages.count == len(v.messages.buf) { 153 | newCap := len(v.messages.buf) * 2 154 | if minCap := 4; newCap < minCap { 155 | newCap = minCap 156 | } 157 | v.messages.resize(newCap) 158 | } 159 | 160 | v.messages.enqueue(messageWithTopicsAndExpiry{messageWithTopics: messageWithTopics{message: message, topics: topics}, exp: now.Add(v.ttl)}) 161 | 162 | return message, nil 163 | } 164 | 165 | func (v *ValidReplayer) shouldGC(now time.Time) bool { 166 | return v.GCInterval > 0 && now.Sub(v.lastGC) >= v.GCInterval 167 | } 168 | 169 | // GC removes all the expired messages from the replayer's buffer. 170 | func (v *ValidReplayer) GC() { 171 | v.doGC(v.Now()) 172 | } 173 | 174 | func (v *ValidReplayer) doGC(now time.Time) { 175 | for v.messages.count > 0 { 176 | e := v.messages.buf[v.messages.head] 177 | if e.exp.After(now) { 178 | break 179 | } 180 | 181 | v.messages.dequeue() 182 | } 183 | 184 | if v.messages.count <= len(v.messages.buf)/4 { 185 | newCap := len(v.messages.buf) / 2 186 | if minCap := 4; newCap < minCap { 187 | newCap = minCap 188 | } 189 | v.messages.resize(newCap) 190 | } 191 | } 192 | 193 | // Replay replays all the valid messages to the listener. 
194 | func (v *ValidReplayer) Replay(subscription Subscription) error { 195 | i := findIDInQueue(&v.messages, subscription.LastEventID, v.currentID != nil) 196 | if i < 0 { 197 | return nil 198 | } 199 | 200 | now := v.Now() 201 | 202 | var err error 203 | v.messages.each(i)(func(_ int, m messageWithTopicsAndExpiry) bool { 204 | if m.exp.After(now) && topicsIntersect(subscription.Topics, m.topics) { 205 | if err = subscription.Client.Send(m.message); err != nil { 206 | return false 207 | } 208 | } 209 | return true 210 | }) 211 | if err != nil { 212 | return err 213 | } 214 | 215 | return subscription.Client.Flush() 216 | } 217 | 218 | // topicsIntersect returns true if the given topic slices have at least one topic in common. 219 | func topicsIntersect(a, b []string) bool { 220 | for _, at := range a { 221 | for _, bt := range b { 222 | if at == bt { 223 | return true 224 | } 225 | } 226 | } 227 | 228 | return false 229 | } 230 | 231 | func ensureID(m *Message, currentID *uint64) (*Message, error) { 232 | if currentID == nil { 233 | if !m.ID.IsSet() { 234 | return nil, errors.New("message has no ID") 235 | } 236 | 237 | return m, nil 238 | } 239 | 240 | if m.ID.IsSet() { 241 | return nil, errors.New("message already has an ID, can't use generated ID") 242 | } 243 | 244 | m = m.Clone() 245 | m.ID = ID(strconv.FormatUint(*currentID, 10)) 246 | 247 | (*currentID)++ 248 | 249 | return m, nil 250 | } 251 | 252 | type queue[T any] struct { 253 | buf []T 254 | head, tail, count int 255 | } 256 | 257 | func (q *queue[T]) each(startAt int) func(func(int, T) bool) { 258 | return func(yield func(int, T) bool) { 259 | if startAt < q.tail { 260 | for i := startAt; i < q.tail; i++ { 261 | if !yield(i, q.buf[i]) { 262 | return 263 | } 264 | } 265 | } else { 266 | for i := startAt; i < len(q.buf); i++ { 267 | if !yield(i, q.buf[i]) { 268 | return 269 | } 270 | } 271 | for i := 0; i < q.tail; i++ { 272 | if !yield(i, q.buf[i]) { 273 | return 274 | } 275 | } 276 | } 277 | } 278 | } 
279 | 280 | func (q *queue[T]) enqueue(v T) { 281 | q.buf[q.tail] = v 282 | 283 | q.tail++ 284 | 285 | overwritten := false 286 | if q.tail > q.head && q.count == len(q.buf) { 287 | q.head = q.tail 288 | overwritten = true 289 | } else { 290 | q.count++ 291 | } 292 | 293 | if q.tail == len(q.buf) { 294 | q.tail = 0 295 | if overwritten { 296 | q.head = 0 297 | } 298 | } 299 | } 300 | 301 | func (q *queue[T]) dequeue() { 302 | q.buf[q.head] = *new(T) 303 | 304 | q.head++ 305 | if q.head == len(q.buf) { 306 | q.head = 0 307 | } 308 | 309 | q.count-- 310 | } 311 | 312 | func (q *queue[T]) resize(newSize int) { 313 | buf := make([]T, newSize) 314 | if q.head < q.tail { 315 | copy(buf, q.buf[q.head:q.tail]) 316 | } else { 317 | n := copy(buf, q.buf[q.head:]) 318 | copy(buf[n:], q.buf[:q.tail]) 319 | } 320 | 321 | q.head = 0 322 | q.tail = q.count 323 | q.buf = buf 324 | } 325 | 326 | func findIDInQueue[M interface{ ID() EventID }](q *queue[M], id EventID, autoID bool) int { 327 | if q.count == 0 { 328 | return -1 329 | } 330 | 331 | if autoID { 332 | id, err := strconv.ParseUint(id.String(), 10, 64) 333 | if err != nil { 334 | return -1 335 | } 336 | 337 | firstID, _ := strconv.ParseUint(q.buf[q.head].ID().String(), 10, 64) 338 | 339 | pos := -1 340 | if delta := id - firstID; id >= firstID { 341 | if delta >= uint64(q.count) { //nolint:gosec // int always positive 342 | return -1 343 | } 344 | pos = int(delta) //nolint:gosec // delta < q.count, which is an int 345 | } 346 | 347 | i := pos + q.head + 1 348 | if i >= len(q.buf) { 349 | i -= len(q.buf) 350 | } 351 | 352 | return i 353 | } 354 | 355 | i := -1 356 | q.each(q.head)(func(j int, m M) bool { 357 | if m.ID() == id { 358 | i = j 359 | return false 360 | } 361 | return true 362 | }) 363 | 364 | if i != -1 { 365 | i++ 366 | if i == len(q.buf) { 367 | i = 0 368 | } else if i == q.tail { 369 | i = -1 370 | } 371 | } 372 | 373 | return i 374 | } 375 | 376 | func (m messageWithTopics) ID() EventID { return m.message.ID 
} 377 | 378 | type messageWithTopicsAndExpiry struct { 379 | exp time.Time 380 | messageWithTopics 381 | } 382 | 383 | // noopReplayer is the default replay provider used if none is given. It does nothing. 384 | // It is used to avoid nil checks for the provider each time it is used. 385 | type noopReplayer struct{} 386 | 387 | func (n noopReplayer) Put(m *Message, _ []string) (*Message, error) { return m, nil } 388 | func (n noopReplayer) Replay(_ Subscription) error { return nil } 389 | -------------------------------------------------------------------------------- /replay_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "testing" 7 | "time" 8 | 9 | "github.com/tmaxmax/go-sse" 10 | "github.com/tmaxmax/go-sse/internal/tests" 11 | ) 12 | 13 | func replay(tb testing.TB, p sse.Replayer, lastEventID sse.EventID, topics ...string) []*sse.Message { 14 | tb.Helper() 15 | 16 | if len(topics) == 0 { 17 | topics = []string{sse.DefaultTopic} 18 | } 19 | 20 | var replayed []*sse.Message 21 | cb := mockClient(func(m *sse.Message) error { 22 | if m != nil { 23 | replayed = append(replayed, m) 24 | } 25 | return nil 26 | }) 27 | 28 | sub := sse.Subscription{ 29 | Client: cb, 30 | LastEventID: lastEventID, 31 | Topics: topics, 32 | } 33 | 34 | _ = p.Replay(sub) 35 | 36 | sub.LastEventID = sse.EventID{} 37 | _ = p.Replay(sub) 38 | 39 | sub.LastEventID = sse.ID("mama") 40 | _ = p.Replay(sub) 41 | 42 | sub.LastEventID = sse.ID("10") 43 | _ = p.Replay(sub) 44 | 45 | return replayed 46 | } 47 | 48 | func put(tb testing.TB, p sse.Replayer, msg *sse.Message, topics ...string) *sse.Message { 49 | tb.Helper() 50 | 51 | if len(topics) == 0 { 52 | topics = []string{sse.DefaultTopic} 53 | } 54 | 55 | msg, err := p.Put(msg, topics) 56 | tests.Equal(tb, err, nil, "invalid message") 57 | 58 | return msg 59 | } 60 | 61 | func testReplayError(tb testing.TB, p sse.Replayer, tm *tests.Time) { 
62 | tb.Helper() 63 | 64 | tm.Reset() 65 | tm.Add(time.Hour) 66 | 67 | put(tb, p, msg(tb, "a", "1")) 68 | put(tb, p, msg(tb, "b", "2")) 69 | 70 | cb := mockClient(func(_ *sse.Message) error { return nil }) 71 | 72 | tm.Rewind() 73 | 74 | err := p.Replay(sse.Subscription{ 75 | Client: cb, 76 | LastEventID: sse.ID("1"), 77 | Topics: []string{sse.DefaultTopic}, 78 | }) 79 | 80 | tests.Equal(tb, err, nil, "received invalid error") 81 | } 82 | 83 | func TestValidReplayProvider(t *testing.T) { 84 | t.Parallel() 85 | 86 | tm := &tests.Time{} 87 | ttl := time.Millisecond * 5 88 | 89 | _, err := sse.NewValidReplayer(0, false) 90 | tests.Expect(t, err != nil, "replay provider cannot be created with zero or negative TTL") 91 | 92 | p, _ := sse.NewValidReplayer(ttl, true) 93 | p.GCInterval = 0 94 | p.Now = tm.Now 95 | 96 | tests.Equal(t, p.Replay(sse.Subscription{}), nil, "replay failed on provider without messages") 97 | 98 | now := time.Now() 99 | tm.Set(now) 100 | 101 | put(t, p, msg(t, "hi", "")) 102 | put(t, p, msg(t, "there", ""), "t") 103 | tm.Add(ttl) 104 | put(t, p, msg(t, "world", "")) 105 | put(t, p, msg(t, "again", ""), "t") 106 | tm.Add(ttl * 3) 107 | put(t, p, msg(t, "world", "")) 108 | put(t, p, msg(t, "x", ""), "t") 109 | tm.Add(ttl * 5) 110 | put(t, p, msg(t, "again", ""), "t") 111 | 112 | tm.Set(now.Add(ttl)) 113 | 114 | p.GC() 115 | 116 | replayed := replay(t, p, sse.ID("3"), sse.DefaultTopic, "topic with no messages")[0] 117 | tests.Equal(t, replayed.String(), "id: 4\ndata: world\n\n", "invalid message received") 118 | 119 | p.GCInterval = ttl / 5 120 | // Should trigger automatic GC which should clean up most of the messages. 
121 | tm.Set(now.Add(ttl * 5)) 122 | put(t, p, msg(t, "not again", ""), "t") 123 | 124 | allReplayed := replay(t, p, sse.ID("3"), "t", "topic with no messages") 125 | tests.Equal(t, len(allReplayed), 2, "there should be two messages in topic 't'") 126 | tests.Equal(t, allReplayed[0].String(), "id: 6\ndata: again\n\n", "invalid message received") 127 | 128 | tr, err := sse.NewValidReplayer(time.Second, false) 129 | tests.Equal(t, err, nil, "replay provider should be created") 130 | 131 | testReplayError(t, tr, tm) 132 | } 133 | 134 | func TestFiniteReplayProvider(t *testing.T) { 135 | t.Parallel() 136 | 137 | _, err := sse.NewFiniteReplayer(1, false) 138 | tests.Expect(t, err != nil, "should not create FiniteReplayProvider with count less than 2") 139 | 140 | p, err := sse.NewFiniteReplayer(3, false) 141 | tests.Equal(t, err, nil, "should create new FiniteReplayProvider") 142 | 143 | tests.Equal(t, p.Replay(sse.Subscription{}), nil, "replay failed on provider without messages") 144 | 145 | _, err = p.Put(msg(t, "panic", ""), []string{sse.DefaultTopic}) 146 | tests.Expect(t, err != nil, "message without IDs cannot be put in a replay provider") 147 | 148 | _, err = p.Put(msg(t, "panic", "5"), nil) 149 | tests.ErrorIs(t, err, sse.ErrNoTopic, "incorrect error returned when no topic is provided") 150 | 151 | put(t, p, msg(t, "", "1")) 152 | put(t, p, msg(t, "hello", "2")) 153 | put(t, p, msg(t, "there", "3"), "t") 154 | put(t, p, msg(t, "world", "4")) 155 | 156 | replayed := replay(t, p, sse.ID("2"))[0] 157 | tests.Equal(t, replayed.String(), "id: 4\ndata: world\n\n", "invalid replayed message") 158 | 159 | put(t, p, msg(t, "", "5"), "t") 160 | put(t, p, msg(t, "again", "6")) 161 | 162 | replayed = replay(t, p, sse.ID("4"), sse.DefaultTopic, "topic with no messages")[0] 163 | tests.Equal(t, replayed.String(), "id: 6\ndata: again\n\n", "invalid replayed message") 164 | 165 | idp, err := sse.NewFiniteReplayer(10, true) 166 | tests.Equal(t, err, nil, "should create new 
FiniteReplayProvider") 167 | 168 | _, err = idp.Put(msg(t, "should error", "should not have ID"), []string{sse.DefaultTopic}) 169 | tests.Expect(t, err != nil, "messages with IDs cannot be put in an autoID replay provider") 170 | 171 | tr, err := sse.NewFiniteReplayer(10, false) 172 | tests.Equal(t, err, nil, "should create new FiniteReplayProvider") 173 | 174 | testReplayError(t, tr, nil) 175 | } 176 | 177 | func TestFiniteReplayProvider_allocations(t *testing.T) { 178 | p, err := sse.NewFiniteReplayer(3, false) 179 | tests.Equal(t, err, nil, "should create new FiniteReplayProvider") 180 | 181 | const runs = 100 182 | 183 | topics := []string{sse.DefaultTopic} 184 | // Add one to the number of runs to take the warmup run of 185 | // AllocsPerRun() into account. 186 | queue := make([]*sse.Message, runs+1) 187 | lastID := runs 188 | 189 | for i := 0; i < len(queue); i++ { 190 | queue[i] = msg(t, 191 | fmt.Sprintf("message %d", i), 192 | strconv.Itoa(i), 193 | ) 194 | } 195 | 196 | var run int 197 | 198 | avgAllocs := testing.AllocsPerRun(runs, func() { 199 | put(t, p, queue[run], topics...) 
200 | 201 | run++ 202 | }) 203 | 204 | tests.Equal(t, avgAllocs, 0, "no allocations should be made on Put()") 205 | 206 | var replayCount int 207 | 208 | cb := mockClient(func(m *sse.Message) error { 209 | if m != nil { 210 | replayCount++ 211 | } 212 | 213 | return nil 214 | }) 215 | 216 | sub := sse.Subscription{ 217 | Client: cb, 218 | Topics: topics, 219 | } 220 | 221 | sub.LastEventID = sse.ID(strconv.Itoa(lastID - 3)) 222 | 223 | err = p.Replay(sub) 224 | tests.Equal(t, err, nil, "replay from fourth last should succeed") 225 | 226 | tests.Equal(t, replayCount, 0, "replay from fourth last should not yield messages") 227 | 228 | sub.LastEventID = sse.ID(strconv.Itoa(lastID - 2)) 229 | 230 | err = p.Replay(sub) 231 | tests.Equal(t, err, nil, "replay from third last should succeed") 232 | 233 | tests.Equal(t, replayCount, 2, "replay from third last should yield 2 messages") 234 | } 235 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package sse provides utilities for creating and consuming fully spec-compliant HTML5 server-sent events streams. 3 | 4 | The central piece of a server's implementation is the Provider interface. A Provider describes a publish-subscribe 5 | system that can be used to implement messaging for the SSE protocol. This package already has an 6 | implementation, called Joe, that is the default provider for any server. Abstracting the messaging 7 | system implementation away allows servers to use any arbitrary provider under the same interface. 8 | The default provider will work for simple use-cases, but where scalability is required, one will 9 | look at a more suitable solution. Adapters that satisfy the Provider interface can easily be created, 10 | and then plugged into the server instance. 11 | Events themselves are represented using the Message type. 
12 | 13 | On the client-side, we use the Client struct to create connections to event streams. Using an `http.Request` 14 | we instantiate a Connection. Then we subscribe to incoming events using callback functions, and then 15 | we establish the connection by calling the Connection's Connect method. 16 | */ 17 | package sse 18 | 19 | import ( 20 | "context" 21 | "errors" 22 | "log/slog" 23 | "net/http" 24 | "sync" 25 | ) 26 | 27 | // The Subscription struct is used to subscribe to a given provider. 28 | type Subscription struct { 29 | // The client to which messages are sent. The implementation of the interface does not have to be 30 | // thread-safe – providers will not call methods on it concurrently. 31 | Client MessageWriter 32 | // An optional last event ID indicating the event to resume the stream from. 33 | // The events will replay starting from the first valid event sent after the one with the given ID. 34 | // If the ID is invalid replaying events will be omitted and new events will be sent as normal. 35 | LastEventID EventID 36 | // The topics to receive message from. Must be a non-empty list. 37 | // Topics are orthogonal to event types. They are used to filter what the server sends to each client. 38 | Topics []string 39 | } 40 | 41 | // A Provider is a publish-subscribe system that can be used to implement a HTML5 server-sent events 42 | // protocol. A standard interface is required so HTTP request handlers are agnostic to the provider's implementation. 43 | // 44 | // Providers are required to be thread-safe. 45 | // 46 | // After Shutdown is called, trying to call any method of the provider must return ErrProviderClosed. The providers 47 | // may return other implementation-specific errors too, but the close error is guaranteed to be the same across 48 | // providers. 49 | type Provider interface { 50 | // Subscribe to the provider. The context is used to remove the subscriber automatically 51 | // when it is done. 
Errors returned by the subscription's callback function must be returned 52 | // by Subscribe. 53 | // 54 | // Providers can assume that the topics list for a subscription has at least one topic. 55 | Subscribe(ctx context.Context, subscription Subscription) error 56 | // Publish a message to all the subscribers that are subscribed to the given topics. 57 | // The topics slice must be non-empty, or ErrNoTopic will be raised. 58 | Publish(message *Message, topics []string) error 59 | // Shutdown stops the provider. Calling Shutdown will clean up all the provider's resources 60 | // and make Subscribe and Publish fail with an error. All the listener channels will be 61 | // closed and any ongoing publishes will be aborted. 62 | // 63 | // If the given context times out before the provider is shut down – shutting it down takes 64 | // longer, the context error is returned. 65 | // 66 | // Calling Shutdown multiple times after it successfully returned the first time 67 | // does nothing but return ErrProviderClosed. 68 | Shutdown(ctx context.Context) error 69 | } 70 | 71 | // ErrProviderClosed is a sentinel error returned by providers when any operation is attempted after the provider is closed. 72 | // A closed provider might also be a result of an unexpected panic inside the provider. 73 | var ErrProviderClosed = errors.New("go-sse.server: provider is closed") 74 | 75 | // ErrNoTopic is a sentinel error returned when a Message is published without any topics. 76 | // It is not an issue to call Server.Publish without topics, because the Server will add the DefaultTopic; 77 | // it is an error to call Provider.Publish or Replayer.Put without any topics, though. 78 | var ErrNoTopic = errors.New("go-sse.server: no topics specified") 79 | 80 | // DefaultTopic is the identifier for the topic that is implied when no topics are specified for a Subscription 81 | // or a Message. 
82 | const DefaultTopic = "" 83 | 84 | // A Server is mostly a convenience wrapper around a Provider. 85 | // It implements the http.Handler interface and has some methods 86 | // for calling the underlying provider's methods. 87 | // 88 | // When creating a server, if no provider is specified using the WithProvider 89 | // option, the Joe provider found in this package with no replay provider is used. 90 | type Server struct { 91 | // The provider used to publish and subscribe clients to events. 92 | // Defaults to Joe. 93 | Provider Provider 94 | // A callback that's called when an SSE session is started. 95 | // You can use this to authorize the session, set the topics 96 | // the client should be subscribed to and so on. Using the 97 | // Res field of the Session you can write an error response 98 | // to the client. 99 | // 100 | // The boolean returned indicates whether the given request 101 | // should be accepted or not. If it is true, the Provider will receive 102 | // a new subscription for the connection and events will be sent 103 | // to this client, otherwise the request will be ended. 104 | // 105 | // Note that OnSession can write the HTTP response code itself, if something other 106 | // than the implicit 200 OK is desired. This is especially helpful when refusing sessions – 107 | // if OnSession does not write a response code, clients will receive a confusing 200 OK. 108 | // 109 | // If this is not set, the client will be subscribed to the provider 110 | // using the DefaultTopic. 111 | OnSession func(w http.ResponseWriter, r *http.Request) (topics []string, allowed bool) 112 | // If the Logger function is set and returns a non-nil Logger instance, 113 | // the Server will log various information about the request lifecycle. 114 | Logger func(r *http.Request) *slog.Logger 115 | 116 | provider Provider 117 | initDone sync.Once 118 | } 119 | 120 | // ServeHTTP implements a default HTTP handler for a server. 
121 | // 122 | // This handler upgrades the request, subscribes it to the server's provider and 123 | // starts sending incoming events to the client, while logging any errors. 124 | // It also sends the Last-Event-ID header's value, if present. 125 | // 126 | // If the request isn't upgradeable, it writes a message to the client along with 127 | // an 500 Internal Server ConnectionError response code. If on subscribe the provider returns 128 | // an error, it writes the error message to the client and a 500 Internal Server ConnectionError 129 | // response code. 130 | // 131 | // To customize behavior, use the OnSession callback or create your custom handler. 132 | func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { 133 | s.init() 134 | // Make sure to keep the ServeHTTP implementation line number in sync with the number in the README! 135 | 136 | var l *slog.Logger 137 | if s.Logger != nil { 138 | l = s.Logger(r) 139 | } 140 | 141 | if l != nil { 142 | l.Info("sse: starting new session") 143 | } 144 | 145 | sess, err := Upgrade(w, r) 146 | if err != nil { 147 | if l != nil { 148 | l.Error("sse: unsupported", "error", err) 149 | } 150 | 151 | http.Error(w, "Server-sent events unsupported", http.StatusInternalServerError) 152 | return 153 | } 154 | 155 | sub, ok := s.getSubscription(sess) 156 | if !ok { 157 | if l != nil { 158 | l.Warn("sse: invalid subscription") 159 | } 160 | 161 | return 162 | } 163 | 164 | if l != nil { 165 | l.Info("sse: subscribing session", "topics", sub.Topics, "lastEventID", sub.LastEventID) 166 | } 167 | 168 | if err = s.provider.Subscribe(r.Context(), sub); err != nil { 169 | if l != nil { 170 | l.Error("sse: subscribe error", "error", err) 171 | } 172 | 173 | http.Error(w, err.Error(), http.StatusInternalServerError) 174 | return 175 | } 176 | 177 | if l != nil { 178 | l.Info("sse: session ended") 179 | } 180 | } 181 | 182 | // Publish sends the event to all subscribes that are subscribed to the topic the event is 
published to. 183 | // The topics are optional - if none are specified, the event is published to the DefaultTopic. 184 | func (s *Server) Publish(e *Message, topics ...string) error { 185 | s.init() 186 | return s.provider.Publish(e, getTopics(topics)) 187 | } 188 | 189 | // Shutdown closes all the connections and stops the server. Publish operations will fail 190 | // with the error sent by the underlying provider. NewServer requests will be ignored. 191 | // 192 | // Call this method when shutting down the HTTP server using http.Server's RegisterOnShutdown 193 | // method. Not doing this will result in the server never shutting down or connections being 194 | // abruptly stopped. 195 | // 196 | // See the Provider.Shutdown documentation for information on context usage and errors. 197 | func (s *Server) Shutdown(ctx context.Context) error { 198 | s.init() 199 | return s.provider.Shutdown(ctx) 200 | } 201 | 202 | func (s *Server) init() { 203 | s.initDone.Do(func() { 204 | s.provider = s.Provider 205 | if s.provider == nil { 206 | s.provider = &Joe{} 207 | } 208 | }) 209 | } 210 | 211 | func (s *Server) getSubscription(sess *Session) (Subscription, bool) { 212 | sub := Subscription{Client: sess, LastEventID: sess.LastEventID, Topics: defaultTopicSlice} 213 | if s.OnSession != nil { 214 | topics, ok := s.OnSession(sess.Res, sess.Req) 215 | if ok && len(topics) > 0 { 216 | sub.Topics = topics 217 | } 218 | 219 | return sub, ok 220 | } 221 | 222 | return sub, true 223 | } 224 | 225 | var defaultTopicSlice = []string{DefaultTopic} 226 | 227 | func getTopics(initial []string) []string { 228 | if len(initial) == 0 { 229 | return defaultTopicSlice 230 | } 231 | 232 | return initial 233 | } 234 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 
| "math/rand" 10 | "net/http" 11 | "net/http/httptest" 12 | "strconv" 13 | "strings" 14 | "testing" 15 | "time" 16 | 17 | "github.com/tmaxmax/go-sse" 18 | "github.com/tmaxmax/go-sse/internal/tests" 19 | ) 20 | 21 | type mockProvider struct { 22 | SubError error 23 | Closed chan struct{} 24 | Pub *sse.Message 25 | PubTopics []string 26 | Sub sse.Subscription 27 | Subscribed bool 28 | Stopped bool 29 | Published bool 30 | } 31 | 32 | func (m *mockProvider) Subscribe(ctx context.Context, sub sse.Subscription) error { 33 | m.Subscribed = true 34 | if m.SubError != nil { 35 | return m.SubError 36 | } 37 | 38 | defer close(m.Closed) 39 | m.Sub = sub 40 | 41 | e := &sse.Message{} 42 | e.AppendData("hello") 43 | 44 | if err := sub.Client.Send(e); err != nil { 45 | return fmt.Errorf("send failed: %w", err) 46 | } 47 | 48 | if err := sub.Client.Flush(); err != nil { 49 | return fmt.Errorf("flush failed: %w", err) 50 | } 51 | 52 | <-ctx.Done() 53 | 54 | return nil 55 | } 56 | 57 | func (m *mockProvider) Publish(msg *sse.Message, topics []string) error { 58 | m.Pub = msg 59 | m.PubTopics = topics 60 | m.Published = true 61 | return nil 62 | } 63 | 64 | func (m *mockProvider) Shutdown(_ context.Context) error { 65 | m.Stopped = true 66 | return nil 67 | } 68 | 69 | var _ sse.Provider = (*mockProvider)(nil) 70 | 71 | func newMockProvider(tb testing.TB, subErr error) *mockProvider { 72 | tb.Helper() 73 | 74 | return &mockProvider{Closed: make(chan struct{}), SubError: subErr} 75 | } 76 | 77 | type mockHandler struct { 78 | slog.Handler 79 | } 80 | 81 | func (h mockHandler) Handle(ctx context.Context, r slog.Record) error { 82 | var zero time.Time 83 | r.Time = zero 84 | return h.Handler.Handle(ctx, r) 85 | } 86 | 87 | func mockLogFunc(w io.Writer) func(*http.Request) *slog.Logger { 88 | h := slog.NewTextHandler(w, nil) 89 | mockH := mockHandler{h} 90 | return func(*http.Request) *slog.Logger { 91 | return slog.New(mockH) 92 | } 93 | } 94 | 95 | func TestServer_ShutdownPublish(t 
*testing.T) { 96 | t.Parallel() 97 | 98 | p := &mockProvider{} 99 | s := &sse.Server{Provider: p} 100 | 101 | _ = s.Publish(&sse.Message{}) 102 | tests.Expect(t, p.Published, "Publish wasn't called") 103 | tests.DeepEqual(t, []any{*p.Pub, p.PubTopics}, []any{sse.Message{}, []string{sse.DefaultTopic}}, "incorrect message") 104 | 105 | p.Published = false 106 | _ = s.Publish(&sse.Message{}, "topic") 107 | tests.Expect(t, p.Published, "Publish wasn't called") 108 | tests.DeepEqual(t, []any{*p.Pub, p.PubTopics}, []any{sse.Message{}, []string{"topic"}}, "incorrect message") 109 | 110 | _ = s.Shutdown(context.Background()) 111 | tests.Expect(t, p.Stopped, "Stop wasn't called") 112 | } 113 | 114 | func request(tb testing.TB, method, address string, body io.Reader) (*http.Request, context.CancelFunc) { //nolint 115 | tb.Helper() 116 | 117 | r := httptest.NewRequest(method, address, body) 118 | ctx, cancel := context.WithCancel(r.Context()) 119 | return r.WithContext(ctx), cancel 120 | } 121 | 122 | func TestServer_ServeHTTP(t *testing.T) { 123 | t.Parallel() 124 | rec := httptest.NewRecorder() 125 | req, cancel := request(t, "", "http://localhost", nil) 126 | defer cancel() 127 | p := newMockProvider(t, nil) 128 | req.Header.Set("Last-Event-ID", "5") 129 | 130 | go cancel() 131 | sb := &strings.Builder{} 132 | (&sse.Server{Provider: p, Logger: mockLogFunc(sb)}).ServeHTTP(rec, req) 133 | 134 | tests.Expect(t, p.Subscribed, "Subscribe wasn't called") 135 | tests.Equal(t, p.Sub.LastEventID, sse.ID("5"), "Invalid last event ID received") 136 | tests.Equal(t, rec.Body.String(), "data: hello\n\n", "Invalid response body") 137 | tests.Equal(t, rec.Code, http.StatusOK, "invalid response code") 138 | tests.Equal(t, sb.String(), "level=INFO msg=\"sse: starting new session\"\nlevel=INFO msg=\"sse: subscribing session\" topics=[] lastEventID=5\nlevel=INFO msg=\"sse: session ended\"\n", "invalid log output") 139 | } 140 | 141 | type noFlusher struct { 142 | http.ResponseWriter 143 | } 
144 | 145 | func TestServer_ServeHTTP_unsupportedRespWriter(t *testing.T) { 146 | t.Parallel() 147 | 148 | rec := httptest.NewRecorder() 149 | req, cancel := request(t, "", "http://localhost", nil) 150 | defer cancel() 151 | p := newMockProvider(t, nil) 152 | sb := &strings.Builder{} 153 | 154 | (&sse.Server{Provider: p, Logger: mockLogFunc(sb)}).ServeHTTP(noFlusher{rec}, req) 155 | 156 | tests.Equal(t, rec.Code, http.StatusInternalServerError, "invalid response code") 157 | tests.Equal(t, rec.Body.String(), "Server-sent events unsupported\n", "invalid response body") 158 | tests.Equal(t, sb.String(), "level=INFO msg=\"sse: starting new session\"\nlevel=ERROR msg=\"sse: unsupported\" error=\"go-sse.server: upgrade unsupported\"\n", "invalid log output") 159 | } 160 | 161 | func TestServer_ServeHTTP_subscribeError(t *testing.T) { 162 | t.Parallel() 163 | 164 | rec := httptest.NewRecorder() 165 | req, _ := http.NewRequest("", "http://localhost", http.NoBody) 166 | p := newMockProvider(t, errors.New("can't subscribe")) 167 | sb := &strings.Builder{} 168 | 169 | (&sse.Server{Provider: p, Logger: mockLogFunc(sb)}).ServeHTTP(rec, req) 170 | 171 | tests.Equal(t, rec.Body.String(), p.SubError.Error()+"\n", "invalid response body") 172 | tests.Equal(t, rec.Code, http.StatusInternalServerError, "invalid response code") 173 | tests.Equal(t, sb.String(), "level=INFO msg=\"sse: starting new session\"\nlevel=INFO msg=\"sse: subscribing session\" topics=[] lastEventID=\"\"\nlevel=ERROR msg=\"sse: subscribe error\" error=\"can't subscribe\"\n", "invalid log output") 174 | } 175 | 176 | func TestServer_OnSession(t *testing.T) { 177 | t.Parallel() 178 | 179 | t.Run("Invalid", func(t *testing.T) { 180 | rec := httptest.NewRecorder() 181 | req := httptest.NewRequest("", "/", http.NoBody) 182 | p := newMockProvider(t, nil) 183 | sb := &strings.Builder{} 184 | 185 | (&sse.Server{ 186 | Provider: p, 187 | Logger: mockLogFunc(sb), 188 | OnSession: func(w http.ResponseWriter, _ 
*http.Request) ([]string, bool) { 189 | http.Error(w, "this is invalid", http.StatusBadRequest) 190 | return nil, false 191 | }, 192 | }).ServeHTTP(rec, req) 193 | 194 | tests.Equal(t, rec.Body.String(), "this is invalid\n", "invalid response body") 195 | tests.Equal(t, rec.Code, http.StatusBadRequest, "invalid response code") 196 | tests.Equal(t, sb.String(), "level=INFO msg=\"sse: starting new session\"\nlevel=WARN msg=\"sse: invalid subscription\"\n", "invalid log output") 197 | }) 198 | } 199 | 200 | type flushResponseWriter interface { 201 | http.Flusher 202 | http.ResponseWriter 203 | } 204 | 205 | type responseWriterErr struct { 206 | flushResponseWriter 207 | } 208 | 209 | func (r *responseWriterErr) Write(p []byte) (int, error) { 210 | n, _ := r.flushResponseWriter.Write(p) 211 | return n, errors.New("") 212 | } 213 | 214 | func TestServer_ServeHTTP_connectionError(t *testing.T) { 215 | t.Parallel() 216 | 217 | rec := httptest.NewRecorder() 218 | req, _ := http.NewRequest("", "http://localhost", http.NoBody) 219 | p := newMockProvider(t, nil) 220 | 221 | (&sse.Server{Provider: p}).ServeHTTP(&responseWriterErr{rec}, req) 222 | _, ok := <-p.Closed 223 | tests.Expect(t, !ok, "request error should not block server") 224 | } 225 | 226 | func getMessage(tb testing.TB) *sse.Message { 227 | tb.Helper() 228 | 229 | m := &sse.Message{ 230 | ID: sse.ID(strconv.Itoa(rand.Int())), 231 | Type: sse.Type("test"), 232 | } 233 | m.AppendData("Hello world!", "Nice to see you all.") 234 | 235 | return m 236 | } 237 | 238 | type discardResponseWriter struct { 239 | w io.Writer 240 | h http.Header 241 | c int 242 | } 243 | 244 | func (d *discardResponseWriter) Header() http.Header { return d.h } 245 | func (d *discardResponseWriter) Write(b []byte) (int, error) { return d.w.Write(b) } 246 | func (d *discardResponseWriter) WriteHeader(code int) { d.c = code } 247 | func (d *discardResponseWriter) Flush() {} 248 | 249 | func getRequest(tb testing.TB) (w *discardResponseWriter, r 
*http.Request) { 250 | tb.Helper() 251 | 252 | w = &discardResponseWriter{w: io.Discard, h: make(http.Header)} 253 | r = httptest.NewRequest("", "http://localhost", http.NoBody) 254 | 255 | return 256 | } 257 | 258 | func benchmarkServer(b *testing.B, conns int) { 259 | b.Helper() 260 | 261 | s := &sse.Server{} 262 | b.Cleanup(func() { _ = s.Shutdown(context.Background()) }) 263 | 264 | m := getMessage(b) 265 | 266 | for i := 0; i < conns; i++ { 267 | w, r := getRequest(b) 268 | go s.ServeHTTP(w, r) 269 | } 270 | 271 | b.ResetTimer() 272 | b.ReportAllocs() 273 | 274 | for n := 0; n < b.N; n++ { 275 | _ = s.Publish(m) 276 | } 277 | } 278 | 279 | func BenchmarkServer(b *testing.B) { 280 | conns := [...]int{10, 100, 1000, 10000, 20000, 50000, 100000} 281 | 282 | for _, c := range conns { 283 | b.Run(strconv.Itoa(c), func(b *testing.B) { 284 | benchmarkServer(b, c) 285 | }) 286 | } 287 | } 288 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package sse 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | ) 7 | 8 | // ResponseWriter is a http.ResponseWriter augmented with a Flush method. 9 | type ResponseWriter interface { 10 | http.ResponseWriter 11 | Flush() error 12 | } 13 | 14 | // MessageWriter is a special kind of response writer used by providers to 15 | // send Messages to clients. 16 | type MessageWriter interface { 17 | // Send sends the message to the client. 18 | // To make sure it is sent, call Flush. 19 | Send(m *Message) error 20 | // Flush sends any buffered messages to the client. 21 | Flush() error 22 | } 23 | 24 | // A Session is an HTTP request from an SSE client. 25 | // Create one using the Upgrade function. 26 | // 27 | // Using a Session you can also access the initial HTTP request, 28 | // get the last event ID, or write data to the client. 29 | type Session struct { 30 | // The response writer for the request. 
Can be used to write an error response 31 | // back to the client. Must not be used after the Session was subscribed! 32 | Res ResponseWriter 33 | // The initial HTTP request. Can be used to retrieve authentication data, 34 | // topics, or data from context – a logger, for example. 35 | Req *http.Request 36 | // Last event ID of the client. It is unset if no ID was provided in the Last-Event-Id 37 | // request header. 38 | LastEventID EventID 39 | 40 | didUpgrade bool 41 | } 42 | 43 | // Send sends the given event to the client. It returns any errors that occurred while writing the event. 44 | func (s *Session) Send(e *Message) error { 45 | if err := s.doUpgrade(); err != nil { 46 | return err 47 | } 48 | if _, err := e.WriteTo(s.Res); err != nil { 49 | return err 50 | } 51 | return nil 52 | } 53 | 54 | // Flush sends any buffered messages to the client. 55 | func (s *Session) Flush() error { 56 | prevDidUpgrade := s.didUpgrade 57 | if err := s.doUpgrade(); err != nil { 58 | return err 59 | } 60 | if prevDidUpgrade == s.didUpgrade { 61 | return s.Res.Flush() 62 | } 63 | return nil 64 | } 65 | 66 | func (s *Session) doUpgrade() error { 67 | if !s.didUpgrade { 68 | s.Res.Header()[headerContentType] = headerContentTypeValue 69 | if err := s.Res.Flush(); err != nil { 70 | return err 71 | } 72 | s.didUpgrade = true 73 | } 74 | return nil 75 | } 76 | 77 | // Upgrade upgrades an HTTP request to support server-sent events. 78 | // It returns a Session that's used to send events to the client, or an 79 | // error if the upgrade failed. 80 | // 81 | // The headers required by the SSE protocol are only sent when calling 82 | // the Send method for the first time. If other operations are done before 83 | // sending messages, other headers and status codes can safely be set. 
84 | func Upgrade(w http.ResponseWriter, r *http.Request) (*Session, error) { 85 | rw := getResponseWriter(w) 86 | if rw == nil { 87 | return nil, ErrUpgradeUnsupported 88 | } 89 | 90 | id := EventID{} 91 | // Clients must not send empty Last-Event-Id headers: 92 | // https://html.spec.whatwg.org/multipage/server-sent-events.html#sse-processing-model 93 | if h := r.Header[headerLastEventID]; len(h) != 0 && h[0] != "" { 94 | // We ignore the validity flag because if the given ID is invalid then an unset ID will be returned, 95 | // which providers are required to ignore. 96 | id, _ = NewID(h[0]) 97 | } 98 | 99 | return &Session{Req: r, Res: rw, LastEventID: id}, nil 100 | } 101 | 102 | // ErrUpgradeUnsupported is returned when a request can't be upgraded to support server-sent events. 103 | var ErrUpgradeUnsupported = errors.New("go-sse.server: upgrade unsupported") 104 | 105 | // Canonicalized header keys. 106 | const ( 107 | headerLastEventID = "Last-Event-Id" 108 | headerContentType = "Content-Type" 109 | ) 110 | 111 | // Pre-allocated header value. 112 | var headerContentTypeValue = []string{"text/event-stream"} 113 | 114 | // Logic below is similar to Go 1.20's ResponseController. 115 | // We can't use that because we need to check if the request supports 116 | // flushing messages before we subscribe it to the event stream. 
117 | 118 | type writeFlusher interface { 119 | http.ResponseWriter 120 | http.Flusher 121 | } 122 | 123 | type writeFlusherError interface { 124 | http.ResponseWriter 125 | FlushError() error 126 | } 127 | 128 | type rwUnwrapper interface { 129 | Unwrap() http.ResponseWriter 130 | } 131 | 132 | func getResponseWriter(w http.ResponseWriter) ResponseWriter { 133 | for { 134 | switch v := w.(type) { 135 | case writeFlusherError: 136 | return flusherErrorWrapper{v} 137 | case writeFlusher: 138 | return flusherWrapper{v} 139 | case rwUnwrapper: 140 | w = v.Unwrap() 141 | default: 142 | return nil 143 | } 144 | } 145 | } 146 | 147 | type flusherWrapper struct { 148 | writeFlusher 149 | } 150 | 151 | func (f flusherWrapper) Flush() error { 152 | f.writeFlusher.Flush() 153 | return nil 154 | } 155 | 156 | type flusherErrorWrapper struct { 157 | writeFlusherError 158 | } 159 | 160 | func (f flusherErrorWrapper) Flush() error { return f.FlushError() } 161 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package sse_test 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/tmaxmax/go-sse" 11 | "github.com/tmaxmax/go-sse/internal/tests" 12 | ) 13 | 14 | func TestUpgrade(t *testing.T) { 15 | t.Parallel() 16 | 17 | req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) 18 | req.Header.Set("Last-Event-Id", "hello") 19 | rec := httptest.NewRecorder() 20 | 21 | sess, err := sse.Upgrade(rec, req) 22 | tests.Equal(t, err, nil, "unexpected error") 23 | tests.Expect(t, !rec.Flushed, "response writer was flushed") 24 | tests.Equal(t, sess.Send(&sse.Message{ID: sess.LastEventID}), nil, "unexpected Send error") 25 | tests.Equal(t, sess.Flush(), nil, "unexpected Flush error") 26 | 27 | r := rec.Result() 28 | t.Cleanup(func() { _ = r.Body.Close() }) 29 | 30 | expectedHeaders := http.Header{ 
31 | "Content-Type": []string{"text/event-stream"}, 32 | } 33 | expectedBody := "id: hello\n\n" 34 | 35 | body, err := io.ReadAll(r.Body) 36 | tests.Equal(t, err, nil, "failed to read response body") 37 | 38 | tests.DeepEqual(t, r.Header, expectedHeaders, "invalid response headers") 39 | tests.Equal(t, expectedBody, string(body), "invalid response body (and Last-Event-Id)") 40 | 41 | _, err = sse.Upgrade(nil, nil) 42 | tests.ErrorIs(t, err, sse.ErrUpgradeUnsupported, "invalid Upgrade error") 43 | } 44 | 45 | var errWriteFailed = errors.New("err") 46 | 47 | type errorWriter struct { 48 | Flushed bool 49 | } 50 | 51 | func (e *errorWriter) WriteHeader(_ int) {} 52 | func (e *errorWriter) Header() http.Header { return http.Header{} } 53 | func (e *errorWriter) Write(_ []byte) (int, error) { return 0, errWriteFailed } 54 | func (e *errorWriter) Flush() { e.Flushed = true } 55 | 56 | func TestUpgradedRequest_Send(t *testing.T) { 57 | t.Parallel() 58 | 59 | rec := httptest.NewRecorder() 60 | 61 | conn, err := sse.Upgrade(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) 62 | tests.Equal(t, err, nil, "unexpected NewConnection error") 63 | 64 | rec.Flushed = false 65 | 66 | ev := sse.Message{} 67 | ev.AppendData("sarmale") 68 | expected, _ := ev.MarshalText() 69 | 70 | tests.Equal(t, conn.Send(&ev), nil, "unexpected Send error") 71 | tests.Expect(t, rec.Flushed, "writer wasn't flushed") 72 | tests.DeepEqual(t, rec.Body.Bytes(), expected, "body not written correctly") 73 | } 74 | 75 | func TestUpgradedRequest_Send_error(t *testing.T) { 76 | t.Parallel() 77 | 78 | rec := &errorWriter{} 79 | 80 | conn, err := sse.Upgrade(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) 81 | tests.Equal(t, err, nil, "unexpected NewConnection error") 82 | 83 | rec.Flushed = false 84 | 85 | tests.ErrorIs(t, conn.Send(&sse.Message{ID: sse.ID("")}), errWriteFailed, "invalid Send error") 86 | tests.Expect(t, rec.Flushed, "writer wasn't flushed") 87 | } 88 | 
--------------------------------------------------------------------------------