├── .gitattributes
├── .github
└── workflows
│ ├── codeql-analysis.yml
│ └── main.yml
├── .gitignore
├── .golangci.yml
├── .run
├── go test signalr.run.xml
└── go test signalr_test.run.xml
├── Invokeresult.go
├── LICENSE
├── Makefile
├── README.md
├── chatsample
├── main.go
├── middleware
│ ├── log_requests.go
│ └── response_writer_wrapper.go
└── public
│ ├── fs.go
│ ├── index.html
│ └── js
│ ├── signalr.js
│ ├── signalr.js.map
│ └── signalr.min.js
├── client.go
├── client_test.go
├── clientoptions.go
├── clientoptions_test.go
├── clientproxy.go
├── clientsseconnection.go
├── codecov.yml
├── connection.go
├── connection_test.go
├── connectionbase.go
├── ctxpipe.go
├── doc.go
├── go.mod
├── go.sum
├── groupmanager.go
├── httpconnection.go
├── httpmux.go
├── httpserver_test.go
├── hub.go
├── hubclients.go
├── hubconnection.go
├── hubcontext.go
├── hubcontext_test.go
├── hublifetimemanager.go
├── hubprotocol.go
├── hubprotocol_test.go
├── invocation_test.go
├── invokeclient.go
├── jsonhubprotocol.go
├── logger_test.go
├── loop.go
├── messagepackhubprotocol.go
├── messagepackhubprotocol_test.go
├── negotiateresponse.go
├── netconnection.go
├── options.go
├── party.go
├── receiver.go
├── router
├── Makefile
├── chirouter.go
├── doc.go
├── go.mod
├── go.sum
├── gorillarouter.go
├── httprouter.go
└── router_test
│ └── router_test.go
├── server.go
├── server_test.go
├── serveroptions.go
├── serveroptions_test.go
├── serversseconnection.go
├── signalr_suite_test.go
├── signalr_test
├── logger_test.go
├── netconnection_test.go
├── package-lock.json
├── package.json
├── server_test.go
├── setupJest.ts
├── spec
│ └── server.spec.ts
├── tsconfig.json
└── tsconfig.spec.json
├── streamclient.go
├── streamclient_test.go
├── streamer.go
├── streaminvocation_test.go
├── testLogConf.json
├── testingconnection_test.go
└── websocketconnection.go
/.gitattributes:
--------------------------------------------------------------------------------
1 | chatsample/public/js/signalr.js linguist-detectable=false
2 | chatsample/public/js/signalr.js.map linguist-detectable=false
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ master ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ master ]
20 | schedule:
21 | - cron: '45 8 * * 3'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'go' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
37 | # Learn more:
38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
39 |
40 | steps:
41 | - name: Checkout repository
42 | uses: actions/checkout@v2
43 |
44 | # Initializes the CodeQL tools for scanning.
45 | - name: Initialize CodeQL
46 | uses: github/codeql-action/init@v1
47 | with:
48 | languages: ${{ matrix.language }}
49 | # If you wish to specify custom queries, you can do so here or in a config file.
50 | # By default, queries listed here will override any specified in a config file.
51 | # Prefix the list here with "+" to use these queries and those in the config file.
52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
53 |
54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
55 | # If this step fails, then you should remove it and run the build manually (see below)
56 | - name: Autobuild
57 | uses: github/codeql-action/autobuild@v1
58 |
59 | # ℹ️ Command-line programs to run using the OS shell.
60 | # 📚 https://git.io/JvXDl
61 |
62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
63 | # and modify them (or add more) to build your code if your project
64 | # uses a compiled language
65 |
66 | #- run: |
67 | # make bootstrap
68 | # make release
69 |
70 | - name: Perform CodeQL Analysis
71 | uses: github/codeql-action/analyze@v1
72 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Build and Test
2 |
# This workflow runs on pushes to the master branch and on every pull request, whatever branch it targets
4 | on:
5 | push:
6 | branches:
7 | - master
8 | pull_request:
9 |
10 | jobs:
11 |
12 | golangci:
13 | name: golangci
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: golangci-lint
18 | uses: golangci/golangci-lint-action@v3
19 | with:
20 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
21 | version: v1.51
22 |
23 | test:
24 | name: Test and Coverage
25 | runs-on: ubuntu-latest
26 | steps:
27 | - name: Set up Go
28 | uses: actions/setup-go@v1
29 | with:
30 | go-version: 1.16
31 |
32 | - name: Check out code
33 | uses: actions/checkout@v1
34 |
35 | - name: Run Unit tests.
36 | run: make test-coverage
37 |
38 | - name: Upload Coverage report to CodeCov
39 | uses: codecov/codecov-action@v1.0.0
40 | with:
41 | token: ${{secrets.CODECOV_TOKEN}}
42 | file: ./coverage.txt
43 |
44 | # test-macos:
45 | # name: Test MacOS
46 | # runs-on: macos-11
47 | # steps:
48 | # - name: Set up Go
49 | # uses: actions/setup-go@v1
50 | # with:
51 | # go-version: 1.16
52 | #
53 | # - name: Check out code
54 | # uses: actions/checkout@v1
55 | #
56 | # - name: Run Unit tests.
57 | # run: make test
58 | #
59 | # test-windows:
60 | # name: Test Windows
61 | # runs-on: windows-2019
62 | # steps:
63 | # - name: Set up Go
64 | # uses: actions/setup-go@v1
65 | # with:
66 | # go-version: 1.16
67 | #
68 | # - name: Check out code
69 | # uses: actions/checkout@v1
70 | #
71 | # - name: Run Unit tests.
72 | # run: make test
73 |
74 | build:
75 | name: Build
76 | runs-on: ubuntu-latest
77 | needs: [golangci, test]
78 | steps:
79 | - name: Set up Go
80 | uses: actions/setup-go@v1
81 | with:
82 | go-version: 1.16
83 |
84 | - name: Check out code
85 | uses: actions/checkout@v1
86 |
87 | - name: Build
88 | run: make build
89 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 | lerna-debug.log*
8 |
9 | # Diagnostic reports (https://nodejs.org/api/report.html)
10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
11 |
12 | # Runtime data
13 | pids
14 | *.pid
15 | *.seed
16 | *.pid.lock
17 |
18 | # Directory for instrumented libs generated by jscoverage/JSCover
19 | lib-cov
20 |
21 | # Coverage directory used by tools like istanbul
22 | coverage
23 | *.lcov
24 |
25 | # nyc test coverage
26 | .nyc_output
27 |
28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
29 | .grunt
30 |
31 | # Bower dependency directory (https://bower.io/)
32 | bower_components
33 |
34 | # node-waf configuration
35 | .lock-wscript
36 |
37 | # Compiled binary addons (https://nodejs.org/api/addons.html)
38 | build/Release
39 |
40 | # Dependency directories
41 | node_modules/
42 | jspm_packages/
43 |
44 | # TypeScript v1 declaration files
45 | typings/
46 |
47 | # TypeScript cache
48 | *.tsbuildinfo
49 |
50 | # Optional npm cache directory
51 | .npm
52 |
53 | # Optional eslint cache
54 | .eslintcache
55 |
56 | # Microbundle cache
57 | .rpt2_cache/
58 | .rts2_cache_cjs/
59 | .rts2_cache_es/
60 | .rts2_cache_umd/
61 |
62 | # Optional REPL history
63 | .node_repl_history
64 |
65 | # Output of 'npm pack'
66 | *.tgz
67 |
68 | # Yarn Integrity file
69 | .yarn-integrity
70 |
71 | # dotenv environment variables file
72 | .env
73 | .env.test
74 |
75 | # parcel-bundler cache (https://parceljs.org/)
76 | .cache
77 |
78 | # Next.js build output
79 | .next
80 |
81 | # Nuxt.js build / generate output
82 | .nuxt
83 | dist
84 |
85 | # Gatsby files
86 | .cache/
87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js
88 | # https://nextjs.org/blog/next-9-1#public-directory-support
89 | # public
90 |
91 | # vuepress build output
92 | .vuepress/dist
93 |
94 | # Serverless directories
95 | .serverless/
96 |
97 | # FuseBox cache
98 | .fusebox/
99 |
100 | # DynamoDB Local files
101 | .dynamodb/
102 |
103 | # TernJS port file
104 | .tern-port
105 |
106 | # Jetbrains IDE files
107 | .idea/
108 | /.vscode
109 | /cover.out
110 | /coverage.txt
111 |
112 | signalr_test/package-lock.json
113 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | skip-dirs:
3 | - chatsample
4 | - signalr_test
5 | - router/router_test
6 | skip-files:
7 | - ".+_test\\.go$"
--------------------------------------------------------------------------------
/.run/go test signalr.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/.run/go test signalr_test.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/Invokeresult.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | )
6 |
// InvokeResult is the combined value/error result for async invocations.
// It is used as the element type of result channels.
type InvokeResult struct {
	// Value holds a result value. newInvokeResultChan fills either Value or
	// Error in each element it emits, never both.
	Value interface{}
	// Error holds an error result.
	Error error
}
12 |
13 | // newInvokeResultChan combines a value result and an error result channel into one InvokeResult channel
14 | // The InvokeResult channel is automatically closed when both input channels are closed.
15 | func newInvokeResultChan(ctx context.Context, resultChan <-chan interface{}, errChan <-chan error) <-chan InvokeResult {
16 | ch := make(chan InvokeResult, 1)
17 | go func(ctx context.Context, ch chan InvokeResult, resultChan <-chan interface{}, errChan <-chan error) {
18 | var resultChanClosed, errChanClosed bool
19 | loop:
20 | for !resultChanClosed || !errChanClosed {
21 | select {
22 | case <-ctx.Done():
23 | break loop
24 | case value, ok := <-resultChan:
25 | if !ok {
26 | resultChanClosed = true
27 | } else {
28 | ch <- InvokeResult{
29 | Value: value,
30 | }
31 | }
32 | case err, ok := <-errChan:
33 | if !ok {
34 | errChanClosed = true
35 | } else {
36 | ch <- InvokeResult{
37 | Error: err,
38 | }
39 | }
40 | }
41 | }
42 | close(ch)
43 | }(ctx, ch, resultChan, errChan)
44 | return ch
45 | }
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 David Fowler, Philipp Seith
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PROJECT_NAME := "signalr"
2 | PKG := "github.com/philippseith/$(PROJECT_NAME)"
3 | PKG_LIST := $(shell go list ${PKG}/... | grep -v /vendor/)
4 | GO_FILES := $(shell find . -name '*.go' | grep -v /vendor/ | grep -v _test.go)
5 |
6 | .PHONY: all dep lint vet test test-coverage build clean
7 |
8 | all: build
9 |
10 | dep: ## Get the dependencies
11 | @go mod download
12 |
13 | lint: ## Lint Golang files
14 | @golangci-lint run
15 |
16 | vet: ## Run go vet
17 | @go vet ${PKG_LIST}
18 |
19 | test: ## Run unittests
20 | @go test -race -short -count=1 ${PKG_LIST}
21 |
22 | test-coverage: ## Run tests with coverage
23 | @go test -race -short -count=1 -coverpkg=. -coverprofile cover.out -covermode=atomic ${PKG_LIST}
24 | @cat cover.out >> coverage.txt
25 |
26 | build: dep ## Build the binary file
27 | @go build -i -o build/main $(PKG)
28 |
29 | clean: ## Remove previous build
30 | @rm -f $(PROJECT_NAME)/build
31 |
32 | run-chatsample: ## run the local ./chatsample server
33 | @go run ./chatsample/*.go
34 |
35 | help: ## Display this help screen
36 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
37 |
--------------------------------------------------------------------------------
/chatsample/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | _ "embed"
6 | "fmt"
7 | "log"
8 | "net/http"
9 | "os"
10 | "strings"
11 | "time"
12 |
13 | kitlog "github.com/go-kit/log"
14 |
15 | "github.com/philippseith/signalr"
16 | "github.com/philippseith/signalr/chatsample/middleware"
17 | "github.com/philippseith/signalr/chatsample/public"
18 | )
19 |
// chat is the hub of the chat sample. The embedded signalr.Hub provides
// Clients(), Groups() and Abort(), which the methods below use.
type chat struct {
	signalr.Hub
}
23 |
24 | func (c *chat) OnConnected(connectionID string) {
25 | fmt.Printf("%s connected\n", connectionID)
26 | c.Groups().AddToGroup("group", connectionID)
27 | }
28 |
29 | func (c *chat) OnDisconnected(connectionID string) {
30 | fmt.Printf("%s disconnected\n", connectionID)
31 | c.Groups().RemoveFromGroup("group", connectionID)
32 | }
33 |
34 | func (c *chat) Broadcast(message string) {
35 | // Broadcast to all clients
36 | c.Clients().Group("group").Send("receive", message)
37 | }
38 |
39 | func (c *chat) Echo(message string) {
40 | c.Clients().Caller().Send("receive", message)
41 | }
42 |
// Panic always panics. NOTE(review): presumably it exists to demonstrate how
// the server reports a panicking hub method to the client; confirm.
func (c *chat) Panic() {
	panic("Don't panic!")
}
46 |
47 | func (c *chat) RequestAsync(message string) <-chan map[string]string {
48 | r := make(chan map[string]string)
49 | go func() {
50 | defer close(r)
51 | time.Sleep(4 * time.Second)
52 | m := make(map[string]string)
53 | m["ToUpper"] = strings.ToUpper(message)
54 | m["ToLower"] = strings.ToLower(message)
55 | m["len"] = fmt.Sprint(len(message))
56 | r <- m
57 | }()
58 | return r
59 | }
60 |
61 | func (c *chat) RequestTuple(message string) (string, string, int) {
62 | return strings.ToUpper(message), strings.ToLower(message), len(message)
63 | }
64 |
65 | func (c *chat) DateStream() <-chan string {
66 | r := make(chan string)
67 | go func() {
68 | defer close(r)
69 | for i := 0; i < 50; i++ {
70 | r <- fmt.Sprint(time.Now().Clock())
71 | time.Sleep(time.Second)
72 | }
73 | }()
74 | return r
75 | }
76 |
77 | func (c *chat) UploadStream(upload1 <-chan int, factor float64, upload2 <-chan float64) {
78 | ok1 := true
79 | ok2 := true
80 | u1 := 0
81 | u2 := 0.0
82 | c.Echo(fmt.Sprintf("f: %v", factor))
83 | for {
84 | select {
85 | case u1, ok1 = <-upload1:
86 | if ok1 {
87 | c.Echo(fmt.Sprintf("u1: %v", u1))
88 | } else if !ok2 {
89 | c.Echo("Finished")
90 | return
91 | }
92 | case u2, ok2 = <-upload2:
93 | if ok2 {
94 | c.Echo(fmt.Sprintf("u2: %v", u2))
95 | } else if !ok1 {
96 | c.Echo("Finished")
97 | return
98 | }
99 | }
100 | }
101 | }
102 |
103 | func (c *chat) Abort() {
104 | fmt.Println("Abort")
105 | c.Hub.Abort()
106 | }
107 |
108 | //func runTCPServer(address string, hub signalr.HubInterface) {
109 | // listener, err := net.Listen("tcp", address)
110 | //
111 | // if err != nil {
112 | // fmt.Println(err)
113 | // return
114 | // }
115 | //
116 | // fmt.Printf("Listening for TCP connection on %s\n", listener.Addr())
117 | //
118 | // server, _ := signalr.NewServer(context.TODO(), signalr.UseHub(hub))
119 | //
120 | // for {
121 | // conn, err := listener.Accept()
122 | //
123 | // if err != nil {
124 | // fmt.Println(err)
125 | // break
126 | // }
127 | //
128 | // go server.Serve(context.TODO(), newNetConnection(conn))
129 | // }
130 | //}
131 |
132 | func runHTTPServer(address string, hub signalr.HubInterface) {
133 | server, _ := signalr.NewServer(context.TODO(), signalr.SimpleHubFactory(hub),
134 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stdout), false),
135 | signalr.KeepAliveInterval(2*time.Second))
136 | router := http.NewServeMux()
137 | server.MapHTTP(signalr.WithHTTPServeMux(router), "/chat")
138 |
139 | fmt.Printf("Serving public content from the embedded filesystem\n")
140 | router.Handle("/", http.FileServer(http.FS(public.FS)))
141 | fmt.Printf("Listening for websocket connections on http://%s\n", address)
142 | if err := http.ListenAndServe(address, middleware.LogRequests(router)); err != nil {
143 | log.Fatal("ListenAndServe:", err)
144 | }
145 | }
146 |
147 | func runHTTPClient(address string, receiver interface{}) error {
148 | c, err := signalr.NewClient(context.Background(), nil,
149 | signalr.WithReceiver(receiver),
150 | signalr.WithConnector(func() (signalr.Connection, error) {
151 | creationCtx, _ := context.WithTimeout(context.Background(), 2*time.Second)
152 | return signalr.NewHTTPConnection(creationCtx, address)
153 | }),
154 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stdout), false))
155 | if err != nil {
156 | return err
157 | }
158 | c.Start()
159 | fmt.Println("Client started")
160 | return nil
161 | }
162 |
// receiver implements the client methods of the sample client. The embedded
// signalr.Receiver provides Server(), which Receive uses to call back into
// the hub.
type receiver struct {
	signalr.Receiver
}
166 |
167 | func (r *receiver) Receive(msg string) {
168 | fmt.Println(msg)
169 | // The silly client urges the server to end his connection after 10 seconds
170 | r.Server().Send("abort")
171 | }
172 |
173 | func main() {
174 | hub := &chat{}
175 |
176 | //go runTCPServer("127.0.0.1:8007", hub)
177 | go runHTTPServer("localhost:8086", hub)
178 | <-time.After(time.Millisecond * 2)
179 | go func() {
180 | fmt.Println(runHTTPClient("http://localhost:8086/chat", &receiver{}))
181 | }()
182 | ch := make(chan struct{})
183 | <-ch
184 | }
185 |
--------------------------------------------------------------------------------
/chatsample/middleware/log_requests.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "time"
7 | )
8 |
9 | // LogRequests writes simple request logs to STDOUT so that we can see what requests the server is handling
10 | func LogRequests(h http.Handler) http.Handler {
11 | // type our middleware as an http.HandlerFunc so that it is seen as an http.Handler
12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13 | // wrap the original response writer so we can capture response details
14 | wrappedWriter := wrapResponseWriter(w)
15 | start := time.Now() // request start time
16 |
17 | // serve the inner request
18 | h.ServeHTTP(wrappedWriter, r)
19 |
20 | // extract request/response details
21 | status := wrappedWriter.status
22 | uri := r.URL.String()
23 | method := r.Method
24 | duration := time.Since(start)
25 |
26 | // write to console
27 | fmt.Printf("%03d %s %s %v\n", status, method, uri, duration)
28 | })
29 | }
30 |
--------------------------------------------------------------------------------
/chatsample/middleware/response_writer_wrapper.go:
--------------------------------------------------------------------------------
1 | package middleware
2 |
3 | import (
4 | "bufio"
5 | "errors"
6 | "net"
7 | "net/http"
8 | )
9 |
// wrapResponseWriter wraps w in a responseWriterWrapper that records the
// status code sent to the client.
func wrapResponseWriter(w http.ResponseWriter) *responseWriterWrapper {
	return &responseWriterWrapper{ResponseWriter: w}
}

// responseWriterWrapper is a minimal wrapper for http.ResponseWriter that allows the
// written HTTP status code to be captured for logging.
// adapted from: https://github.com/elithrar/admission-control/blob/df0c4bf37a96d159d9181a71cee6e5485d5a50a9/request_logger.go#L11-L13
type responseWriterWrapper struct {
	http.ResponseWriter
	status      int
	wroteHeader bool
}

// Status returns the status code written so far. It is 0 if neither
// WriteHeader nor Write has been called, e.g. when the connection was hijacked.
func (rw *responseWriterWrapper) Status() int {
	return rw.status
}

// Header provides access to the wrapped http.ResponseWriter's header
// allowing handlers to set HTTP headers on the wrapped response
func (rw *responseWriterWrapper) Header() http.Header {
	return rw.ResponseWriter.Header()
}

// WriteHeader records code and forwards it to the wrapped writer. Only the
// first call has an effect; later calls are ignored.
func (rw *responseWriterWrapper) WriteHeader(code int) {
	if rw.wroteHeader {
		return
	}
	rw.status = code
	rw.ResponseWriter.WriteHeader(code)
	rw.wroteHeader = true
}

// Write implements http.ResponseWriter. A handler that writes a body without
// calling WriteHeader implicitly sends 200 OK. Record that status here so it
// is not reported as 0.
func (rw *responseWriterWrapper) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	return rw.ResponseWriter.Write(b)
}

// Flush implements http.Flusher. It is a no-op if the wrapped writer is not
// an http.Flusher.
func (rw *responseWriterWrapper) Flush() {
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Hijack implements http.Hijacker (needed for websocket upgrades). It returns
// an error if the wrapped writer is not an http.Hijacker.
func (rw *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
		return hijacker.Hijack()
	}
	return nil, nil, errors.New("http.Hijacker not implemented")
}
61 |
--------------------------------------------------------------------------------
/chatsample/public/fs.go:
--------------------------------------------------------------------------------
1 | package public
2 |
3 | import "embed"
4 |
// FS holds the files of this directory, embedded at build time by the
// go:embed directive, so the chat sample can serve them without reading
// from disk.
//
//go:embed *
var FS embed.FS
7 |
--------------------------------------------------------------------------------
/chatsample/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
17 |
18 |
19 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
/clientoptions.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/cenkalti/backoff/v4"
7 | )
8 |
9 | // WithConnection sets the Connection of the Client
10 | func WithConnection(connection Connection) func(party Party) error {
11 | return func(party Party) error {
12 | if client, ok := party.(*client); ok {
13 | if client.connectionFactory != nil {
14 | return errors.New("options WithConnection and WithConnector can not be used together")
15 | }
16 | client.conn = connection
17 | return nil
18 | }
19 | return errors.New("option WithConnection is client only")
20 | }
21 | }
22 |
23 | // WithConnector allows the Client to establish a connection
24 | // using the Connection build by the connectionFactory.
25 | // It is also used for auto reconnect if the connection is lost.
26 | func WithConnector(connectionFactory func() (Connection, error)) func(Party) error {
27 | return func(party Party) error {
28 | if client, ok := party.(*client); ok {
29 | if client.conn != nil {
30 | return errors.New("options WithConnection and WithConnector can not be used together")
31 | }
32 | client.connectionFactory = connectionFactory
33 | return nil
34 | }
35 | return errors.New("option WithConnector is client only")
36 | }
37 | }
38 |
39 | // WithReceiver sets the object which will receive server side calls to client methods (e.g. callbacks)
40 | func WithReceiver(receiver interface{}) func(Party) error {
41 | return func(party Party) error {
42 | if client, ok := party.(*client); ok {
43 | client.receiver = receiver
44 | if receiver, ok := receiver.(ReceiverInterface); ok {
45 | receiver.Init(client)
46 | }
47 | return nil
48 | }
49 | return errors.New("option WithReceiver is client only")
50 | }
51 | }
52 |
53 | // WithBackoff sets the backoff.BackOff used for repeated connection attempts in the client.
54 | // See https://pkg.go.dev/github.com/cenkalti/backoff for configuration options.
55 | // If the option is not set, backoff.NewExponentialBackOff() without any further configuration will be used.
56 | func WithBackoff(backoffFactory func() backoff.BackOff) func(party Party) error {
57 | return func(party Party) error {
58 | if client, ok := party.(*client); ok {
59 | client.backoffFactory = backoffFactory
60 | return nil
61 | }
62 | return errors.New("option WithBackoff is client only")
63 | }
64 | }
65 |
66 | // TransferFormat sets the transfer format used on the transport. Allowed values are "Text" and "Binary"
67 | func TransferFormat(format string) func(Party) error {
68 | return func(p Party) error {
69 | if c, ok := p.(*client); ok {
70 | switch format {
71 | case "Text":
72 | c.format = "json"
73 | case "Binary":
74 | c.format = "messagepack"
75 | default:
76 | return fmt.Errorf("invalid transferformat %v", format)
77 | }
78 | return nil
79 | }
80 | return errors.New("option TransferFormat is client only")
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/clientoptions_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "github.com/cenkalti/backoff/v4"
6 |
7 | . "github.com/onsi/ginkgo"
8 | . "github.com/onsi/gomega"
9 | )
10 |
// Specs for the client connection options: NewClient fails when neither or
// both of WithConnection and WithConnector are given, and succeeds with
// exactly one of them.
var _ = Describe("Client options", func() {

	Describe("WithConnection and WithConnector option", func() {
		// Without any way to obtain a connection the client cannot be built.
		Context("none of them is given", func() {
			It("NewClient should fail", func() {
				_, err := NewClient(context.TODO())
				Expect(err).To(HaveOccurred())
			}, 3.0)
		})
		// WithConnection and WithConnector are mutually exclusive.
		Context("both are given", func() {
			It("NewClient should fail", func() {
				conn := NewNetConnection(context.TODO(), nil)
				_, err := NewClient(context.TODO(), WithConnection(conn), WithConnector(func() (Connection, error) {
					return conn, nil
				}))
				Expect(err).To(HaveOccurred())
			}, 3.0)
		})
		Context("only WithConnection is given", func() {
			It("NewClient should not fail", func() {
				conn := NewNetConnection(context.TODO(), nil)
				_, err := NewClient(context.TODO(), WithConnection(conn))
				Expect(err).NotTo(HaveOccurred())
			}, 3.0)
		})
		Context("only WithConnector is given", func() {
			It("NewClient should not fail", func() {
				conn := NewNetConnection(context.TODO(), nil)
				_, err := NewClient(context.TODO(), WithConnector(func() (Connection, error) {
					return conn, nil
				}))
				Expect(err).NotTo(HaveOccurred())
			}, 3.0)
		})
		// NOTE(review): despite its description, this case also passes
		// WithConnection, because NewClient fails without a connection source
		// (see "none of them is given"). It checks that WithBackoff is
		// accepted alongside WithConnection.
		Context("only WithBackoff is given", func() {
			It("NewClient should not fail", func() {
				conn := NewNetConnection(context.TODO(), nil)
				_, err := NewClient(context.TODO(), WithConnection(conn), WithBackoff(func() backoff.BackOff {
					return backoff.NewExponentialBackOff()
				}))
				Expect(err).NotTo(HaveOccurred())
			}, 3.0)
		})
	})

})
57 |
--------------------------------------------------------------------------------
/clientproxy.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
// ClientProxy allows the hub to send messages to one or more of its clients.
type ClientProxy interface {
	// Send invokes the client-side method target with args on the client(s)
	// this proxy stands for.
	Send(target string, args ...interface{})
}
7 |
8 | type allClientProxy struct {
9 | lifetimeManager HubLifetimeManager
10 | }
11 |
12 | func (a *allClientProxy) Send(target string, args ...interface{}) {
13 | a.lifetimeManager.InvokeAll(target, args)
14 | }
15 |
16 | type singleClientProxy struct {
17 | connectionID string
18 | lifetimeManager HubLifetimeManager
19 | }
20 |
21 | func (a *singleClientProxy) Send(target string, args ...interface{}) {
22 | a.lifetimeManager.InvokeClient(a.connectionID, target, args)
23 | }
24 |
25 | type groupClientProxy struct {
26 | groupName string
27 | lifetimeManager HubLifetimeManager
28 | }
29 |
30 | func (g *groupClientProxy) Send(target string, args ...interface{}) {
31 | g.lifetimeManager.InvokeGroup(g.groupName, target, args)
32 | }
33 |
--------------------------------------------------------------------------------
/clientsseconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "io"
8 | "net/http"
9 | "net/url"
10 | "strings"
11 | )
12 |
// clientSSEConnection is the client side of a Server-Sent Events transport.
// Incoming messages are the "data:" payloads of the SSE response stream.
// Outgoing messages are sent as HTTP POST requests to reqURL.
type clientSSEConnection struct {
	ConnectionBase
	reqURL    string    // URL (with the connection id query parameter) that Write POSTs to
	sseReader io.Reader // read side of the pipe filled with SSE "data:" payloads; used by Read
	sseWriter io.Writer // write side of that pipe, fed by the SSE reading goroutine
}
19 |
20 | func newClientSSEConnection(address string, connectionID string, body io.ReadCloser) (*clientSSEConnection, error) {
21 | // Setup request
22 | reqURL, err := url.Parse(address)
23 | if err != nil {
24 | return nil, err
25 | }
26 | q := reqURL.Query()
27 | q.Set("id", connectionID)
28 | reqURL.RawQuery = q.Encode()
29 | c := clientSSEConnection{
30 | ConnectionBase: ConnectionBase{
31 | ctx: context.Background(),
32 | connectionID: connectionID,
33 | },
34 | reqURL: reqURL.String(),
35 | }
36 | c.sseReader, c.sseWriter = io.Pipe()
37 | go func() {
38 | defer func() { closeResponseBody(body) }()
39 | p := make([]byte, 1<<15)
40 | loop:
41 | for {
42 | n, err := body.Read(p)
43 | if err != nil {
44 | break loop
45 | }
46 | lines := strings.Split(string(p[:n]), "\n")
47 | for _, line := range lines {
48 | line = strings.Trim(line, "\r\t ")
49 | // Ignore everything but data
50 | if strings.Index(line, "data:") != 0 {
51 | continue
52 | }
53 | json := strings.Replace(strings.Trim(line, "\r"), "data:", "", 1)
54 | // Spec says: If it starts with Space, remove it
55 | if len(json) > 0 && json[0] == ' ' {
56 | json = json[1:]
57 | }
58 | _, err = c.sseWriter.Write([]byte(json))
59 | if err != nil {
60 | break loop
61 | }
62 | }
63 | }
64 | }()
65 | return &c, nil
66 | }
67 |
68 | func (c *clientSSEConnection) Read(p []byte) (n int, err error) {
69 | return c.sseReader.Read(p)
70 | }
71 |
72 | func (c *clientSSEConnection) Write(p []byte) (n int, err error) {
73 | req, err := http.NewRequest("POST", c.reqURL, bytes.NewReader(p))
74 | if err != nil {
75 | return 0, err
76 | }
77 | client := &http.Client{}
78 | resp, err := client.Do(req)
79 | if err != nil {
80 | return 0, err
81 | }
82 | if resp.StatusCode != 200 {
83 | err = fmt.Errorf("POST %v -> %v", c.reqURL, resp.Status)
84 | }
85 | closeResponseBody(resp.Body)
86 | return len(p), err
87 | }
88 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: true
3 |
4 | coverage:
5 | precision: 2
6 | round: down
7 | range: "60...100"
8 |
9 | parsers:
10 | gcov:
11 | branch_detection:
12 | conditional: yes
13 | loop: yes
14 | method: no
15 | macro: no
16 |
17 | comment:
18 | layout: "reach,diff,flags,files,footer"
19 | behavior: default
20 | require_changes: false
21 |
22 | ignore:
23 | - chatsample/**/*
24 | - signalr_test/**/*
25 |
--------------------------------------------------------------------------------
/connection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "io"
6 | )
7 |
// Connection describes a connection between signalR client and server.
// Messages are exchanged over it as a byte stream (io.Reader / io.Writer);
// Context can be used to wait for the connection's cancellation.
type Connection interface {
	io.ReadWriter
	// Context can be used to wait for cancellation of the connection.
	Context() context.Context
	// ConnectionID returns the ID of the connection.
	ConnectionID() string
	// SetConnectionID replaces the ID of the connection.
	SetConnectionID(id string)
}
16 |
// TransferMode is either TextTransferMode or BinaryTransferMode.
type TransferMode int

// TransferMode constants.
const (
	// TextTransferMode is for UTF-8 encoded text messages like JSON.
	TextTransferMode TransferMode = 1
	// BinaryTransferMode is for binary messages like MessagePack.
	BinaryTransferMode TransferMode = 2
)

// ConnectionWithTransferMode is implemented by connections (e.g. Websocket)
// that distinguish between text and binary messages.
type ConnectionWithTransferMode interface {
	// TransferMode returns the current transfer mode.
	TransferMode() TransferMode
	// SetTransferMode sets the transfer mode.
	SetTransferMode(transferMode TransferMode)
}
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "time"
7 |
8 | . "github.com/onsi/ginkgo"
9 | . "github.com/onsi/gomega"
10 | )
11 |
12 | var _ = Describe("Connection", func() {
13 |
14 | Describe("Connection closed", func() {
15 | var server Server
16 | var conn *testingConnection
17 | BeforeEach(func(done Done) {
18 | server, conn = connect(&Hub{})
19 | close(done)
20 | })
21 | AfterEach(func(done Done) {
22 | server.cancel()
23 | close(done)
24 | })
25 | Context("When the connection is closed", func() {
26 | It("should close the connection and not answer an invocation", func(done Done) {
27 | conn.ClientSend(`{"type":7}`)
28 | conn.ClientSend(`{"type":1,"invocationId": "123","target":"unknownFunc"}`)
29 | // When the connection is closed, the server should either send a closeMessage or nothing at all
30 | select {
31 | case message := <-conn.received:
32 | Expect(message.(closeMessage)).NotTo(BeNil())
33 | case <-time.After(100 * time.Millisecond):
34 | }
35 | close(done)
36 | })
37 | })
38 | Context("When the connection is closed with an invalid close message", func() {
39 | It("should close the connection and should not answer an invocation", func(done Done) {
40 | conn.ClientSend(`{"type":7,"error":1}`)
41 | conn.ClientSend(`{"type":1,"invocationId": "123","target":"unknownFunc"}`)
42 | // When the connection is closed, the server should either send a closeMessage or nothing at all
43 | select {
44 | case message := <-conn.received:
45 | Expect(message.(closeMessage)).NotTo(BeNil())
46 | case <-time.After(100 * time.Millisecond):
47 | }
48 | close(done)
49 | })
50 | })
51 | })
52 | })
53 |
54 | var _ = Describe("Protocol", func() {
55 | var server Server
56 | var conn *testingConnection
57 | BeforeEach(func(done Done) {
58 | server, conn = connect(&Hub{})
59 | close(done)
60 | })
61 | AfterEach(func(done Done) {
62 | server.cancel()
63 | close(done)
64 | })
65 | Describe("Invalid messages", func() {
66 | Context("When a message with invalid id is sent", func() {
67 | It("should close the connection with an error", func(done Done) {
68 | conn.ClientSend(`{"type":99}`)
69 | select {
70 | case message := <-conn.received:
71 | Expect(message).To(BeAssignableToTypeOf(closeMessage{}))
72 | Expect(message.(closeMessage).Error).NotTo(BeNil())
73 | case <-time.After(100 * time.Millisecond):
74 | Fail("timed out")
75 | }
76 | close(done)
77 | })
78 | })
79 | })
80 |
81 | Describe("Ping", func() {
82 | Context("When a ping is received", func() {
83 | It("should ignore it", func(done Done) {
84 | conn.ClientSend(`{"type":6}`)
85 | select {
86 | case <-conn.received:
87 | Fail("ping not ignored")
88 | case <-time.After(100 * time.Millisecond):
89 | }
90 | close(done)
91 | })
92 | })
93 | })
94 | })
95 |
96 | type handshakeHub struct {
97 | Hub
98 | }
99 |
100 | func (h *handshakeHub) Shake() {
101 | shakeQueue <- "Shake()"
102 | }
103 |
104 | var shakeQueue = make(chan string, 10)
105 |
106 | func getTestBedHandshake() (*testingConnection, context.CancelFunc) {
107 | ctx, cancel := context.WithCancel(context.Background())
108 | server, _ := NewServer(ctx, SimpleHubFactory(&handshakeHub{}), testLoggerOption())
109 | conn := newTestingConnection()
110 | go func() { _ = server.Serve(conn) }()
111 | return conn, cancel
112 | }
113 |
114 | var _ = Describe("Handshake", func() {
115 | Context("When the handshake is sent as one message to the server", func() {
116 | It("should be connected", func(done Done) {
117 | conn, cancel := getTestBedHandshake()
118 | conn.ClientSend(`{"protocol": "json","version": 1}`)
119 | conn.ClientSend(`{"type":1,"invocationId": "123A","target":"shake"}`)
120 | Expect(<-shakeQueue).To(Equal("Shake()"))
121 | cancel()
122 | close(done)
123 | })
124 | })
125 | Context("When the handshake is sent as partial message to the server", func() {
126 | It("should be connected", func(done Done) {
127 | conn, cancel := getTestBedHandshake()
128 | _, _ = conn.cliWriter.Write([]byte(`{"protocol"`))
129 | conn.ClientSend(`: "json","version": 1}`)
130 | conn.ClientSend(`{"type":1,"invocationId": "123B","target":"shake"}`)
131 | Expect(<-shakeQueue).To(Equal("Shake()"))
132 | cancel()
133 | close(done)
134 | })
135 | })
136 | Context("When an invalid handshake is sent as partial message to the server", func() {
137 | It("should not be connected", func(done Done) {
138 | conn, cancel := getTestBedHandshake()
139 | _, _ = conn.cliWriter.Write([]byte(`{"protocol"`))
140 | // Opening curly brace is invalid
141 | conn.ClientSend(`{: "json","version": 1}`)
142 | conn.ClientSend(`{"type":1,"invocationId": "123C","target":"shake"}`)
143 | select {
144 | case <-shakeQueue:
145 | Fail("server connected with invalid handshake")
146 | case <-time.After(100 * time.Millisecond):
147 | }
148 | cancel()
149 | close(done)
150 | })
151 | })
152 | Context("When a handshake is sent with an unsupported protocol", func() {
153 | It("should return an error handshake response and be not connected", func(done Done) {
154 | conn, cancel := getTestBedHandshake()
155 | conn.ClientSend(`{"protocol": "bson","version": 1}`)
156 | response, err := conn.ClientReceive()
157 | Expect(err).To(BeNil())
158 | Expect(response).NotTo(BeNil())
159 | jsonMap := make(map[string]interface{})
160 | err = json.Unmarshal([]byte(response), &jsonMap)
161 | Expect(err).To(BeNil())
162 | Expect(jsonMap["error"]).NotTo(BeNil())
163 | conn.ClientSend(`{"type":1,"invocationId": "123D","target":"shake"}`)
164 | select {
165 | case <-shakeQueue:
166 | Fail("server connected with invalid handshake")
167 | case <-time.After(100 * time.Millisecond):
168 | }
169 | cancel()
170 | close(done)
171 | })
172 | })
173 | Context("When the connection fails before the server can receive handshake request", func() {
174 | It("should not be connected", func(done Done) {
175 | conn, cancel := getTestBedHandshake()
176 | conn.SetFailRead("failed read in handshake")
177 | conn.ClientSend(`{"protocol": "json","version": 1}`)
178 | conn.ClientSend(`{"type":1,"invocationId": "123E","target":"shake"}`)
179 | select {
180 | case <-shakeQueue:
181 | Fail("server connected with fail before handshake")
182 | case <-time.After(100 * time.Millisecond):
183 | }
184 | cancel()
185 | close(done)
186 | })
187 | })
188 | Context("When the handshake is received by the server but the connection fails when the response should be sent ", func() {
189 | It("should not be connected", func(done Done) {
190 | conn, cancel := getTestBedHandshake()
191 | conn.SetFailWrite("failed write in handshake")
192 | conn.ClientSend(`{"protocol": "json","version": 1}`)
193 | conn.ClientSend(`{"type":1,"invocationId": "123F","target":"shake"}`)
194 | select {
195 | case <-shakeQueue:
196 | Fail("server connected with fail before handshake")
197 | case <-time.After(100 * time.Millisecond):
198 | }
199 | cancel()
200 | close(done)
201 | })
202 | })
203 | Context("When the handshake with an unsupported protocol is received by the server but the connection fails when the response should be sent ", func() {
204 | It("should not be connected", func(done Done) {
205 | conn, cancel := getTestBedHandshake()
206 | conn.SetFailWrite("failed write in handshake")
207 | conn.ClientSend(`{"protocol": "bson","version": 1}`)
208 | conn.ClientSend(`{"type":1,"invocationId": "123G","target":"shake"}`)
209 | select {
210 | case <-shakeQueue:
211 | Fail("server connected with fail before handshake")
212 | case <-time.After(100 * time.Millisecond):
213 | }
214 | cancel()
215 | close(done)
216 | })
217 | })
218 | Context("When the handshake connection is initiated, but the client does not send a handshake request within the handshake timeout ", func() {
219 | It("should not be connected", func(done Done) {
220 | server, _ := NewServer(context.TODO(), SimpleHubFactory(&handshakeHub{}), HandshakeTimeout(time.Millisecond*100), testLoggerOption())
221 | conn := newTestingConnection()
222 | go func() { _ = server.Serve(conn) }()
223 | time.Sleep(time.Millisecond * 200)
224 | conn.ClientSend(`{"protocol": "json","version": 1}`)
225 | conn.ClientSend(`{"type":1,"invocationId": "123H","target":"shake"}`)
226 | select {
227 | case <-shakeQueue:
228 | Fail("server connected with fail before handshake")
229 | case <-time.After(100 * time.Millisecond):
230 | }
231 | server.cancel()
232 | close(done)
233 | }, 2.0)
234 | })
235 | })
236 |
--------------------------------------------------------------------------------
/connectionbase.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "sync"
6 | )
7 |
8 | // ConnectionBase is a baseclass for implementers of the Connection interface.
9 | type ConnectionBase struct {
10 | mx sync.RWMutex
11 | ctx context.Context
12 | connectionID string
13 | }
14 |
15 | // NewConnectionBase creates a new ConnectionBase
16 | func NewConnectionBase(ctx context.Context, connectionID string) *ConnectionBase {
17 | cb := &ConnectionBase{
18 | ctx: ctx,
19 | connectionID: connectionID,
20 | }
21 | return cb
22 | }
23 |
24 | // Context can be used to wait for cancellation of the Connection
25 | func (cb *ConnectionBase) Context() context.Context {
26 | cb.mx.RLock()
27 | defer cb.mx.RUnlock()
28 | return cb.ctx
29 | }
30 |
31 | // ConnectionID is the ID of the connection.
32 | func (cb *ConnectionBase) ConnectionID() string {
33 | cb.mx.RLock()
34 | defer cb.mx.RUnlock()
35 | return cb.connectionID
36 | }
37 |
38 | // SetConnectionID sets the ConnectionID
39 | func (cb *ConnectionBase) SetConnectionID(id string) {
40 | cb.mx.Lock()
41 | defer cb.mx.Unlock()
42 | cb.connectionID = id
43 | }
44 |
45 | // ReadWriteWithContext is a wrapper to make blocking io.Writer / io.Reader cancelable.
46 | // It can be used to implement cancellation of connections.
47 | // ReadWriteWithContext will return when either the Read/Write operation has ended or ctx has been canceled.
48 | // doRW func() (int, error)
49 | // doRW should contain the Read/Write operation.
50 | // unblockRW func()
51 | // unblockRW should contain the operation to unblock the Read/Write operation.
52 | // If there is no way to unblock the operation, one goroutine will leak when ctx is canceled.
53 | // As the standard use case when ReadWriteWithContext is canceled is the cancellation of a connection this leak
54 | // will be problematic on heavily used servers with uncommon connection types. Luckily, the standard connection types
55 | // for ServerSentEvents, Websockets and common net.Conn connections can be unblocked.
56 | func ReadWriteWithContext(ctx context.Context, doRW func() (int, error), unblockRW func()) (int, error) {
57 | if ctx.Err() != nil {
58 | return 0, ctx.Err()
59 | }
60 | resultChan := make(chan RWJobResult, 1)
61 | go func() {
62 | n, err := doRW()
63 | resultChan <- RWJobResult{n: n, err: err}
64 | close(resultChan)
65 | }()
66 | select {
67 | case <-ctx.Done():
68 | unblockRW()
69 | return 0, ctx.Err()
70 | case r := <-resultChan:
71 | return r.n, r.err
72 | }
73 | }
74 |
75 | // RWJobResult can be used to send the result of an io.Writer / io.Reader operation over a channel.
76 | // Use it for special connection types, where ReadWriteWithContext does not fit all needs.
77 | type RWJobResult struct {
78 | n int
79 | err error
80 | }
81 |
--------------------------------------------------------------------------------
/ctxpipe.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "io"
7 | "sync"
8 | )
9 |
10 | // onceError is an object that will only store an error once.
11 | type onceError struct {
12 | sync.Mutex // guards following
13 | err error
14 | }
15 |
16 | func (a *onceError) Store(err error) {
17 | a.Lock()
18 | defer a.Unlock()
19 | if a.err != nil {
20 | return
21 | }
22 | a.err = err
23 | }
24 | func (a *onceError) Load() error {
25 | a.Lock()
26 | defer a.Unlock()
27 | return a.err
28 | }
29 |
30 | // ErrClosedPipe is the error used for read or write operations on a closed pipe.
31 | var ErrClosedPipe = errors.New("io: read/write on closed pipe")
32 |
33 | // A pipe is the shared pipe structure underlying PipeReader and PipeWriter.
34 | type pipe struct {
35 | wrMu sync.Mutex // Serializes Write operations
36 | wrCh chan []byte
37 | rdCh chan int
38 |
39 | once sync.Once // Protects closing done
40 | ctx context.Context
41 | cancel context.CancelFunc
42 | rErr onceError
43 | wErr onceError
44 | }
45 |
46 | func (p *pipe) Read(b []byte) (n int, err error) {
47 | select {
48 | case <-p.ctx.Done():
49 | return 0, p.readCloseError()
50 | default:
51 | }
52 |
53 | select {
54 | case bw := <-p.wrCh:
55 | nr := copy(b, bw)
56 | p.rdCh <- nr
57 | return nr, nil
58 | case <-p.ctx.Done():
59 | return 0, p.readCloseError()
60 | }
61 | }
62 |
63 | func (p *pipe) readCloseError() error {
64 | rErr := p.rErr.Load()
65 | if wErr := p.wErr.Load(); rErr == nil && wErr != nil {
66 | return wErr
67 | }
68 | return ErrClosedPipe
69 | }
70 |
71 | func (p *pipe) CloseRead(err error) error {
72 | if err == nil {
73 | err = ErrClosedPipe
74 | }
75 | p.rErr.Store(err)
76 | p.once.Do(func() { p.cancel() })
77 | return nil
78 | }
79 |
80 | func (p *pipe) Write(b []byte) (n int, err error) {
81 | select {
82 | case <-p.ctx.Done():
83 | return 0, p.writeCloseError()
84 | default:
85 | p.wrMu.Lock()
86 | defer p.wrMu.Unlock()
87 | }
88 |
89 | for once := true; once || len(b) > 0; once = false {
90 | select {
91 | case p.wrCh <- b:
92 | nw := <-p.rdCh
93 | b = b[nw:]
94 | n += nw
95 | case <-p.ctx.Done():
96 | return n, p.writeCloseError()
97 | }
98 | }
99 | return n, nil
100 | }
101 |
102 | func (p *pipe) writeCloseError() error {
103 | wErr := p.wErr.Load()
104 | if rErr := p.rErr.Load(); wErr == nil && rErr != nil {
105 | return rErr
106 | }
107 | return ErrClosedPipe
108 | }
109 |
110 | func (p *pipe) CloseWrite(err error) error {
111 | if err == nil {
112 | err = io.EOF
113 | }
114 | p.wErr.Store(err)
115 | p.once.Do(func() { p.cancel() })
116 | return nil
117 | }
118 |
119 | // A PipeReader is the read half of a pipe.
120 | type PipeReader struct {
121 | p *pipe
122 | }
123 |
124 | // Read implements the standard Read interface:
125 | // it reads data from the pipe, blocking until a writer
126 | // arrives or the write end is closed.
127 | // If the write end is closed with an error, that error is
128 | // returned as err; otherwise err is EOF.
129 | func (r *PipeReader) Read(data []byte) (n int, err error) {
130 | return r.p.Read(data)
131 | }
132 |
133 | // Close closes the reader; subsequent writes to the
134 | // write half of the pipe will return the error ErrClosedPipe.
135 | func (r *PipeReader) Close() error {
136 | return r.CloseWithError(nil)
137 | }
138 |
139 | // CloseWithError closes the reader; subsequent writes
140 | // to the write half of the pipe will return the error err.
141 | //
142 | // CloseWithError never overwrites the previous error if it exists
143 | // and always returns nil.
144 | func (r *PipeReader) CloseWithError(err error) error {
145 | return r.p.CloseRead(err)
146 | }
147 |
148 | // A PipeWriter is the write half of a pipe.
149 | type PipeWriter struct {
150 | p *pipe
151 | }
152 |
153 | // Write implements the standard Write interface:
154 | // it writes data to the pipe, blocking until one or more readers
155 | // have consumed all the data or the read end is closed.
156 | // If the read end is closed with an error, that err is
157 | // returned as err; otherwise err is ErrClosedPipe.
158 | func (w *PipeWriter) Write(data []byte) (n int, err error) {
159 | return w.p.Write(data)
160 | }
161 |
162 | // Close closes the writer; subsequent reads from the
163 | // read half of the pipe will return no bytes and EOF.
164 | func (w *PipeWriter) Close() error {
165 | return w.CloseWithError(nil)
166 | }
167 |
168 | // CloseWithError closes the writer; subsequent reads from the
169 | // read half of the pipe will return no bytes and the error err,
170 | // or EOF if err is nil.
171 | //
172 | // CloseWithError never overwrites the previous error if it exists
173 | // and always returns nil.
174 | func (w *PipeWriter) CloseWithError(err error) error {
175 | return w.p.CloseWrite(err)
176 | }
177 |
178 | // CtxPipe creates a synchronous in-memory pipe.
179 | // It can be used to connect code expecting an io.Reader
180 | // with code expecting an io.Writer.
181 | //
182 | // By canceling the context, Read and Write can be canceled
183 | //
184 | // Reads and Writes on the pipe are matched one to one
185 | // except when multiple Reads are needed to consume a single Write.
186 | // That is, each Write to the PipeWriter blocks until it has satisfied
187 | // one or more Reads from the PipeReader that fully consume
188 | // the written data.
189 | // The data is copied directly from the Write to the corresponding
190 | // Read (or Reads); there is no internal buffering.
191 | //
192 | // It is safe to call Read and Write in parallel with each other or with Close.
193 | // Parallel calls to Read and parallel calls to Write are also safe:
194 | // the individual calls will be gated sequentially.
195 | func CtxPipe(ctx context.Context) (*PipeReader, *PipeWriter) {
196 | p := &pipe{
197 | wrCh: make(chan []byte),
198 | rdCh: make(chan int),
199 | }
200 | p.ctx, p.cancel = context.WithCancel(ctx)
201 | return &PipeReader{p}, &PipeWriter{p}
202 | }
203 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package signalr contains a SignalR client and a SignalR server.
3 | Both support the transport types Websockets and Server-Sent Events
4 | and the transfer formats Text (JSON) and Binary (MessagePack).
5 |
6 | # Basics
7 |
8 | The SignalR Protocol is a protocol for two-way RPC over any stream- or message-based transport.
9 | Either party in the connection may invoke procedures on the other party,
10 | and procedures can return zero or more results or an error.
11 | Typically, SignalR connections are HTTP-based, but it is dead simple to implement a signalr.Connection on any transport
12 | that supports io.Reader and io.Writer.
13 |
14 | # Client
15 |
16 | A Client can be used in client side code to access server methods. From an existing connection, it can be created with NewClient().
17 |
18 | // NewClient with raw TCP connection and MessagePack encoding
19 | conn, err := net.Dial("tcp", "example.com:6502")
20 | client, err := NewClient(ctx,
21 | WithConnection(NewNetConnection(ctx, conn)),
22 | TransferFormat("Binary"),
23 | WithReceiver(receiver))
24 |
25 | client.Start()
26 |
27 | For HTTP, NewHTTPConnection() creates a Connection from a server address and negotiates with the server
28 | which kind of connection (Websockets, Server-Sent Events) will be used.
29 |
30 | // Configurable HTTP connection
31 | conn, err := NewHTTPConnection(ctx, "http://example.com/hub", WithHTTPHeaders(..))
32 | // Client with JSON encoding
33 | client, err := NewClient(ctx,
34 | WithConnection(conn),
35 | TransferFormat("Text"),
36 | WithReceiver(receiver))
37 |
38 | client.Start()
39 |
40 | The object which will receive server callbacks is passed to NewClient() by using the WithReceiver option.
41 | After calling client.Start(), the client is ready to call server methods or to receive callbacks.
42 |
43 | # Server
44 |
45 | A Server provides the public methods of a server side class over signalr to the client.
46 | Such a server side class is called a hub and must implement HubInterface.
47 | It is reasonable to derive your hubs from the Hub struct type, which already implements HubInterface.
48 | Servers for arbitrary connection types can be created with NewServer().
49 |
50 | // Typical server with log level debug to Stderr
51 | server, err := NewServer(ctx, SimpleHubFactory(hub), Logger(log.NewLogfmtLogger(os.Stderr), true))
52 |
53 | To serve a connection, call server.Serve(connection) in a goroutine. Serve ends when the connection is closed or the
54 | server's context is canceled.
55 |
56 | // Serving over TCP, accepting clients that use MessagePack or JSON
57 | addr, _ := net.ResolveTCPAddr("tcp", "localhost:6502")
58 | listener, _ := net.ListenTCP("tcp", addr)
59 | tcpConn, _ := listener.Accept()
60 | go server.Serve(NewNetConnection(ctx, tcpConn))
61 |
62 | To serve a HTTP connection, use server.MapHTTP(), which connects the server with a path in an http.ServeMux.
63 | The server then automatically negotiates which kind of connection (Websockets, Server-Sent Events) will be used.
64 |
65 | // build a signalr.Server using your hub
66 | // and any server options you may need
67 | server, _ := signalr.NewServer(ctx,
68 | signalr.SimpleHubFactory(&AppHub{}),
69 | signalr.KeepAliveInterval(2*time.Second),
70 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stderr), true),
71 | )
72 |
73 | // create a new http.ServeMux to handle your app's http requests
74 | router := http.NewServeMux()
75 |
76 | // ask the signalr server to map its server
77 | // api routes to your custom baseurl
78 | server.MapHTTP(signalr.WithHTTPServeMux(router), "/hub")
79 |
80 | // in addition to mapping the signalr routes
81 | // your mux will need to serve the static files
82 | // which make up your client-side app, including
83 | // the signalr javascript files. here is an example
84 | // of doing that using a local `public` package
85 | // which was created with the go:embed directive
86 | //
87 | // fmt.Printf("Serving static content from the embedded filesystem\n")
88 | // router.Handle("/", http.FileServer(http.FS(public.FS)))
89 |
90 | // bind your mux to a given address and start handling requests
91 | fmt.Printf("Listening for websocket connections on http://%s\n", address)
92 | if err := http.ListenAndServe(address, router); err != nil {
93 | log.Fatal("ListenAndServe:", err)
94 | }
95 |
96 | # Supported method signatures
97 |
98 | The SignalR protocol constrains the signature of hub or receiver methods that can be used over SignalR.
99 | All methods with serializable types as parameters and return types are supported.
100 | Methods with multiple return values are not generally supported, but returning one or no value and an optional error is supported.
101 |
102 | // Simple signatures for hub/receiver methods
103 | func (mh *MathHub) Divide(a, b float64) (float64, error) // error on division by zero
104 | func (ah *AlgoHub) Sort(values []string) []string
105 | func (ah *AlgoHub) FindKey(value []string, dict map[int][]string) (int, error) // error on not found
106 | func (receiver *View) DisplayServerValue(value interface{}) // will work for every serializable value
107 |
108 | Methods which return a single sending channel (<-chan), and optionally an error, are used to initiate callee side streaming.
109 | The caller will receive the contents of the channel as stream.
110 | When the returned channel is closed, the stream will be completed.
111 |
112 | // Streaming methods
113 | func (n *Netflix) Stream(show string, season, episode int) (<-chan []byte, error) // error on password shared
114 |
115 | Methods with one or multiple receiving channels (chan<-) as parameters are used as receivers for caller side streaming.
116 | The caller invokes this method and pushes one or multiple streams to the callee. The method should end when all channels
117 | are closed. A channel is closed by the server when the assigned stream is completed.
118 | Methods that take receiving channels and also return a channel are not supported.
119 |
120 | // Caller side streaming
121 | func (mh *MathHub) MultiplyAndSum(a, b chan<- float64) float64
122 |
123 | In most cases, the caller will be the client and the callee the server. But the vice versa case is also possible.
124 | */
125 | package signalr
126 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippseith/signalr
2 |
3 | go 1.16
4 |
5 | require (
6 | github.com/cenkalti/backoff/v4 v4.2.1
7 | github.com/dave/jennifer v1.6.1
8 | github.com/go-kit/log v0.2.1
9 | github.com/google/uuid v1.3.0
10 | github.com/onsi/ginkgo v1.12.1
11 | github.com/onsi/gomega v1.11.0
12 | github.com/stretchr/testify v1.8.4
13 | github.com/teivah/onecontext v1.3.0
14 | github.com/vmihailenco/msgpack/v5 v5.3.5
15 | nhooyr.io/websocket v1.8.7
16 | )
17 |
18 | require (
19 | github.com/fsnotify/fsnotify v1.6.0 // indirect
20 | github.com/klauspost/compress v1.16.6 // indirect
21 | golang.org/x/text v0.10.0 // indirect
22 | )
23 |
--------------------------------------------------------------------------------
/groupmanager.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
// GroupManager manages the client groups of the hub.
type GroupManager interface {
	// AddToGroup adds the connection with connectionID to the group groupName.
	AddToGroup(groupName string, connectionID string)
	// RemoveFromGroup removes the connection with connectionID from the group groupName.
	RemoveFromGroup(groupName string, connectionID string)
}
8 |
9 | type defaultGroupManager struct {
10 | lifetimeManager HubLifetimeManager
11 | }
12 |
13 | func (d *defaultGroupManager) AddToGroup(groupName string, connectionID string) {
14 | d.lifetimeManager.AddToGroup(groupName, connectionID)
15 | }
16 |
17 | func (d *defaultGroupManager) RemoveFromGroup(groupName string, connectionID string) {
18 | d.lifetimeManager.RemoveFromGroup(groupName, connectionID)
19 | }
20 |
--------------------------------------------------------------------------------
/httpconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "net/http"
9 | "net/url"
10 | "path"
11 |
12 | "nhooyr.io/websocket"
13 | )
14 |
15 | // Doer is the *http.Client interface
16 | type Doer interface {
17 | Do(req *http.Request) (*http.Response, error)
18 | }
19 |
20 | type httpConnection struct {
21 | client Doer
22 | headers func() http.Header
23 | }
24 |
25 | // WithHTTPClient sets the http client used to connect to the signalR server.
26 | // The client is only used for http requests. It is not used for the websocket connection.
27 | func WithHTTPClient(client Doer) func(*httpConnection) error {
28 | return func(c *httpConnection) error {
29 | c.client = client
30 | return nil
31 | }
32 | }
33 |
34 | // WithHTTPHeaders sets the function for providing request headers for HTTP and websocket requests
35 | func WithHTTPHeaders(headers func() http.Header) func(*httpConnection) error {
36 | return func(c *httpConnection) error {
37 | c.headers = headers
38 | return nil
39 | }
40 | }
41 |
42 | // NewHTTPConnection creates a signalR HTTP Connection for usage with a Client.
43 | // ctx can be used to cancel the SignalR negotiation during the creation of the Connection
44 | // but not the Connection itself.
45 | func NewHTTPConnection(ctx context.Context, address string, options ...func(*httpConnection) error) (Connection, error) {
46 | httpConn := &httpConnection{}
47 |
48 | for _, option := range options {
49 | if option != nil {
50 | if err := option(httpConn); err != nil {
51 | return nil, err
52 | }
53 | }
54 | }
55 |
56 | if httpConn.client == nil {
57 | httpConn.client = http.DefaultClient
58 | }
59 |
60 | reqURL, err := url.Parse(address)
61 | if err != nil {
62 | return nil, err
63 | }
64 |
65 | negotiateURL := *reqURL
66 | negotiateURL.Path = path.Join(negotiateURL.Path, "negotiate")
67 | req, err := http.NewRequestWithContext(ctx, "POST", negotiateURL.String(), nil)
68 | if err != nil {
69 | return nil, err
70 | }
71 |
72 | if httpConn.headers != nil {
73 | req.Header = httpConn.headers()
74 | }
75 |
76 | resp, err := httpConn.client.Do(req)
77 | if err != nil {
78 | return nil, err
79 | }
80 | defer func() { closeResponseBody(resp.Body) }()
81 |
82 | if resp.StatusCode != 200 {
83 | return nil, fmt.Errorf("%v %v -> %v", req.Method, req.URL.String(), resp.Status)
84 | }
85 |
86 | body, err := io.ReadAll(resp.Body)
87 | if err != nil {
88 | return nil, err
89 | }
90 |
91 | nr := negotiateResponse{}
92 | if err := json.Unmarshal(body, &nr); err != nil {
93 | return nil, err
94 | }
95 |
96 | q := reqURL.Query()
97 | q.Set("id", nr.ConnectionID)
98 | reqURL.RawQuery = q.Encode()
99 |
100 | // Select the best connection
101 | var conn Connection
102 | switch {
103 | case nr.getTransferFormats("WebTransports") != nil:
104 | // TODO
105 |
106 | case nr.getTransferFormats("WebSockets") != nil:
107 | wsURL := reqURL
108 |
109 | // switch to wss for secure connection
110 | if reqURL.Scheme == "https" {
111 | wsURL.Scheme = "wss"
112 | } else {
113 | wsURL.Scheme = "ws"
114 | }
115 |
116 | opts := &websocket.DialOptions{}
117 |
118 | if httpConn.headers != nil {
119 | opts.HTTPHeader = httpConn.headers()
120 | } else {
121 | opts.HTTPHeader = http.Header{}
122 | }
123 |
124 | for _, cookie := range resp.Cookies() {
125 | opts.HTTPHeader.Add("Cookie", cookie.String())
126 | }
127 |
128 | ws, _, err := websocket.Dial(ctx, wsURL.String(), opts)
129 | if err != nil {
130 | return nil, err
131 | }
132 |
133 | // TODO think about if the API should give the possibility to cancel this connection
134 | conn = newWebSocketConnection(context.Background(), nr.ConnectionID, ws)
135 |
136 | case nr.getTransferFormats("ServerSentEvents") != nil:
137 | req, err := http.NewRequest("GET", reqURL.String(), nil)
138 | if err != nil {
139 | return nil, err
140 | }
141 |
142 | if httpConn.headers != nil {
143 | req.Header = httpConn.headers()
144 | }
145 | req.Header.Set("Accept", "text/event-stream")
146 |
147 | resp, err := httpConn.client.Do(req)
148 | if err != nil {
149 | return nil, err
150 | }
151 |
152 | conn, err = newClientSSEConnection(address, nr.ConnectionID, resp.Body)
153 | if err != nil {
154 | return nil, err
155 | }
156 | }
157 |
158 | return conn, nil
159 | }
160 |
// closeResponseBody drains an HTTP response body and then closes it.
// net/http only puts the underlying connection back into its pool for reuse when the
// body has been read to the end and closed, see
// https://blog.cubieserver.de/2022/http-connection-reuse-in-go-clients/
// Both errors are deliberately ignored: this is best-effort cleanup.
func closeResponseBody(body io.ReadCloser) {
	defer func() { _ = body.Close() }()
	_, _ = io.Copy(io.Discard, body)
}
168 |
--------------------------------------------------------------------------------
/httpmux.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "crypto/rand"
5 | "encoding/base64"
6 | "encoding/json"
7 | "fmt"
8 | "net/http"
9 | "strconv"
10 | "strings"
11 | "sync"
12 | "time"
13 |
14 | "github.com/teivah/onecontext"
15 | "nhooyr.io/websocket"
16 | )
17 |
18 | type httpMux struct {
19 | mx sync.RWMutex
20 | connectionMap map[string]Connection
21 | server Server
22 | }
23 |
24 | func newHTTPMux(server Server) *httpMux {
25 | return &httpMux{
26 | connectionMap: make(map[string]Connection),
27 | server: server,
28 | }
29 | }
30 |
// ServeHTTP dispatches by HTTP method: POST requests go to handlePost, GET
// requests to handleGet. Every other method is answered with 400 Bad Request.
func (h *httpMux) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
	if request.Method == http.MethodPost {
		h.handlePost(writer, request)
		return
	}
	if request.Method == http.MethodGet {
		h.handleGet(writer, request)
		return
	}
	writer.WriteHeader(http.StatusBadRequest)
}
41 |
// handlePost delivers a client->server message for a Server-Sent Events connection.
// The target connection comes from the "id" query parameter:
//   - missing id: 400 Bad Request
//   - unknown id: 404 Not Found
//   - SSE connection: the request is handed to it and its status code is returned
//   - negotiated but not yet started: looked up again every 10ms
//   - any other connection type: 409 Conflict
//
// The retry loop ends when the client goes away (nothing is written) or when the
// server context is done (503 Service Unavailable), so it cannot spin forever.
func (h *httpMux) handlePost(writer http.ResponseWriter, request *http.Request) {
	connectionID := request.URL.Query().Get("id")
	if connectionID == "" {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	info, _ := h.server.prefixLoggers("")
	retry := time.NewTicker(10 * time.Millisecond)
	defer retry.Stop()
	for {
		h.mx.RLock()
		c, ok := h.connectionMap[connectionID]
		h.mx.RUnlock()
		if !ok {
			writer.WriteHeader(http.StatusNotFound)
			return
		}
		switch conn := c.(type) {
		case *serverSSEConnection:
			writer.WriteHeader(conn.consumeRequest(request))
			return
		case *negotiateConnection:
			// connection start initiated but not completed: wait and look again
		default:
			// ConnectionID already used for WebSocket(?)
			writer.WriteHeader(http.StatusConflict)
			return
		}
		select {
		case <-request.Context().Done():
			// The client gave up, nobody is waiting for an answer any more.
			return
		case <-h.server.context().Done():
			writer.WriteHeader(http.StatusServiceUnavailable)
			return
		case <-retry.C:
		}
		_ = info.Log("event", "handlePost for SSE connection repeated")
	}
}
74 |
// handleGet starts the transport a GET request asks for:
//   - a WebSocket upgrade (Connection contains "upgrade" and Upgrade contains "websocket")
//   - a Server-Sent Events stream (Accept contains "text/event-stream")
//
// Anything else is answered with 400 Bad Request. Header values are matched as
// comma separated, case-insensitive tokens, so list values like "keep-alive, Upgrade"
// and media types with parameters like "text/event-stream; charset=utf-8" are accepted.
func (h *httpMux) handleGet(writer http.ResponseWriter, request *http.Request) {
	switch {
	case headerHasToken(request.Header.Get("Connection"), "upgrade") &&
		headerHasToken(request.Header.Get("Upgrade"), "websocket"):
		h.handleWebsocket(writer, request)
	case headerHasToken(request.Header.Get("Accept"), "text/event-stream"):
		h.handleServerSentEvent(writer, request)
	default:
		writer.WriteHeader(http.StatusBadRequest)
	}
}

// headerHasToken reports whether the comma separated header value contains token.
// The comparison is case-insensitive and ignores surrounding whitespace and any
// ";"-separated parameters of an element.
func headerHasToken(value, token string) bool {
	for _, part := range strings.Split(value, ",") {
		if i := strings.IndexByte(part, ';'); i >= 0 {
			part = part[:i]
		}
		if strings.EqualFold(strings.TrimSpace(part), token) {
			return true
		}
	}
	return false
}
92 |
// handleServerSentEvent starts a Server-Sent Events connection for a negotiated
// connection ID (query parameter "id"). The GET request stays open and carries
// server->client messages; client->server messages arrive as POSTs (see handlePost).
// Responses: 400 without id, 404 for unknown ids, 409 if the id already belongs to a
// started connection, 500 if the ResponseWriter cannot flush or the SSE connection
// cannot be created.
func (h *httpMux) handleServerSentEvent(writer http.ResponseWriter, request *http.Request) {
	connectionID := request.URL.Query().Get("id")
	if connectionID == "" {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	h.mx.RLock()
	c, ok := h.connectionMap[connectionID]
	h.mx.RUnlock()
	if !ok {
		writer.WriteHeader(http.StatusNotFound)
		return
	}
	if _, negotiated := c.(*negotiateConnection); !negotiated {
		// connectionID in use
		writer.WriteHeader(http.StatusConflict)
		return
	}
	// Streaming is impossible without flushing, so check before creating anything.
	flusher, ok := writer.(http.Flusher)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}
	// The merged context ends when the server or the request ends. cancel releases it
	// once this handler returns (the stream is over then anyway).
	ctx, cancel := onecontext.Merge(h.server.context(), request.Context())
	defer cancel()
	sseConn, jobChan, jobResultChan, err := newServerSSEConnection(ctx, c.ConnectionID())
	if err != nil {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Connection is negotiated but not initiated
	// We compose http and send it over sse
	writer.Header().Set("Content-Type", "text/event-stream")
	writer.Header().Set("Connection", "keep-alive")
	writer.Header().Set("Cache-Control", "no-cache")
	writer.WriteHeader(http.StatusOK)
	// End this Server Sent Event (yes, your response now is one and the client will wait for this initial event to end)
	_, _ = fmt.Fprint(writer, ":\r\n\r\n")
	flusher.Flush()
	go func() {
		// We can't WriteHeader 500 if we get an error as we already wrote the header, so ignore it.
		_ = h.serveConnection(sseConn)
	}()
	// Loop for write jobs from the sseServerConnection
	for buf := range jobChan {
		n, err := writer.Write(buf)
		if err == nil {
			flusher.Flush()
		}
		jobResultChan <- RWJobResult{n: n, err: err}
	}
	close(jobResultChan)
}
145 |
// handleWebsocket accepts a WebSocket upgrade and serves the connection over it.
// The connection is looked up by the "id" query parameter. Without an id, a new
// connection ID is registered on the fly, so clients may skip the negotiate step.
// Unknown ids and ids that belong to an already started connection are closed with
// StatusProtocolError (1002).
func (h *httpMux) handleWebsocket(writer http.ResponseWriter, request *http.Request) {
	accOptions := &websocket.AcceptOptions{
		CompressionMode:    websocket.CompressionContextTakeover,
		InsecureSkipVerify: h.server.insecureSkipVerify(),
		OriginPatterns:     h.server.originPatterns(),
	}
	websocketConn, err := websocket.Accept(writer, request, accOptions)
	if err != nil {
		_, debug := h.server.loggers()
		_ = debug.Log(evt, "handleWebsocket", msg, "error accepting websockets", "error", err)
		// don't need to write an error header here as websocket.Accept has already used http.Error
		return
	}
	websocketConn.SetReadLimit(int64(h.server.maximumReceiveMessageSize()))
	connectionMapKey := request.URL.Query().Get("id")
	if connectionMapKey == "" {
		// Support websocket connection without negotiate
		connectionMapKey = newConnectionID()
		h.mx.Lock()
		h.connectionMap[connectionMapKey] = &negotiateConnection{
			ConnectionBase{connectionID: connectionMapKey},
		}
		h.mx.Unlock()
	}
	h.mx.RLock()
	c, ok := h.connectionMap[connectionMapKey]
	h.mx.RUnlock()
	if !ok {
		// Not negotiated
		_ = websocketConn.Close(websocket.StatusProtocolError, "Not found")
		return
	}
	if _, negotiated := c.(*negotiateConnection); !negotiated {
		// Already initiated
		_ = websocketConn.Close(websocket.StatusProtocolError, "Bad request")
		return
	}
	// Connection is negotiated but not initiated. The merged context ends when the
	// server or the request ends; cancel releases it once serving is over.
	ctx, cancel := onecontext.Merge(h.server.context(), request.Context())
	defer cancel()
	err = h.serveConnection(newWebSocketConnection(ctx, c.ConnectionID(), websocketConn))
	if err != nil {
		// NOTE(review): RFC 6455 §7.4.1 reserves 1005 (StatusNoStatusRcvd) as "no status
		// code present". It must not be sent in a Close frame, so err's text likely never
		// reaches the peer. StatusInternalError may be the better choice; 1005 is kept
		// here to preserve the behavior clients see.
		_ = websocketConn.Close(websocket.StatusNoStatusRcvd, err.Error())
	}
}
190 |
// negotiate answers a SignalR negotiate request. Only POST is accepted; other
// methods get 400 Bad Request. A negotiateConnection placeholder is registered and
// the response returns, as JSON, the connection ID, the connection token (negotiate
// version 1 only) and the transports this server offers.
func (h *httpMux) negotiate(w http.ResponseWriter, req *http.Request) {
	if req.Method != "POST" {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	connectionID := newConnectionID()
	connectionMapKey := connectionID
	// NOTE(review): only the "negotiateVersion" header is read. SignalR clients may
	// send it as a query parameter instead; confirm which one the clients use.
	negotiateVersion, err := strconv.Atoi(req.Header.Get("negotiateVersion"))
	if err != nil {
		negotiateVersion = 0
	}
	connectionToken := ""
	if negotiateVersion == 1 {
		// With version 1 the connection is looked up by its token, not by its ID.
		connectionToken = newConnectionID()
		connectionMapKey = connectionToken
	}
	h.mx.Lock()
	h.connectionMap[connectionMapKey] = &negotiateConnection{
		ConnectionBase{connectionID: connectionID},
	}
	h.mx.Unlock()
	var availableTransports []availableTransport
	for _, transport := range h.server.availableTransports() {
		switch transport {
		case "ServerSentEvents":
			availableTransports = append(availableTransports,
				availableTransport{
					Transport:       "ServerSentEvents",
					TransferFormats: []string{"Text"},
				})
		case "WebSockets":
			availableTransports = append(availableTransports,
				availableTransport{
					Transport:       "WebSockets",
					TransferFormats: []string{"Text", "Binary"},
				})
		}
	}
	response := negotiateResponse{
		ConnectionToken:     connectionToken,
		ConnectionID:        connectionID,
		NegotiateVersion:    negotiateVersion,
		AvailableTransports: availableTransports,
	}
	// Declare the payload type explicitly; otherwise net/http sniffs it as text/plain.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(response) // Can't imagine an error when encoding
}
239 |
240 | func (h *httpMux) serveConnection(c Connection) error {
241 | h.mx.Lock()
242 | h.connectionMap[c.ConnectionID()] = c
243 | h.mx.Unlock()
244 | return h.server.Serve(c)
245 | }
246 |
247 | func newConnectionID() string {
248 | bytes := make([]byte, 16)
249 | // rand.Read only fails when the systems random number generator fails. Rare case, ignore
250 | _, _ = rand.Read(bytes)
251 | // Important: Use URLEncoding. StdEncoding contains "/" which will be randomly part of the connectionID and cause parsing problems
252 | return base64.URLEncoding.EncodeToString(bytes)
253 | }
254 |
255 | type negotiateConnection struct {
256 | ConnectionBase
257 | }
258 |
259 | func (n *negotiateConnection) Read([]byte) (int, error) {
260 | return 0, nil
261 | }
262 |
263 | func (n *negotiateConnection) Write([]byte) (int, error) {
264 | return 0, nil
265 | }
266 |
--------------------------------------------------------------------------------
/httpserver_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "net"
10 | "net/http"
11 | "net/http/httptest"
12 | "net/url"
13 | "strconv"
14 |
15 | "strings"
16 | "time"
17 |
18 | "github.com/go-kit/log/level"
19 | . "github.com/onsi/ginkgo"
20 | . "github.com/onsi/gomega"
21 | "nhooyr.io/websocket"
22 | )
23 |
24 | type addHub struct {
25 | Hub
26 | }
27 |
28 | func (w *addHub) Add2(i int) int {
29 | return i + 2
30 | }
31 |
32 | func (w *addHub) Echo(s string) string {
33 | return s
34 | }
35 |
36 | var _ = Describe("HTTP server", func() {
37 | for _, transport := range [][]string{
38 | {"WebSockets", "Text"},
39 | {"WebSockets", "Binary"},
40 | {"ServerSentEvents", "Text"},
41 | } {
42 | transport := transport
43 | Context(fmt.Sprintf("%v %v", transport[0], transport[1]), func() {
44 | Context("A correct negotiation request is sent", func() {
45 | It(fmt.Sprintf("should send a correct negotiation response with support for %v with text protocol", transport), func(done Done) {
46 | // Start server
47 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]), testLoggerOption())
48 | Expect(err).NotTo(HaveOccurred())
49 | router := http.NewServeMux()
50 | server.MapHTTP(WithHTTPServeMux(router), "/hub")
51 | testServer := httptest.NewServer(router)
52 | url, _ := url.Parse(testServer.URL)
53 | port, _ := strconv.Atoi(url.Port())
54 | // Negotiate
55 | negResp := negotiateWebSocketTestServer(port)
56 | Expect(negResp["connectionId"]).NotTo(BeNil())
57 | Expect(negResp["availableTransports"]).To(BeAssignableToTypeOf([]interface{}{}))
58 | avt := negResp["availableTransports"].([]interface{})
59 | Expect(len(avt)).To(BeNumerically(">", 0))
60 | Expect(avt[0]).To(BeAssignableToTypeOf(map[string]interface{}{}))
61 | avtVal := avt[0].(map[string]interface{})
62 | Expect(avtVal["transport"]).To(Equal(transport[0]))
63 | Expect(avtVal["transferFormats"]).To(BeAssignableToTypeOf([]interface{}{}))
64 | tf := avtVal["transferFormats"].([]interface{})
65 | Expect(tf).To(ContainElement("Text"))
66 | if transport[0] == "WebSockets" {
67 | Expect(tf).To(ContainElement("Binary"))
68 | }
69 | testServer.Close()
70 | close(done)
71 | }, 2.0)
72 | })
73 |
74 | Context("A invalid negotiation request is sent", func() {
75 | It(fmt.Sprintf("should send a correct negotiation response with support for %v with text protocol", transport), func(done Done) {
76 | // Start server
77 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]), testLoggerOption())
78 | Expect(err).NotTo(HaveOccurred())
79 | router := http.NewServeMux()
80 | server.MapHTTP(WithHTTPServeMux(router), "/hub")
81 | testServer := httptest.NewServer(router)
82 | url, _ := url.Parse(testServer.URL)
83 | port, _ := strconv.Atoi(url.Port())
84 | waitForPort(port)
85 | // Negotiate the wrong way
86 | resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port))
87 | Expect(err).To(BeNil())
88 | Expect(resp).NotTo(BeNil())
89 | Expect(resp.StatusCode).ToNot(Equal(200))
90 | testServer.Close()
91 | close(done)
92 | }, 2.0)
93 | })
94 |
95 | Context("Connection with client", func() {
96 | It("should successfully handle an Invoke call", func(done Done) {
97 | logger := &nonProtocolLogger{testLogger()}
98 | // Start server
99 | ctx, cancel := context.WithCancel(context.Background())
100 | server, err := NewServer(ctx,
101 | SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]),
102 | MaximumReceiveMessageSize(50000),
103 | Logger(logger, true))
104 | Expect(err).NotTo(HaveOccurred())
105 | router := http.NewServeMux()
106 | server.MapHTTP(WithHTTPServeMux(router), "/hub")
107 | testServer := httptest.NewServer(router)
108 | url, _ := url.Parse(testServer.URL)
109 | port, _ := strconv.Atoi(url.Port())
110 | waitForPort(port)
111 | // Try first connection
112 | conn, err := NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port))
113 | Expect(err).NotTo(HaveOccurred())
114 | client, err := NewClient(ctx,
115 | WithConnection(conn),
116 | MaximumReceiveMessageSize(60000),
117 | Logger(logger, true),
118 | TransferFormat(transport[1]))
119 | Expect(err).NotTo(HaveOccurred())
120 | Expect(client).NotTo(BeNil())
121 | client.Start()
122 | Expect(<-client.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
123 | result := <-client.Invoke("Add2", 1)
124 | Expect(result.Error).NotTo(HaveOccurred())
125 | Expect(result.Value).To(BeEquivalentTo(3))
126 |
127 | // Try second connection
128 | conn2, err := NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port))
129 | Expect(err).NotTo(HaveOccurred())
130 | client2, err := NewClient(ctx,
131 | WithConnection(conn2),
132 | Logger(logger, true),
133 | TransferFormat(transport[1]))
134 | Expect(err).NotTo(HaveOccurred())
135 | Expect(client2).NotTo(BeNil())
136 | client2.Start()
137 | Expect(<-client2.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
138 | result = <-client2.Invoke("Add2", 2)
139 | Expect(result.Error).NotTo(HaveOccurred())
140 | Expect(result.Value).To(BeEquivalentTo(4))
141 | // Huge message
142 | hugo := strings.Repeat("#", 2500)
143 | result = <-client.Invoke("Echo", hugo)
144 | Expect(result.Error).NotTo(HaveOccurred())
145 | s := result.Value.(string)
146 | Expect(s).To(Equal(hugo))
147 | cancel()
148 | go testServer.Close()
149 | close(done)
150 | }, 2.0)
151 | })
152 | })
153 | }
154 | Context("When no negotiation is send", func() {
155 | It("should serve websocket requests", func(done Done) {
156 | // Start server
157 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports("WebSockets"), testLoggerOption())
158 | Expect(err).NotTo(HaveOccurred())
159 | router := http.NewServeMux()
160 | server.MapHTTP(WithHTTPServeMux(router), "/hub")
161 | testServer := httptest.NewServer(router)
162 | url, _ := url.Parse(testServer.URL)
163 | port, _ := strconv.Atoi(url.Port())
164 | waitForPort(port)
165 | handShakeAndCallWebSocketTestServer(port, "")
166 | testServer.Close()
167 | close(done)
168 | }, 5.0)
169 | })
170 | })
171 |
172 | type nonProtocolLogger struct {
173 | logger StructuredLogger
174 | }
175 |
176 | func (n *nonProtocolLogger) Log(keyVals ...interface{}) error {
177 | for _, kv := range keyVals {
178 | if kv == "protocol" {
179 | return nil
180 | }
181 | }
182 | return n.logger.Log(keyVals...)
183 | }
184 |
185 | func negotiateWebSocketTestServer(port int) map[string]interface{} {
186 | waitForPort(port)
187 | buf := bytes.Buffer{}
188 | resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port), "text/plain;charset=UTF-8", &buf)
189 | Expect(err).To(BeNil())
190 | Expect(resp).ToNot(BeNil())
191 | defer func() {
192 | _ = resp.Body.Close()
193 | }()
194 | var body []byte
195 | body, err = io.ReadAll(resp.Body)
196 | Expect(err).To(BeNil())
197 | response := make(map[string]interface{})
198 | err = json.Unmarshal(body, &response)
199 | Expect(err).To(BeNil())
200 | return response
201 | }
202 |
// handShakeAndCallWebSocketTestServer connects to the test server on port with a
// raw WebSocket (adding "?id=connectionID" when connectionID is not empty), sends
// the JSON protocol handshake, invokes "add2" with argument 1 and expects the
// completion result 3 within one second.
// All goroutines it starts are stopped when it returns, including after a timeout.
func handShakeAndCallWebSocketTestServer(port int, connectionID string) {
	waitForPort(port)
	logger := testLogger()
	protocol := jsonHubProtocol{}
	protocol.setDebugLogger(level.Debug(logger))
	var urlParam string
	if connectionID != "" {
		urlParam = fmt.Sprintf("?id=%v", connectionID)
	}
	// Canceled last (after the WebSocket has been closed). This stops the receive
	// goroutines of cliConn and the result reader below.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ws, _, err := websocket.Dial(context.Background(), fmt.Sprintf("ws://127.0.0.1:%v/hub%v", port, urlParam), nil)
	Expect(err).To(BeNil())
	defer func() {
		_ = ws.Close(websocket.StatusNormalClosure, "")
	}()
	wsConn := newWebSocketConnection(ctx, connectionID, ws)
	cliConn := newHubConnection(wsConn, &protocol, 1<<15, testLogger())
	_, _ = wsConn.Write(append([]byte(`{"protocol": "json","version": 1}`), 30))
	_, _ = wsConn.Write(append([]byte(`{"type":1,"invocationId":"666","target":"add2","arguments":[1]}`), 30))
	// Buffered, so the reader can always deliver its result and exit, even after a timeout.
	result := make(chan interface{}, 1)
	received := cliConn.Receive()
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case recvResult := <-received:
				if cm, ok := recvResult.message.(completionMessage); ok {
					result <- cm.Result
					return
				}
			}
		}
	}()
	select {
	case r := <-result:
		var f float64
		Expect(protocol.UnmarshalArgument(r, &f)).NotTo(HaveOccurred())
		Expect(f).To(Equal(3.0))
	case <-time.After(1000 * time.Millisecond):
		Fail("timed out")
	}
}
239 |
// waitForPort blocks until a TCP connection to 127.0.0.1:port can be established,
// retrying every 100ms. The probe connection is closed right away so repeated
// calls do not leak sockets.
func waitForPort(port int) {
	addr := fmt.Sprintf("127.0.0.1:%v", port)
	for {
		if conn, err := net.Dial("tcp", addr); err == nil {
			_ = conn.Close()
			return
		}
		time.Sleep(100 * time.Millisecond)
	}
}
248 |
--------------------------------------------------------------------------------
/hub.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "sync"
6 | )
7 |
8 | // HubInterface is a hubs interface
9 | type HubInterface interface {
10 | Initialize(hubContext HubContext)
11 | OnConnected(connectionID string)
12 | OnDisconnected(connectionID string)
13 | }
14 |
15 | // Hub is a base class for hubs
16 | type Hub struct {
17 | context HubContext
18 | cm sync.RWMutex
19 | }
20 |
21 | // Initialize initializes a hub with a HubContext
22 | func (h *Hub) Initialize(ctx HubContext) {
23 | h.cm.Lock()
24 | defer h.cm.Unlock()
25 | h.context = ctx
26 | }
27 |
28 | // Clients returns the clients of this hub
29 | func (h *Hub) Clients() HubClients {
30 | h.cm.RLock()
31 | defer h.cm.RUnlock()
32 | return h.context.Clients()
33 | }
34 |
35 | // Groups returns the client groups of this hub
36 | func (h *Hub) Groups() GroupManager {
37 | h.cm.RLock()
38 | defer h.cm.RUnlock()
39 | return h.context.Groups()
40 | }
41 |
42 | // Items returns the items for this connection
43 | func (h *Hub) Items() *sync.Map {
44 | h.cm.RLock()
45 | defer h.cm.RUnlock()
46 | return h.context.Items()
47 | }
48 |
49 | // ConnectionID gets the ID of the current connection
50 | func (h *Hub) ConnectionID() string {
51 | h.cm.RLock()
52 | defer h.cm.RUnlock()
53 | return h.context.ConnectionID()
54 | }
55 |
56 | // Context is the context.Context of the current connection
57 | func (h *Hub) Context() context.Context {
58 | h.cm.RLock()
59 | defer h.cm.RUnlock()
60 | return h.context.Context()
61 | }
62 |
63 | // Abort aborts the current connection
64 | func (h *Hub) Abort() {
65 | h.cm.RLock()
66 | defer h.cm.RUnlock()
67 | h.context.Abort()
68 | }
69 |
70 | // Logger returns the loggers used in this server. By this, derived hubs can use the same loggers as the server.
71 | func (h *Hub) Logger() (info StructuredLogger, dbg StructuredLogger) {
72 | h.cm.RLock()
73 | defer h.cm.RUnlock()
74 | return h.context.Logger()
75 | }
76 |
77 | // OnConnected is called when the hub is connected
78 | func (h *Hub) OnConnected(string) {}
79 |
80 | // OnDisconnected is called when the hub is disconnected
81 | func (h *Hub) OnDisconnected(string) {}
82 |
--------------------------------------------------------------------------------
/hubclients.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | // HubClients gives the hub access to various client groups
4 | // All() gets a ClientProxy that can be used to invoke methods on all clients connected to the hub
5 | // Caller() gets a ClientProxy that can be used to invoke methods of the current calling client
6 | // Client() gets a ClientProxy that can be used to invoke methods on the specified client connection
7 | // Group() gets a ClientProxy that can be used to invoke methods on all connections in the specified group
8 | type HubClients interface {
9 | All() ClientProxy
10 | Caller() ClientProxy
11 | Client(connectionID string) ClientProxy
12 | Group(groupName string) ClientProxy
13 | }
14 |
15 | type defaultHubClients struct {
16 | lifetimeManager HubLifetimeManager
17 | allCache allClientProxy
18 | }
19 |
20 | func (c *defaultHubClients) All() ClientProxy {
21 | return &c.allCache
22 | }
23 |
24 | func (c *defaultHubClients) Client(connectionID string) ClientProxy {
25 | return &singleClientProxy{connectionID: connectionID, lifetimeManager: c.lifetimeManager}
26 | }
27 |
28 | func (c *defaultHubClients) Group(groupName string) ClientProxy {
29 | return &groupClientProxy{groupName: groupName, lifetimeManager: c.lifetimeManager}
30 | }
31 |
32 | // Caller is only implemented to fulfill the HubClients interface, so the servers defaultHubClients interface can be
33 | // used for implementing Server.HubClients.
34 | func (c *defaultHubClients) Caller() ClientProxy {
35 | return nil
36 | }
37 |
38 | type callerHubClients struct {
39 | defaultHubClients *defaultHubClients
40 | connectionID string
41 | }
42 |
43 | func (c *callerHubClients) All() ClientProxy {
44 | return c.defaultHubClients.All()
45 | }
46 |
47 | func (c *callerHubClients) Caller() ClientProxy {
48 | return c.defaultHubClients.Client(c.connectionID)
49 | }
50 |
51 | func (c *callerHubClients) Client(connectionID string) ClientProxy {
52 | return c.defaultHubClients.Client(connectionID)
53 | }
54 |
55 | func (c *callerHubClients) Group(groupName string) ClientProxy {
56 | return c.defaultHubClients.Group(groupName)
57 | }
58 |
--------------------------------------------------------------------------------
/hubconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "io"
8 | "sync"
9 | "time"
10 | )
11 |
12 | // hubConnection is used by HubContext, Server and Client to realize the external API.
13 | // hubConnection uses a transport connection (of type Connection) and a hubProtocol to send and receive SignalR messages.
14 | type hubConnection interface {
15 | ConnectionID() string
16 | Receive() <-chan receiveResult
17 | SendInvocation(id string, target string, args []interface{}) error
18 | SendStreamInvocation(id string, target string, args []interface{}) error
19 | SendInvocationWithStreamIds(id string, target string, args []interface{}, streamIds []string) error
20 | StreamItem(id string, item interface{}) error
21 | Completion(id string, result interface{}, error string) error
22 | Close(error string, allowReconnect bool) error
23 | Ping() error
24 | LastWriteStamp() time.Time
25 | Items() *sync.Map
26 | Context() context.Context
27 | Abort()
28 | }
29 |
// receiveResult is one element delivered by hubConnection.Receive. Either message
// holds a parsed protocol message, or err holds a read or parse error.
type receiveResult struct {
	message interface{} // parsed protocol message, nil when err is set
	err     error       // read/parse error, nil when message is set
}
34 |
// newHubConnection wraps connection into a hubConnection that speaks protocol.
// The hubConnection gets its own cancelable child context of connection.Context().
// Incoming reads use a buffer of maximumReceiveMessageSize bytes. Write errors are
// logged to info. A connection that implements ConnectionWithTransferMode is told
// the protocol's transfer mode.
func newHubConnection(connection Connection, protocol hubProtocol, maximumReceiveMessageSize uint, info StructuredLogger) hubConnection {
	ctx, cancel := context.WithCancel(connection.Context())
	hubConn := &defaultHubConnection{
		ctx:                       ctx,
		cancelFunc:                cancel,
		protocol:                  protocol,
		connection:                connection,
		maximumReceiveMessageSize: maximumReceiveMessageSize,
		items:                     new(sync.Map),
		info:                      info,
	}
	if withMode, ok := connection.(ConnectionWithTransferMode); ok {
		withMode.SetTransferMode(protocol.transferMode())
	}
	return hubConn
}
52 |
53 | type defaultHubConnection struct {
54 | ctx context.Context
55 | cancelFunc context.CancelFunc
56 | protocol hubProtocol
57 | mx sync.Mutex
58 | connection Connection
59 | maximumReceiveMessageSize uint
60 | items *sync.Map
61 | lastWriteStamp time.Time
62 | info StructuredLogger
63 | }
64 |
65 | func (c *defaultHubConnection) Items() *sync.Map {
66 | return c.items
67 | }
68 |
69 | func (c *defaultHubConnection) Close(errorText string, allowReconnect bool) error {
70 | var closeMessage = closeMessage{
71 | Type: 7,
72 | Error: errorText,
73 | AllowReconnect: allowReconnect,
74 | }
75 | return c.protocol.WriteMessage(closeMessage, c.connection)
76 | }
77 |
78 | func (c *defaultHubConnection) ConnectionID() string {
79 | return c.connection.ConnectionID()
80 | }
81 |
82 | func (c *defaultHubConnection) Context() context.Context {
83 | return c.ctx
84 | }
85 |
86 | func (c *defaultHubConnection) Abort() {
87 | c.cancelFunc()
88 | }
89 |
90 | func (c *defaultHubConnection) Receive() <-chan receiveResult {
91 | recvChan := make(chan receiveResult, 20)
92 | // Prepare cleanup
93 | writerDone := make(chan struct{}, 1)
94 | // the pipe connects the goroutine which reads from the connection and the goroutine which parses the read data
95 | reader, writer := CtxPipe(c.ctx)
96 | p := make([]byte, c.maximumReceiveMessageSize)
97 | go func(ctx context.Context, connection io.Reader, writer io.Writer, recvChan chan<- receiveResult, writerDone chan<- struct{}) {
98 | loop:
99 | for {
100 | select {
101 | case <-ctx.Done():
102 | break loop
103 | default:
104 | n, err := connection.Read(p)
105 | if err != nil {
106 | select {
107 | case recvChan <- receiveResult{err: err}:
108 | case <-ctx.Done():
109 | break loop
110 | }
111 | }
112 | if n > 0 {
113 | _, err = writer.Write(p[:n])
114 | if err != nil {
115 | select {
116 | case recvChan <- receiveResult{err: err}:
117 | case <-ctx.Done():
118 | break loop
119 | }
120 | }
121 | }
122 | }
123 | }
124 | // The pipe writer is done
125 | close(writerDone)
126 | }(c.ctx, c.connection, writer, recvChan, writerDone)
127 | // parse
128 | go func(ctx context.Context, reader io.Reader, recvChan chan<- receiveResult, writerDone <-chan struct{}) {
129 | remainBuf := bytes.Buffer{}
130 | loop:
131 | for {
132 | select {
133 | case <-ctx.Done():
134 | break loop
135 | case <-writerDone:
136 | break loop
137 | default:
138 | messages, err := c.protocol.ParseMessages(reader, &remainBuf)
139 | if err != nil {
140 | select {
141 | case recvChan <- receiveResult{err: err}:
142 | case <-ctx.Done():
143 | break loop
144 | case <-writerDone:
145 | break loop
146 | }
147 | } else {
148 | for _, message := range messages {
149 | select {
150 | case recvChan <- receiveResult{message: message}:
151 | case <-ctx.Done():
152 | break loop
153 | case <-writerDone:
154 | break loop
155 | }
156 | }
157 | }
158 | }
159 | }
160 | }(c.ctx, reader, recvChan, writerDone)
161 | return recvChan
162 | }
163 |
164 | func (c *defaultHubConnection) SendInvocation(id string, target string, args []interface{}) error {
165 | if args == nil {
166 | args = make([]interface{}, 0)
167 | }
168 | var invocationMessage = invocationMessage{
169 | Type: 1,
170 | InvocationID: id,
171 | Target: target,
172 | Arguments: args,
173 | }
174 | return c.writeMessage(invocationMessage)
175 | }
176 |
177 | func (c *defaultHubConnection) SendStreamInvocation(id string, target string, args []interface{}) error {
178 | if args == nil {
179 | args = make([]interface{}, 0)
180 | }
181 | var invocationMessage = invocationMessage{
182 | Type: 4,
183 | InvocationID: id,
184 | Target: target,
185 | Arguments: args,
186 | }
187 | return c.writeMessage(invocationMessage)
188 | }
189 |
190 | func (c *defaultHubConnection) SendInvocationWithStreamIds(id string, target string, args []interface{}, streamIds []string) error {
191 | var invocationMessage = invocationMessage{
192 | Type: 1,
193 | InvocationID: id,
194 | Target: target,
195 | Arguments: args,
196 | StreamIds: streamIds,
197 | }
198 | return c.writeMessage(invocationMessage)
199 | }
200 |
201 | func (c *defaultHubConnection) StreamItem(id string, item interface{}) error {
202 | var streamItemMessage = streamItemMessage{
203 | Type: 2,
204 | InvocationID: id,
205 | Item: item,
206 | }
207 | return c.writeMessage(streamItemMessage)
208 | }
209 |
210 | func (c *defaultHubConnection) Completion(id string, result interface{}, error string) error {
211 | var completionMessage = completionMessage{
212 | Type: 3,
213 | InvocationID: id,
214 | Result: result,
215 | Error: error,
216 | }
217 | return c.writeMessage(completionMessage)
218 | }
219 |
220 | func (c *defaultHubConnection) Ping() error {
221 | var pingMessage = hubMessage{
222 | Type: 6,
223 | }
224 | return c.writeMessage(pingMessage)
225 | }
226 |
// LastWriteStamp returns the time writeMessage was last called. The stamp is
// taken before the write, whether or not the write succeeds.
func (c *defaultHubConnection) LastWriteStamp() time.Time {
	c.mx.Lock()
	defer c.mx.Unlock()
	return c.lastWriteStamp
}

// writeMessage encodes message with the protocol and writes it to the connection.
// It fails immediately if the hubConnection's context is already done, and stops
// waiting if the context ends during the write. A failed write aborts the
// hubConnection. Every error is logged to the info logger before it is returned.
func (c *defaultHubConnection) writeMessage(message interface{}) error {
	c.mx.Lock()
	c.lastWriteStamp = time.Now()
	c.mx.Unlock()

	send := func() error {
		if err := c.ctx.Err(); err != nil {
			return fmt.Errorf("hubConnection canceled: %w", err)
		}
		// The write runs in its own goroutine so a canceled context can end the wait.
		// The channel is buffered so that goroutine never blocks on delivering its result.
		written := make(chan error, 1)
		go func() { written <- c.protocol.WriteMessage(message, c.connection) }()
		select {
		case <-c.ctx.Done():
			return fmt.Errorf("hubConnection canceled: %w", c.ctx.Err())
		case err := <-written:
			if err != nil {
				c.Abort()
			}
			return err
		}
	}

	if err := send(); err != nil {
		_ = c.info.Log(evt, msgSend, "message", fmtMsg(message), "error", err)
		return err
	}
	return nil
}
258 |
--------------------------------------------------------------------------------
/hubcontext.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "sync"
6 | )
7 |
8 | // HubContext is a context abstraction for a hub
9 | // Clients gets a HubClients that can be used to invoke methods on clients connected to the hub
10 | // Groups gets a GroupManager that can be used to add and remove connections to named groups
11 | // Items holds key/value pairs scoped to the hubs connection
12 | // ConnectionID gets the ID of the current connection
13 | // Abort aborts the current connection
14 | // Logger returns the logger used in this server
15 | type HubContext interface {
16 | Clients() HubClients
17 | Groups() GroupManager
18 | Items() *sync.Map
19 | ConnectionID() string
20 | Context() context.Context
21 | Abort()
22 | Logger() (info StructuredLogger, dbg StructuredLogger)
23 | }
24 |
25 | type connectionHubContext struct {
26 | abort context.CancelFunc
27 | connection hubConnection
28 | clients HubClients
29 | groups GroupManager
30 | info StructuredLogger
31 | dbg StructuredLogger
32 | }
33 |
34 | func (c *connectionHubContext) Clients() HubClients {
35 | return c.clients
36 | }
37 |
38 | func (c *connectionHubContext) Groups() GroupManager {
39 | return c.groups
40 | }
41 |
42 | func (c *connectionHubContext) Items() *sync.Map {
43 | return c.connection.Items()
44 | }
45 |
46 | func (c *connectionHubContext) ConnectionID() string {
47 | return c.connection.ConnectionID()
48 | }
49 |
50 | func (c *connectionHubContext) Context() context.Context {
51 | return c.connection.Context()
52 | }
53 |
54 | func (c *connectionHubContext) Abort() {
55 | c.abort()
56 | }
57 |
58 | func (c *connectionHubContext) Logger() (info StructuredLogger, dbg StructuredLogger) {
59 | return c.info, c.dbg
60 | }
61 |
--------------------------------------------------------------------------------
/hubcontext_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "sync"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/assert"
12 |
13 | . "github.com/onsi/ginkgo"
14 | . "github.com/onsi/gomega"
15 | )
16 |
17 | type contextHub struct {
18 | Hub
19 | }
20 |
21 | func (c *contextHub) OnConnected(string) {
22 | }
23 |
24 | func (c *contextHub) CallAll() {
25 | c.Clients().All().Send("clientFunc")
26 | }
27 |
28 | func (c *contextHub) CallCaller() {
29 | c.Clients().Caller().Send("clientFunc")
30 | }
31 |
32 | func (c *contextHub) CallClient(connectionID string) {
33 | c.Clients().Client(connectionID).Send("clientFunc")
34 | }
35 |
36 | func (c *contextHub) BuildGroup(connectionID1 string, connectionID2 string) {
37 | c.Groups().AddToGroup("local", connectionID1)
38 | c.Groups().AddToGroup("local", connectionID2)
39 | }
40 |
41 | func (c *contextHub) RemoveFromGroup(connectionID string) {
42 | c.Groups().RemoveFromGroup("local", connectionID)
43 | }
44 |
45 | func (c *contextHub) CallGroup() {
46 | c.Clients().Group("local").Send("clientFunc")
47 | }
48 |
49 | func (c *contextHub) AddItem(key string, value interface{}) {
50 | c.Items().Store(key, value)
51 | }
52 |
53 | func (c *contextHub) GetItem(key string) interface{} {
54 | if item, ok := c.Items().Load(key); ok {
55 | return item
56 | }
57 | return nil
58 | }
59 |
60 | func (c *contextHub) TestConnectionID() {
61 | }
62 |
63 | func (c *contextHub) Abort() {
64 | c.Hub.Abort()
65 | }
66 |
// SimpleReceiver is a client-side receiver used in tests. Its channel ch is closed
// when the server invokes "clientFunc" on the client, so tests can wait for that call.
type SimpleReceiver struct {
	ch chan struct{}
}

// ClientFunc signals the call by closing ch. A repeated call is ignored instead of
// panicking with "close of closed channel", which would abort the whole test binary
// rather than failing a single assertion. It is not meant for concurrent calls.
func (sr *SimpleReceiver) ClientFunc() {
	select {
	case <-sr.ch:
		// Already closed by an earlier call.
	default:
		close(sr.ch)
	}
}
74 |
75 | func makeTCPServerAndClients(ctx context.Context, clientCount int) (Server, []Client, []*SimpleReceiver, []Connection, []Connection, error) {
76 | server, err := NewServer(ctx, SimpleHubFactory(&contextHub{}), testLoggerOption())
77 | if err != nil {
78 | return nil, nil, nil, nil, nil, err
79 | }
80 | cliConn := make([]Connection, clientCount)
81 | srvConn := make([]Connection, clientCount)
82 | receiver := make([]*SimpleReceiver, clientCount)
83 | client := make([]Client, clientCount)
84 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
85 | if err != nil {
86 | return nil, nil, nil, nil, nil, err
87 | }
88 | for i := 0; i < clientCount; i++ {
89 | listener, err := net.ListenTCP("tcp", addr)
90 | go func(i int) {
91 | for {
92 | tcpConn, _ := listener.Accept()
93 | conn := NewNetConnection(ctx, tcpConn)
94 | conn.SetConnectionID(fmt.Sprint(i))
95 | srvConn[i] = conn
96 | go func() { _ = server.Serve(conn) }()
97 | break
98 | }
99 | }(i)
100 | tcpConn, err := net.Dial("tcp",
101 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port))
102 | if err != nil {
103 | return nil, nil, nil, nil, nil, err
104 | }
105 | cliConn[i] = NewNetConnection(ctx, tcpConn)
106 | receiver[i] = &SimpleReceiver{ch: make(chan struct{})}
107 | client[i], err = NewClient(ctx, WithConnection(cliConn[i]), WithReceiver(receiver[i]), TransferFormat("Text"), testLoggerOption())
108 | if err != nil {
109 | return nil, nil, nil, nil, nil, err
110 | }
111 | client[i].Start()
112 | select {
113 | case err := <-client[i].WaitForState(ctx, ClientConnected):
114 | if err != nil {
115 | return nil, nil, nil, nil, nil, err
116 | }
117 | case <-ctx.Done():
118 | return nil, nil, nil, nil, nil, ctx.Err()
119 | }
120 | }
121 | return server, client, receiver, srvConn, cliConn, nil
122 | }
123 |
124 | func makePipeClientsAndReceivers() ([]Client, []*SimpleReceiver, context.CancelFunc) {
125 | cliConn := make([]*pipeConnection, 3)
126 | srvConn := make([]*pipeConnection, 3)
127 | for i := 0; i < 3; i++ {
128 | cliConn[i], srvConn[i] = newClientServerConnections()
129 | cliConn[i].SetConnectionID(fmt.Sprint(i))
130 | srvConn[i].SetConnectionID(fmt.Sprint(i))
131 | }
132 | ctx, cancel := context.WithCancel(context.Background())
133 | server, _ := NewServer(ctx, SimpleHubFactory(&contextHub{}), testLoggerOption())
134 | var wg sync.WaitGroup
135 | wg.Add(3)
136 | for i := 0; i < 3; i++ {
137 | go func(i int) {
138 | wg.Done()
139 | _ = server.Serve(srvConn[i])
140 | }(i)
141 | }
142 | wg.Wait()
143 | client := make([]Client, 3)
144 | receiver := make([]*SimpleReceiver, 3)
145 | for i := 0; i < 3; i++ {
146 | receiver[i] = &SimpleReceiver{ch: make(chan struct{})}
147 | client[i], _ = NewClient(ctx, WithConnection(cliConn[i]), WithReceiver(receiver[i]), testLoggerOption())
148 | client[i].Start()
149 | <-client[i].WaitForState(ctx, ClientConnected)
150 | }
151 | return client, receiver, cancel
152 | }
153 |
154 | var _ = Describe("HubContext", func() {
155 | for i := 0; i < 10; i++ {
156 | Context("Clients().All()", func() {
157 | It("should invoke all clients", func() {
158 | client, receiver, cancel := makePipeClientsAndReceivers()
159 | r := <-client[0].Invoke("CallAll")
160 | Expect(r.Error).NotTo(HaveOccurred())
161 | result := 0
162 | for result < 3 {
163 | select {
164 | case <-receiver[0].ch:
165 | result++
166 | case <-receiver[1].ch:
167 | result++
168 | case <-receiver[2].ch:
169 | result++
170 | case <-time.After(2 * time.Second):
171 | Fail("timeout waiting for clients getting results")
172 | }
173 | }
174 | cancel()
175 | })
176 | })
177 | Context("Clients().Caller()", func() {
178 | It("should invoke only the caller", func() {
179 | client, receiver, cancel := makePipeClientsAndReceivers()
180 | r := <-client[0].Invoke("CallCaller")
181 | Expect(r.Error).NotTo(HaveOccurred())
182 | select {
183 | case <-receiver[0].ch:
184 | case <-receiver[1].ch:
185 | Fail("Wrong client received message")
186 | case <-receiver[2].ch:
187 | Fail("Wrong client received message")
188 | case <-time.After(2 * time.Second):
189 | Fail("timeout waiting for clients getting results")
190 | }
191 | cancel()
192 | })
193 | })
194 | Context("Clients().Client()", func() {
195 | It("should invoke only the client which was addressed", func() {
196 | client, receiver, cancel := makePipeClientsAndReceivers()
197 | r := <-client[0].Invoke("CallClient", "1")
198 | Expect(r.Error).NotTo(HaveOccurred())
199 | select {
200 | case <-receiver[0].ch:
201 | Fail("Wrong client received message")
202 | case <-receiver[1].ch:
203 | case <-receiver[2].ch:
204 | Fail("Wrong client received message")
205 | case <-time.After(2 * time.Second):
206 | Fail("timeout waiting for clients getting results")
207 | }
208 | cancel()
209 | })
210 | })
211 | }
212 | })
213 |
214 | func TestGroupShouldInvokeOnlyTheClientsInTheGroup(t *testing.T) {
215 | ctx, cancel := context.WithCancel(context.Background())
216 | defer cancel()
217 | _, client, receiver, srvConn, _, err := makeTCPServerAndClients(ctx, 3)
218 | assert.NoError(t, err)
219 | select {
220 | case ir := <-client[0].Invoke("buildgroup", srvConn[1].ConnectionID(), srvConn[2].ConnectionID()):
221 | assert.NoError(t, ir.Error)
222 | case <-time.After(100 * time.Millisecond):
223 | assert.Fail(t, "timeout in invoke")
224 | }
225 | select {
226 | case ir := <-client[0].Invoke("callgroup"):
227 | assert.NoError(t, ir.Error)
228 | case <-time.After(100 * time.Millisecond):
229 | assert.Fail(t, "timeout in invoke")
230 | }
231 | gotCalled := 0
232 | select {
233 | case <-receiver[0].ch:
234 | assert.Fail(t, "client 1 received message for client 2, 3")
235 | case <-receiver[1].ch:
236 | gotCalled++
237 | case <-receiver[2].ch:
238 | gotCalled++
239 | case <-time.After(100 * time.Millisecond):
240 | if gotCalled < 2 {
241 | assert.Fail(t, "timeout without client 2 and 3 got called")
242 | }
243 | }
244 | }
245 |
246 | func TestRemoveClientsShouldRemoveClientsFromTheGroup(t *testing.T) {
247 | ctx, cancel := context.WithCancel(context.Background())
248 | defer cancel()
249 | _, client, receiver, srvConn, _, err := makeTCPServerAndClients(ctx, 3)
250 | assert.NoError(t, err)
251 | select {
252 | case ir := <-client[0].Invoke("buildgroup", srvConn[1].ConnectionID(), srvConn[2].ConnectionID()):
253 | assert.NoError(t, ir.Error)
254 | case <-time.After(100 * time.Millisecond):
255 | assert.Fail(t, "timeout in invoke")
256 | }
257 | select {
258 | case ir := <-client[0].Invoke("removefromgroup", srvConn[2].ConnectionID()):
259 | assert.NoError(t, ir.Error)
260 | case <-time.After(100 * time.Millisecond):
261 | assert.Fail(t, "timeout in invoke")
262 | }
263 | select {
264 | case ir := <-client[0].Invoke("callgroup"):
265 | assert.NoError(t, ir.Error)
266 | case <-time.After(100 * time.Millisecond):
267 | assert.Fail(t, "timeout in invoke")
268 | }
269 | gotCalled := false
270 | select {
271 | case <-receiver[0].ch:
272 | assert.Fail(t, "client 1 received message for client 2")
273 | case <-receiver[1].ch:
274 | gotCalled = true
275 | case <-receiver[2].ch:
276 | assert.Fail(t, "client 3 received message for client 2")
277 | case <-time.After(100 * time.Millisecond):
278 | if !gotCalled {
279 | assert.Fail(t, "timeout without client 3 got called")
280 | }
281 | }
282 | }
283 |
284 | func TestItemsShouldHoldItemsConnectionWise(t *testing.T) {
285 | ctx, cancel := context.WithCancel(context.Background())
286 | defer cancel()
287 | _, client, _, _, _, err := makeTCPServerAndClients(ctx, 2)
288 | assert.NoError(t, err)
289 | select {
290 | case ir := <-client[0].Invoke("additem", "first", 1):
291 | assert.NoError(t, ir.Error)
292 | case <-time.After(100 * time.Millisecond):
293 | assert.Fail(t, "timeout in invoke")
294 | }
295 | select {
296 | case ir := <-client[0].Invoke("getitem", "first"):
297 | assert.NoError(t, ir.Error)
298 | assert.Equal(t, ir.Value, 1.0)
299 | case <-time.After(100 * time.Millisecond):
300 | assert.Fail(t, "timeout in invoke")
301 | }
302 | select {
303 | case ir := <-client[1].Invoke("getitem", "first"):
304 | assert.NoError(t, ir.Error)
305 | assert.Equal(t, ir.Value, nil)
306 | case <-time.After(100 * time.Millisecond):
307 | assert.Fail(t, "timeout in invoke")
308 | }
309 | }
310 |
311 | func TestAbortShouldAbortTheConnectionOfTheCurrentCaller(t *testing.T) {
312 | ctx, cancel := context.WithCancel(context.Background())
313 | defer cancel()
314 | _, client, _, _, _, err := makeTCPServerAndClients(ctx, 2)
315 | assert.NoError(t, err)
316 | select {
317 | case ir := <-client[0].Invoke("abort"):
318 | assert.Error(t, ir.Error)
319 | select {
320 | case err := <-client[0].WaitForState(ctx, ClientClosed):
321 | assert.NoError(t, err)
322 | case <-time.After(500 * time.Millisecond):
323 | assert.Fail(t, "timeout waiting for client close")
324 | }
325 | case <-time.After(100 * time.Millisecond):
326 | assert.Fail(t, "timeout in invoke")
327 | }
328 | select {
329 | case ir := <-client[1].Invoke("additem", "first", 2):
330 | assert.NoError(t, ir.Error)
331 | case <-time.After(100 * time.Millisecond):
332 | assert.Fail(t, "timeout in invoke")
333 | }
334 | select {
335 | case ir := <-client[1].Invoke("getitem", "first"):
336 | assert.NoError(t, ir.Error)
337 | assert.Equal(t, ir.Value, 2.0)
338 | case <-time.After(100 * time.Millisecond):
339 | assert.Fail(t, "timeout in invoke")
340 | }
341 | }
342 |
--------------------------------------------------------------------------------
/hublifetimemanager.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "sync"
5 |
6 | "github.com/go-kit/log"
7 | )
8 |
9 | // HubLifetimeManager is a lifetime manager abstraction for hub instances
10 | // OnConnected() is called when a connection is started
11 | // OnDisconnected() is called when a connection is finished
12 | // InvokeAll() sends an invocation message to all hub connections
13 | // InvokeClient() sends an invocation message to a specified hub connection
14 | // InvokeGroup() sends an invocation message to a specified group of hub connections
15 | // AddToGroup() adds a connection to the specified group
16 | // RemoveFromGroup() removes a connection from the specified group
17 | type HubLifetimeManager interface {
18 | OnConnected(conn hubConnection)
19 | OnDisconnected(conn hubConnection)
20 | InvokeAll(target string, args []interface{})
21 | InvokeClient(connectionID string, target string, args []interface{})
22 | InvokeGroup(groupName string, target string, args []interface{})
23 | AddToGroup(groupName, connectionID string)
24 | RemoveFromGroup(groupName, connectionID string)
25 | }
26 |
27 | func newLifeTimeManager(info StructuredLogger) defaultHubLifetimeManager {
28 | return defaultHubLifetimeManager{
29 | info: log.WithPrefix(info, "ts", log.DefaultTimestampUTC,
30 | "class", "lifeTimeManager"),
31 | }
32 | }
33 |
34 | type defaultHubLifetimeManager struct {
35 | clients sync.Map
36 | groups sync.Map
37 | info StructuredLogger
38 | }
39 |
40 | func (d *defaultHubLifetimeManager) OnConnected(conn hubConnection) {
41 | d.clients.Store(conn.ConnectionID(), conn)
42 | }
43 |
44 | func (d *defaultHubLifetimeManager) OnDisconnected(conn hubConnection) {
45 | d.clients.Delete(conn.ConnectionID())
46 | }
47 |
48 | func (d *defaultHubLifetimeManager) InvokeAll(target string, args []interface{}) {
49 | d.clients.Range(func(key, value interface{}) bool {
50 | go func() {
51 | _ = value.(hubConnection).SendInvocation("", target, args)
52 | }()
53 | return true
54 | })
55 | }
56 |
57 | func (d *defaultHubLifetimeManager) InvokeClient(connectionID string, target string, args []interface{}) {
58 | if client, ok := d.clients.Load(connectionID); ok {
59 | go func() {
60 | _ = client.(hubConnection).SendInvocation("", target, args)
61 | }()
62 | }
63 | }
64 |
65 | func (d *defaultHubLifetimeManager) InvokeGroup(groupName string, target string, args []interface{}) {
66 | if groups, ok := d.groups.Load(groupName); ok {
67 | for _, v := range groups.(map[string]hubConnection) {
68 | conn := v
69 | go func() {
70 | _ = conn.SendInvocation("", target, args)
71 | }()
72 | }
73 | }
74 | }
75 |
76 | func (d *defaultHubLifetimeManager) AddToGroup(groupName string, connectionID string) {
77 | if client, ok := d.clients.Load(connectionID); ok {
78 | groups, _ := d.groups.LoadOrStore(groupName, make(map[string]hubConnection))
79 | groups.(map[string]hubConnection)[connectionID] = client.(hubConnection)
80 | }
81 | }
82 |
83 | func (d *defaultHubLifetimeManager) RemoveFromGroup(groupName string, connectionID string) {
84 | if groups, ok := d.groups.Load(groupName); ok {
85 | delete(groups.(map[string]hubConnection), connectionID)
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/hubprotocol.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | )
7 |
8 | // hubProtocol interface
9 | // ParseMessages() parses messages from an io.Reader and stores unparsed bytes in remainBuf.
10 | // If buf does not contain the whole message, it returns a nil message and complete false
11 | // WriteMessage writes a message to the specified writer
12 | // UnmarshalArgument() unmarshals a raw message depending of the specified value type into a destination value
13 | type hubProtocol interface {
14 | ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) ([]interface{}, error)
15 | WriteMessage(message interface{}, writer io.Writer) error
16 | UnmarshalArgument(src interface{}, dst interface{}) error
17 | setDebugLogger(dbg StructuredLogger)
18 | transferMode() TransferMode
19 | }
20 |
// hubMessage carries only the message type. It is used to detect the type of an
// incoming message and is the parsed result for types without a dedicated struct (e.g. ping).
//easyjson:json
type hubMessage struct {
	Type int `json:"type"`
}

// invocationMessage is an invocation (type 1) or stream invocation (type 4).
// An empty InvocationID marks an invocation that expects no completion.
//easyjson:json
type invocationMessage struct {
	Type         int           `json:"type"`
	Target       string        `json:"target"`
	InvocationID string        `json:"invocationId,omitempty"`
	Arguments    []interface{} `json:"arguments"`
	StreamIds    []string      `json:"streamIds,omitempty"`
}

// completionMessage (type 3) completes the invocation InvocationID with a Result or an Error.
//easyjson:json
type completionMessage struct {
	Type         int         `json:"type"`
	InvocationID string      `json:"invocationId"`
	Result       interface{} `json:"result,omitempty"`
	Error        string      `json:"error,omitempty"`
}

// streamItemMessage (type 2) carries one Item of the stream belonging to InvocationID.
//easyjson:json
type streamItemMessage struct {
	Type         int         `json:"type"`
	InvocationID string      `json:"invocationId"`
	Item         interface{} `json:"item"`
}

// cancelInvocationMessage (type 5) cancels the invocation InvocationID.
//easyjson:json
type cancelInvocationMessage struct {
	Type         int    `json:"type"`
	InvocationID string `json:"invocationId"`
}

// closeMessage (type 7) announces that the connection is closed, optionally with an
// Error and whether the peer may reconnect.
//easyjson:json
type closeMessage struct {
	Type           int    `json:"type"`
	Error          string `json:"error"`
	AllowReconnect bool   `json:"allowReconnect"`
}

// handshakeRequest is sent by the client to select the hub protocol and its version.
//easyjson:json
type handshakeRequest struct {
	Protocol string `json:"protocol"`
	Version  int    `json:"version"`
}

// handshakeResponse answers a handshakeRequest; a non-empty Error rejects it.
//easyjson:json
type handshakeResponse struct {
	Error string `json:"error,omitempty"`
}
73 |
--------------------------------------------------------------------------------
/invokeclient.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "sync"
7 | "time"
8 | )
9 |
// invokeClient tracks the pending invocations of a party. For every invocation ID it
// holds the channels on which the result and the error of the completion are delivered.
type invokeClient struct {
	mx                 sync.Mutex // guards resultChans
	resultChans        map[string]invocationResultChans
	protocol           hubProtocol   // used to unmarshal completion results
	chanReceiveTimeout time.Duration // how long delivering a result/error may take
}
16 |
17 | func newInvokeClient(protocol hubProtocol, chanReceiveTimeout time.Duration) *invokeClient {
18 | return &invokeClient{
19 | mx: sync.Mutex{},
20 | resultChans: make(map[string]invocationResultChans),
21 | protocol: protocol,
22 | chanReceiveTimeout: chanReceiveTimeout,
23 | }
24 | }
25 |
// invocationResultChans are the channels through which the completion of one
// invocation is delivered to its invoker.
type invocationResultChans struct {
	resultChan chan interface{} // receives the unmarshaled result, if there is one
	errChan    chan error       // receives the invocation error (nil alongside a result)
}
30 |
31 | func (i *invokeClient) newInvocation(id string) (chan interface{}, chan error) {
32 | i.mx.Lock()
33 | r := invocationResultChans{
34 | resultChan: make(chan interface{}, 1),
35 | errChan: make(chan error, 1),
36 | }
37 | i.resultChans[id] = r
38 | i.mx.Unlock()
39 | return r.resultChan, r.errChan
40 | }
41 |
42 | func (i *invokeClient) deleteInvocation(id string) {
43 | i.mx.Lock()
44 | if r, ok := i.resultChans[id]; ok {
45 | delete(i.resultChans, id)
46 | close(r.resultChan)
47 | close(r.errChan)
48 | }
49 | i.mx.Unlock()
50 | }
51 |
52 | func (i *invokeClient) cancelAllInvokes() {
53 | i.mx.Lock()
54 | for _, r := range i.resultChans {
55 | close(r.resultChan)
56 | go func(errChan chan error) {
57 | errChan <- errors.New("message loop ended")
58 | close(errChan)
59 | }(r.errChan)
60 | }
61 | // Clear map
62 | i.resultChans = make(map[string]invocationResultChans)
63 | i.mx.Unlock()
64 | }
65 |
66 | func (i *invokeClient) handlesInvocationID(invocationID string) bool {
67 | i.mx.Lock()
68 | _, ok := i.resultChans[invocationID]
69 | i.mx.Unlock()
70 | return ok
71 | }
72 |
73 | func (i *invokeClient) receiveCompletionItem(completion completionMessage) error {
74 | defer i.deleteInvocation(completion.InvocationID)
75 | i.mx.Lock()
76 | ir, ok := i.resultChans[completion.InvocationID]
77 | i.mx.Unlock()
78 | if ok {
79 | if completion.Error != "" {
80 | done := make(chan struct{})
81 | go func() {
82 | ir.errChan <- errors.New(completion.Error)
83 | done <- struct{}{}
84 | }()
85 | select {
86 | case <-done:
87 | return nil
88 | case <-time.After(i.chanReceiveTimeout):
89 | return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client sent error", i.chanReceiveTimeout)}
90 | }
91 | }
92 | if completion.Result != nil {
93 | var result interface{}
94 | if err := i.protocol.UnmarshalArgument(completion.Result, &result); err != nil {
95 | return err
96 | }
97 | done := make(chan struct{})
98 | go func() {
99 | ir.resultChan <- result
100 | if completion.Error != "" {
101 | ir.errChan <- errors.New(completion.Error)
102 | } else {
103 | ir.errChan <- nil
104 | }
105 | close(done)
106 | }()
107 | select {
108 | case <-done:
109 | return nil
110 | case <-time.After(i.chanReceiveTimeout):
111 | return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client sent value", i.chanReceiveTimeout)}
112 | }
113 | }
114 | return nil
115 | }
116 | return fmt.Errorf(`unknown completion id "%v"`, completion.InvocationID)
117 | }
118 |
--------------------------------------------------------------------------------
/jsonhubprotocol.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "reflect"
10 |
11 | "github.com/go-kit/log"
12 | )
13 |
// jsonHubProtocol is the JSON based SignalR protocol. Messages are JSON texts, each
// terminated by the record separator 0x1e.
// dbg receives debug output and is set by setDebugLogger.
type jsonHubProtocol struct {
	dbg log.Logger
}
18 |
// Protocol specific messages for correct unmarshaling of arguments, stream items and
// results: these fields are kept as json.RawMessage, so they can be decoded later into
// the concrete target type via UnmarshalArgument.
// jsonInvocationMessage is only used when parsing messages (parseMessage), not in WriteMessage.
type jsonInvocationMessage struct {
	Type         int               `json:"type"`
	Target       string            `json:"target"`
	InvocationID string            `json:"invocationId"`
	Arguments    []json.RawMessage `json:"arguments"`
	StreamIds    []string          `json:"streamIds,omitempty"`
}

// jsonStreamItemMessage is the parsing counterpart of streamItemMessage.
type jsonStreamItemMessage struct {
	Type         int             `json:"type"`
	InvocationID string          `json:"invocationId"`
	Item         json.RawMessage `json:"item"`
}

// jsonCompletionMessage is the parsing counterpart of completionMessage.
type jsonCompletionMessage struct {
	Type         int             `json:"type"`
	InvocationID string          `json:"invocationId"`
	Result       json.RawMessage `json:"result,omitempty"`
	Error        string          `json:"error,omitempty"`
}
41 |
// jsonError is a JSON decoding error together with the raw text that failed to decode.
type jsonError struct {
	raw string
	err error
}

// Error returns the decoding error followed by the offending source text.
func (j *jsonError) Error() string {
	return fmt.Sprintf("%v (source: %v)", j.err, j.raw)
}

// Unwrap returns the underlying decoding error, so errors.Is and errors.As can inspect
// it (e.g. to find a *json.SyntaxError).
func (j *jsonError) Unwrap() error {
	return j.err
}
50 |
51 | // UnmarshalArgument unmarshals a json.RawMessage depending on the specified value type into value
52 | func (j *jsonHubProtocol) UnmarshalArgument(src interface{}, dst interface{}) error {
53 | rawSrc, ok := src.(json.RawMessage)
54 | if !ok {
55 | return fmt.Errorf("invalid source %#v for UnmarshalArgument", src)
56 | }
57 | if err := json.Unmarshal(rawSrc, dst); err != nil {
58 | return &jsonError{string(rawSrc), err}
59 | }
60 | _ = j.dbg.Log(evt, "UnmarshalArgument",
61 | "argument", string(rawSrc),
62 | "value", fmt.Sprintf("%v", reflect.ValueOf(dst).Elem()))
63 | return nil
64 | }
65 |
66 | // ParseMessages reads all messages from the reader and puts the remaining bytes into remainBuf
67 | func (j *jsonHubProtocol) ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) (messages []interface{}, err error) {
68 | frames, err := readJSONFrames(reader, remainBuf)
69 | if err != nil {
70 | return nil, err
71 | }
72 | message := hubMessage{}
73 | messages = make([]interface{}, 0)
74 | for _, frame := range frames {
75 | err = json.Unmarshal(frame, &message)
76 | _ = j.dbg.Log(evt, "read", msg, string(frame))
77 | if err != nil {
78 | return nil, &jsonError{string(frame), err}
79 | }
80 | typedMessage, err := j.parseMessage(message.Type, frame)
81 | if err != nil {
82 | return nil, err
83 | }
84 | // No specific type (aka Ping), use hubMessage
85 | if typedMessage == nil {
86 | typedMessage = message
87 | }
88 | messages = append(messages, typedMessage)
89 | }
90 | return messages, nil
91 | }
92 |
93 | func (j *jsonHubProtocol) parseMessage(messageType int, text []byte) (message interface{}, err error) {
94 | switch messageType {
95 | case 1, 4:
96 | jsonInvocation := jsonInvocationMessage{}
97 | if err = json.Unmarshal(text, &jsonInvocation); err != nil {
98 | err = &jsonError{string(text), err}
99 | }
100 | arguments := make([]interface{}, len(jsonInvocation.Arguments))
101 | for i, a := range jsonInvocation.Arguments {
102 | arguments[i] = a
103 | }
104 | return invocationMessage{
105 | Type: jsonInvocation.Type,
106 | Target: jsonInvocation.Target,
107 | InvocationID: jsonInvocation.InvocationID,
108 | Arguments: arguments,
109 | StreamIds: jsonInvocation.StreamIds,
110 | }, err
111 | case 2:
112 | jsonStreamItem := jsonStreamItemMessage{}
113 | if err = json.Unmarshal(text, &jsonStreamItem); err != nil {
114 | err = &jsonError{string(text), err}
115 | }
116 | return streamItemMessage{
117 | Type: jsonStreamItem.Type,
118 | InvocationID: jsonStreamItem.InvocationID,
119 | Item: jsonStreamItem.Item,
120 | }, err
121 | case 3:
122 | jsonCompletion := jsonCompletionMessage{}
123 | if err = json.Unmarshal(text, &jsonCompletion); err != nil {
124 | err = &jsonError{string(text), err}
125 | }
126 | completion := completionMessage{
127 | Type: jsonCompletion.Type,
128 | InvocationID: jsonCompletion.InvocationID,
129 | Error: jsonCompletion.Error,
130 | }
131 | // Only assign Result when non nil. setting interface{} Result to (json.RawMessage)(nil)
132 | // will produce a value which can not compared to nil even if it is pointing towards nil!
133 | // See https://www.calhoun.io/when-nil-isnt-equal-to-nil/ for explanation
134 | if jsonCompletion.Result != nil {
135 | completion.Result = jsonCompletion.Result
136 | }
137 | return completion, err
138 | case 5:
139 | invocation := cancelInvocationMessage{}
140 | if err = json.Unmarshal(text, &invocation); err != nil {
141 | err = &jsonError{string(text), err}
142 | }
143 | return invocation, err
144 | case 7:
145 | cm := closeMessage{}
146 | if err = json.Unmarshal(text, &cm); err != nil {
147 | err = &jsonError{string(text), err}
148 | }
149 | return cm, err
150 | default:
151 | return nil, nil
152 | }
153 | }
154 |
155 | // readJSONFrames reads all complete frames (delimited by 0x1e) from the reader and puts the remaining bytes into remainBuf
156 | func readJSONFrames(reader io.Reader, remainBuf *bytes.Buffer) ([][]byte, error) {
157 | p := make([]byte, 1<<15)
158 | buf := &bytes.Buffer{}
159 | _, _ = buf.ReadFrom(remainBuf)
160 | // Try getting data until at least one frame is available
161 | for {
162 | n, err := reader.Read(p)
163 | // Some reader implementations return io.EOF additionally to n=0 if no data could be read
164 | if err != nil && !errors.Is(err, io.EOF) {
165 | return nil, err
166 | }
167 | if n > 0 {
168 | _, _ = buf.Write(p[:n])
169 | frames, err := parseJSONFrames(buf)
170 | if err != nil {
171 | return nil, err
172 | }
173 | if len(frames) > 0 {
174 | _, _ = remainBuf.ReadFrom(buf)
175 | return frames, nil
176 | }
177 | }
178 | }
179 | }
180 |
// parseJSONFrames removes every complete 0x1e-terminated frame from buf and returns the
// frames without their terminator. An incomplete trailing frame is left in buf. The
// result is an empty, non-nil slice when buf holds no complete frame.
func parseJSONFrames(buf *bytes.Buffer) ([][]byte, error) {
	collected := make([][]byte, 0)
	for {
		chunk, readErr := buf.ReadBytes(0x1e)
		switch {
		case errors.Is(readErr, io.EOF):
			// Put the incomplete frame back for the next round.
			_, _ = buf.Write(chunk)
			return collected, nil
		case readErr != nil:
			return nil, readErr
		}
		collected = append(collected, chunk[:len(chunk)-1])
	}
}
197 |
198 | // WriteMessage writes a message as JSON to the specified writer
199 | func (j *jsonHubProtocol) WriteMessage(message interface{}, writer io.Writer) error {
200 | var b []byte
201 | var err error
202 | if marshaler, ok := message.(json.Marshaler); ok {
203 | b, err = marshaler.MarshalJSON()
204 | } else {
205 | b, err = json.Marshal(message)
206 | }
207 | if err != nil {
208 | return err
209 | }
210 | b = append(b, 0x1e)
211 | _ = j.dbg.Log(evt, "write", msg, string(b))
212 | _, err = writer.Write(b)
213 | return err
214 | }
215 |
216 | func (j *jsonHubProtocol) transferMode() TransferMode {
217 | return TextTransferMode
218 | }
219 |
220 | func (j *jsonHubProtocol) setDebugLogger(dbg StructuredLogger) {
221 | j.dbg = log.WithPrefix(dbg, "ts", log.DefaultTimestampUTC, "protocol", "JSON")
222 | }
223 |
--------------------------------------------------------------------------------
/logger_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "io"
7 | "os"
8 | "sync"
9 | "testing"
10 | "time"
11 |
12 | "github.com/go-kit/log"
13 | )
14 |
// loggerConfig is the test logger configuration read from testLogConf.json.
// Enabled writes log output to stderr instead of discarding it; Debug is passed as the
// debug flag of the Logger option.
type loggerConfig struct {
	Enabled bool
	Debug   bool
}
19 |
// lConf holds the test logger configuration; it is filled by testLogger.
var lConf loggerConfig

// tLog is the shared test logger, created lazily by testLogger.
var tLog StructuredLogger
23 |
24 | func testLoggerOption() func(Party) error {
25 | testLogger()
26 | return Logger(tLog, lConf.Debug)
27 | }
28 |
29 | func testLogger() StructuredLogger {
30 | if tLog == nil {
31 | lConf = loggerConfig{Enabled: false, Debug: false}
32 | b, err := os.ReadFile("testLogConf.json")
33 | if err == nil {
34 | err = json.Unmarshal(b, &lConf)
35 | if err != nil {
36 | lConf = loggerConfig{Enabled: false, Debug: false}
37 | }
38 | }
39 | writer := io.Discard
40 | if lConf.Enabled {
41 | writer = os.Stderr
42 | }
43 | tLog = log.NewLogfmtLogger(writer)
44 | }
45 | return tLog
46 | }
47 |
48 | type panicLogger struct {
49 | log log.Logger
50 | }
51 |
52 | func (p *panicLogger) Log(keyVals ...interface{}) error {
53 | _ = p.log.Log(keyVals...)
54 | panic("panic as expected")
55 | }
56 |
// testLogWriter is an io.Writer that forwards log output to t.Log. Written bytes are
// buffered in p until a write ends with a newline; the buffer is then logged and cleared.
type testLogWriter struct {
	mx sync.Mutex
	p  []byte
	t  *testing.T
}

// Write appends p to the buffer and flushes the buffer to t.Log when p ends with '\n'.
// It never fails and always reports len(p) bytes written.
func (t *testLogWriter) Write(p []byte) (n int, err error) {
	t.mx.Lock()
	defer t.mx.Unlock()
	t.p = append(t.p, p...)
	if last := len(p) - 1; last >= 0 && p[last] == '\n' {
		t.t.Log(string(t.p))
		t.p = nil
	}
	return len(p), nil
}
73 |
74 | func Test_PanicLogger(t *testing.T) {
75 | defer func() {
76 | if err := recover(); err != nil {
77 | t.Errorf("panic in logger: '%v'", err)
78 | }
79 | }()
80 | ctx, cancel := context.WithCancel(context.Background())
81 | server, _ := NewServer(ctx, SimpleHubFactory(&simpleHub{}),
82 | Logger(&panicLogger{log: log.NewLogfmtLogger(&testLogWriter{t: t})}, true),
83 | ChanReceiveTimeout(200*time.Millisecond),
84 | StreamBufferCapacity(5))
85 | // Create both ends of the connection
86 | cliConn, srvConn := newClientServerConnections()
87 | // Start the server
88 | go func() { _ = server.Serve(srvConn) }()
89 | // Create the Client
90 | client, _ := NewClient(ctx, WithConnection(cliConn), Logger(&panicLogger{log: log.NewLogfmtLogger(&testLogWriter{t: t})}, true))
91 | // Start it
92 | client.Start()
93 | // Do something
94 | <-client.Send("InvokeMe")
95 | cancel()
96 | }
97 |
--------------------------------------------------------------------------------
/messagepackhubprotocol_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "reflect"
6 |
7 | . "github.com/onsi/ginkgo"
8 | . "github.com/onsi/gomega"
9 | )
10 |
11 | // The MessagePackHubProtocol specs write messages with WriteMessage and read them back with ParseMessages.
12 | var _ = Describe("MessagePackHubProtocol", func() {
13 | protocol := messagePackHubProtocol{}
14 | protocol.setDebugLogger(testLogger())
15 | Context("ParseMessages", func() {
16 | It("should encode/decode an InvocationMessage", func() {
17 | message := invocationMessage{
18 | Type: 4,
19 | Target: "target",
20 | InvocationID: "1",
21 | // msgpack decodes ints to the smallest type and arrays always to []interface{},
22 | // so the arguments are spelled out with exactly those types
23 | Arguments: []interface{}{"1", int8(1), []interface{}{int8(7), int8(3)}},
24 | StreamIds: []string{"0"},
25 | }
26 | buf := bytes.Buffer{}
27 | err := protocol.WriteMessage(message, &buf)
28 | Expect(err).NotTo(HaveOccurred())
29 | remainBuf := bytes.Buffer{}
30 | got, err := protocol.ParseMessages(&buf, &remainBuf)
31 | Expect(err).NotTo(HaveOccurred())
32 | Expect(remainBuf.Len()).To(Equal(0))
33 | Expect(got).To(HaveLen(1))
34 | Expect(got[0]).To(BeAssignableToTypeOf(invocationMessage{}))
35 | gotMsg := got[0].(invocationMessage)
36 | Expect(gotMsg.Type).To(Equal(message.Type))
37 | Expect(gotMsg.Target).To(Equal(message.Target))
38 | Expect(gotMsg.InvocationID).To(Equal(message.InvocationID))
39 | Expect(gotMsg.StreamIds).To(Equal(message.StreamIds))
40 | // Without a length check, arguments lost in decoding would let the loop below pass vacuously.
41 | Expect(gotMsg.Arguments).To(HaveLen(len(message.Arguments)))
42 | for i, gotArg := range gotMsg.Arguments {
43 | // We can not directly compare gotArg and message.Arguments[i]
44 | // because msgpack serializes numbers to the shortest possible type
45 | argType := reflect.TypeOf(message.Arguments[i])
46 | value := reflect.New(argType)
47 | Expect(protocol.UnmarshalArgument(gotArg, value.Interface())).NotTo(HaveOccurred())
48 | Expect(reflect.Indirect(value).Interface()).To(Equal(message.Arguments[i]))
49 | }
50 | })
51 | })
52 | })
53 |
--------------------------------------------------------------------------------
/negotiateresponse.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | // availableTransport is one entry of the availableTransports list of a negotiate
4 | // response: a transport name and the transfer formats it supports.
5 | type availableTransport struct {
6 | Transport string `json:"transport"`
7 | TransferFormats []string `json:"transferFormats"`
8 | }
9 |
10 | // negotiateResponse is the JSON body of a SignalR negotiate response.
11 | type negotiateResponse struct {
12 | ConnectionToken string `json:"connectionToken,omitempty"`
13 | ConnectionID string `json:"connectionId"`
14 | NegotiateVersion int `json:"negotiateVersion,omitempty"`
15 | AvailableTransports []availableTransport `json:"availableTransports"`
16 | }
17 |
18 | // getTransferFormats returns the transfer formats of the first available transport
19 | // named transportType, or nil if the response does not offer that transport.
20 | func (nr *negotiateResponse) getTransferFormats(transportType string) []string {
21 | for idx := range nr.AvailableTransports {
22 | if offered := &nr.AvailableTransports[idx]; offered.Transport == transportType {
23 | return offered.TransferFormats
24 | }
25 | }
26 | return nil
27 | }
28 |
--------------------------------------------------------------------------------
/netconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "crypto/rand"
6 | "encoding/base64"
7 | "fmt"
8 | "net"
9 | "time"
10 | )
11 |
12 | // netConnection is a Connection backed by a net.Conn.
13 | type netConnection struct {
14 | ConnectionBase
15 | conn net.Conn
16 | }
17 |
18 | // NewNetConnection wraps net.Conn into a Connection with a random connection ID.
19 | // conn is closed once ctx is done.
20 | func NewNetConnection(ctx context.Context, conn net.Conn) Connection {
21 | wrapped := &netConnection{
22 | ConnectionBase: *NewConnectionBase(ctx, getConnectionID()),
23 | conn: conn,
24 | }
25 | // This goroutine ends only when ctx is done; closing conn then unblocks pending Read/Write calls.
26 | go func() {
27 | <-ctx.Done()
28 | _ = conn.Close()
29 | }()
30 | return wrapped
31 | }
32 |
33 | // Write writes p to the net.Conn through ReadWriteWithContext.
34 | // A non-nil error is prefixed with the connection type.
35 | func (nc *netConnection) Write(p []byte) (n int, err error) {
36 | write := func() (int, error) { return nc.conn.Write(p) }
37 | // A write deadline of "now" makes a blocked Write fail immediately; presumably
38 | // ReadWriteWithContext calls this when the connection's context ends.
39 | unblock := func() { _ = nc.conn.SetWriteDeadline(time.Now()) }
40 | if n, err = ReadWriteWithContext(nc.Context(), write, unblock); err != nil {
41 | err = fmt.Errorf("%T: %w", nc, err)
42 | }
43 | return n, err
44 | }
45 |
46 | // Read reads from the net.Conn into p through ReadWriteWithContext.
47 | // A non-nil error is prefixed with the connection type.
48 | func (nc *netConnection) Read(p []byte) (n int, err error) {
49 | read := func() (int, error) { return nc.conn.Read(p) }
50 | // Same mechanism as in Write, applied to the read deadline.
51 | unblock := func() { _ = nc.conn.SetReadDeadline(time.Now()) }
52 | if n, err = ReadWriteWithContext(nc.Context(), read, unblock); err != nil {
53 | err = fmt.Errorf("%T: %w", nc, err)
54 | }
55 | return n, err
56 | }
57 |
58 | // getConnectionID returns 16 bytes from crypto/rand, encoded with base64.StdEncoding.
59 | func getConnectionID() string {
60 | id := make([]byte, 16)
61 | // The error is ignored on purpose; on failure the unfilled bytes of id stay zero.
62 | _, _ = rand.Read(id)
63 | return base64.StdEncoding.EncodeToString(id)
64 | }
65 |
--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/go-kit/log"
9 | "github.com/go-kit/log/level"
10 | )
11 |
12 | // TimeoutInterval is the interval one Party will consider the other Party disconnected
13 | // if it hasn't received a message (including keep-alive) in it.
14 | // The recommended value is double the KeepAliveInterval value.
15 | // Default is 30 seconds.
16 | func TimeoutInterval(timeout time.Duration) func(Party) error {
17 | return func(p Party) error {
18 | p.setTimeout(timeout)
19 | return nil
20 | }
21 | }
22 |
23 | // HandshakeTimeout is the interval within which the other Party has to send its initial
24 | // handshake message; otherwise the connection is closed. This is an advanced setting that
25 | // should only be modified if handshake timeout errors are occurring due to severe network latency.
26 | // Default is 15 seconds.
27 | // For more detail on the handshake process,
28 | // see https://github.com/dotnet/aspnetcore/blob/master/src/SignalR/docs/specs/HubProtocol.md
29 | func HandshakeTimeout(timeout time.Duration) func(Party) error {
30 | return func(p Party) error {
31 | p.setHandshakeTimeout(timeout)
32 | return nil
33 | }
34 | }
35 |
36 | // KeepAliveInterval is the interval after which a ping message is sent automatically
37 | // to keep the connection open if the Party hasn't sent any other message within it.
38 | // When changing KeepAliveInterval, change the Timeout setting on the other Party.
39 | // The recommended Timeout value is double the KeepAliveInterval value.
40 | // Default is 5 seconds (as set by newPartyBase).
41 | func KeepAliveInterval(interval time.Duration) func(Party) error {
42 | return func(p Party) error {
43 | p.setKeepAliveInterval(interval)
44 | return nil
45 | }
46 | }
47 |
48 | // StreamBufferCapacity is the maximum number of items that can be buffered for client upload streams.
49 | // If this limit is reached, the processing of invocations is blocked until the server processes stream items.
50 | // A capacity of 0 is rejected: the option returns an error.
51 | // Default is 10.
52 | func StreamBufferCapacity(capacity uint) func(Party) error {
53 | return func(p Party) error {
54 | if capacity == 0 {
55 | return errors.New("unsupported StreamBufferCapacity 0")
56 | }
57 | p.setStreamBufferCapacity(capacity)
58 | return nil
59 | }
60 | }
61 |
62 | // MaximumReceiveMessageSize is the maximum size in bytes of a single incoming hub message.
63 | // A size of 0 is rejected: the option returns an error.
64 | // Default is 32768 bytes (32KB).
65 | func MaximumReceiveMessageSize(sizeInBytes uint) func(Party) error {
66 | return func(p Party) error {
67 | if sizeInBytes == 0 {
68 | return errors.New("unsupported maximumReceiveMessageSize 0")
69 | }
70 | p.setMaximumReceiveMessageSize(sizeInBytes)
71 | return nil
72 | }
73 | }
74 |
75 | // ChanReceiveTimeout is the timeout for processing stream items from the client, after StreamBufferCapacity was reached.
76 | // If the hub method is not able to process a stream item during the timeout duration,
77 | // the server will send a completion with error.
78 | // Default is 5 seconds.
79 | func ChanReceiveTimeout(timeout time.Duration) func(Party) error {
80 | return func(p Party) error {
81 | p.setChanReceiveTimeout(timeout)
82 | return nil
83 | }
84 | }
85 |
86 | // EnableDetailedErrors sets whether detailed exception messages are returned to the other
87 | // Party when an exception is thrown in a Hub method.
88 | // The default is false, as these exception messages can contain sensitive information.
89 | func EnableDetailedErrors(enable bool) func(Party) error {
90 | return func(p Party) error {
91 | p.setEnableDetailedErrors(enable)
92 | return nil
93 | }
94 | }
95 |
96 | // StructuredLogger is the simplest logging interface for structured logging.
97 | // See github.com/go-kit/log
98 | type StructuredLogger interface {
99 | Log(keyVals ...interface{}) error
100 | }
101 |
102 | // Logger sets the logger used by the Party to log info events.
103 | // If debug is true, debug log events are generated, too.
104 | func Logger(logger StructuredLogger, debug bool) func(Party) error {
105 | return func(p Party) error {
106 | i, d := buildInfoDebugLogger(logger, debug)
107 | p.setLoggers(i, d)
108 | return nil
109 | }
110 | }
111 |
112 | // recoverLogger wraps a log.Logger so that a panic inside the wrapped logger
113 | // does not propagate to the code that logs.
114 | type recoverLogger struct {
115 | logger log.Logger
116 | }
117 |
118 | // Log forwards keyVals to the wrapped logger. If that logger panics, the panic value
119 | // is printed to stdout and Log returns nil.
120 | func (r *recoverLogger) Log(keyVals ...interface{}) error {
121 | defer func() {
122 | if err := recover(); err != nil {
123 | fmt.Printf("recovering from panic in logger: %v\n", err)
124 | }
125 | }()
126 | return r.logger.Log(keyVals...)
127 | }
128 |
129 | // buildInfoDebugLogger derives an info and a debug logger from logger.
130 | // Debug events pass the level filter only if debug is true. Both loggers are wrapped
131 | // in recoverLogger, and the debug logger adds a "caller" key to every event.
132 | func buildInfoDebugLogger(logger log.Logger, debug bool) (log.Logger, log.Logger) {
133 | if debug {
134 | logger = level.NewFilter(logger, level.AllowDebug())
135 | } else {
136 | logger = level.NewFilter(logger, level.AllowInfo())
137 | }
138 | infoLogger := &recoverLogger{level.Info(logger)}
139 | debugLogger := log.With(&recoverLogger{level.Debug(logger)}, "caller", log.DefaultCaller)
140 | return infoLogger, debugLogger
141 | }
142 |
--------------------------------------------------------------------------------
/party.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/go-kit/log"
8 | )
9 |
10 | // Party is the common base of Server and Client. The Party methods are only used internally,
11 | // but the interface is public to allow using Options on Party as parameters for external functions.
12 | type Party interface {
13 | context() context.Context
14 | cancel()
15 |
16 | onConnected(hc hubConnection)
17 | onDisconnected(hc hubConnection)
18 |
19 | invocationTarget(hc hubConnection) interface{}
20 |
21 | timeout() time.Duration
22 | setTimeout(timeout time.Duration)
23 |
24 | setHandshakeTimeout(timeout time.Duration)
25 |
26 | keepAliveInterval() time.Duration
27 | setKeepAliveInterval(interval time.Duration)
28 |
29 | insecureSkipVerify() bool
30 | setInsecureSkipVerify(skip bool)
31 |
32 | originPatterns() []string
33 | setOriginPatterns(origins []string)
34 |
35 | chanReceiveTimeout() time.Duration
36 | setChanReceiveTimeout(interval time.Duration)
37 |
38 | streamBufferCapacity() uint
39 | setStreamBufferCapacity(capacity uint)
40 |
41 | allowReconnect() bool
42 |
43 | enableDetailedErrors() bool
44 | setEnableDetailedErrors(enable bool)
45 |
46 | loggers() (info StructuredLogger, dbg StructuredLogger)
47 | setLoggers(info StructuredLogger, dbg StructuredLogger)
48 |
49 | prefixLoggers(connectionID string) (info StructuredLogger, dbg StructuredLogger)
50 |
51 | maximumReceiveMessageSize() uint
52 | setMaximumReceiveMessageSize(size uint)
53 | }
54 |
55 | // newPartyBase returns a partyBase whose context is a cancelable child of parentContext
56 | // and whose options hold their default values.
57 | func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase {
58 | ctx, cancelFunc := context.WithCancel(parentContext)
59 | return partyBase{
60 | ctx: ctx,
61 | cancelFunc: cancelFunc,
62 | _timeout: time.Second * 30,
63 | _handshakeTimeout: time.Second * 15,
64 | _keepAliveInterval: time.Second * 5,
65 | _chanReceiveTimeout: time.Second * 5,
66 | _streamBufferCapacity: 10,
67 | _maximumReceiveMessageSize: 1 << 15, // 32KB
68 | _enableDetailedErrors: false,
69 | _insecureSkipVerify: false,
70 | _originPatterns: nil,
71 | info: info,
72 | dbg: dbg,
73 | }
74 | }
75 |
76 | // partyBase stores the context, option values and loggers of a Party.
77 | // It does not implement onConnected, onDisconnected, invocationTarget, allowReconnect
78 | // and prefixLoggers; the embedding type has to provide them.
79 | // NOTE(review): the fields are not synchronized; presumably they are only set while options are applied.
80 | type partyBase struct {
81 | ctx context.Context
82 | cancelFunc context.CancelFunc
83 | _timeout time.Duration
84 | _handshakeTimeout time.Duration
85 | _keepAliveInterval time.Duration
86 | _chanReceiveTimeout time.Duration
87 | _streamBufferCapacity uint
88 | _maximumReceiveMessageSize uint
89 | _enableDetailedErrors bool
90 | _insecureSkipVerify bool
91 | _originPatterns []string
92 | info StructuredLogger
93 | dbg StructuredLogger
94 | }
95 |
96 | func (p *partyBase) context() context.Context {
97 | return p.ctx
98 | }
99 |
100 | // cancel cancels the Party's context.
101 | func (p *partyBase) cancel() {
102 | p.cancelFunc()
103 | }
104 |
105 | func (p *partyBase) timeout() time.Duration {
106 | return p._timeout
107 | }
108 |
109 | func (p *partyBase) setTimeout(timeout time.Duration) {
110 | p._timeout = timeout
111 | }
112 |
113 | // HandshakeTimeout returns the handshake timeout.
114 | // NOTE(review): unlike the other getters this one is exported and is not part of Party.
115 | func (p *partyBase) HandshakeTimeout() time.Duration {
116 | return p._handshakeTimeout
117 | }
118 |
119 | func (p *partyBase) setHandshakeTimeout(timeout time.Duration) {
120 | p._handshakeTimeout = timeout
121 | }
122 |
123 | func (p *partyBase) keepAliveInterval() time.Duration {
124 | return p._keepAliveInterval
125 | }
126 |
127 | func (p *partyBase) setKeepAliveInterval(interval time.Duration) {
128 | p._keepAliveInterval = interval
129 | }
130 |
131 | func (p *partyBase) insecureSkipVerify() bool {
132 | return p._insecureSkipVerify
133 | }
134 |
135 | func (p *partyBase) setInsecureSkipVerify(skip bool) {
136 | p._insecureSkipVerify = skip
137 | }
138 |
139 | func (p *partyBase) originPatterns() []string {
140 | return p._originPatterns
141 | }
142 |
143 | // setOriginPatterns stores origins as given; the slice is not copied.
144 | func (p *partyBase) setOriginPatterns(origins []string) {
145 | p._originPatterns = origins
146 | }
147 |
148 | func (p *partyBase) chanReceiveTimeout() time.Duration {
149 | return p._chanReceiveTimeout
150 | }
151 |
152 | func (p *partyBase) setChanReceiveTimeout(interval time.Duration) {
153 | p._chanReceiveTimeout = interval
154 | }
155 |
156 | func (p *partyBase) streamBufferCapacity() uint {
157 | return p._streamBufferCapacity
158 | }
159 |
160 | func (p *partyBase) setStreamBufferCapacity(capacity uint) {
161 | p._streamBufferCapacity = capacity
162 | }
163 |
164 | func (p *partyBase) maximumReceiveMessageSize() uint {
165 | return p._maximumReceiveMessageSize
166 | }
167 |
168 | func (p *partyBase) setMaximumReceiveMessageSize(size uint) {
169 | p._maximumReceiveMessageSize = size
170 | }
171 |
172 | func (p *partyBase) enableDetailedErrors() bool {
173 | return p._enableDetailedErrors
174 | }
175 |
176 | func (p *partyBase) setEnableDetailedErrors(enable bool) {
177 | p._enableDetailedErrors = enable
178 | }
179 |
180 | func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) {
181 | p.info = info
182 | p.dbg = dbg
183 | }
184 |
185 | func (p *partyBase) loggers() (info StructuredLogger, dbg StructuredLogger) {
186 | return p.info, p.dbg
187 | }
188 |
--------------------------------------------------------------------------------
/receiver.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | // ReceiverInterface allows receivers to interact with the server directly from the receiver methods.
4 | //
5 | // Init(Client)
6 | // is used by the Client to connect the receiver to the server.
7 | //
8 | // Server() Client
9 | // can be used inside receiver methods to call Client methods,
10 | // e.g. Client.Send, Client.Invoke, Client.PullStream and Client.PushStreams.
11 | type ReceiverInterface interface {
12 | Init(Client)
13 | Server() Client
14 | }
15 |
16 | // Receiver is a base type for receivers in the client.
17 | // *Receiver implements ReceiverInterface, so embedding Receiver in a receiver struct provides Init and Server.
18 | type Receiver struct {
19 | client Client
20 | }
21 |
22 | // Init stores client; it is used by the Client to connect the receiver to the server.
23 | func (r *Receiver) Init(client Client) {
24 | r.client = client
25 | }
26 |
27 | // Server returns the Client stored by Init (nil before Init was called).
28 | // Receiver methods can use it to call Client methods such as Send or Invoke.
29 | func (r *Receiver) Server() Client {
30 | return r.client
31 | }
32 |
--------------------------------------------------------------------------------
/router/Makefile:
--------------------------------------------------------------------------------
1 | PROJECT_NAME := "signalr/router"
2 | PKG := "github.com/philippseith/$(PROJECT_NAME)"
3 | PKG_LIST := $(shell go list ${PKG}/... | grep -v /vendor/)
4 | GO_FILES := $(shell find . -name '*.go' | grep -v /vendor/ | grep -v _test.go)
5 |
6 | .PHONY: all dep lint vet test test-coverage build clean help
7 |
8 | all: build
9 |
10 | dep: ## Get the dependencies
11 | @go mod download
12 |
13 | lint: ## Lint Golang files
14 | @golint -set_exit_status ${PKG_LIST}
15 |
16 | vet: ## Run go vet
17 | @go vet ${PKG_LIST}
18 |
19 | test: ## Run unittests
20 | @go test -short ${PKG_LIST}
21 |
22 | test-coverage: ## Run tests with coverage
23 | @go test -short -coverpkg=. -coverprofile cover.out -covermode=atomic ${PKG_LIST}
24 | @cat cover.out >> coverage.txt
25 |
26 | # go build no longer accepts -i (deprecated in Go 1.16, removed in Go 1.20).
27 | build: dep ## Build the binary file
28 | @go build -o build/main $(PKG)
29 |
30 | # build writes into ./build, so that directory is what has to be removed.
31 | clean: ## Remove previous build
32 | @rm -rf build
33 |
34 | help: ## Display this help screen
35 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
36 |
--------------------------------------------------------------------------------
/router/chirouter.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/go-chi/chi/v5"
7 | "github.com/philippseith/signalr"
8 | )
9 |
10 | // WithChiRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP
11 | // which converts a chi.Router to a signalr.MappableRouter.
12 | func WithChiRouter(r chi.Router) func() signalr.MappableRouter {
13 | return func() signalr.MappableRouter {
14 | return &chiRouter{router: r}
15 | }
16 | }
17 |
18 | // chiRouter adapts a chi.Router to signalr.MappableRouter by forwarding to it.
19 | type chiRouter struct {
20 | router chi.Router
21 | }
22 |
23 | // HandleFunc registers handler for path on the wrapped chi.Router.
24 | func (c *chiRouter) HandleFunc(path string, handler func(w http.ResponseWriter, r *http.Request)) {
25 | c.router.HandleFunc(path, handler)
26 | }
27 |
28 | // Handle registers handler for pattern on the wrapped chi.Router.
29 | func (c *chiRouter) Handle(pattern string, handler http.Handler) {
30 | c.router.Handle(pattern, handler)
31 | }
32 |
--------------------------------------------------------------------------------
/router/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package router contains signalr.MappableRouter factories for some popular http routers.
3 | (Listed in https://www.alexedwards.net/blog/which-go-router-should-i-use)
4 | The factories can be used to integrate the SignalR server from github.com/philippseith/signalr with these routers.
5 | If you would rather not depend on router modules you don't use, copy the source code
6 | of the factory for your router into your own project instead.
7 | */
8 | package router
9 |
--------------------------------------------------------------------------------
/router/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippseith/signalr/router
2 |
3 | go 1.17
4 |
5 | require (
6 | github.com/go-chi/chi/v5 v5.0.5
7 | github.com/go-kit/log v0.2.0
8 | github.com/gorilla/mux v1.8.0
9 | github.com/julienschmidt/httprouter v1.3.0
10 | github.com/onsi/ginkgo v1.12.1
11 | github.com/onsi/gomega v1.11.0
12 | github.com/philippseith/signalr v0.4.2-0.20211029170321-9b04bbc12782
13 | )
14 |
15 | require (
16 | github.com/fsnotify/fsnotify v1.4.9 // indirect
17 | github.com/go-logfmt/logfmt v0.5.1 // indirect
18 | github.com/klauspost/compress v1.10.3 // indirect
19 | github.com/nxadm/tail v1.4.8 // indirect
20 | github.com/teivah/onecontext v1.3.0 // indirect
21 | github.com/vmihailenco/msgpack/v5 v5.3.4 // indirect
22 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
23 | golang.org/x/net v0.7.0 // indirect
24 | golang.org/x/sys v0.5.0 // indirect
25 | golang.org/x/text v0.7.0 // indirect
26 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
27 | gopkg.in/yaml.v2 v2.4.0 // indirect
28 | nhooyr.io/websocket v1.8.7 // indirect
29 | )
30 |
--------------------------------------------------------------------------------
/router/gorillarouter.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/gorilla/mux"
7 | "github.com/philippseith/signalr"
8 | )
9 |
10 | // WithGorillaRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP
11 | // which converts a mux.Router to a signalr.MappableRouter.
12 | func WithGorillaRouter(r *mux.Router) func() signalr.MappableRouter {
13 | return func() signalr.MappableRouter {
14 | return &gorillaMappableRouter{router: r}
15 | }
16 | }
17 |
18 | // gorillaMappableRouter adapts a *mux.Router to signalr.MappableRouter by forwarding to it.
19 | type gorillaMappableRouter struct {
20 | router *mux.Router
21 | }
22 |
23 | // Handle registers handler for path on the wrapped mux.Router.
24 | func (m *gorillaMappableRouter) Handle(path string, handler http.Handler) {
25 | m.router.Handle(path, handler)
26 | }
27 |
28 | // HandleFunc registers handleFunc for path on the wrapped mux.Router.
29 | func (m *gorillaMappableRouter) HandleFunc(path string, handleFunc func(w http.ResponseWriter, r *http.Request)) {
30 | m.router.HandleFunc(path, handleFunc)
31 | }
32 |
--------------------------------------------------------------------------------
/router/httprouter.go:
--------------------------------------------------------------------------------
1 | package router
2 |
3 | import (
4 | "net/http"
5 |
6 | "github.com/julienschmidt/httprouter"
7 | "github.com/philippseith/signalr"
8 | )
9 |
10 | // WithHttpRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP
11 | // which converts a httprouter.Router to a signalr.MappableRouter.
12 | func WithHttpRouter(r *httprouter.Router) func() signalr.MappableRouter {
13 | return func() signalr.MappableRouter {
14 | return &julienRouter{router: r}
15 | }
16 | }
17 |
18 | // julienRouter adapts a *httprouter.Router to signalr.MappableRouter.
19 | // httprouter routes by method, so every path is registered per HTTP method.
20 | type julienRouter struct {
21 | router *httprouter.Router
22 | }
23 |
24 | // HandleFunc registers handler for POST requests on path (MapHTTP uses it for the negotiate endpoint).
25 | func (jr *julienRouter) HandleFunc(path string, handler func(w http.ResponseWriter, r *http.Request)) {
26 | jr.router.HandlerFunc(http.MethodPost, path, handler)
27 | }
28 |
29 | // Handle registers handler for POST and GET requests on pattern.
30 | func (jr *julienRouter) Handle(pattern string, handler http.Handler) {
31 | for _, method := range []string{http.MethodPost, http.MethodGet} {
32 | jr.router.Handler(method, pattern, handler)
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/router/router_test/router_test.go:
--------------------------------------------------------------------------------
1 | package router_test
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net"
11 | "net/http"
12 | "testing"
13 | "time"
14 |
15 | "github.com/go-chi/chi/v5"
16 | "github.com/go-kit/log"
17 | "github.com/gorilla/mux"
18 | "github.com/julienschmidt/httprouter"
19 | . "github.com/onsi/ginkgo"
20 | . "github.com/onsi/gomega"
21 | "github.com/philippseith/signalr"
22 | "github.com/philippseith/signalr/router"
23 | )
24 |
25 | func TestRouter(t *testing.T) {
26 | RegisterFailHandler(Fail)
27 | RunSpecs(t, "signalr/mux Suite")
28 | }
29 |
30 | // initGorillaRouter maps the hub of server to /hub on a gorilla/mux router served on 127.0.0.1:port.
31 | func initGorillaRouter(server signalr.Server, port int) {
32 | r := mux.NewRouter()
33 | server.MapHTTP(router.WithGorillaRouter(r), "/hub")
34 | go func() {
35 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r)
36 | }()
37 | }
38 |
39 | // initHttpRouter maps the hub of server to /hub on a julienschmidt/httprouter served on 127.0.0.1:port.
40 | func initHttpRouter(server signalr.Server, port int) {
41 | r := httprouter.New()
42 | server.MapHTTP(router.WithHttpRouter(r), "/hub")
43 | go func() {
44 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r)
45 | }()
46 | }
47 |
48 | // initChiRouter maps the hub of server to /hub on a chi router served on 127.0.0.1:port.
49 | func initChiRouter(server signalr.Server, port int) {
50 | r := chi.NewRouter()
51 | server.MapHTTP(router.WithChiRouter(r), "/hub")
52 | go func() {
53 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r)
54 | }()
55 | }
56 |
57 | // The Router specs run the same negotiation and Invoke tests against every supported router.
58 | var _ = Describe("Router", func() {
59 | routerNames := []string{
60 | "gorilla/mux.Router",
61 | "julienschmidt/httprouter",
62 | "chi/Router",
63 | }
64 | for i, initFunc := range []func(server signalr.Server, port int){
65 | initGorillaRouter,
66 | initHttpRouter,
67 | initChiRouter,
68 | } {
69 | // The It bodies run after this loop has finished. Under the go 1.17 loop semantics of this
70 | // module they would otherwise share the loop variable and all test the last router only.
71 | initFunc := initFunc
72 | Context(fmt.Sprintf("With %v", routerNames[i]), func() {
73 | Context("A correct negotiation request is sent", func() {
74 | It("should send a correct negotiation response", func(done Done) {
75 | // Start server
76 | ctx, serverCancel := context.WithCancel(context.Background())
77 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&addHub{}), signalr.HTTPTransports("WebSockets"))
78 | Expect(err).NotTo(HaveOccurred())
79 | port := freePort()
80 | initFunc(server, port)
81 | // Negotiate
82 | negResp := negotiateWebSocketTestServer(port)
83 | Expect(negResp["connectionId"]).NotTo(BeNil())
84 | Expect(negResp["availableTransports"]).To(BeAssignableToTypeOf([]interface{}{}))
85 | avt := negResp["availableTransports"].([]interface{})
86 | Expect(len(avt)).To(BeNumerically(">", 0))
87 | Expect(avt[0]).To(BeAssignableToTypeOf(map[string]interface{}{}))
88 | avtVal := avt[0].(map[string]interface{})
89 | Expect(avtVal["transferFormats"]).To(BeAssignableToTypeOf([]interface{}{}))
90 | tf := avtVal["transferFormats"].([]interface{})
91 | Expect(tf).To(ContainElement("Text"))
92 | Expect(tf).To(ContainElement("Binary"))
93 | serverCancel()
94 | close(done)
95 | })
96 | })
97 | Context("Connection with client", func() {
98 | It("should successfully handle an Invoke call", func(done Done) {
99 | // Start server
100 | server, err := signalr.NewServer(context.Background(),
101 | signalr.SimpleHubFactory(&addHub{}),
102 | signalr.Logger(log.NewNopLogger(), false),
103 | signalr.HTTPTransports("WebSockets"))
104 | Expect(err).NotTo(HaveOccurred())
105 | port := freePort()
106 | initFunc(server, port)
107 | waitForPort(port)
108 | // Start client
109 | conn, err := signalr.NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port))
110 | Expect(err).NotTo(HaveOccurred())
111 | client, err := signalr.NewClient(context.Background(),
112 | signalr.WithConnection(conn),
113 | signalr.Logger(log.NewNopLogger(), false),
114 | signalr.TransferFormat("Text"))
115 | Expect(err).NotTo(HaveOccurred())
116 | Expect(client).NotTo(BeNil())
117 | errCh := client.Start()
118 | Expect(client.WaitConnected(context.Background())).NotTo(HaveOccurred())
119 | go func() {
120 | defer GinkgoRecover()
121 | Expect(errors.Is(<-errCh, context.Canceled)).To(BeTrue())
122 | }()
123 | result := <-client.Invoke("Add2", 1)
124 | Expect(result.Error).NotTo(HaveOccurred())
125 | Expect(result.Value).To(BeEquivalentTo(3))
126 | close(done)
127 | }, 10.0)
128 | })
129 | })
130 | }
131 | })
132 |
133 | // addHub is the hub served in the Router specs.
134 | type addHub struct {
135 | signalr.Hub
136 | }
137 |
138 | // Add2 returns i + 2.
139 | func (w *addHub) Add2(i int) int {
140 | return i + 2
141 | }
142 |
143 | // negotiateWebSocketTestServer posts an empty negotiate request to the test server and returns the decoded JSON response.
144 | func negotiateWebSocketTestServer(port int) map[string]interface{} {
145 | waitForPort(port)
146 | buf := bytes.Buffer{}
147 | resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port), "text/plain;charset=UTF-8", &buf)
148 | Expect(err).To(BeNil())
149 | Expect(resp).ToNot(BeNil())
150 | defer func() {
151 | _ = resp.Body.Close()
152 | }()
153 | var body []byte
154 | body, err = io.ReadAll(resp.Body)
155 | Expect(err).To(BeNil())
156 | response := make(map[string]interface{})
157 | err = json.Unmarshal(body, &response)
158 | Expect(err).To(BeNil())
159 | return response
160 | }
161 |
162 | // freePort asks the OS for a currently unused TCP port on localhost; it returns 0 if that fails.
163 | func freePort() int {
164 | if addr, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
165 | if listener, err := net.ListenTCP("tcp", addr); err == nil {
166 | defer func() {
167 | _ = listener.Close()
168 | }()
169 | return listener.Addr().(*net.TCPAddr).Port
170 | }
171 | }
172 | return 0
173 | }
174 |
175 | // waitForPort blocks until a TCP connection to 127.0.0.1:port succeeds, polling every 100ms.
176 | func waitForPort(port int) {
177 | addr := fmt.Sprintf("127.0.0.1:%v", port)
178 | for {
179 | if conn, err := net.Dial("tcp", addr); err == nil {
180 | // Close the probe so that it does not leave an open socket behind.
181 | _ = conn.Close()
182 | return
183 | }
184 | time.Sleep(100 * time.Millisecond)
185 | }
186 | }
187 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "net/http"
10 | "os"
11 | "reflect"
12 | "runtime/debug"
13 |
14 | "github.com/go-kit/log"
15 | )
16 |
17 | // Server is a SignalR server for one type of hub.
18 | //
19 | // MapHTTP(routerFactory func() MappableRouter, path string)
20 | // maps the server's hub to a path in the MappableRouter created by routerFactory.
21 | //
22 | // Serve(conn Connection)
23 | // serves the hub of the server on one connection.
24 | // The same server might serve different connections in parallel. Serve does not return until the connection is closed
25 | // or the server's context is canceled.
26 | //
27 | // HubClients()
28 | // allows calling all HubClients of the server from server-side, non-hub code.
29 | // Note that HubClients.Caller() returns nil, because there is no real caller which can be reached over a HubConnection.
30 | type Server interface {
31 | Party
32 | MapHTTP(routerFactory func() MappableRouter, path string)
33 | Serve(conn Connection) error
34 | HubClients() HubClients
35 | availableTransports() []string
36 | }
37 |
38 | type server struct {
39 | partyBase
40 | newHub func() HubInterface // hub factory, set by UseHub, HubFactory or SimpleHubFactory
41 | lifetimeManager HubLifetimeManager
42 | defaultHubClients *defaultHubClients
43 | groupManager GroupManager
44 | reconnectAllowed bool // cleared by recoverHubLifeCyclePanic
45 | transports []string // NewServer defaults this to WebSockets and ServerSentEvents
46 | }
47 |
48 | // NewServer creates a new server for one type of hub. The hub type is set by one of the
49 | // options UseHub, HubFactory or SimpleHubFactory; without one of them NewServer returns nil and an error.
50 | func NewServer(ctx context.Context, options ...func(Party) error) (Server, error) {
51 | info, dbg := buildInfoDebugLogger(log.NewLogfmtLogger(os.Stderr), false)
52 | lifetimeManager := newLifeTimeManager(info)
53 | server := &server{
54 | lifetimeManager: &lifetimeManager,
55 | defaultHubClients: &defaultHubClients{
56 | lifetimeManager: &lifetimeManager,
57 | allCache: allClientProxy{lifetimeManager: &lifetimeManager},
58 | },
59 | groupManager: &defaultGroupManager{
60 | lifetimeManager: &lifetimeManager,
61 | },
62 | partyBase: newPartyBase(ctx, info, dbg),
63 | reconnectAllowed: true,
64 | }
65 | for _, option := range options {
66 | if option != nil {
67 | if err := option(server); err != nil {
68 | return nil, err
69 | }
70 | }
71 | }
72 | if server.transports == nil {
73 | server.transports = []string{"WebSockets", "ServerSentEvents"}
74 | }
75 | if server.newHub == nil { // a server without hub factory is unusable (prefixLoggers calls newHub), so don't return it
76 | return nil, errors.New("cannot determine hub type. Neither UseHub, HubFactory or SimpleHubFactory given as option")
77 | }
78 | return server, nil
79 | }
80 |
81 | // MappableRouter encapsulates the methods used by server.MapHTTP to configure the
82 | // handlers required by the signalr protocol. This abstraction removes the explicit
83 | // binding to http.ServeMux and allows use of any mux which implements those basic
84 | // Handle and HandleFunc methods.
85 | type MappableRouter interface {
86 | HandleFunc(string, func(w http.ResponseWriter, r *http.Request))
87 | Handle(string, http.Handler)
88 | }
89 |
90 | // WithHTTPServeMux is a MappableRouter factory for MapHTTP which converts a
91 | // http.ServeMux to a MappableRouter.
92 | // For factories for other routers, see github.com/philippseith/signalr/router
93 | func WithHTTPServeMux(serveMux *http.ServeMux) func() MappableRouter {
94 | return func() MappableRouter {
95 | return serveMux
96 | }
97 | }
98 |
99 | // MapHTTP maps the server's hub to path, and its negotiate endpoint to path+"/negotiate", in the router built by routerFactory.
100 | func (s *server) MapHTTP(routerFactory func() MappableRouter, path string) {
101 | httpMux := newHTTPMux(s)
102 | router := routerFactory()
103 | router.HandleFunc(fmt.Sprintf("%s/negotiate", path), httpMux.negotiate)
104 | router.Handle(path, httpMux)
105 | }
106 |
107 | // Serve serves the hub of the server on one connection.
108 | // The same server might serve different connections in parallel. Serve does not return until the connection is closed
109 | // or the server's context is canceled.
110 | func (s *server) Serve(conn Connection) error {
111 | protocol, err := s.processHandshake(conn)
112 | if err == nil {
113 | return newLoop(s, conn, protocol).Run(make(chan struct{}, 1))
114 | }
115 | info, _ := s.prefixLoggers("")
116 | _ = info.Log(evt, "processHandshake", "connectionId", conn.ConnectionID(), "error", err, react, "do not connect")
117 | return err
118 | }
119 |
120 | // HubClients returns the HubClients for calling clients from non-hub code.
121 | func (s *server) HubClients() HubClients {
122 | return s.defaultHubClients
123 | }
124 |
125 | func (s *server) availableTransports() []string {
126 | return s.transports
127 | }
128 |
129 | // onConnected registers hc, then runs the hub's OnConnected in a new goroutine.
130 | func (s *server) onConnected(hc hubConnection) {
131 | s.lifetimeManager.OnConnected(hc)
132 | go func() {
133 | defer s.recoverHubLifeCyclePanic()
134 | s.invocationTarget(hc).(HubInterface).OnConnected(hc.ConnectionID())
135 | }()
136 | }
137 |
138 | // onDisconnected runs the hub's OnDisconnected in a new goroutine and unregisters hc.
139 | func (s *server) onDisconnected(hc hubConnection) {
140 | go func() {
141 | defer s.recoverHubLifeCyclePanic()
142 | s.invocationTarget(hc).(HubInterface).OnDisconnected(hc.ConnectionID())
143 | }()
144 | s.lifetimeManager.OnDisconnected(hc)
145 | }
146 |
147 | // invocationTarget returns a new hub, initialized with a hub context for conn.
148 | func (s *server) invocationTarget(conn hubConnection) interface{} {
149 | target := s.newHub()
150 | target.Initialize(s.newConnectionHubContext(conn))
151 | return target
152 | }
153 |
154 | // allowReconnect is true until recoverHubLifeCyclePanic handled a panic.
155 | func (s *server) allowReconnect() bool { return s.reconnectAllowed }
156 |
157 | // recoverHubLifeCyclePanic must be deferred; after a panic it disallows reconnects, logs the panic and cancels the server.
158 | func (s *server) recoverHubLifeCyclePanic() {
159 | if err := recover(); err != nil {
160 | s.reconnectAllowed = false // NOTE(review): written from hub goroutines and read via allowReconnect without synchronization
161 | info, dbg := s.prefixLoggers("")
162 | _ = info.Log(evt, "panic in hub lifecycle", "error", err, react, "close connection, allow no reconnect")
163 | _ = dbg.Log(evt, "panic in hub lifecycle", "error", err, react, "close connection, allow no reconnect", "stack", string(debug.Stack()))
164 | s.cancel()
165 | }
166 | }
167 |
168 | // prefixLoggers returns the info and debug loggers, prefixed with timestamp, class, connection ID and hub type.
169 | func (s *server) prefixLoggers(connectionID string) (info StructuredLogger, dbg StructuredLogger) {
170 | // newHub may be a user-supplied factory, so only one hub is instantiated to determine the hub type.
171 | hubType := reflect.ValueOf(s.newHub()).Elem().Type()
172 | prefix := []interface{}{"ts", log.DefaultTimestampUTC,
173 | "class", "Server",
174 | "connection", connectionID, "hub", hubType}
175 | return log.WithPrefix(s.info, prefix...), log.WithPrefix(s.dbg, prefix...)
176 | }
177 |
178 | func (s *server) newConnectionHubContext(hubConn hubConnection) HubContext {
179 | return &connectionHubContext{
180 | abort: hubConn.Abort,
181 | clients: &callerHubClients{
182 | defaultHubClients: s.defaultHubClients,
183 | connectionID: hubConn.ConnectionID(),
184 | },
185 | groups: s.groupManager,
186 | connection: hubConn,
187 | info: s.info,
188 | dbg: s.dbg,
189 | }
190 | }
191 |
192 | func (s *server) processHandshake(conn Connection) (hubProtocol, error) {
193 | if request, err := s.receiveHandshakeRequest(conn); err != nil {
194 | return nil, err
195 | } else {
196 | return s.sendHandshakeResponse(conn, request)
197 | }
198 | }
199 |
200 | func (s *server) receiveHandshakeRequest(conn Connection) (handshakeRequest, error) {
201 | _, dbg := s.prefixLoggers(conn.ConnectionID())
202 | ctx, cancelRead := context.WithTimeout(s.context(), s.HandshakeTimeout())
203 | defer cancelRead()
204 | readJSONFramesChan := make(chan []interface{}, 1)
205 | go func() {
206 | var remainBuf bytes.Buffer
207 | rawHandshake, err := readJSONFrames(conn, &remainBuf)
208 | readJSONFramesChan <- []interface{}{rawHandshake, err}
209 | }()
210 | request := handshakeRequest{}
211 | select {
212 | case result := <-readJSONFramesChan:
213 | if result[1] != nil {
214 | return request, result[1].(error)
215 | }
216 | rawHandshake := result[0].([][]byte)
217 | _ = dbg.Log(evt, "handshake received", "msg", string(rawHandshake[0]))
218 | return request, json.Unmarshal(rawHandshake[0], &request)
219 | case <-ctx.Done():
220 | return request, ctx.Err()
221 | }
222 | }
223 |
224 | func (s *server) sendHandshakeResponse(conn Connection, request handshakeRequest) (protocol hubProtocol, err error) {
225 | info, dbg := s.prefixLoggers(conn.ConnectionID())
226 | ctx, cancelWrite := context.WithTimeout(s.context(), s.HandshakeTimeout())
227 | defer cancelWrite()
228 | var ok bool
229 | if protocol, ok = protocolMap[request.Protocol]; ok {
230 | // Send the handshake response
231 | const handshakeResponse = "{}\u001e"
232 | if _, err = ReadWriteWithContext(ctx,
233 | func() (int, error) {
234 | return conn.Write([]byte(handshakeResponse))
235 | }, func() {}); err != nil {
236 | _ = dbg.Log(evt, "handshake sent", "error", err)
237 | } else {
238 | _ = dbg.Log(evt, "handshake sent", "msg", handshakeResponse)
239 | }
240 | } else {
241 | err = fmt.Errorf("protocol %v not supported", request.Protocol)
242 | _ = info.Log(evt, "protocol requested", "error", err)
243 | if _, respErr := ReadWriteWithContext(ctx,
244 | func() (int, error) {
245 | const errorHandshakeResponse = "{\"error\":\"%s\"}\u001e"
246 | return conn.Write([]byte(fmt.Sprintf(errorHandshakeResponse, err)))
247 | }, func() {}); respErr != nil {
248 | _ = dbg.Log(evt, "handshake sent", "error", respErr)
249 | err = respErr
250 | }
251 | }
252 | return protocol, err
253 | }
254 |
255 | var protocolMap = map[string]hubProtocol{
256 | "json": &jsonHubProtocol{},
257 | "messagepack": &messagePackHubProtocol{},
258 | }
259 |
// Keys and values used in the structured log output of the package.
const (
	evt     string = "event"
	msgRecv string = "message received"
	msgSend string = "message send"
	msg     string = "message"
	react   string = "reaction"
)
266 |
--------------------------------------------------------------------------------
/server_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | . "github.com/onsi/ginkgo"
9 | . "github.com/onsi/gomega"
10 | )
11 |
12 | var _ = Describe("Server.HubClients", func() {
13 | Context("All().Send()", func() {
14 | j := 1
15 | It(fmt.Sprintf("should send clients %v", j), func(done Done) {
16 | // Create a simple server
17 | server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
18 | testLoggerOption(),
19 | ChanReceiveTimeout(200*time.Millisecond),
20 | StreamBufferCapacity(5))
21 | Expect(err).NotTo(HaveOccurred())
22 | Expect(server).NotTo(BeNil())
23 | // Create both ends of the connection
24 | cliConn, srvConn := newClientServerConnections()
25 | // Start the server
26 | go func() { _ = server.Serve(srvConn) }()
27 | // Give the server some time. In contrast to the client, we have not connected state to query
28 | <-time.After(100 * time.Millisecond)
29 | // Create the Client
30 | receiver := &simpleReceiver{ch: make(chan string, 1)}
31 | ctx, cancelClient := context.WithCancel(context.Background())
32 | client, _ := NewClient(ctx,
33 | WithConnection(cliConn),
34 | WithReceiver(receiver),
35 | testLoggerOption(),
36 | TransferFormat("Text"))
37 | Expect(client).NotTo(BeNil())
38 | // Start it
39 | client.Start()
40 | // Wait for client running
41 | Expect(<-client.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
42 | // Send from the server to "all" clients
43 | <-time.After(100 * time.Millisecond)
44 | server.HubClients().All().Send("OnCallback", fmt.Sprintf("All%v", j))
45 | // Did the receiver get what we did send?
46 | Expect(<-receiver.ch).To(Equal(fmt.Sprintf("All%v", j)))
47 | cancelClient()
48 | server.cancel()
49 | close(done)
50 | }, 1.0)
51 | })
52 |
53 | Context("Caller()", func() {
54 | It("should return nil", func() {
55 | server, _ := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
56 | testLoggerOption(),
57 | ChanReceiveTimeout(200*time.Millisecond),
58 | StreamBufferCapacity(5))
59 | Expect(server.HubClients().Caller()).To(BeNil())
60 | })
61 | })
62 | })
63 |
--------------------------------------------------------------------------------
/serveroptions.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "reflect"
7 | )
8 |
9 | // UseHub sets the hub instance used by the server
10 | func UseHub(hub HubInterface) func(Party) error {
11 | return func(p Party) error {
12 | if s, ok := p.(*server); ok {
13 | s.newHub = func() HubInterface { return hub }
14 | return nil
15 | }
16 | return errors.New("option UseHub is server only")
17 | }
18 | }
19 |
20 | // HubFactory sets the function which returns the hub instance for every hub method invocation
21 | // The function might create a new hub instance on every invocation.
22 | // If hub instances should be created and initialized by a DI framework,
23 | // the frameworks' factory method can be called here.
24 | func HubFactory(factory func() HubInterface) func(Party) error {
25 | return func(p Party) error {
26 | if s, ok := p.(*server); ok {
27 | s.newHub = factory
28 | return nil
29 | }
30 | return errors.New("option HubFactory is server only")
31 | }
32 | }
33 |
34 | // SimpleHubFactory sets a HubFactory which creates a new hub with the underlying type
35 | // of hubProto on each hub method invocation.
36 | func SimpleHubFactory(hubProto HubInterface) func(Party) error {
37 | return HubFactory(
38 | func() HubInterface {
39 | return reflect.New(reflect.ValueOf(hubProto).Elem().Type()).Interface().(HubInterface)
40 | })
41 | }
42 |
43 | // HTTPTransports sets the list of available transports for http connections. Allowed transports are
44 | // "WebSockets", "ServerSentEvents". Default is both transports are available.
45 | func HTTPTransports(transports ...string) func(Party) error {
46 | return func(p Party) error {
47 | if s, ok := p.(*server); ok {
48 | for _, transport := range transports {
49 | switch transport {
50 | case "WebSockets", "ServerSentEvents":
51 | s.transports = append(s.transports, transport)
52 | default:
53 | return fmt.Errorf("unsupported transport: %v", transport)
54 | }
55 | }
56 | return nil
57 | }
58 | return errors.New("option Transports is server only")
59 | }
60 | }
61 |
62 | // InsecureSkipVerify disables Accepts origin verification behaviour which is used to avoid same origin strategy.
63 | // See https://pkg.go.dev/nhooyr.io/websocket#AcceptOptions
64 | func InsecureSkipVerify(skip bool) func(Party) error {
65 | return func(p Party) error {
66 | p.setInsecureSkipVerify(skip)
67 | return nil
68 | }
69 | }
70 |
71 | // AllowOriginPatterns lists the host patterns for authorized origins which is used for avoid same origin strategy.
72 | // See https://pkg.go.dev/nhooyr.io/websocket#AcceptOptions
73 | func AllowOriginPatterns(origins []string) func(Party) error {
74 | return func(p Party) error {
75 | p.setOriginPatterns(origins)
76 | return nil
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/serversseconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "strings"
9 | "sync"
10 | "time"
11 | )
12 |
13 | type serverSSEConnection struct {
14 | ConnectionBase
15 | mx sync.Mutex
16 | postWriting bool
17 | postWriter io.Writer
18 | postReader io.Reader
19 | jobChan chan []byte
20 | jobResultChan chan RWJobResult
21 | }
22 |
23 | func newServerSSEConnection(ctx context.Context, connectionID string) (*serverSSEConnection, <-chan []byte, chan RWJobResult, error) {
24 | s := serverSSEConnection{
25 | ConnectionBase: *NewConnectionBase(ctx, connectionID),
26 | jobChan: make(chan []byte, 1),
27 | jobResultChan: make(chan RWJobResult, 1),
28 | }
29 | s.postReader, s.postWriter = io.Pipe()
30 | go func() {
31 | <-s.Context().Done()
32 | s.mx.Lock()
33 | close(s.jobChan)
34 | s.mx.Unlock()
35 | }()
36 | return &s, s.jobChan, s.jobResultChan, nil
37 | }
38 |
39 | func (s *serverSSEConnection) consumeRequest(request *http.Request) int {
40 | if err := s.Context().Err(); err != nil {
41 | return http.StatusGone // 410
42 | }
43 | s.mx.Lock()
44 | if s.postWriting {
45 | s.mx.Unlock()
46 | return http.StatusConflict // 409
47 | }
48 | s.postWriting = true
49 | s.mx.Unlock()
50 | defer func() {
51 | _ = request.Body.Close()
52 | }()
53 | body, err := io.ReadAll(request.Body)
54 | if err != nil {
55 | return http.StatusBadRequest // 400
56 | } else if _, err := s.postWriter.Write(body); err != nil {
57 | return http.StatusInternalServerError // 500
58 | }
59 | s.mx.Lock()
60 | s.postWriting = false
61 | s.mx.Unlock()
62 | <-time.After(50 * time.Millisecond)
63 | return http.StatusOK // 200
64 | }
65 |
66 | func (s *serverSSEConnection) Read(p []byte) (n int, err error) {
67 | n, err = ReadWriteWithContext(s.Context(),
68 | func() (int, error) { return s.postReader.Read(p) },
69 | func() { _, _ = s.postWriter.Write([]byte("\n")) })
70 | if err != nil {
71 | err = fmt.Errorf("%T: %w", s, err)
72 | }
73 | return n, err
74 | }
75 |
76 | func (s *serverSSEConnection) Write(p []byte) (n int, err error) {
77 | if err := s.Context().Err(); err != nil {
78 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
79 | }
80 | payload := ""
81 | for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
82 | payload = payload + "data: " + line + "\n"
83 | }
84 | // prevent race with goroutine closing the jobChan
85 | s.mx.Lock()
86 | if s.Context().Err() == nil {
87 | s.jobChan <- []byte(payload + "\n")
88 | } else {
89 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
90 | }
91 | s.mx.Unlock()
92 | select {
93 | case <-s.Context().Done():
94 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
95 | case r := <-s.jobResultChan:
96 | return r.n, r.err
97 | }
98 |
99 | }
100 |
--------------------------------------------------------------------------------
/signalr_suite_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | . "github.com/onsi/ginkgo"
9 | . "github.com/onsi/gomega"
10 | )
11 |
12 | func TestSignalR(t *testing.T) {
13 | RegisterFailHandler(Fail)
14 | RunSpecs(t, "SignalR Suite")
15 | }
16 |
17 | func connect(hubProto HubInterface) (Server, *testingConnection) {
18 | server, err := NewServer(context.TODO(), SimpleHubFactory(hubProto),
19 | testLoggerOption(),
20 | ChanReceiveTimeout(200*time.Millisecond),
21 | StreamBufferCapacity(5))
22 | if err != nil {
23 | Fail(err.Error())
24 | return nil, nil
25 | }
26 | conn := newTestingConnectionForServer()
27 | go func() { _ = server.Serve(conn) }()
28 | return server, conn
29 | }
30 |
--------------------------------------------------------------------------------
/signalr_test/logger_test.go:
--------------------------------------------------------------------------------
1 | package signalr_test
2 |
3 | import (
4 | "encoding/json"
5 | "io/ioutil"
6 | "os"
7 |
8 | "github.com/go-kit/log"
9 | "github.com/philippseith/signalr"
10 | )
11 |
12 | type loggerConfig struct {
13 | Enabled bool
14 | Debug bool
15 | }
16 |
17 | var lConf loggerConfig
18 |
19 | var tLog signalr.StructuredLogger
20 |
21 | func testLoggerOption() func(signalr.Party) error {
22 | testLogger()
23 | return signalr.Logger(tLog, lConf.Debug)
24 | }
25 |
26 | func testLogger() signalr.StructuredLogger {
27 | if tLog == nil {
28 | lConf = loggerConfig{Enabled: false, Debug: false}
29 | b, err := ioutil.ReadFile("../testLogConf.json")
30 | if err == nil {
31 | err = json.Unmarshal(b, &lConf)
32 | if err != nil {
33 | lConf = loggerConfig{Enabled: false, Debug: false}
34 | }
35 | }
36 | writer := ioutil.Discard
37 | if lConf.Enabled {
38 | writer = os.Stderr
39 | }
40 | tLog = log.NewLogfmtLogger(writer)
41 | }
42 | return tLog
43 | }
44 |
--------------------------------------------------------------------------------
/signalr_test/netconnection_test.go:
--------------------------------------------------------------------------------
1 | package signalr_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net"
7 | "testing"
8 | "time"
9 |
10 | . "github.com/onsi/ginkgo"
11 | . "github.com/onsi/gomega"
12 | "github.com/philippseith/signalr"
13 | )
14 |
15 | func TestSignalR(t *testing.T) {
16 | RegisterFailHandler(Fail)
17 | RunSpecs(t, "SignalR external Suite")
18 | }
19 |
20 | type NetHub struct {
21 | signalr.Hub
22 | }
23 |
24 | func (n *NetHub) Smoke() string {
25 | return "no smoke!"
26 | }
27 |
28 | func (n *NetHub) ContinuousSmoke() chan string {
29 | ch := make(chan string, 1)
30 | go func() {
31 | loop:
32 | for i := 0; i < 5; i++ {
33 | select {
34 | case ch <- "smoke...":
35 | case <-n.Context().Done():
36 | break loop
37 | }
38 | <-time.After(100 * time.Millisecond)
39 | }
40 | close(ch)
41 | }()
42 | return ch
43 | }
44 |
45 | var _ = Describe("NetConnection", func() {
46 | Context("Smoke", func() {
47 | It("should transport a simple invocation over raw rcp", func(done Done) {
48 | ctx, cancel := context.WithCancel(context.Background())
49 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&NetHub{}), testLoggerOption())
50 | Expect(err).NotTo(HaveOccurred())
51 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
52 | Expect(err).NotTo(HaveOccurred())
53 | listener, err := net.ListenTCP("tcp", addr)
54 | Expect(err).NotTo(HaveOccurred())
55 | go func() {
56 | for {
57 | tcpConn, err := listener.Accept()
58 | Expect(err).NotTo(HaveOccurred())
59 | go func() { _ = server.Serve(signalr.NewNetConnection(ctx, tcpConn)) }()
60 | break
61 | }
62 | }()
63 | var client signalr.Client
64 | for {
65 | if clientConn, err := net.Dial("tcp",
66 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port)); err == nil {
67 | client, err = signalr.NewClient(ctx, signalr.WithConnection(signalr.NewNetConnection(ctx, clientConn)), testLoggerOption())
68 | Expect(err).NotTo(HaveOccurred())
69 | break
70 | }
71 | time.Sleep(100 * time.Millisecond)
72 | }
73 | client.Start()
74 | result := <-client.Invoke("smoke")
75 | Expect(result.Value).To(Equal("no smoke!"))
76 | cancel()
77 | close(done)
78 | })
79 | })
80 | Context("Stream and Timeout", func() {
81 | It("Client and Server should timeout when no messages are exchanged, but message exchange should prevent timeout", func(done Done) {
82 | ctx, cancel := context.WithCancel(context.Background())
83 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&NetHub{}), testLoggerOption(),
84 | // Set KeepAlive and Timeout so KeepAlive can't keep it alive
85 | signalr.TimeoutInterval(500*time.Millisecond), signalr.KeepAliveInterval(2*time.Second))
86 | Expect(err).NotTo(HaveOccurred())
87 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
88 | Expect(err).NotTo(HaveOccurred())
89 | listener, err := net.ListenTCP("tcp", addr)
90 | Expect(err).NotTo(HaveOccurred())
91 | serverDone := make(chan struct{}, 1)
92 | go func() {
93 | tcpConn, err := listener.Accept()
94 | Expect(err).NotTo(HaveOccurred())
95 | go func() {
96 | _ = server.Serve(signalr.NewNetConnection(ctx, tcpConn))
97 | serverDone <- struct{}{}
98 | }()
99 | }()
100 | var client signalr.Client
101 | for {
102 | if clientConn, err := net.Dial("tcp",
103 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port)); err == nil {
104 | client, err = signalr.NewClient(ctx, signalr.WithConnection(signalr.NewNetConnection(ctx, clientConn)), testLoggerOption(),
105 | // Set KeepAlive and Timeout so KeepAlive can't keep it alive
106 | signalr.TimeoutInterval(500*time.Millisecond), signalr.KeepAliveInterval(2*time.Second))
107 | Expect(err).NotTo(HaveOccurred())
108 | break
109 | }
110 | time.Sleep(100 * time.Millisecond)
111 | }
112 | client.Start()
113 | // The Server will send values each 100ms, so this should keep the connection alive
114 | i := 0
115 | for range client.PullStream("continuoussmoke") {
116 | i++
117 | }
118 | // some smoke messages and one error message when timed out
119 | Expect(i).To(BeNumerically(">", 1))
120 | // Wait for client and server to timeout
121 | <-time.After(time.Second)
122 | select {
123 | case <-serverDone:
124 | Expect(client.State() == signalr.ClientClosed)
125 | case <-time.After(10 * time.Millisecond):
126 | Fail("server not closed")
127 | }
128 | cancel()
129 | close(done)
130 | }, 5.0)
131 | })
132 | Context("SetConnectionID", func() {
133 | It("should change the ConnectionID", func() {
134 | clientConn, _ := net.Pipe()
135 | conn := signalr.NewNetConnection(context.Background(), clientConn)
136 | id := conn.ConnectionID()
137 | conn.SetConnectionID("Other" + id)
138 | Expect(conn.ConnectionID()).To(Equal("Other" + id))
139 | })
140 | })
141 | Context("Cancel", func() {
142 | It("should cancel the connection", func(done Done) {
143 | clientConn, serverConn := net.Pipe()
144 | ctx, cancel := context.WithCancel(context.Background())
145 | conn := signalr.NewNetConnection(ctx, clientConn)
146 | // Server loop
147 | go func() {
148 | b := make([]byte, 1024)
149 | for {
150 | if _, err := serverConn.Read(b); err != nil {
151 | break
152 | }
153 | }
154 | }()
155 | go func() {
156 | time.Sleep(500 * time.Millisecond)
157 | // cancel the connection
158 | cancel()
159 | }()
160 | for {
161 | if _, err := conn.Write([]byte("foobar")); err != nil {
162 | // This will never happen if the connection is not canceled
163 | break
164 | }
165 | }
166 | close(done)
167 | }, 2.0)
168 | })
169 | })
170 |
--------------------------------------------------------------------------------
/signalr_test/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "signalr_test",
3 | "version": "0.0.0",
4 | "description": "e2e test for signalr server",
5 | "private": true,
6 | "license": "UNLICENSED",
7 | "scripts": {
8 | "test": "run-p test:wsgo test:jest --race",
9 | "test:wsgo": "go test -v -run TestServerWebSockets",
10 | "test:jest": "npx jest -t 'e2e test with microsoft/signalr client should work'"
11 | },
12 | "dependencies": {
13 | "@microsoft/signalr": "^7.0.7",
14 | "@microsoft/signalr-protocol-msgpack": "^5.0.9",
15 | "@types/jest": "^27.0.2",
16 | "@types/node": "^12.20.4",
17 | "jest": "^27.2.5",
18 | "jest-preset-typescript": "^1.2.0",
19 | "path-parse": ">=1.0.7",
20 | "rxjs": "^6.6.6",
21 | "rxjs-for-await": "^0.0.2",
22 | "set-value": ">=4.0.1",
23 | "ts-jest": "^27.0.7",
24 | "typescript": "^4.2.2"
25 | },
26 | "devDependencies": {
27 | "npm-run-all": "^4.1.5"
28 | },
29 | "jest": {
30 | "testEnvironment": "node",
31 | "testTimeout": 10000,
32 | "preset": "jest-preset-typescript",
33 | "setupFilesAfterEnv": [
34 | "<rootDir>/setupJest.ts"
35 | ],
36 | "transformIgnorePatterns": [
37 | "node_modules/(?!(@aspnet/signalr)/)"
38 | ]
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/signalr_test/server_test.go:
--------------------------------------------------------------------------------
1 | package signalr_test
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "io"
7 | "math/rand"
8 | "net/http"
9 | "os"
10 | "os/exec"
11 | "path/filepath"
12 | "strings"
13 | "testing"
14 | "time"
15 |
16 | "github.com/philippseith/signalr"
17 | )
18 |
// TestMain installs the npm dependencies of the jest e2e specs before running the
// tests. It exits with status 120-122 if npm install cannot be started. It exits
// with 123, printing npm's output, if npm install fails.
func TestMain(m *testing.M) {
	npmInstall := exec.Command("npm", "install")
	stdout, err := npmInstall.StdoutPipe()
	if err != nil {
		println(err.Error())
		os.Exit(120)
	}
	stderr, err := npmInstall.StderrPipe()
	if err != nil {
		println(err.Error())
		os.Exit(121)
	}
	if err := npmInstall.Start(); err != nil {
		println(err.Error())
		os.Exit(122)
	}
	// Drain stderr concurrently with stdout. Reading the pipes one after the other
	// can deadlock when npm fills the stderr pipe buffer while we block on stdout.
	errSlurpChan := make(chan []byte, 1)
	go func() {
		errSlurp, _ := io.ReadAll(stderr)
		errSlurpChan <- errSlurp
	}()
	outSlurp, _ := io.ReadAll(stdout)
	errSlurp := <-errSlurpChan
	// Wait closes the pipes, so it must only be called after both are drained.
	err = npmInstall.Wait()
	if err != nil {
		println(err.Error())
		fmt.Println(string(outSlurp))
		fmt.Println(string(errSlurp))
		os.Exit(123)
	}
	os.Exit(m.Run())
}
46 |
47 | func TestServerSmoke(t *testing.T) {
48 | testServer(t, "^smoke", signalr.HTTPTransports("WebSockets"))
49 | }
50 |
51 | func TestServerJsonWebSockets(t *testing.T) {
52 | testServer(t, "^JSON", signalr.HTTPTransports("WebSockets"))
53 | }
54 |
55 | func TestServerJsonSSE(t *testing.T) {
56 | testServer(t, "^JSON", signalr.HTTPTransports("ServerSentEvents"))
57 | }
58 |
59 | func TestServerMessagePack(t *testing.T) {
60 | testServer(t, "^MessagePack", signalr.HTTPTransports("WebSockets"))
61 | }
62 |
63 | func testServer(t *testing.T, testNamePattern string, transports func(signalr.Party) error) {
64 | serverIsUp := make(chan struct{}, 1)
65 | quitServer := make(chan struct{}, 1)
66 | serverIsDown := make(chan struct{}, 1)
67 | go func() {
68 | runServer(t, serverIsUp, quitServer, transports)
69 | serverIsDown <- struct{}{}
70 | }()
71 | <-serverIsUp
72 | runJest(t, testNamePattern, quitServer)
73 | <-serverIsDown
74 | }
75 |
// runJest runs the jest specs matching testNamePattern. A jest failure is reported
// as a test error that includes jest's output. Whatever happens, runJest signals
// quitServer on return so the server under test shuts down.
func runJest(t *testing.T, testNamePattern string, quitServer chan struct{}) {
	defer func() { quitServer <- struct{}{} }()
	jest := exec.Command(filepath.FromSlash("node_modules/.bin/jest"), fmt.Sprintf("--testNamePattern=%v", testNamePattern))
	stdout, err := jest.StdoutPipe()
	if err != nil {
		t.Error(err)
		return
	}
	stderr, err := jest.StderrPipe()
	if err != nil {
		t.Error(err)
		return
	}
	if err := jest.Start(); err != nil {
		t.Error(err)
		return
	}
	// Drain stderr concurrently with stdout. Reading the pipes one after the other
	// can deadlock when jest fills the stderr pipe buffer while we block on stdout.
	errSlurpChan := make(chan []byte, 1)
	go func() {
		errSlurp, _ := io.ReadAll(stderr)
		errSlurpChan <- errSlurp
	}()
	outSlurp, _ := io.ReadAll(stdout)
	errSlurp := <-errSlurpChan
	// Wait closes the pipes, so it must only be called after both are drained.
	if err := jest.Wait(); err != nil {
		t.Error(err, fmt.Sprintf("\n%s\n%s", outSlurp, errSlurp))
		return
	}
	// Strange: Jest reports test results to stderr
	t.Logf("\n%s", errSlurp)
}
100 |
101 | func runServer(t *testing.T, serverIsUp chan struct{}, quitServer chan struct{}, transports func(signalr.Party) error) {
102 | // Install a handler to cancel the server
103 | doneQuit := make(chan struct{}, 1)
104 | ctx, cancelSignalRServer := context.WithCancel(context.Background())
105 | sRServer, _ := signalr.NewServer(ctx, signalr.SimpleHubFactory(&hub{}),
106 | signalr.KeepAliveInterval(2*time.Second),
107 | transports,
108 | testLoggerOption())
109 | router := http.NewServeMux()
110 | sRServer.MapHTTP(signalr.WithHTTPServeMux(router), "/hub")
111 |
112 | server := &http.Server{
113 | Addr: "127.0.0.1:5001",
114 | Handler: router,
115 | ReadTimeout: 2 * time.Second,
116 | WriteTimeout: 5 * time.Second,
117 | IdleTimeout: 10 * time.Second,
118 | }
119 | // wait for someone triggering quitServer
120 | go func() {
121 | <-quitServer
122 | // Cancel the signalR server and all its connections
123 | cancelSignalRServer()
124 | // Now shutdown the http server
125 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
126 | // If it does not Shutdown during 10s, try to end it by canceling the context
127 | defer cancel()
128 | server.SetKeepAlivesEnabled(false)
129 | _ = server.Shutdown(ctx)
130 | doneQuit <- struct{}{}
131 | }()
132 | // alternate method for quiting
133 | router.HandleFunc("/quit", func(http.ResponseWriter, *http.Request) {
134 | quitServer <- struct{}{}
135 | })
136 | t.Logf("Server %v is up", server.Addr)
137 | serverIsUp <- struct{}{}
138 | // Run the server
139 | _ = server.ListenAndServe()
140 | <-doneQuit
141 | }
142 |
143 | type hub struct {
144 | signalr.Hub
145 | }
146 |
147 | func (h *hub) Ping() string {
148 | return "Pong!"
149 | }
150 |
151 | func (h *hub) Touch() {
152 | h.Clients().Caller().Send("touched")
153 | }
154 |
155 | func (h *hub) TriumphantTriple(club string) []string {
156 | if strings.Contains(club, "FC Bayern") {
157 | return []string{"German Championship", "DFB Cup", "Champions League"}
158 | }
159 | return []string{}
160 | }
161 |
162 | type AlcoholicContent struct {
163 | Drink string `json:"drink"`
164 | Strength float32 `json:"strength"`
165 | }
166 |
167 | func (h *hub) AlcoholicContents() []AlcoholicContent {
168 | return []AlcoholicContent{
169 | {
170 | Drink: "Brunello",
171 | Strength: 13.5,
172 | },
173 | {
174 | Drink: "Beer",
175 | Strength: 4.9,
176 | },
177 | {
178 | Drink: "Lagavulin Cask Strength",
179 | Strength: 56.2,
180 | },
181 | }
182 | }
183 |
184 | func (h *hub) AlcoholicContentMap() map[string]float64 {
185 | return map[string]float64{
186 | "Brunello": 13.5,
187 | "Beer": 4.9,
188 | "Lagavulin Cask Strength": 56.2,
189 | }
190 | }
191 |
192 | func (h *hub) LargeCompressableContent() string {
193 | return strings.Repeat("data_", 10000)
194 | }
195 |
196 | func (h *hub) LargeUncompressableContent() string {
197 | return randString(20000)
198 | }
199 |
// letters is the alphabet randString draws from.
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")

// randString returns n runes chosen uniformly at random from letters, using the
// global math/rand source. It panics for negative n.
func randString(n int) string {
	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		sb.WriteRune(letters[rand.Intn(len(letters))])
	}
	return sb.String()
}
209 |
210 | func (h *hub) FiveDates() <-chan string {
211 | r := make(chan string)
212 | go func() {
213 | for i := 0; i < 5; i++ {
214 | r <- fmt.Sprint(time.Now().Nanosecond())
215 | time.Sleep(time.Millisecond * 100)
216 | }
217 | close(r)
218 | }()
219 | return r
220 | }
221 |
222 | func (h *hub) UploadStream(upload <-chan int) {
223 | var allUs []int
224 | for u := range upload {
225 | allUs = append(allUs, u)
226 | }
227 | h.Clients().Caller().Send("OnUploadComplete", allUs)
228 | }
229 |
--------------------------------------------------------------------------------
/signalr_test/setupJest.ts:
--------------------------------------------------------------------------------
1 | import 'jest-preset-typescript';
2 |
--------------------------------------------------------------------------------
/signalr_test/spec/server.spec.ts:
--------------------------------------------------------------------------------
1 | import * as signalR from '@microsoft/signalr';
2 | import { Subject, from } from 'rxjs'
3 | import { eachValueFrom } from 'rxjs-for-await';
4 | import {MessagePackHubProtocol} from "@microsoft/signalr-protocol-msgpack";
5 |
// IMPORTANT: When a proxy (e.g. px) is in use, the server will get the request,
// but the client will not get the response.
// So disable the proxy for this process.
process.env.http_proxy = "";

// Builder shared by all specs in this file; each spec configures URL and protocol on it.
const builder: signalR.HubConnectionBuilder = new signalR.HubConnectionBuilder()
    .configureLogging(signalR.LogLevel.Debug);

// Endpoint of the Go test server started by runServer in server_test.go.
const hubUrl = "http://127.0.0.1:5001/hub";
15 |
16 |
// Smoke tests: connect (default protocol, then MessagePack) and invoke "ping".
// The connection is stopped in `finally`, so a failed expectation does not leave it open.
describe("smoke test", () => {
    it("should connect on a clients request for connection and answer a simple request",
        async () => {
            const connection: signalR.HubConnection = builder
                .withUrl(hubUrl)
                .build();
            await connection.start();
            try {
                const pong = await connection.invoke<string>("ping");
                expect(pong).toEqual("Pong!");
            } finally {
                await connection.stop();
            }
        });
});

describe("MessagePack smoke test", () => {
    it("should connect on a clients request for connection and answer a simple request",
        async () => {
            const connection: signalR.HubConnection = builder
                .withUrl(hubUrl)
                .withHubProtocol(new MessagePackHubProtocol())
                .build();
            await connection.start();
            try {
                const pong = await connection.invoke<string>("ping");
                expect(pong).toEqual("Pong!");
            } finally {
                await connection.stop();
            }
        });
});
43 |
44 |
/** Shape of the objects returned by the hub method "alcoholicContents". */
class AlcoholicContent {
    drink: string
    strength: number
}

/**
 * Registers the e2e specs for one hub protocol. Every spec gets a fresh connection
 * to the Go test server, started before and stopped after the spec.
 */
function runE2E(protocol: signalR.IHubProtocol) {
    let connection: signalR.HubConnection;
    beforeEach(async() => {
        connection = builder
            .withUrl(hubUrl)
            .withHubProtocol(protocol)
            .build();
        await connection.start();
    })
    afterEach(async() => {
        await connection.stop();
    })
    it("should answer a simple request", async () => {
        const pong = await connection.invoke<string>("ping");
        expect(pong).toEqual("Pong!");
    })
    it("should send correct ping messages", async () => {
        const pong = await connection.invoke<string>("ping");
        expect(pong).toEqual("Pong!");
        // Wait for a ping (the Go server uses a 2s KeepAliveInterval)
        await new Promise(r => setTimeout(r, 5000));
    })
    it("should answer a simple request with multiple results", async () => {
        const triple = await connection.invoke<string[]>("triumphantTriple", "1.FC Bayern München");
        expect(triple).toEqual(["German Championship", "DFB Cup", "Champions League"]);
    })
    it("should answer a request with an resulting array of structs", async () => {
        const contents = await connection.invoke<AlcoholicContent[]>("alcoholicContents");
        expect(contents.length).toEqual(3);
        expect(contents[0].drink).toEqual('Brunello');
        expect(Math.abs(contents[2].strength- 56.2)).toBeLessThan(0.0001);
    })
    it("should answer a request with an resulting map", async () => {
        const contents = await connection.invoke<Record<string, number>>("alcoholicContentMap");
        expect(contents["Beer"]).toEqual(4.9);
        expect(Math.abs(contents["Lagavulin Cask Strength"] - 56.2)).toBeLessThan(0.0001);
    })
    it("should answer a request with a large amount of compressable data", async () => {
        const data = await connection.invoke<string>("largeCompressableContent");
        expect(data.length).toEqual(50000);
    })

    it("should answer a request with a large amount of uncompressable data", async () => {
        const data = await connection.invoke<string>("largeUncompressableContent");
        expect(data.length).toEqual(20000);
    })
    it("should receive a stream", async () => {
        // Subject<T> requires its element type; the hub streams strings.
        const fiveDates = new Subject<string>();
        connection.stream<string>("FiveDates").subscribe(fiveDates);
        let i = 0;
        let lastValue = '';
        for await (const value of eachValueFrom(fiveDates)) {
            expect(value).toBeDefined();
            expect(value).not.toEqual(lastValue);
            lastValue = value;
            i++;
        }
        expect(i).toEqual(5)
    })
    it("should upload a stream", async() =>{
        const receive = new Promise<number[]>(resolve => {
            connection.on("onUploadComplete", (r: number[]) => {
                resolve(r);
            });
        });
        await connection.send("uploadStream", from([2, 0, 7]));
        expect(await receive).toEqual([2, 0, 7])
    })
    it("should receive subsequent sends without await", async() => {
        let or: (value?: unknown) => void;
        const p = new Promise((r) => {
            or = r;
        });
        let tc = 0
        connection.on("touched", () => {
            tc++;
            if (tc == 2) {
                or();
            }
        })
        // Intentionally not awaited: both sends must be delivered.
        connection.send("touch");
        connection.send("touch");
        await p;
    })
}
135 |
// Run the complete end-to-end suite defined by runE2E once per hub protocol
// the @microsoft/signalr client is configured with: JSON and MessagePack.
describe("JSON e2e test with microsoft/signalr client", () => {
    runE2E(new signalR.JsonHubProtocol());
})

// Same scenarios as above, negotiated with the MessagePack hub protocol.
describe("MessagePack e2e test with microsoft/signalr client", () => {
    runE2E(new MessagePackHubProtocol());
})
143 |
--------------------------------------------------------------------------------
/signalr_test/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "module": "commonjs",
4 | "moduleResolution": "node",
5 | "target": "ES5",
6 | "typeRoots": [
7 | "node_modules/@types"
8 | ],
9 | "sourceMap": true,
10 | "noImplicitAny": true,
11 | "allowJs": true
12 | },
13 | "exclude": [
14 | "node_modules"
15 | ]
16 | }
--------------------------------------------------------------------------------
/signalr_test/tsconfig.spec.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "./tsconfig.json",
3 | "compilerOptions": {
4 | "outDir": "./out-tsc/spec",
5 | "types": [
6 | "node",
7 | "webpack-env"
8 | ]
9 | },
10 | "include": [
11 | "**/*.spec.ts",
12 | "**/*.d.ts"
13 | ]
14 | }
15 |
--------------------------------------------------------------------------------
/streamclient.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "sync"
7 | "time"
8 | )
9 |
10 | func newStreamClient(protocol hubProtocol, chanReceiveTimeout time.Duration, streamBufferCapacity uint) *streamClient {
11 | return &streamClient{
12 | mx: sync.Mutex{},
13 | upstreamChannels: make(map[string]reflect.Value),
14 | runningStreams: make(map[string]bool),
15 | chanReceiveTimeout: chanReceiveTimeout,
16 | streamBufferCapacity: streamBufferCapacity,
17 | protocol: protocol,
18 | }
19 | }
20 |
21 | type streamClient struct {
22 | mx sync.Mutex
23 | upstreamChannels map[string]reflect.Value
24 | runningStreams map[string]bool
25 | chanReceiveTimeout time.Duration
26 | streamBufferCapacity uint
27 | protocol hubProtocol
28 | }
29 |
30 | func (c *streamClient) buildChannelArgument(invocation invocationMessage, argType reflect.Type, chanCount int) (arg reflect.Value, canClientStreaming bool, err error) {
31 | c.mx.Lock()
32 | defer c.mx.Unlock()
33 | if argType.Kind() != reflect.Chan || argType.ChanDir() == reflect.SendDir {
34 | return reflect.Value{}, false, nil
35 | } else if len(invocation.StreamIds) > chanCount {
36 | // MakeChan does only accept bidirectional channels, and we need to Send to this channel anyway
37 | arg = reflect.MakeChan(reflect.ChanOf(reflect.BothDir, argType.Elem()), int(c.streamBufferCapacity))
38 | c.upstreamChannels[invocation.StreamIds[chanCount]] = arg
39 | return arg, true, nil
40 | } else {
41 | // To many channel parameters arguments this method. The client will not send streamItems for these
42 | return reflect.Value{}, true, fmt.Errorf("method %s has more chan parameters than the client will stream", invocation.Target)
43 | }
44 | }
45 |
46 | func (c *streamClient) newUpstreamChannel(invocationID string) <-chan interface{} {
47 | c.mx.Lock()
48 | defer c.mx.Unlock()
49 | upChan := make(chan interface{}, c.streamBufferCapacity)
50 | c.upstreamChannels[invocationID] = reflect.ValueOf(upChan)
51 | return upChan
52 | }
53 |
54 | func (c *streamClient) deleteUpstreamChannel(invocationID string) {
55 | c.mx.Lock()
56 | if upChan, ok := c.upstreamChannels[invocationID]; ok {
57 | upChan.Close()
58 | delete(c.upstreamChannels, invocationID)
59 | }
60 | c.mx.Unlock()
61 | }
62 |
63 | func (c *streamClient) receiveStreamItem(streamItem streamItemMessage) error {
64 | c.mx.Lock()
65 | defer c.mx.Unlock()
66 | if upChan, ok := c.upstreamChannels[streamItem.InvocationID]; ok {
67 | // Mark the stream as running to detect illegal completion with result on this id
68 | c.runningStreams[streamItem.InvocationID] = true
69 | chanVal := reflect.New(upChan.Type().Elem())
70 | err := c.protocol.UnmarshalArgument(streamItem.Item, chanVal.Interface())
71 | if err != nil {
72 | return err
73 | }
74 | return c.sendChanValSave(upChan, chanVal.Elem())
75 | }
76 | return fmt.Errorf(`unknown stream id "%v"`, streamItem.InvocationID)
77 | }
78 |
79 | func (c *streamClient) sendChanValSave(upChan reflect.Value, chanVal reflect.Value) error {
80 | done := make(chan error)
81 | go func() {
82 | defer func() {
83 | if r := recover(); r != nil {
84 | done <- fmt.Errorf("%v", r)
85 | }
86 | }()
87 | upChan.Send(chanVal)
88 | done <- nil
89 | }()
90 | select {
91 | case err := <-done:
92 | return err
93 | case <-time.After(c.chanReceiveTimeout):
94 | return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client streamed value", c.chanReceiveTimeout)}
95 | }
96 | }
97 |
// hubChanTimeoutError is returned when a streamed value was not taken off
// its channel within the configured timeout.
type hubChanTimeoutError struct {
	msg string
}

// Compile-time check that hubChanTimeoutError satisfies error.
var _ error = (*hubChanTimeoutError)(nil)

// Error returns the message the error was created with.
func (e *hubChanTimeoutError) Error() string {
	text := e.msg
	return text
}
105 |
106 | func (c *streamClient) handlesInvocationID(invocationID string) bool {
107 | c.mx.Lock()
108 | defer c.mx.Unlock()
109 | _, ok := c.upstreamChannels[invocationID]
110 | return ok
111 | }
112 |
113 | func (c *streamClient) receiveCompletionItem(completion completionMessage, invokeClient *invokeClient) error {
114 | c.mx.Lock()
115 | channel, ok := c.upstreamChannels[completion.InvocationID]
116 | c.mx.Unlock()
117 | if ok {
118 | var err error
119 | if completion.Error != "" {
120 | // Push error to the error channel
121 | err = invokeClient.receiveCompletionItem(completion)
122 | } else {
123 | if completion.Result != nil {
124 | c.mx.Lock()
125 | running := c.runningStreams[completion.InvocationID]
126 | c.mx.Unlock()
127 | if running {
128 | err = fmt.Errorf("client side streaming: received completion with result %v", completion)
129 | } else {
130 | // handle result like a stream item
131 | err = c.receiveStreamItem(streamItemMessage{
132 | Type: 2,
133 | InvocationID: completion.InvocationID,
134 | Item: completion.Result,
135 | })
136 | }
137 | }
138 | }
139 | channel.Close()
140 | // Close error channel
141 | invokeClient.deleteInvocation(completion.InvocationID)
142 | c.mx.Lock()
143 | delete(c.upstreamChannels, completion.InvocationID)
144 | delete(c.runningStreams, completion.InvocationID)
145 | c.mx.Unlock()
146 | return err
147 | }
148 | return fmt.Errorf("received completion with unknown id %v", completion.InvocationID)
149 | }
150 |
--------------------------------------------------------------------------------
/streamer.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "reflect"
5 | "sync"
6 | )
7 |
// streamer forwards the values received from a reflected channel to conn as
// stream items, and finishes each stream with a completion.
type streamer struct {
	// cancels holds the invocation IDs for which Stop has been called
	// (value: an empty struct).
	cancels sync.Map
	// conn is the connection stream items and completions are sent on.
	conn hubConnection
}
12 |
13 | func (s *streamer) Start(invocationID string, reflectedChannel reflect.Value) {
14 | go func() {
15 | loop:
16 | for {
17 | // Waits for channel, so might hang
18 | if chanResult, ok := reflectedChannel.Recv(); ok {
19 | if _, ok := s.cancels.Load(invocationID); ok {
20 | s.cancels.Delete(invocationID)
21 | _ = s.conn.Completion(invocationID, nil, "")
22 | break loop
23 | }
24 | if s.conn.Context().Err() != nil {
25 | break loop
26 | }
27 | _ = s.conn.StreamItem(invocationID, chanResult.Interface())
28 | } else {
29 | if s.conn.Context().Err() == nil {
30 | _ = s.conn.Completion(invocationID, nil, "")
31 | }
32 | break loop
33 | }
34 | }
35 | }()
36 | }
37 |
38 | func (s *streamer) Stop(invocationID string) {
39 | s.cancels.Store(invocationID, struct{}{})
40 | }
41 |
--------------------------------------------------------------------------------
/streaminvocation_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "time"
5 |
6 | . "github.com/onsi/ginkgo"
7 | . "github.com/onsi/gomega"
8 | )
9 |
10 | var streamInvocationQueue = make(chan string, 20)
11 |
12 | type streamHub struct {
13 | Hub
14 | }
15 |
16 | func (s *streamHub) SimpleStream() <-chan int {
17 | r := make(chan int)
18 | go func() {
19 | defer close(r)
20 | for i := 1; i < 4; i++ {
21 | r <- i
22 | }
23 | }()
24 | streamInvocationQueue <- "SimpleStream()"
25 | return r
26 | }
27 |
28 | func (s *streamHub) EndlessStream() <-chan int {
29 | r := make(chan int)
30 | go func() {
31 | defer close(r)
32 | for i := 1; ; i++ {
33 | r <- i
34 | }
35 | }()
36 | streamInvocationQueue <- "EndlessStream()"
37 | return r
38 | }
39 |
40 | func (s *streamHub) SliceStream() <-chan []int {
41 | r := make(chan []int)
42 | go func() {
43 | defer close(r)
44 | for i := 1; i < 4; i++ {
45 | s := make([]int, 2)
46 | s[0] = i
47 | s[1] = i * 2
48 | r <- s
49 | }
50 | }()
51 | streamInvocationQueue <- "SliceStream()"
52 | return r
53 | }
54 |
55 | func (s *streamHub) SimpleInt() int {
56 | streamInvocationQueue <- "SimpleInt()"
57 | return -1
58 | }
59 |
// StreamInvocation specs: each Describe connects a server hosting streamHub
// to a testingConnection. The client side sends raw JSON hub messages
// (type 4 = StreamInvocation, type 5 = CancelInvocation). The specs inspect
// the decoded replies on conn.received, where receiveLoop delivers stream
// items (type 2), completions (type 3) and close messages (type 7).
var _ = Describe("StreamInvocation", func() {

	Describe("Simple stream invocation", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When invoked by the client", func() {
			It("should be invoked on the server, return stream items and a final completion without result", func(done Done) {
				p := &jsonHubProtocol{dbg: testLogger()}
				conn.ClientSend(`{"type":4,"invocationId": "zzz","target":"simplestream"}`)
				Expect(<-streamInvocationQueue).To(Equal("SimpleStream()"))
				// SimpleStream yields 1, 2, 3; the items are decoded as float64 here.
				for i := 1; i < 4; i++ {
					recv := (<-conn.received).(streamItemMessage)
					// NOTE(review): recv is a struct value, so this check always passes.
					Expect(recv).NotTo(BeNil())
					Expect(recv.InvocationID).To(Equal("zzz"))
					var f float64
					Expect(p.UnmarshalArgument(recv.Item, &f)).NotTo(HaveOccurred())
					Expect(f).To(Equal(float64(i)))
				}
				recv := (<-conn.received).(completionMessage)
				Expect(recv).NotTo(BeNil())
				Expect(recv.InvocationID).To(Equal("zzz"))
				Expect(recv.Result).To(BeNil())
				Expect(recv.Error).To(Equal(""))
				close(done)
			})
		})
	})

	Describe("Slice stream invocation", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When invoked by the client", func() {
			It("should be invoked on the server, return stream items and a final completion without result", func(done Done) {
				protocol := jsonHubProtocol{dbg: testLogger()}
				conn.ClientSend(`{"type":4,"invocationId": "slice","target":"slicestream"}`)
				Expect(<-streamInvocationQueue).To(Equal("SliceStream()"))
				// SliceStream yields [i, 2i] for i = 1..3.
				for i := 1; i < 4; i++ {
					recv := (<-conn.received).(streamItemMessage)
					Expect(recv).NotTo(BeNil())
					Expect(recv.InvocationID).To(Equal("slice"))
					exp := make([]int, 0, 2)
					exp = append(exp, i)
					exp = append(exp, i*2)
					var got []int
					Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
					Expect(got).To(Equal(exp))
				}
				recv := (<-conn.received).(completionMessage)
				Expect(recv).NotTo(BeNil())
				Expect(recv.InvocationID).To(Equal("slice"))
				Expect(recv.Result).To(BeNil())
				Expect(recv.Error).To(Equal(""))
				close(done)
			})
		})
	})

	Describe("Stop simple stream invocation", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When invoked by the client and stop after one result", func() {
			It("should be invoked on the server, return stream one item and a final completion without result", func(done Done) {
				protocol := jsonHubProtocol{dbg: testLogger()}
				conn.ClientSend(`{"type":4,"invocationId": "xxx","target":"endlessstream"}`)
				Expect(<-streamInvocationQueue).To(Equal("EndlessStream()"))
				recv := (<-conn.received).(streamItemMessage)
				Expect(recv).NotTo(BeNil())
				Expect(recv.InvocationID).To(Equal("xxx"))
				var got int
				Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
				Expect(got).To(Equal(1))
				// stop it
				conn.ClientSend(`{"type":5,"invocationId": "xxx"}`)
				// Items already in flight may still arrive before the completion;
				// skip them until the completion shows up.
			loop:
				for {
					recv := <-conn.received
					Expect(recv).NotTo(BeNil())
					switch recv := recv.(type) {
					case streamItemMessage:
						Expect(recv.InvocationID).To(Equal("xxx"))
					case completionMessage:
						Expect(recv.InvocationID).To(Equal("xxx"))
						Expect(recv.Result).To(BeNil())
						Expect(recv.Error).To(Equal(""))
						break loop
					}
				}
				close(done)
			})
		})
	})

	Describe("Invalid CancelInvocation", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When invoked by the client and receiving an invalid CancelInvocation", func() {
			It("should close the connection with an error", func(done Done) {
				protocol := &jsonHubProtocol{dbg: testLogger()}
				conn.ClientSend(`{"type":4,"invocationId": "xyz","target":"endlessstream"}`)
				Expect(<-streamInvocationQueue).To(Equal("EndlessStream()"))
				recv := (<-conn.received).(streamItemMessage)
				Expect(recv).NotTo(BeNil())
				Expect(recv.InvocationID).To(Equal("xyz"))
				var got int
				Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
				Expect(got).To(Equal(1))
				// try to stop it, but do not get it right
				// (invocationId is a number here instead of a string)
				conn.ClientSend(`{"type":5,"invocationId":1}`)
				// Drain stream items until the server's close message arrives.
			loop:
				for {
					message := <-conn.received
					switch message := message.(type) {
					case closeMessage:
						// NOTE(review): if closeMessage.Error is a string, NotTo(BeNil())
						// always passes; Not(BeEmpty()) would check it — confirm the type.
						Expect(message.Error).NotTo(BeNil())
						break loop
					default:
					}
				}
				close(done)
			})
		})
	})

	Describe("Stream invocation of method with no stream result", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When invoked by the client", func() {
			It("should be invoked on the server, return one stream item with the \"no stream\" result and a final completion without result", func(done Done) {
				protocol := &jsonHubProtocol{dbg: testLogger()}
				conn.ClientSend(`{"type":4,"invocationId": "yyy","target":"simpleint"}`)
				Expect(<-streamInvocationQueue).To(Equal("SimpleInt()"))
				// A plain return value is expected as one stream item followed by a completion.
				sRecv := (<-conn.received).(streamItemMessage)
				Expect(sRecv).NotTo(BeNil())
				Expect(sRecv.InvocationID).To(Equal("yyy"))
				var got int
				Expect(protocol.UnmarshalArgument(sRecv.Item, &got)).NotTo(HaveOccurred())
				Expect(got).To(Equal(-1))
				cRecv := (<-conn.received).(completionMessage)
				Expect(cRecv).NotTo(BeNil())
				Expect(cRecv.InvocationID).To(Equal("yyy"))
				Expect(cRecv.Result).To(BeNil())
				Expect(cRecv.Error).To(Equal(""))
				close(done)
			})
		})
	})

	Describe("invalid messages", func() {
		var server Server
		var conn *testingConnection
		BeforeEach(func(done Done) {
			server, conn = connect(&streamHub{})
			close(done)
		})
		AfterEach(func(done Done) {
			server.cancel()
			close(done)
		})
		Context("When an invalid stream invocation message is sent", func() {
			It("should return a completion with error", func(done Done) {
				conn.ClientSend(`{"type":4}`)
				// NOTE(review): if nothing arrives within 100ms the spec passes
				// without any check. completionMessage.Error is a string (compared
				// with "" elsewhere), so NotTo(BeNil()) also holds for "".
				select {
				case message := <-conn.received:
					completionMessage := message.(completionMessage)
					Expect(completionMessage).NotTo(BeNil())
					Expect(completionMessage.Error).NotTo(BeNil())
				case <-time.After(100 * time.Millisecond):
				}
				close(done)
			})
		})
	})

})
274 |
--------------------------------------------------------------------------------
/testLogConf.json:
--------------------------------------------------------------------------------
1 | {
2 | "Enabled": false,
3 | "Debug": false
4 | }
--------------------------------------------------------------------------------
/testingconnection_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "sync"
11 | "time"
12 |
13 | "github.com/onsi/ginkgo"
14 | )
15 |
// testingConnection is an in-memory connection used by the tests. Two
// io.Pipe pairs (created in newTestingConnection) link a simulated client
// with the server under test:
//   - srvReader/srvWriter belong to the server side (Read, and the server
//     send loop);
//   - cliReader/cliWriter belong to the client side (ClientReceive, and the
//     client send loop).
type testingConnection struct {
	// timeout bounds each Read; 0 means Read waits without limit.
	timeout      time.Duration
	connectionID string
	srvWriter    io.Writer
	srvReader    io.Reader
	cliWriter    io.Writer
	cliReader    io.Reader
	// received gets the messages decoded by receiveLoop (via ReceiveChan).
	received chan interface{}
	// cnMutex guards connected.
	cnMutex   sync.Mutex
	connected bool
	// cliSendChan queues client messages; the client send loop appends the
	// record separator and writes them to cliWriter.
	cliSendChan chan string
	// srvSendChan queues data from Write for the server send loop.
	srvSendChan chan []byte
	// failRead/failWrite, when non-empty, make the next Read/Write fail with
	// that text (the value is cleared afterwards).
	failRead  string
	failWrite string
	// failMx guards failRead and failWrite.
	failMx sync.Mutex
}
32 |
33 | func (t *testingConnection) Context() context.Context {
34 | return context.TODO()
35 | }
36 |
37 | var connNum = 0
38 | var connNumMx sync.Mutex
39 |
40 | func (t *testingConnection) SetTimeout(timeout time.Duration) {
41 | t.timeout = timeout
42 | }
43 |
44 | func (t *testingConnection) Timeout() time.Duration {
45 | return t.timeout
46 | }
47 |
48 | func (t *testingConnection) ConnectionID() string {
49 | connNumMx.Lock()
50 | defer connNumMx.Unlock()
51 | if t.connectionID == "" {
52 | connNum++
53 | t.connectionID = fmt.Sprintf("test%v", connNum)
54 | }
55 | return t.connectionID
56 | }
57 |
58 | func (t *testingConnection) SetConnectionID(id string) {
59 | t.connectionID = id
60 | }
61 |
62 | func (t *testingConnection) Read(b []byte) (n int, err error) {
63 | if fr := t.FailRead(); fr != "" {
64 | defer func() { t.SetFailRead("") }()
65 | return 0, errors.New(fr)
66 | }
67 | timer := make(<-chan time.Time)
68 | if t.Timeout() > 0 {
69 | timer = time.After(t.Timeout())
70 | }
71 | nch := make(chan int)
72 | go func() {
73 | n, _ := t.srvReader.Read(b)
74 | nch <- n
75 | }()
76 | select {
77 | case n := <-nch:
78 | return n, nil
79 | case <-timer:
80 | return 0, fmt.Errorf("timeout %v", t.Timeout())
81 | }
82 | }
83 |
84 | func (t *testingConnection) Write(b []byte) (n int, err error) {
85 | if fw := t.FailWrite(); fw != "" {
86 | defer func() { t.SetFailWrite("") }()
87 | return 0, errors.New(fw)
88 | }
89 | t.srvSendChan <- b
90 | return len(b), nil
91 | }
92 |
93 | func (t *testingConnection) Connected() bool {
94 | t.cnMutex.Lock()
95 | defer t.cnMutex.Unlock()
96 | return t.connected
97 | }
98 |
99 | func (t *testingConnection) SetConnected(connected bool) {
100 | t.cnMutex.Lock()
101 | defer t.cnMutex.Unlock()
102 | t.connected = connected
103 | }
104 |
105 | func (t *testingConnection) FailRead() string {
106 | defer t.failMx.Unlock()
107 | t.failMx.Lock()
108 | return t.failRead
109 | }
110 |
111 | func (t *testingConnection) FailWrite() string {
112 | defer t.failMx.Unlock()
113 | t.failMx.Lock()
114 | return t.failWrite
115 | }
116 |
117 | func (t *testingConnection) SetFailRead(fail string) {
118 | defer t.failMx.Unlock()
119 | t.failMx.Lock()
120 | t.failRead = fail
121 | }
122 |
123 | func (t *testingConnection) SetFailWrite(fail string) {
124 | defer t.failMx.Unlock()
125 | t.failMx.Lock()
126 | t.failWrite = fail
127 | }
128 |
129 | // newTestingConnectionForServer builds a testingConnection with an sent (but not yet received) handshake for testing a server
130 | func newTestingConnectionForServer() *testingConnection {
131 | conn := newTestingConnection()
132 | // client receive loop
133 | go receiveLoop(conn)()
134 | // Send initial Handshake
135 | conn.ClientSend(`{"protocol": "json","version": 1}`)
136 | conn.SetConnected(true)
137 | return conn
138 | }
139 |
140 | func newTestingConnection() *testingConnection {
141 | cliReader, srvWriter := io.Pipe()
142 | srvReader, cliWriter := io.Pipe()
143 | conn := testingConnection{
144 | srvWriter: srvWriter,
145 | srvReader: srvReader,
146 | cliWriter: cliWriter,
147 | cliReader: cliReader,
148 | received: make(chan interface{}, 20),
149 | cliSendChan: make(chan string, 20),
150 | srvSendChan: make(chan []byte, 20),
151 | timeout: time.Second * 5,
152 | }
153 | // client send loop
154 | go func() {
155 | for {
156 | _, _ = conn.cliWriter.Write(append([]byte(<-conn.cliSendChan), 30))
157 | }
158 | }()
159 | // server send loop
160 | go func() {
161 | for {
162 | _, _ = conn.srvWriter.Write(<-conn.srvSendChan)
163 | }
164 | }()
165 | return &conn
166 | }
167 |
168 | func (t *testingConnection) ClientSend(message string) {
169 | t.cliSendChan <- message
170 | }
171 |
172 | func (t *testingConnection) ClientReceive() (string, error) {
173 | var buf bytes.Buffer
174 | var data = make([]byte, 1<<15) // 32K
175 | var nn int
176 | for {
177 | if message, err := buf.ReadString(30); err != nil {
178 | buf.Write(data[:nn])
179 | if n, err := t.cliReader.Read(data[nn:]); err == nil {
180 | buf.Write(data[nn : nn+n])
181 | nn = nn + n
182 | } else {
183 | return "", err
184 | }
185 | } else {
186 | return message[:len(message)-1], nil
187 | }
188 | }
189 | }
190 |
191 | func (t *testingConnection) ReceiveChan() chan interface{} {
192 | return t.received
193 | }
194 |
195 | type clientReceiver interface {
196 | ClientReceive() (string, error)
197 | ReceiveChan() chan interface{}
198 | SetConnected(bool)
199 | }
200 |
201 | func receiveLoop(conn clientReceiver) func() {
202 | return func() {
203 | defer ginkgo.GinkgoRecover()
204 | errorHandler := func(err error) { ginkgo.Fail(fmt.Sprintf("received invalid message from server %v", err.Error())) }
205 | for {
206 | if message, err := conn.ClientReceive(); err == nil {
207 | var hubMessage hubMessage
208 | if err = json.Unmarshal([]byte(message), &hubMessage); err == nil {
209 | switch hubMessage.Type {
210 | case 1, 4:
211 | var invocationMessage invocationMessage
212 | if err = json.Unmarshal([]byte(message), &invocationMessage); err == nil {
213 | conn.ReceiveChan() <- invocationMessage
214 | } else {
215 | errorHandler(err)
216 | }
217 | case 2:
218 | var jsonStreamItemMessage jsonStreamItemMessage
219 | if err = json.Unmarshal([]byte(message), &jsonStreamItemMessage); err == nil {
220 |
221 | conn.ReceiveChan() <- streamItemMessage{
222 | Type: jsonStreamItemMessage.Type,
223 | InvocationID: jsonStreamItemMessage.InvocationID,
224 | Item: jsonStreamItemMessage.Item,
225 | }
226 | } else {
227 | errorHandler(err)
228 | }
229 | case 3:
230 | var completionMessage completionMessage
231 | if err = json.Unmarshal([]byte(message), &completionMessage); err == nil {
232 | conn.ReceiveChan() <- completionMessage
233 | } else {
234 | errorHandler(err)
235 | }
236 | case 7:
237 | var closeMessage closeMessage
238 | if err = json.Unmarshal([]byte(message), &closeMessage); err == nil {
239 | conn.SetConnected(false)
240 | conn.ReceiveChan() <- closeMessage
241 | } else {
242 | errorHandler(err)
243 | }
244 | }
245 | } else {
246 | errorHandler(err)
247 | }
248 | }
249 | }
250 | }
251 | }
252 |
--------------------------------------------------------------------------------
/websocketconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 |
8 | "nhooyr.io/websocket"
9 | )
10 |
11 | type webSocketConnection struct {
12 | ConnectionBase
13 | conn *websocket.Conn
14 | transferMode TransferMode
15 | }
16 |
17 | func newWebSocketConnection(ctx context.Context, connectionID string, conn *websocket.Conn) *webSocketConnection {
18 | w := &webSocketConnection{
19 | conn: conn,
20 | ConnectionBase: *NewConnectionBase(ctx, connectionID),
21 | }
22 | return w
23 | }
24 |
25 | func (w *webSocketConnection) Write(p []byte) (n int, err error) {
26 | messageType := websocket.MessageText
27 | if w.transferMode == BinaryTransferMode {
28 | messageType = websocket.MessageBinary
29 | }
30 | n, err = ReadWriteWithContext(w.Context(),
31 | func() (int, error) {
32 | err := w.conn.Write(w.Context(), messageType, p)
33 | if err != nil {
34 | return 0, err
35 | }
36 | return len(p), nil
37 | },
38 | func() {})
39 | if err != nil {
40 | err = fmt.Errorf("%T: %w", w, err)
41 | _ = w.conn.Close(1000, err.Error())
42 | }
43 | return n, err
44 | }
45 |
46 | func (w *webSocketConnection) Read(p []byte) (n int, err error) {
47 | n, err = ReadWriteWithContext(w.Context(),
48 | func() (int, error) {
49 | _, data, err := w.conn.Read(w.Context())
50 | if err != nil {
51 | return 0, err
52 | }
53 | return bytes.NewReader(data).Read(p)
54 | },
55 | func() {})
56 | if err != nil {
57 | err = fmt.Errorf("%T: %w", w, err)
58 | _ = w.conn.Close(1000, err.Error())
59 | }
60 | return n, err
61 | }
62 |
63 | func (w *webSocketConnection) TransferMode() TransferMode {
64 | return w.transferMode
65 | }
66 |
67 | func (w *webSocketConnection) SetTransferMode(transferMode TransferMode) {
68 | w.transferMode = transferMode
69 | }
70 |
--------------------------------------------------------------------------------