├── .gitattributes ├── .github └── workflows │ ├── codeql-analysis.yml │ └── main.yml ├── .gitignore ├── .golangci.yml ├── .run ├── go test signalr.run.xml └── go test signalr_test.run.xml ├── Invokeresult.go ├── LICENSE ├── Makefile ├── README.md ├── chatsample ├── main.go ├── middleware │ ├── log_requests.go │ └── response_writer_wrapper.go └── public │ ├── fs.go │ ├── index.html │ └── js │ ├── signalr.js │ ├── signalr.js.map │ └── signalr.min.js ├── client.go ├── client_test.go ├── clientoptions.go ├── clientoptions_test.go ├── clientproxy.go ├── clientsseconnection.go ├── codecov.yml ├── connection.go ├── connection_test.go ├── connectionbase.go ├── ctxpipe.go ├── doc.go ├── go.mod ├── go.sum ├── groupmanager.go ├── httpconnection.go ├── httpmux.go ├── httpserver_test.go ├── hub.go ├── hubclients.go ├── hubconnection.go ├── hubcontext.go ├── hubcontext_test.go ├── hublifetimemanager.go ├── hubprotocol.go ├── hubprotocol_test.go ├── invocation_test.go ├── invokeclient.go ├── jsonhubprotocol.go ├── logger_test.go ├── loop.go ├── messagepackhubprotocol.go ├── messagepackhubprotocol_test.go ├── negotiateresponse.go ├── netconnection.go ├── options.go ├── party.go ├── receiver.go ├── router ├── Makefile ├── chirouter.go ├── doc.go ├── go.mod ├── go.sum ├── gorillarouter.go ├── httprouter.go └── router_test │ └── router_test.go ├── server.go ├── server_test.go ├── serveroptions.go ├── serveroptions_test.go ├── serversseconnection.go ├── signalr_suite_test.go ├── signalr_test ├── logger_test.go ├── netconnection_test.go ├── package-lock.json ├── package.json ├── server_test.go ├── setupJest.ts ├── spec │ └── server.spec.ts ├── tsconfig.json └── tsconfig.spec.json ├── streamclient.go ├── streamclient_test.go ├── streamer.go ├── streaminvocation_test.go ├── testLogConf.json ├── testingconnection_test.go └── websocketconnection.go /.gitattributes: -------------------------------------------------------------------------------- 1 | chatsample/public/js/signalr.js 
linguist-detectable=false 2 | chatsample/public/js/signalr.js.map linguist-detectable=false -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '45 8 * * 3' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 37 | # Learn more: 38 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 39 | 40 | steps: 41 | - name: Checkout repository 42 | uses: actions/checkout@v2 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@v1 47 | with: 48 | languages: ${{ matrix.language }} 49 | # If you wish to specify custom queries, you can do so here or in a config file. 50 | # By default, queries listed here will override any specified in a config file. 
51 | # Prefix the list here with "+" to use these queries and those in the config file. 52 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 53 | 54 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 55 | # If this step fails, then you should remove it and run the build manually (see below) 56 | - name: Autobuild 57 | uses: github/codeql-action/autobuild@v1 58 | 59 | # ℹ️ Command-line programs to run using the OS shell. 60 | # 📚 https://git.io/JvXDl 61 | 62 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 63 | # and modify them (or add more) to build your code if your project 64 | # uses a compiled language 65 | 66 | #- run: | 67 | # make bootstrap 68 | # make release 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v1 72 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | # This workflow will run on master branch and on any pull requests targeting master 4 | on: 5 | push: 6 | branches: 7 | - master 8 | pull_request: 9 | 10 | jobs: 11 | 12 | golangci: 13 | name: golangci 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: golangci-lint 18 | uses: golangci/golangci-lint-action@v3 19 | with: 20 | # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. 21 | version: v1.51 22 | 23 | test: 24 | name: Test and Coverage 25 | runs-on: ubuntu-latest 26 | steps: 27 | - name: Set up Go 28 | uses: actions/setup-go@v1 29 | with: 30 | go-version: 1.16 31 | 32 | - name: Check out code 33 | uses: actions/checkout@v1 34 | 35 | - name: Run Unit tests. 
36 | run: make test-coverage 37 | 38 | - name: Upload Coverage report to CodeCov 39 | uses: codecov/codecov-action@v1.0.0 40 | with: 41 | token: ${{secrets.CODECOV_TOKEN}} 42 | file: ./coverage.txt 43 | 44 | # test-macos: 45 | # name: Test MacOS 46 | # runs-on: macos-11 47 | # steps: 48 | # - name: Set up Go 49 | # uses: actions/setup-go@v1 50 | # with: 51 | # go-version: 1.16 52 | # 53 | # - name: Check out code 54 | # uses: actions/checkout@v1 55 | # 56 | # - name: Run Unit tests. 57 | # run: make test 58 | # 59 | # test-windows: 60 | # name: Test Windows 61 | # runs-on: windows-2019 62 | # steps: 63 | # - name: Set up Go 64 | # uses: actions/setup-go@v1 65 | # with: 66 | # go-version: 1.16 67 | # 68 | # - name: Check out code 69 | # uses: actions/checkout@v1 70 | # 71 | # - name: Run Unit tests. 72 | # run: make test 73 | 74 | build: 75 | name: Build 76 | runs-on: ubuntu-latest 77 | needs: [golangci, test] 78 | steps: 79 | - name: Set up Go 80 | uses: actions/setup-go@v1 81 | with: 82 | go-version: 1.16 83 | 84 | - name: Check out code 85 | uses: actions/checkout@v1 86 | 87 | - name: Build 88 | run: make build 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | 9 | # Diagnostic reports (https://nodejs.org/api/report.html) 10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 11 | 12 | # Runtime data 13 | pids 14 | *.pid 15 | *.seed 16 | *.pid.lock 17 | 18 | # Directory for instrumented libs generated by jscoverage/JSCover 19 | lib-cov 20 | 21 | # Coverage directory used by tools like istanbul 22 | coverage 23 | *.lcov 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 | # Bower dependency directory (https://bower.io/) 32 | 
bower_components 33 | 34 | # node-waf configuration 35 | .lock-wscript 36 | 37 | # Compiled binary addons (https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # TypeScript cache 48 | *.tsbuildinfo 49 | 50 | # Optional npm cache directory 51 | .npm 52 | 53 | # Optional eslint cache 54 | .eslintcache 55 | 56 | # Microbundle cache 57 | .rpt2_cache/ 58 | .rts2_cache_cjs/ 59 | .rts2_cache_es/ 60 | .rts2_cache_umd/ 61 | 62 | # Optional REPL history 63 | .node_repl_history 64 | 65 | # Output of 'npm pack' 66 | *.tgz 67 | 68 | # Yarn Integrity file 69 | .yarn-integrity 70 | 71 | # dotenv environment variables file 72 | .env 73 | .env.test 74 | 75 | # parcel-bundler cache (https://parceljs.org/) 76 | .cache 77 | 78 | # Next.js build output 79 | .next 80 | 81 | # Nuxt.js build / generate output 82 | .nuxt 83 | dist 84 | 85 | # Gatsby files 86 | .cache/ 87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js 88 | # https://nextjs.org/blog/next-9-1#public-directory-support 89 | # public 90 | 91 | # vuepress build output 92 | .vuepress/dist 93 | 94 | # Serverless directories 95 | .serverless/ 96 | 97 | # FuseBox cache 98 | .fusebox/ 99 | 100 | # DynamoDB Local files 101 | .dynamodb/ 102 | 103 | # TernJS port file 104 | .tern-port 105 | 106 | # Jetbrains IDE files 107 | .idea/ 108 | /.vscode 109 | /cover.out 110 | /coverage.txt 111 | 112 | signalr_test/package-lock.json 113 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | skip-dirs: 3 | - chatsample 4 | - signalr_test 5 | - router/router_test 6 | skip-files: 7 | - ".+_test\\.go$" -------------------------------------------------------------------------------- /.run/go test signalr.run.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /.run/go test signalr_test.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /Invokeresult.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // InvokeResult is the combined value/error result for async invocations. Used as channel type. 8 | type InvokeResult struct { 9 | Value interface{} 10 | Error error 11 | } 12 | 13 | // newInvokeResultChan combines a value result and an error result channel into one InvokeResult channel 14 | // The InvokeResult channel is automatically closed when both input channels are closed. 
15 | // It is also closed as soon as ctx is done; values not delivered by then are dropped, 16 | // so the internal goroutine never blocks forever on a consumer that stopped reading. 17 | func newInvokeResultChan(ctx context.Context, resultChan <-chan interface{}, errChan <-chan error) <-chan InvokeResult { 18 | ch := make(chan InvokeResult, 1) 19 | go func() { 20 | defer close(ch) 21 | // send delivers r to the consumer unless ctx is done first. 22 | send := func(r InvokeResult) bool { 23 | select { 24 | case ch <- r: 25 | return true 26 | case <-ctx.Done(): 27 | return false 28 | } 29 | } 30 | var resultChanClosed, errChanClosed bool 31 | for !resultChanClosed || !errChanClosed { 32 | select { 33 | case <-ctx.Done(): 34 | return 35 | case value, ok := <-resultChan: 36 | if !ok { 37 | // A nil channel is never ready: this stops select from spinning on the closed input. 38 | resultChanClosed, resultChan = true, nil 39 | } else if !send(InvokeResult{Value: value}) { 40 | return 41 | } 42 | case err, ok := <-errChan: 43 | if !ok { 44 | errChanClosed, errChan = true, nil 45 | } else if !send(InvokeResult{Error: err}) { 46 | return 47 | } 48 | } 49 | } 50 | }() 51 | return ch 52 | } 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 David Fowler, Philipp Seith 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT_NAME := "signalr" 2 | PKG := "github.com/philippseith/$(PROJECT_NAME)" 3 | PKG_LIST := $(shell go list ${PKG}/... | grep -v /vendor/) 4 | GO_FILES := $(shell find . -name '*.go' | grep -v /vendor/ | grep -v _test.go) 5 | 6 | .PHONY: all dep lint vet test test-coverage build clean 7 | 8 | all: build 9 | 10 | dep: ## Get the dependencies 11 | @go mod download 12 | 13 | lint: ## Lint Golang files 14 | @golangci-lint run 15 | 16 | vet: ## Run go vet 17 | @go vet ${PKG_LIST} 18 | 19 | test: ## Run unittests 20 | @go test -race -short -count=1 ${PKG_LIST} 21 | 22 | test-coverage: ## Run tests with coverage 23 | @go test -race -short -count=1 -coverpkg=. 
-coverprofile cover.out -covermode=atomic ${PKG_LIST} 24 | @cat cover.out > coverage.txt 25 | 26 | build: dep ## Build the binary file 27 | @go build -o build/main $(PKG) 28 | 29 | clean: ## Remove previous build 30 | @rm -rf build 31 | 32 | run-chatsample: ## run the local ./chatsample server 33 | @go run ./chatsample/*.go 34 | 35 | help: ## Display this help screen 36 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 37 | -------------------------------------------------------------------------------- /chatsample/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | _ "embed" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "os" 10 | "strings" 11 | "time" 12 | 13 | kitlog "github.com/go-kit/log" 14 | 15 | "github.com/philippseith/signalr" 16 | "github.com/philippseith/signalr/chatsample/middleware" 17 | "github.com/philippseith/signalr/chatsample/public" 18 | ) 19 | 20 | type chat struct { 21 | signalr.Hub 22 | } 23 | 24 | func (c *chat) OnConnected(connectionID string) { 25 | fmt.Printf("%s connected\n", connectionID) 26 | c.Groups().AddToGroup("group", connectionID) 27 | } 28 | 29 | func (c *chat) OnDisconnected(connectionID string) { 30 | fmt.Printf("%s disconnected\n", connectionID) 31 | c.Groups().RemoveFromGroup("group", connectionID) 32 | } 33 | 34 | func (c *chat) Broadcast(message string) { 35 | // Broadcast to all clients 36 | c.Clients().Group("group").Send("receive", message) 37 | } 38 | 39 | func (c *chat) Echo(message string) { 40 | c.Clients().Caller().Send("receive", message) 41 | } 42 | 43 | func (c *chat) Panic() { 44 | panic("Don't panic!") 45 | } 46 | 47 | func (c *chat) RequestAsync(message string) <-chan map[string]string { 48 | r := make(chan map[string]string) 49 | go func() { 50 | defer close(r) 51 | time.Sleep(4 * time.Second) 52 | m :=
make(map[string]string) 53 | m["ToUpper"] = strings.ToUpper(message) 54 | m["ToLower"] = strings.ToLower(message) 55 | m["len"] = fmt.Sprint(len(message)) 56 | r <- m 57 | }() 58 | return r 59 | } 60 | 61 | func (c *chat) RequestTuple(message string) (string, string, int) { 62 | return strings.ToUpper(message), strings.ToLower(message), len(message) 63 | } 64 | 65 | func (c *chat) DateStream() <-chan string { 66 | r := make(chan string) 67 | go func() { 68 | defer close(r) 69 | for i := 0; i < 50; i++ { 70 | r <- fmt.Sprint(time.Now().Clock()) 71 | time.Sleep(time.Second) 72 | } 73 | }() 74 | return r 75 | } 76 | 77 | func (c *chat) UploadStream(upload1 <-chan int, factor float64, upload2 <-chan float64) { 78 | ok1 := true 79 | ok2 := true 80 | u1 := 0 81 | u2 := 0.0 82 | c.Echo(fmt.Sprintf("f: %v", factor)) 83 | for { 84 | select { 85 | case u1, ok1 = <-upload1: 86 | if ok1 { 87 | c.Echo(fmt.Sprintf("u1: %v", u1)) 88 | } else if !ok2 { 89 | c.Echo("Finished") 90 | return 91 | } 92 | case u2, ok2 = <-upload2: 93 | if ok2 { 94 | c.Echo(fmt.Sprintf("u2: %v", u2)) 95 | } else if !ok1 { 96 | c.Echo("Finished") 97 | return 98 | } 99 | } 100 | } 101 | } 102 | 103 | func (c *chat) Abort() { 104 | fmt.Println("Abort") 105 | c.Hub.Abort() 106 | } 107 | 108 | //func runTCPServer(address string, hub signalr.HubInterface) { 109 | // listener, err := net.Listen("tcp", address) 110 | // 111 | // if err != nil { 112 | // fmt.Println(err) 113 | // return 114 | // } 115 | // 116 | // fmt.Printf("Listening for TCP connection on %s\n", listener.Addr()) 117 | // 118 | // server, _ := signalr.NewServer(context.TODO(), signalr.UseHub(hub)) 119 | // 120 | // for { 121 | // conn, err := listener.Accept() 122 | // 123 | // if err != nil { 124 | // fmt.Println(err) 125 | // break 126 | // } 127 | // 128 | // go server.Serve(context.TODO(), newNetConnection(conn)) 129 | // } 130 | //} 131 | 132 | func runHTTPServer(address string, hub signalr.HubInterface) { 133 | server, _ := 
signalr.NewServer(context.TODO(), signalr.SimpleHubFactory(hub), 134 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stdout), false), 135 | signalr.KeepAliveInterval(2*time.Second)) 136 | router := http.NewServeMux() 137 | server.MapHTTP(signalr.WithHTTPServeMux(router), "/chat") 138 | 139 | fmt.Printf("Serving public content from the embedded filesystem\n") 140 | router.Handle("/", http.FileServer(http.FS(public.FS))) 141 | fmt.Printf("Listening for websocket connections on http://%s\n", address) 142 | if err := http.ListenAndServe(address, middleware.LogRequests(router)); err != nil { 143 | log.Fatal("ListenAndServe:", err) 144 | } 145 | } 146 | 147 | func runHTTPClient(address string, receiver interface{}) error { 148 | c, err := signalr.NewClient(context.Background(), nil, 149 | signalr.WithReceiver(receiver), 150 | signalr.WithConnector(func() (signalr.Connection, error) { 151 | creationCtx, _ := context.WithTimeout(context.Background(), 2*time.Second) 152 | return signalr.NewHTTPConnection(creationCtx, address) 153 | }), 154 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stdout), false)) 155 | if err != nil { 156 | return err 157 | } 158 | c.Start() 159 | fmt.Println("Client started") 160 | return nil 161 | } 162 | 163 | type receiver struct { 164 | signalr.Receiver 165 | } 166 | 167 | func (r *receiver) Receive(msg string) { 168 | fmt.Println(msg) 169 | // The silly client urges the server to end his connection after 10 seconds 170 | r.Server().Send("abort") 171 | } 172 | 173 | func main() { 174 | hub := &chat{} 175 | 176 | //go runTCPServer("127.0.0.1:8007", hub) 177 | go runHTTPServer("localhost:8086", hub) 178 | <-time.After(time.Millisecond * 2) 179 | go func() { 180 | fmt.Println(runHTTPClient("http://localhost:8086/chat", &receiver{})) 181 | }() 182 | ch := make(chan struct{}) 183 | <-ch 184 | } 185 | -------------------------------------------------------------------------------- /chatsample/middleware/log_requests.go: 
-------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | // LogRequests writes simple request logs to STDOUT so that we can see what requests the server is handling 10 | func LogRequests(h http.Handler) http.Handler { 11 | // type our middleware as an http.HandlerFunc so that it is seen as an http.Handler 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | // wrap the original response writer so we can capture response details 14 | wrappedWriter := wrapResponseWriter(w) 15 | start := time.Now() // request start time 16 | 17 | // serve the inner request 18 | h.ServeHTTP(wrappedWriter, r) 19 | 20 | // extract request/response details 21 | status := wrappedWriter.status 22 | uri := r.URL.String() 23 | method := r.Method 24 | duration := time.Since(start) 25 | 26 | // write to console 27 | fmt.Printf("%03d %s %s %v\n", status, method, uri, duration) 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /chatsample/middleware/response_writer_wrapper.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | func wrapResponseWriter(w http.ResponseWriter) *responseWriterWrapper { 11 | return &responseWriterWrapper{ResponseWriter: w} 12 | } 13 | 14 | // responseWriterWrapper is a minimal wrapper for http.ResponseWriter that allows the 15 | // written HTTP status code to be captured for logging. 
16 | // adapted from: https://github.com/elithrar/admission-control/blob/df0c4bf37a96d159d9181a71cee6e5485d5a50a9/request_logger.go#L11-L13 17 | type responseWriterWrapper struct { 18 | http.ResponseWriter 19 | status int 20 | wroteHeader bool 21 | } 22 | 23 | // Status provides access to the wrapped http.ResponseWriter's status. 24 | // It is 0 until a status has been written, either explicitly or by the first Write. 25 | func (rw *responseWriterWrapper) Status() int { 26 | return rw.status 27 | } 28 | 29 | // Header provides access to the wrapped http.ResponseWriter's header 30 | // allowing handlers to set HTTP headers on the wrapped response 31 | func (rw *responseWriterWrapper) Header() http.Header { 32 | return rw.ResponseWriter.Header() 33 | } 34 | 35 | // WriteHeader intercepts the written status code and caches it 36 | // so that we can access it later. Only the first call is forwarded. 37 | func (rw *responseWriterWrapper) WriteHeader(code int) { 38 | if rw.wroteHeader { 39 | return 40 | } 41 | 42 | rw.status = code 43 | rw.ResponseWriter.WriteHeader(code) 44 | rw.wroteHeader = true 45 | } 46 | 47 | // Write implements http.ResponseWriter. If no status has been written yet, the 48 | // wrapped Write sends an implicit 200 OK; record that status here (without a second 49 | // WriteHeader call) so handlers that only write a body are not logged as 000. 50 | func (rw *responseWriterWrapper) Write(b []byte) (int, error) { 51 | if !rw.wroteHeader { 52 | rw.status = http.StatusOK 53 | rw.wroteHeader = true 54 | } 55 | return rw.ResponseWriter.Write(b) 56 | } 57 | 58 | // Flush implements http.Flusher 59 | func (rw *responseWriterWrapper) Flush() { 60 | if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { 61 | flusher.Flush() 62 | } 63 | } 64 | 65 | // Hijack implements http.Hijacker (e.g. for websocket upgrades) when the wrapped 66 | // http.ResponseWriter supports it; otherwise it returns an error. 67 | func (rw *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { 68 | if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { 69 | return hijacker.Hijack() 70 | } 71 | return nil, nil, errors.New("http.Hijacker not implemented") 72 | } 73 | -------------------------------------------------------------------------------- /chatsample/public/fs.go: -------------------------------------------------------------------------------- 1 | package public 2 | 3 | import "embed" 4 | 5 | // FS contains the static web content of the chat sample, embedded at build time. 6 | // NOTE(review): the pattern "*" also embeds (and therefore serves) this fs.go file. 7 | //go:embed * 8 | var FS embed.FS 9 | -------------------------------------------------------------------------------- /chatsample/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
17 | 18 | 19 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /clientoptions.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/cenkalti/backoff/v4" 7 | ) 8 | 9 | // WithConnection sets the Connection of the Client 10 | func WithConnection(connection Connection) func(party Party) error { 11 | return func(party Party) error { 12 | if client, ok := party.(*client); ok { 13 | if client.connectionFactory != nil { 14 | return errors.New("options WithConnection and WithConnector can not be used together") 15 | } 16 | client.conn = connection 17 | return nil 18 | } 19 | return errors.New("option WithConnection is client only") 20 | } 21 | } 22 | 23 | // WithConnector allows the Client to establish a connection 24 | // using the Connection build by the connectionFactory. 25 | // It is also used for auto reconnect if the connection is lost. 26 | func WithConnector(connectionFactory func() (Connection, error)) func(Party) error { 27 | return func(party Party) error { 28 | if client, ok := party.(*client); ok { 29 | if client.conn != nil { 30 | return errors.New("options WithConnection and WithConnector can not be used together") 31 | } 32 | client.connectionFactory = connectionFactory 33 | return nil 34 | } 35 | return errors.New("option WithConnector is client only") 36 | } 37 | } 38 | 39 | // WithReceiver sets the object which will receive server side calls to client methods (e.g. 
callbacks) 40 | func WithReceiver(receiver interface{}) func(Party) error { 41 | return func(party Party) error { 42 | if client, ok := party.(*client); ok { 43 | client.receiver = receiver 44 | if receiver, ok := receiver.(ReceiverInterface); ok { 45 | receiver.Init(client) 46 | } 47 | return nil 48 | } 49 | return errors.New("option WithReceiver is client only") 50 | } 51 | } 52 | 53 | // WithBackoff sets the backoff.BackOff used for repeated connection attempts in the client. 54 | // See https://pkg.go.dev/github.com/cenkalti/backoff for configuration options. 55 | // If the option is not set, backoff.NewExponentialBackOff() without any further configuration will be used. 56 | func WithBackoff(backoffFactory func() backoff.BackOff) func(party Party) error { 57 | return func(party Party) error { 58 | if client, ok := party.(*client); ok { 59 | client.backoffFactory = backoffFactory 60 | return nil 61 | } 62 | return errors.New("option WithBackoff is client only") 63 | } 64 | } 65 | 66 | // TransferFormat sets the transfer format used on the transport. Allowed values are "Text" and "Binary" 67 | func TransferFormat(format string) func(Party) error { 68 | return func(p Party) error { 69 | if c, ok := p.(*client); ok { 70 | switch format { 71 | case "Text": 72 | c.format = "json" 73 | case "Binary": 74 | c.format = "messagepack" 75 | default: 76 | return fmt.Errorf("invalid transferformat %v", format) 77 | } 78 | return nil 79 | } 80 | return errors.New("option TransferFormat is client only") 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /clientoptions_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "github.com/cenkalti/backoff/v4" 6 | 7 | . "github.com/onsi/ginkgo" 8 | . 
"github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("Client options", func() { 12 | 13 | Describe("WithConnection and WithConnector option", func() { 14 | Context("none of them is given", func() { 15 | It("NewClient should fail", func() { 16 | _, err := NewClient(context.TODO()) 17 | Expect(err).To(HaveOccurred()) 18 | }, 3.0) 19 | }) 20 | Context("both are given", func() { 21 | It("NewClient should fail", func() { 22 | conn := NewNetConnection(context.TODO(), nil) 23 | _, err := NewClient(context.TODO(), WithConnection(conn), WithConnector(func() (Connection, error) { 24 | return conn, nil 25 | })) 26 | Expect(err).To(HaveOccurred()) 27 | }, 3.0) 28 | }) 29 | Context("only WithConnection is given", func() { 30 | It("NewClient should not fail", func() { 31 | conn := NewNetConnection(context.TODO(), nil) 32 | _, err := NewClient(context.TODO(), WithConnection(conn)) 33 | Expect(err).NotTo(HaveOccurred()) 34 | }, 3.0) 35 | }) 36 | Context("only WithConnector is given", func() { 37 | It("NewClient should not fail", func() { 38 | conn := NewNetConnection(context.TODO(), nil) 39 | _, err := NewClient(context.TODO(), WithConnector(func() (Connection, error) { 40 | return conn, nil 41 | })) 42 | Expect(err).NotTo(HaveOccurred()) 43 | }, 3.0) 44 | }) 45 | Context("only WithBackoff is given", func() { 46 | It("NewClient should not fail", func() { 47 | conn := NewNetConnection(context.TODO(), nil) 48 | _, err := NewClient(context.TODO(), WithConnection(conn), WithBackoff(func() backoff.BackOff { 49 | return backoff.NewExponentialBackOff() 50 | })) 51 | Expect(err).NotTo(HaveOccurred()) 52 | }, 3.0) 53 | }) 54 | }) 55 | 56 | }) 57 | -------------------------------------------------------------------------------- /clientproxy.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | //ClientProxy allows the hub to send messages to one or more of its clients 4 | type ClientProxy interface { 5 | Send(target string, 
args ...interface{}) 6 | } 7 | 8 | type allClientProxy struct { 9 | lifetimeManager HubLifetimeManager 10 | } 11 | 12 | func (a *allClientProxy) Send(target string, args ...interface{}) { 13 | a.lifetimeManager.InvokeAll(target, args) 14 | } 15 | 16 | type singleClientProxy struct { 17 | connectionID string 18 | lifetimeManager HubLifetimeManager 19 | } 20 | 21 | func (a *singleClientProxy) Send(target string, args ...interface{}) { 22 | a.lifetimeManager.InvokeClient(a.connectionID, target, args) 23 | } 24 | 25 | type groupClientProxy struct { 26 | groupName string 27 | lifetimeManager HubLifetimeManager 28 | } 29 | 30 | func (g *groupClientProxy) Send(target string, args ...interface{}) { 31 | g.lifetimeManager.InvokeGroup(g.groupName, target, args) 32 | } 33 | -------------------------------------------------------------------------------- /clientsseconnection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | ) 13 | 14 | // clientSSEConnection is the client side of the Server-Sent Events transport: 15 | // server messages are taken from the "data:" fields of the event stream, 16 | // client messages are sent as separate POST requests to reqURL. 17 | type clientSSEConnection struct { 18 | ConnectionBase 19 | reqURL string 20 | sseReader io.Reader 21 | sseWriter io.Writer 22 | } 23 | 24 | // newClientSSEConnection creates the connection for connectionID on address. 25 | // body is the already opened event stream. A goroutine copies the payload of its 26 | // "data:" lines into a pipe which Read consumes. When body ends or fails, Read 27 | // returns that error (io.EOF for a normal end) and body is closed. 28 | func newClientSSEConnection(address string, connectionID string, body io.ReadCloser) (*clientSSEConnection, error) { 29 | // Setup request 30 | reqURL, err := url.Parse(address) 31 | if err != nil { 32 | return nil, err 33 | } 34 | q := reqURL.Query() 35 | q.Set("id", connectionID) 36 | reqURL.RawQuery = q.Encode() 37 | pipeReader, pipeWriter := io.Pipe() 38 | c := clientSSEConnection{ 39 | ConnectionBase: ConnectionBase{ 40 | ctx: context.Background(), 41 | connectionID: connectionID, 42 | }, 43 | reqURL: reqURL.String(), 44 | sseReader: pipeReader, 45 | sseWriter: pipeWriter, 46 | } 47 | go func() { 48 | defer closeResponseBody(body) 49 | // Read complete lines: a single body.Read can end in the middle of a line, 50 | // which must not be parsed before the rest of it has arrived. 51 | lines := bufio.NewReader(body) 52 | for { 53 | line, readErr := lines.ReadString('\n') 54 | line = strings.Trim(line, "\r\n\t ") 55 | // Ignore everything but data 56 | if strings.HasPrefix(line, "data:") { 57 | data := strings.TrimPrefix(line, "data:") 58 | // Spec says: If it starts with Space, remove it 59 | data = strings.TrimPrefix(data, " ") 60 | if _, writeErr := pipeWriter.Write([]byte(data)); writeErr != nil { 61 | return 62 | } 63 | } 64 | if readErr != nil { 65 | // Unblock Read instead of leaving it waiting forever on a dead stream. 66 | _ = pipeWriter.CloseWithError(readErr) 67 | return 68 | } 69 | } 70 | }() 71 | return &c, nil 72 | } 73 | 74 | // Read returns the payload of the data fields received on the event stream. 75 | func (c *clientSSEConnection) Read(p []byte) (n int, err error) { 76 | return c.sseReader.Read(p) 77 | } 78 | 79 | // Write sends p to the server as the body of one POST request. 80 | // A response status other than 200 is reported as an error. 81 | func (c *clientSSEConnection) Write(p []byte) (n int, err error) { 82 | req, err := http.NewRequest(http.MethodPost, c.reqURL, bytes.NewReader(p)) 83 | if err != nil { 84 | return 0, err 85 | } 86 | resp, err := http.DefaultClient.Do(req) 87 | if err != nil { 88 | return 0, err 89 | } 90 | defer closeResponseBody(resp.Body) 91 | if resp.StatusCode != http.StatusOK { 92 | return len(p), fmt.Errorf("POST %v -> %v", c.reqURL, resp.Status) 93 | } 94 | return len(p), nil 95 | }
96 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: true 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "60...100" 8 | 9 | parsers: 10 | gcov: 11 | branch_detection: 12 | conditional: yes 13 | loop: yes 14 | method: no 15 | macro: no 16 | 17 | comment: 18 | layout: "reach,diff,flags,files,footer" 19 | behavior: default 20 | require_changes: false 21 | 22 | ignore: 23 | - chatsample/**/* 24 | - signalr_test/**/* 25 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "io" 6 | ) 7 | 8 | // Connection
describes a connection between signalR client and server 9 | type Connection interface { 10 | io.Reader 11 | io.Writer 12 | Context() context.Context 13 | ConnectionID() string 14 | SetConnectionID(id string) 15 | } 16 | 17 | // TransferMode is either TextTransferMode or BinaryTransferMode 18 | type TransferMode int 19 | 20 | // MessageType constants. 21 | const ( 22 | // TextTransferMode is for UTF-8 encoded text messages like JSON. 23 | TextTransferMode TransferMode = iota + 1 24 | // BinaryTransferMode is for binary messages like MessagePack. 25 | BinaryTransferMode 26 | ) 27 | 28 | // ConnectionWithTransferMode is a Connection with TransferMode (e.g. Websocket) 29 | type ConnectionWithTransferMode interface { 30 | TransferMode() TransferMode 31 | SetTransferMode(transferMode TransferMode) 32 | } 33 | -------------------------------------------------------------------------------- /connection_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "time" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . 
"github.com/onsi/gomega" 10 | ) 11 | 12 | var _ = Describe("Connection", func() { 13 | 14 | Describe("Connection closed", func() { 15 | var server Server 16 | var conn *testingConnection 17 | BeforeEach(func(done Done) { 18 | server, conn = connect(&Hub{}) 19 | close(done) 20 | }) 21 | AfterEach(func(done Done) { 22 | server.cancel() 23 | close(done) 24 | }) 25 | Context("When the connection is closed", func() { 26 | It("should close the connection and not answer an invocation", func(done Done) { 27 | conn.ClientSend(`{"type":7}`) 28 | conn.ClientSend(`{"type":1,"invocationId": "123","target":"unknownFunc"}`) 29 | // When the connection is closed, the server should either send a closeMessage or nothing at all 30 | select { 31 | case message := <-conn.received: 32 | Expect(message.(closeMessage)).NotTo(BeNil()) 33 | case <-time.After(100 * time.Millisecond): 34 | } 35 | close(done) 36 | }) 37 | }) 38 | Context("When the connection is closed with an invalid close message", func() { 39 | It("should close the connection and should not answer an invocation", func(done Done) { 40 | conn.ClientSend(`{"type":7,"error":1}`) 41 | conn.ClientSend(`{"type":1,"invocationId": "123","target":"unknownFunc"}`) 42 | // When the connection is closed, the server should either send a closeMessage or nothing at all 43 | select { 44 | case message := <-conn.received: 45 | Expect(message.(closeMessage)).NotTo(BeNil()) 46 | case <-time.After(100 * time.Millisecond): 47 | } 48 | close(done) 49 | }) 50 | }) 51 | }) 52 | }) 53 | 54 | var _ = Describe("Protocol", func() { 55 | var server Server 56 | var conn *testingConnection 57 | BeforeEach(func(done Done) { 58 | server, conn = connect(&Hub{}) 59 | close(done) 60 | }) 61 | AfterEach(func(done Done) { 62 | server.cancel() 63 | close(done) 64 | }) 65 | Describe("Invalid messages", func() { 66 | Context("When a message with invalid id is sent", func() { 67 | It("should close the connection with an error", func(done Done) { 68 | 
conn.ClientSend(`{"type":99}`) 69 | select { 70 | case message := <-conn.received: 71 | Expect(message).To(BeAssignableToTypeOf(closeMessage{})) 72 | Expect(message.(closeMessage).Error).NotTo(BeNil()) 73 | case <-time.After(100 * time.Millisecond): 74 | Fail("timed out") 75 | } 76 | close(done) 77 | }) 78 | }) 79 | }) 80 | 81 | Describe("Ping", func() { 82 | Context("When a ping is received", func() { 83 | It("should ignore it", func(done Done) { 84 | conn.ClientSend(`{"type":6}`) 85 | select { 86 | case <-conn.received: 87 | Fail("ping not ignored") 88 | case <-time.After(100 * time.Millisecond): 89 | } 90 | close(done) 91 | }) 92 | }) 93 | }) 94 | }) 95 | 96 | type handshakeHub struct { 97 | Hub 98 | } 99 | 100 | func (h *handshakeHub) Shake() { 101 | shakeQueue <- "Shake()" 102 | } 103 | 104 | var shakeQueue = make(chan string, 10) 105 | 106 | func getTestBedHandshake() (*testingConnection, context.CancelFunc) { 107 | ctx, cancel := context.WithCancel(context.Background()) 108 | server, _ := NewServer(ctx, SimpleHubFactory(&handshakeHub{}), testLoggerOption()) 109 | conn := newTestingConnection() 110 | go func() { _ = server.Serve(conn) }() 111 | return conn, cancel 112 | } 113 | 114 | var _ = Describe("Handshake", func() { 115 | Context("When the handshake is sent as one message to the server", func() { 116 | It("should be connected", func(done Done) { 117 | conn, cancel := getTestBedHandshake() 118 | conn.ClientSend(`{"protocol": "json","version": 1}`) 119 | conn.ClientSend(`{"type":1,"invocationId": "123A","target":"shake"}`) 120 | Expect(<-shakeQueue).To(Equal("Shake()")) 121 | cancel() 122 | close(done) 123 | }) 124 | }) 125 | Context("When the handshake is sent as partial message to the server", func() { 126 | It("should be connected", func(done Done) { 127 | conn, cancel := getTestBedHandshake() 128 | _, _ = conn.cliWriter.Write([]byte(`{"protocol"`)) 129 | conn.ClientSend(`: "json","version": 1}`) 130 | conn.ClientSend(`{"type":1,"invocationId": 
"123B","target":"shake"}`) 131 | Expect(<-shakeQueue).To(Equal("Shake()")) 132 | cancel() 133 | close(done) 134 | }) 135 | }) 136 | Context("When an invalid handshake is sent as partial message to the server", func() { 137 | It("should not be connected", func(done Done) { 138 | conn, cancel := getTestBedHandshake() 139 | _, _ = conn.cliWriter.Write([]byte(`{"protocol"`)) 140 | // Opening curly brace is invalid 141 | conn.ClientSend(`{: "json","version": 1}`) 142 | conn.ClientSend(`{"type":1,"invocationId": "123C","target":"shake"}`) 143 | select { 144 | case <-shakeQueue: 145 | Fail("server connected with invalid handshake") 146 | case <-time.After(100 * time.Millisecond): 147 | } 148 | cancel() 149 | close(done) 150 | }) 151 | }) 152 | Context("When a handshake is sent with an unsupported protocol", func() { 153 | It("should return an error handshake response and be not connected", func(done Done) { 154 | conn, cancel := getTestBedHandshake() 155 | conn.ClientSend(`{"protocol": "bson","version": 1}`) 156 | response, err := conn.ClientReceive() 157 | Expect(err).To(BeNil()) 158 | Expect(response).NotTo(BeNil()) 159 | jsonMap := make(map[string]interface{}) 160 | err = json.Unmarshal([]byte(response), &jsonMap) 161 | Expect(err).To(BeNil()) 162 | Expect(jsonMap["error"]).NotTo(BeNil()) 163 | conn.ClientSend(`{"type":1,"invocationId": "123D","target":"shake"}`) 164 | select { 165 | case <-shakeQueue: 166 | Fail("server connected with invalid handshake") 167 | case <-time.After(100 * time.Millisecond): 168 | } 169 | cancel() 170 | close(done) 171 | }) 172 | }) 173 | Context("When the connection fails before the server can receive handshake request", func() { 174 | It("should not be connected", func(done Done) { 175 | conn, cancel := getTestBedHandshake() 176 | conn.SetFailRead("failed read in handshake") 177 | conn.ClientSend(`{"protocol": "json","version": 1}`) 178 | conn.ClientSend(`{"type":1,"invocationId": "123E","target":"shake"}`) 179 | select { 180 | case 
<-shakeQueue: 181 | Fail("server connected with fail before handshake") 182 | case <-time.After(100 * time.Millisecond): 183 | } 184 | cancel() 185 | close(done) 186 | }) 187 | }) 188 | Context("When the handshake is received by the server but the connection fails when the response should be sent ", func() { 189 | It("should not be connected", func(done Done) { 190 | conn, cancel := getTestBedHandshake() 191 | conn.SetFailWrite("failed write in handshake") 192 | conn.ClientSend(`{"protocol": "json","version": 1}`) 193 | conn.ClientSend(`{"type":1,"invocationId": "123F","target":"shake"}`) 194 | select { 195 | case <-shakeQueue: 196 | Fail("server connected with fail before handshake") 197 | case <-time.After(100 * time.Millisecond): 198 | } 199 | cancel() 200 | close(done) 201 | }) 202 | }) 203 | Context("When the handshake with an unsupported protocol is received by the server but the connection fails when the response should be sent ", func() { 204 | It("should not be connected", func(done Done) { 205 | conn, cancel := getTestBedHandshake() 206 | conn.SetFailWrite("failed write in handshake") 207 | conn.ClientSend(`{"protocol": "bson","version": 1}`) 208 | conn.ClientSend(`{"type":1,"invocationId": "123G","target":"shake"}`) 209 | select { 210 | case <-shakeQueue: 211 | Fail("server connected with fail before handshake") 212 | case <-time.After(100 * time.Millisecond): 213 | } 214 | cancel() 215 | close(done) 216 | }) 217 | }) 218 | Context("When the handshake connection is initiated, but the client does not send a handshake request within the handshake timeout ", func() { 219 | It("should not be connected", func(done Done) { 220 | server, _ := NewServer(context.TODO(), SimpleHubFactory(&handshakeHub{}), HandshakeTimeout(time.Millisecond*100), testLoggerOption()) 221 | conn := newTestingConnection() 222 | go func() { _ = server.Serve(conn) }() 223 | time.Sleep(time.Millisecond * 200) 224 | conn.ClientSend(`{"protocol": "json","version": 1}`) 225 | 
conn.ClientSend(`{"type":1,"invocationId": "123H","target":"shake"}`) 226 | select { 227 | case <-shakeQueue: 228 | Fail("server connected with fail before handshake") 229 | case <-time.After(100 * time.Millisecond): 230 | } 231 | server.cancel() 232 | close(done) 233 | }, 2.0) 234 | }) 235 | }) 236 | -------------------------------------------------------------------------------- /connectionbase.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // ConnectionBase is a baseclass for implementers of the Connection interface. 9 | type ConnectionBase struct { 10 | mx sync.RWMutex 11 | ctx context.Context 12 | connectionID string 13 | } 14 | 15 | // NewConnectionBase creates a new ConnectionBase 16 | func NewConnectionBase(ctx context.Context, connectionID string) *ConnectionBase { 17 | cb := &ConnectionBase{ 18 | ctx: ctx, 19 | connectionID: connectionID, 20 | } 21 | return cb 22 | } 23 | 24 | // Context can be used to wait for cancellation of the Connection 25 | func (cb *ConnectionBase) Context() context.Context { 26 | cb.mx.RLock() 27 | defer cb.mx.RUnlock() 28 | return cb.ctx 29 | } 30 | 31 | // ConnectionID is the ID of the connection. 32 | func (cb *ConnectionBase) ConnectionID() string { 33 | cb.mx.RLock() 34 | defer cb.mx.RUnlock() 35 | return cb.connectionID 36 | } 37 | 38 | // SetConnectionID sets the ConnectionID 39 | func (cb *ConnectionBase) SetConnectionID(id string) { 40 | cb.mx.Lock() 41 | defer cb.mx.Unlock() 42 | cb.connectionID = id 43 | } 44 | 45 | // ReadWriteWithContext is a wrapper to make blocking io.Writer / io.Reader cancelable. 46 | // It can be used to implement cancellation of connections. 47 | // ReadWriteWithContext will return when either the Read/Write operation has ended or ctx has been canceled. 48 | // doRW func() (int, error) 49 | // doRW should contain the Read/Write operation. 
50 | // unblockRW func() 51 | // unblockRW should contain the operation to unblock the Read/Write operation. 52 | // If there is no way to unblock the operation, one goroutine will leak when ctx is canceled. 53 | // As the standard use case when ReadWriteWithContext is canceled is the cancellation of a connection this leak 54 | // will be problematic on heavily used servers with uncommon connection types. Luckily, the standard connection types 55 | // for ServerSentEvents, Websockets and common net.Conn connections can be unblocked. 56 | func ReadWriteWithContext(ctx context.Context, doRW func() (int, error), unblockRW func()) (int, error) { 57 | if ctx.Err() != nil { 58 | return 0, ctx.Err() 59 | } 60 | resultChan := make(chan RWJobResult, 1) 61 | go func() { 62 | n, err := doRW() 63 | resultChan <- RWJobResult{n: n, err: err} 64 | close(resultChan) 65 | }() 66 | select { 67 | case <-ctx.Done(): 68 | unblockRW() 69 | return 0, ctx.Err() 70 | case r := <-resultChan: 71 | return r.n, r.err 72 | } 73 | } 74 | 75 | // RWJobResult can be used to send the result of an io.Writer / io.Reader operation over a channel. 76 | // Use it for special connection types, where ReadWriteWithContext does not fit all needs. 77 | type RWJobResult struct { 78 | n int 79 | err error 80 | } 81 | -------------------------------------------------------------------------------- /ctxpipe.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "sync" 8 | ) 9 | 10 | // onceError is an object that will only store an error once. 
type onceError struct {
	sync.Mutex // guards following
	err        error
}

// Store records err unless an error has already been stored; the first
// non-nil error wins and later calls leave it unchanged.
func (a *onceError) Store(err error) {
	a.Lock()
	defer a.Unlock()
	if a.err != nil {
		return
	}
	a.err = err
}

// Load returns the stored error, or nil if none has been stored yet.
func (a *onceError) Load() error {
	a.Lock()
	defer a.Unlock()
	return a.err
}

// ErrClosedPipe is the error used for read or write operations on a closed pipe.
var ErrClosedPipe = errors.New("io: read/write on closed pipe")

// A pipe is the shared pipe structure underlying PipeReader and PipeWriter.
// Closing either end cancels ctx; canceling the parent context the pipe was
// created from cancels it as well. Once ctx is done, pending and future Read
// and Write calls are unblocked and report a close error.
type pipe struct {
	wrMu sync.Mutex  // Serializes Write operations
	wrCh chan []byte // Write offers its remaining data to a Read on wrCh...
	rdCh chan int    // ...and that Read answers with the number of bytes it consumed

	once   sync.Once // Ensures the close paths call cancel only once
	ctx    context.Context
	cancel context.CancelFunc
	rErr   onceError // Error the read end was closed with (first one wins)
	wErr   onceError // Error the write end was closed with (first one wins)
}

// Read copies data from the current Write into b. It blocks until a writer
// offers data or ctx is done.
func (p *pipe) Read(b []byte) (n int, err error) {
	// Check for closure first so that a Read started after the pipe was
	// closed never returns data, even if a writer is still waiting.
	select {
	case <-p.ctx.Done():
		return 0, p.readCloseError()
	default:
	}

	select {
	case bw := <-p.wrCh:
		nr := copy(b, bw)
		p.rdCh <- nr // tell the writer how much of bw was consumed
		return nr, nil
	case <-p.ctx.Done():
		return 0, p.readCloseError()
	}
}

// readCloseError picks the error Read returns once ctx is done: the
// writer's close error if only the write end was closed, ErrClosedPipe
// otherwise (read end closed, or context canceled with neither end closed).
func (p *pipe) readCloseError() error {
	rErr := p.rErr.Load()
	if wErr := p.wErr.Load(); rErr == nil && wErr != nil {
		return wErr
	}
	return ErrClosedPipe
}

// CloseRead closes the read end with err (ErrClosedPipe if err is nil) and
// cancels ctx. Only the first error is kept. It always returns nil.
func (p *pipe) CloseRead(err error) error {
	if err == nil {
		err = ErrClosedPipe
	}
	p.rErr.Store(err)
	p.once.Do(func() { p.cancel() })
	return nil
}

// Write hands b to readers, possibly split over several Reads, until all of
// it has been consumed or ctx is done. n is the number of bytes consumed.
func (p *pipe) Write(b []byte) (n int, err error) {
	select {
	case <-p.ctx.Done():
		return 0, p.writeCloseError()
	default:
		p.wrMu.Lock()
		defer p.wrMu.Unlock()
	}

	// The once flag makes even a zero-length Write rendezvous with one Read.
	for once := true; once || len(b) > 0; once = false {
		select {
		case p.wrCh <- b:
			nw := <-p.rdCh
			b = b[nw:]
			n += nw
		case <-p.ctx.Done():
			return n, p.writeCloseError()
		}
	}
	return n, nil
}

// writeCloseError picks the error Write returns once ctx is done: the
// reader's close error if only the read end was closed, ErrClosedPipe
// otherwise (write end closed, or context canceled with neither end closed).
func (p *pipe) writeCloseError() error {
	wErr := p.wErr.Load()
	if rErr := p.rErr.Load(); wErr == nil && rErr != nil {
		return rErr
	}
	return ErrClosedPipe
}

// CloseWrite closes the write end with err (io.EOF if err is nil) and
// cancels ctx. Only the first error is kept. It always returns nil.
func (p *pipe) CloseWrite(err error) error {
	if err == nil {
		err = io.EOF
	}
	p.wErr.Store(err)
	p.once.Do(func() { p.cancel() })
	return nil
}

// A PipeReader is the read half of a pipe.
type PipeReader struct {
	p *pipe
}

// Read implements the standard Read interface:
// it reads data from the pipe, blocking until a writer
// arrives or the pipe is closed.
// If the read end has been closed, or the pipe's context was canceled
// while neither end was closed, err is ErrClosedPipe. Otherwise, once the
// write end is closed, err is the error passed to the writer's
// CloseWithError, or io.EOF if the writer was closed without an error.
func (r *PipeReader) Read(data []byte) (n int, err error) {
	return r.p.Read(data)
}

// Close closes the reader; subsequent writes to the
// write half of the pipe will return the error ErrClosedPipe.
func (r *PipeReader) Close() error {
	return r.CloseWithError(nil)
}

// CloseWithError closes the reader; subsequent writes
// to the write half of the pipe will return the error err.
//
// CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (r *PipeReader) CloseWithError(err error) error {
	return r.p.CloseRead(err)
}

// A PipeWriter is the write half of a pipe.
type PipeWriter struct {
	p *pipe
}

// Write implements the standard Write interface:
// it writes data to the pipe, blocking until one or more readers
// have consumed all the data or the pipe is closed.
// If only the read end has been closed, err is the error passed to the
// reader's CloseWithError (ErrClosedPipe for a plain Close); otherwise a
// closed pipe yields ErrClosedPipe. On error, n reports how many bytes
// readers consumed before the pipe was closed.
func (w *PipeWriter) Write(data []byte) (n int, err error) {
	return w.p.Write(data)
}

// Close closes the writer; subsequent reads from the
// read half of the pipe will return no bytes and EOF.
func (w *PipeWriter) Close() error {
	return w.CloseWithError(nil)
}

// CloseWithError closes the writer; subsequent reads from the
// read half of the pipe will return no bytes and the error err,
// or EOF if err is nil.
//
// CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (w *PipeWriter) CloseWithError(err error) error {
	return w.p.CloseWrite(err)
}

// CtxPipe creates a synchronous in-memory pipe.
// It can be used to connect code expecting an io.Reader
// with code expecting an io.Writer.
//
// By canceling the context, Read and Write can be canceled: blocked and
// subsequent calls return, with ErrClosedPipe if neither end had been
// closed before the cancellation.
//
// Reads and Writes on the pipe are matched one to one
// except when multiple Reads are needed to consume a single Write.
// That is, each Write to the PipeWriter blocks until it has satisfied
// one or more Reads from the PipeReader that fully consume
// the written data.
// The data is copied directly from the Write to the corresponding
// Read (or Reads); there is no internal buffering.
//
// It is safe to call Read and Write in parallel with each other or with Close.
// Parallel calls to Read and parallel calls to Write are also safe:
// the individual calls will be gated sequentially.
195 | func CtxPipe(ctx context.Context) (*PipeReader, *PipeWriter) { 196 | p := &pipe{ 197 | wrCh: make(chan []byte), 198 | rdCh: make(chan int), 199 | } 200 | p.ctx, p.cancel = context.WithCancel(ctx) 201 | return &PipeReader{p}, &PipeWriter{p} 202 | } 203 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package signalr contains a SignalR client and a SignalR server. 3 | Both support the transport types Websockets and Server-Sent Events 4 | and the transfer formats Text (JSON) and Binary (MessagePack). 5 | 6 | # Basics 7 | 8 | The SignalR Protocol is a protocol for two-way RPC over any stream- or message-based transport. 9 | Either party in the connection may invoke procedures on the other party, 10 | and procedures can return zero or more results or an error. 11 | Typically, SignalR connections are HTTP-based, but it is dead simple to implement a signalr.Connection on any transport 12 | that supports io.Reader and io.Writer. 13 | 14 | # Client 15 | 16 | A Client can be used in client side code to access server methods. From an existing connection, it can be created with NewClient(). 17 | 18 | // NewClient with raw TCP connection and MessagePack encoding 19 | conn, err := net.Dial("tcp", "example.com:6502") 20 | client := NewClient(ctx, 21 | WithConnection(NewNetConnection(ctx, conn)), 22 | TransferFormat("Binary), 23 | WithReceiver(receiver)) 24 | 25 | client.Start() 26 | 27 | A special case is NewHTTPClient(), which creates a Client from a server address and negotiates with the server 28 | which kind of connection (Websockets, Server-Sent Events) will be used. 
29 | 30 | // Configurable HTTP connection 31 | conn, err := NewHTTPConnection(ctx, "http://example.com/hub", WithHTTPHeaders(..)) 32 | // Client with JSON encoding 33 | client, err := NewClient(ctx, 34 | WithConnection(conn), 35 | TransferFormat("Text"), 36 | WithReceiver(receiver)) 37 | 38 | client.Start() 39 | 40 | The object which will receive server callbacks is passed to NewClient() by using the WithReceiver option. 41 | After calling client.Start(), the client is ready to call server methods or to receive callbacks. 42 | 43 | # Server 44 | 45 | A Server provides the public methods of a server side class over signalr to the client. 46 | Such a server side class is called a hub and must implement HubInterface. 47 | It is reasonable to derive your hubs from the Hub struct type, which already implements HubInterface. 48 | Servers for arbitrary connection types can be created with NewServer(). 49 | 50 | // Typical server with log level debug to Stderr 51 | server, err := NewServer(ctx, SimpleHubFactory(hub), Logger(log.NewLogfmtLogger(os.Stderr), true)) 52 | 53 | To serve a connection, call server.Serve(connection) in a goroutine. Serve ends when the connection is closed or the 54 | server's context is canceled. 55 | 56 | // Serving over TCP, accepting clients who use MessagePack or JSON 57 | addr, _ := net.ResolveTCPAddr("tcp", "localhost:6502") 58 | listener, _ := net.ListenTCP("tcp", addr) 59 | tcpConn, _ := listener.Accept() 60 | go server.Serve(NewNetConnection(ctx, tcpConn)) 61 | 62 | To serve an HTTP connection, use server.MapHTTP(), which connects the server with a path in an http.ServeMux. 63 | The server then automatically negotiates which kind of connection (Websockets, Server-Sent Events) will be used.
64 | 65 | // build a signalr.Server using your hub 66 | // and any server options you may need 67 | server, _ := signalr.NewServer(ctx, 68 | signalr.SimpleHubFactory(&AppHub{}), 69 | signalr.KeepAliveInterval(2*time.Second), 70 | signalr.Logger(kitlog.NewLogfmtLogger(os.Stderr), true)) 71 | 72 | 73 | // create a new http.ServeMux to handle your app's http requests 74 | router := http.NewServeMux() 75 | 76 | // ask the signalr server to map its server 77 | // api routes to your custom baseurl 78 | server.MapHTTP(signalr.WithHTTPServeMux(router), "/hub") 79 | 80 | // in addition to mapping the signalr routes 81 | // your mux will need to serve the static files 82 | // which make up your client-side app, including 83 | // the signalr javascript files. here is an example 84 | // of doing that using a local `public` package 85 | // which was created with the go:embed directive 86 | // 87 | // fmt.Printf("Serving static content from the embedded filesystem\n") 88 | // router.Handle("/", http.FileServer(http.FS(public.FS))) 89 | 90 | // bind your mux to a given address and start handling requests 91 | fmt.Printf("Listening for websocket connections on http://%s\n", address) 92 | if err := http.ListenAndServe(address, router); err != nil { 93 | log.Fatal("ListenAndServe:", err) 94 | } 95 | 96 | # Supported method signatures 97 | 98 | The SignalR protocol constrains the signature of hub or receiver methods that can be used over SignalR. 99 | All methods with serializable types as parameters and return types are supported. 100 | Methods with multiple return values are not generally supported, but returning one or no value and an optional error is supported.
101 | 102 | // Simple signatures for hub/receiver methods 103 | func (mh *MathHub) Divide(a, b float64) (float64, error) // error on division by zero 104 | func (ah *AlgoHub) Sort(values []string) []string 105 | func (ah *AlgoHub) FindKey(value []string, dict map[int][]string) (int, error) // error on not found 106 | func (receiver *View) DisplayServerValue(value interface{}) // will work for every serializable value 107 | 108 | Methods which return a single receive-only channel (<-chan), and optionally an error, are used to initiate callee side streaming. 109 | The caller will receive the contents of the channel as a stream. 110 | When the returned channel is closed, the stream will be completed. 111 | 112 | // Streaming methods 113 | func (n *Netflix) Stream(show string, season, episode int) (<-chan []byte, error) // error on password shared 114 | 115 | Methods with one or multiple receive-only channels (<-chan) as parameters are used as receivers for caller side streaming. 116 | The caller invokes this method and pushes one or multiple streams to the callee. The method should end when all channels 117 | are closed. A channel is closed by the server when the assigned stream is completed. 118 | Such methods must not also return a channel; that combination is not supported. 119 | 120 | // Caller side streaming 121 | func (mh *MathHub) MultiplyAndSum(a, b <-chan float64) float64 122 | 123 | In most cases, the caller will be the client and the callee the server. But the reverse, with the server as caller and the client as callee, is also possible.
124 | */ 125 | package signalr 126 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippseith/signalr 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/cenkalti/backoff/v4 v4.2.1 7 | github.com/dave/jennifer v1.6.1 8 | github.com/go-kit/log v0.2.1 9 | github.com/google/uuid v1.3.0 10 | github.com/onsi/ginkgo v1.12.1 11 | github.com/onsi/gomega v1.11.0 12 | github.com/stretchr/testify v1.8.4 13 | github.com/teivah/onecontext v1.3.0 14 | github.com/vmihailenco/msgpack/v5 v5.3.5 15 | nhooyr.io/websocket v1.8.7 16 | ) 17 | 18 | require ( 19 | github.com/fsnotify/fsnotify v1.6.0 // indirect 20 | github.com/klauspost/compress v1.16.6 // indirect 21 | golang.org/x/text v0.10.0 // indirect 22 | ) 23 | -------------------------------------------------------------------------------- /groupmanager.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | // GroupManager manages the client groups of the hub 4 | type GroupManager interface { 5 | AddToGroup(groupName string, connectionID string) 6 | RemoveFromGroup(groupName string, connectionID string) 7 | } 8 | 9 | type defaultGroupManager struct { 10 | lifetimeManager HubLifetimeManager 11 | } 12 | 13 | func (d *defaultGroupManager) AddToGroup(groupName string, connectionID string) { 14 | d.lifetimeManager.AddToGroup(groupName, connectionID) 15 | } 16 | 17 | func (d *defaultGroupManager) RemoveFromGroup(groupName string, connectionID string) { 18 | d.lifetimeManager.RemoveFromGroup(groupName, connectionID) 19 | } 20 | -------------------------------------------------------------------------------- /httpconnection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "path" 
11 | 12 | "nhooyr.io/websocket" 13 | ) 14 | 15 | // Doer is the *http.Client interface 16 | type Doer interface { 17 | Do(req *http.Request) (*http.Response, error) 18 | } 19 | 20 | type httpConnection struct { 21 | client Doer 22 | headers func() http.Header 23 | } 24 | 25 | // WithHTTPClient sets the http client used to connect to the signalR server. 26 | // The client is only used for http requests. It is not used for the websocket connection. 27 | func WithHTTPClient(client Doer) func(*httpConnection) error { 28 | return func(c *httpConnection) error { 29 | c.client = client 30 | return nil 31 | } 32 | } 33 | 34 | // WithHTTPHeaders sets the function for providing request headers for HTTP and websocket requests 35 | func WithHTTPHeaders(headers func() http.Header) func(*httpConnection) error { 36 | return func(c *httpConnection) error { 37 | c.headers = headers 38 | return nil 39 | } 40 | } 41 | 42 | // NewHTTPConnection creates a signalR HTTP Connection for usage with a Client. 43 | // ctx can be used to cancel the SignalR negotiation during the creation of the Connection 44 | // but not the Connection itself. 
45 | func NewHTTPConnection(ctx context.Context, address string, options ...func(*httpConnection) error) (Connection, error) { 46 | httpConn := &httpConnection{} 47 | 48 | for _, option := range options { 49 | if option != nil { 50 | if err := option(httpConn); err != nil { 51 | return nil, err 52 | } 53 | } 54 | } 55 | 56 | if httpConn.client == nil { 57 | httpConn.client = http.DefaultClient 58 | } 59 | 60 | reqURL, err := url.Parse(address) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | negotiateURL := *reqURL 66 | negotiateURL.Path = path.Join(negotiateURL.Path, "negotiate") 67 | req, err := http.NewRequestWithContext(ctx, "POST", negotiateURL.String(), nil) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if httpConn.headers != nil { 73 | req.Header = httpConn.headers() 74 | } 75 | 76 | resp, err := httpConn.client.Do(req) 77 | if err != nil { 78 | return nil, err 79 | } 80 | defer func() { closeResponseBody(resp.Body) }() 81 | 82 | if resp.StatusCode != 200 { 83 | return nil, fmt.Errorf("%v %v -> %v", req.Method, req.URL.String(), resp.Status) 84 | } 85 | 86 | body, err := io.ReadAll(resp.Body) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | nr := negotiateResponse{} 92 | if err := json.Unmarshal(body, &nr); err != nil { 93 | return nil, err 94 | } 95 | 96 | q := reqURL.Query() 97 | q.Set("id", nr.ConnectionID) 98 | reqURL.RawQuery = q.Encode() 99 | 100 | // Select the best connection 101 | var conn Connection 102 | switch { 103 | case nr.getTransferFormats("WebTransports") != nil: 104 | // TODO 105 | 106 | case nr.getTransferFormats("WebSockets") != nil: 107 | wsURL := reqURL 108 | 109 | // switch to wss for secure connection 110 | if reqURL.Scheme == "https" { 111 | wsURL.Scheme = "wss" 112 | } else { 113 | wsURL.Scheme = "ws" 114 | } 115 | 116 | opts := &websocket.DialOptions{} 117 | 118 | if httpConn.headers != nil { 119 | opts.HTTPHeader = httpConn.headers() 120 | } else { 121 | opts.HTTPHeader = http.Header{} 122 | } 123 
| 124 | for _, cookie := range resp.Cookies() { 125 | opts.HTTPHeader.Add("Cookie", cookie.String()) 126 | } 127 | 128 | ws, _, err := websocket.Dial(ctx, wsURL.String(), opts) 129 | if err != nil { 130 | return nil, err 131 | } 132 | 133 | // TODO think about if the API should give the possibility to cancel this connection 134 | conn = newWebSocketConnection(context.Background(), nr.ConnectionID, ws) 135 | 136 | case nr.getTransferFormats("ServerSentEvents") != nil: 137 | req, err := http.NewRequest("GET", reqURL.String(), nil) 138 | if err != nil { 139 | return nil, err 140 | } 141 | 142 | if httpConn.headers != nil { 143 | req.Header = httpConn.headers() 144 | } 145 | req.Header.Set("Accept", "text/event-stream") 146 | 147 | resp, err := httpConn.client.Do(req) 148 | if err != nil { 149 | return nil, err 150 | } 151 | 152 | conn, err = newClientSSEConnection(address, nr.ConnectionID, resp.Body) 153 | if err != nil { 154 | return nil, err 155 | } 156 | } 157 | 158 | return conn, nil 159 | } 160 | 161 | // closeResponseBody reads a http response body to the end and closes it 162 | // See https://blog.cubieserver.de/2022/http-connection-reuse-in-go-clients/ 163 | // The body needs to be fully read and closed, otherwise the connection will not be reused 164 | func closeResponseBody(body io.ReadCloser) { 165 | _, _ = io.Copy(io.Discard, body) 166 | _ = body.Close() 167 | } 168 | -------------------------------------------------------------------------------- /httpmux.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "strconv" 10 | "strings" 11 | "sync" 12 | "time" 13 | 14 | "github.com/teivah/onecontext" 15 | "nhooyr.io/websocket" 16 | ) 17 | 18 | type httpMux struct { 19 | mx sync.RWMutex 20 | connectionMap map[string]Connection 21 | server Server 22 | } 23 | 24 | func newHTTPMux(server Server) *httpMux { 25 
| return &httpMux{ 26 | connectionMap: make(map[string]Connection), 27 | server: server, 28 | } 29 | } 30 | 31 | func (h *httpMux) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 32 | switch request.Method { 33 | case "POST": 34 | h.handlePost(writer, request) 35 | case "GET": 36 | h.handleGet(writer, request) 37 | default: 38 | writer.WriteHeader(http.StatusBadRequest) 39 | } 40 | } 41 | 42 | func (h *httpMux) handlePost(writer http.ResponseWriter, request *http.Request) { 43 | connectionID := request.URL.Query().Get("id") 44 | if connectionID == "" { 45 | writer.WriteHeader(http.StatusBadRequest) 46 | return 47 | } 48 | info, _ := h.server.prefixLoggers("") 49 | for { 50 | h.mx.RLock() 51 | c, ok := h.connectionMap[connectionID] 52 | h.mx.RUnlock() 53 | if ok { 54 | // Connection is initiated 55 | switch conn := c.(type) { 56 | case *serverSSEConnection: 57 | writer.WriteHeader(conn.consumeRequest(request)) 58 | return 59 | case *negotiateConnection: 60 | // connection start initiated but not completed 61 | default: 62 | // ConnectionID already used for WebSocket(?) 
63 | writer.WriteHeader(http.StatusConflict) 64 | return 65 | } 66 | } else { 67 | writer.WriteHeader(http.StatusNotFound) 68 | return 69 | } 70 | <-time.After(10 * time.Millisecond) 71 | _ = info.Log("event", "handlePost for SSE connection repeated") 72 | } 73 | } 74 | 75 | func (h *httpMux) handleGet(writer http.ResponseWriter, request *http.Request) { 76 | upgrade := false 77 | for _, connHead := range strings.Split(request.Header.Get("Connection"), ",") { 78 | if strings.ToLower(strings.TrimSpace(connHead)) == "upgrade" { 79 | upgrade = true 80 | break 81 | } 82 | } 83 | if upgrade && 84 | strings.ToLower(request.Header.Get("Upgrade")) == "websocket" { 85 | h.handleWebsocket(writer, request) 86 | } else if strings.ToLower(request.Header.Get("Accept")) == "text/event-stream" { 87 | h.handleServerSentEvent(writer, request) 88 | } else { 89 | writer.WriteHeader(http.StatusBadRequest) 90 | } 91 | } 92 | 93 | func (h *httpMux) handleServerSentEvent(writer http.ResponseWriter, request *http.Request) { 94 | connectionID := request.URL.Query().Get("id") 95 | if connectionID == "" { 96 | writer.WriteHeader(http.StatusBadRequest) 97 | return 98 | } 99 | h.mx.RLock() 100 | c, ok := h.connectionMap[connectionID] 101 | h.mx.RUnlock() 102 | if ok { 103 | if _, ok := c.(*negotiateConnection); ok { 104 | ctx, _ := onecontext.Merge(h.server.context(), request.Context()) 105 | sseConn, jobChan, jobResultChan, err := newServerSSEConnection(ctx, c.ConnectionID()) 106 | if err != nil { 107 | writer.WriteHeader(http.StatusInternalServerError) 108 | return 109 | } 110 | flusher, ok := writer.(http.Flusher) 111 | if !ok { 112 | writer.WriteHeader(http.StatusInternalServerError) 113 | return 114 | } 115 | // Connection is negotiated but not initiated 116 | // We compose http and send it over sse 117 | writer.Header().Set("Content-Type", "text/event-stream") 118 | writer.Header().Set("Connection", "keep-alive") 119 | writer.Header().Set("Cache-Control", "no-cache") 120 | 
writer.WriteHeader(http.StatusOK) 121 | // End this Server Sent Event (yes, your response now is one and the client will wait for this initial event to end) 122 | _, _ = fmt.Fprint(writer, ":\r\n\r\n") 123 | writer.(http.Flusher).Flush() 124 | go func() { 125 | // We can't WriteHeader 500 if we get an error as we already wrote the header, so ignore it. 126 | _ = h.serveConnection(sseConn) 127 | }() 128 | // Loop for write jobs from the sseServerConnection 129 | for buf := range jobChan { 130 | n, err := writer.Write(buf) 131 | if err == nil { 132 | flusher.Flush() 133 | } 134 | jobResultChan <- RWJobResult{n: n, err: err} 135 | } 136 | close(jobResultChan) 137 | } else { 138 | // connectionID in use 139 | writer.WriteHeader(http.StatusConflict) 140 | } 141 | } else { 142 | writer.WriteHeader(http.StatusNotFound) 143 | } 144 | } 145 | 146 | func (h *httpMux) handleWebsocket(writer http.ResponseWriter, request *http.Request) { 147 | accOptions := &websocket.AcceptOptions{ 148 | CompressionMode: websocket.CompressionContextTakeover, 149 | InsecureSkipVerify: h.server.insecureSkipVerify(), 150 | OriginPatterns: h.server.originPatterns(), 151 | } 152 | websocketConn, err := websocket.Accept(writer, request, accOptions) 153 | if err != nil { 154 | _, debug := h.server.loggers() 155 | _ = debug.Log(evt, "handleWebsocket", msg, "error accepting websockets", "error", err) 156 | // don't need to write an error header here as websocket.Accept has already used http.Error 157 | return 158 | } 159 | websocketConn.SetReadLimit(int64(h.server.maximumReceiveMessageSize())) 160 | connectionMapKey := request.URL.Query().Get("id") 161 | if connectionMapKey == "" { 162 | // Support websocket connection without negotiate 163 | connectionMapKey = newConnectionID() 164 | h.mx.Lock() 165 | h.connectionMap[connectionMapKey] = &negotiateConnection{ 166 | ConnectionBase{connectionID: connectionMapKey}, 167 | } 168 | h.mx.Unlock() 169 | } 170 | h.mx.RLock() 171 | c, ok := 
h.connectionMap[connectionMapKey] 172 | h.mx.RUnlock() 173 | if ok { 174 | if _, ok := c.(*negotiateConnection); ok { 175 | // Connection is negotiated but not initiated 176 | ctx, _ := onecontext.Merge(h.server.context(), request.Context()) 177 | err = h.serveConnection(newWebSocketConnection(ctx, c.ConnectionID(), websocketConn)) 178 | if err != nil { 179 | _ = websocketConn.Close(1005, err.Error()) 180 | } 181 | } else { 182 | // Already initiated 183 | _ = websocketConn.Close(1002, "Bad request") 184 | } 185 | } else { 186 | // Not negotiated 187 | _ = websocketConn.Close(1002, "Not found") 188 | } 189 | } 190 | 191 | func (h *httpMux) negotiate(w http.ResponseWriter, req *http.Request) { 192 | if req.Method != "POST" { 193 | w.WriteHeader(http.StatusBadRequest) 194 | } else { 195 | connectionID := newConnectionID() 196 | connectionMapKey := connectionID 197 | negotiateVersion, err := strconv.Atoi(req.Header.Get("negotiateVersion")) 198 | if err != nil { 199 | negotiateVersion = 0 200 | } 201 | connectionToken := "" 202 | if negotiateVersion == 1 { 203 | connectionToken = newConnectionID() 204 | connectionMapKey = connectionToken 205 | } 206 | h.mx.Lock() 207 | h.connectionMap[connectionMapKey] = &negotiateConnection{ 208 | ConnectionBase{connectionID: connectionID}, 209 | } 210 | h.mx.Unlock() 211 | var availableTransports []availableTransport 212 | for _, transport := range h.server.availableTransports() { 213 | switch transport { 214 | case "ServerSentEvents": 215 | availableTransports = append(availableTransports, 216 | availableTransport{ 217 | Transport: "ServerSentEvents", 218 | TransferFormats: []string{"Text"}, 219 | }) 220 | case "WebSockets": 221 | availableTransports = append(availableTransports, 222 | availableTransport{ 223 | Transport: "WebSockets", 224 | TransferFormats: []string{"Text", "Binary"}, 225 | }) 226 | } 227 | } 228 | response := negotiateResponse{ 229 | ConnectionToken: connectionToken, 230 | ConnectionID: connectionID, 231 | 
NegotiateVersion: negotiateVersion, 232 | AvailableTransports: availableTransports, 233 | } 234 | 235 | w.WriteHeader(http.StatusOK) 236 | _ = json.NewEncoder(w).Encode(response) // Can't imagine an error when encoding 237 | } 238 | } 239 | 240 | func (h *httpMux) serveConnection(c Connection) error { 241 | h.mx.Lock() 242 | h.connectionMap[c.ConnectionID()] = c 243 | h.mx.Unlock() 244 | return h.server.Serve(c) 245 | } 246 | 247 | func newConnectionID() string { 248 | bytes := make([]byte, 16) 249 | // rand.Read only fails when the systems random number generator fails. Rare case, ignore 250 | _, _ = rand.Read(bytes) 251 | // Important: Use URLEncoding. StdEncoding contains "/" which will be randomly part of the connectionID and cause parsing problems 252 | return base64.URLEncoding.EncodeToString(bytes) 253 | } 254 | 255 | type negotiateConnection struct { 256 | ConnectionBase 257 | } 258 | 259 | func (n *negotiateConnection) Read([]byte) (int, error) { 260 | return 0, nil 261 | } 262 | 263 | func (n *negotiateConnection) Write([]byte) (int, error) { 264 | return 0, nil 265 | } 266 | -------------------------------------------------------------------------------- /httpserver_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "strconv" 14 | 15 | "strings" 16 | "time" 17 | 18 | "github.com/go-kit/log/level" 19 | . "github.com/onsi/ginkgo" 20 | . 
"github.com/onsi/gomega" 21 | "nhooyr.io/websocket" 22 | ) 23 | 24 | type addHub struct { 25 | Hub 26 | } 27 | 28 | func (w *addHub) Add2(i int) int { 29 | return i + 2 30 | } 31 | 32 | func (w *addHub) Echo(s string) string { 33 | return s 34 | } 35 | 36 | var _ = Describe("HTTP server", func() { 37 | for _, transport := range [][]string{ 38 | {"WebSockets", "Text"}, 39 | {"WebSockets", "Binary"}, 40 | {"ServerSentEvents", "Text"}, 41 | } { 42 | transport := transport 43 | Context(fmt.Sprintf("%v %v", transport[0], transport[1]), func() { 44 | Context("A correct negotiation request is sent", func() { 45 | It(fmt.Sprintf("should send a correct negotiation response with support for %v with text protocol", transport), func(done Done) { 46 | // Start server 47 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]), testLoggerOption()) 48 | Expect(err).NotTo(HaveOccurred()) 49 | router := http.NewServeMux() 50 | server.MapHTTP(WithHTTPServeMux(router), "/hub") 51 | testServer := httptest.NewServer(router) 52 | url, _ := url.Parse(testServer.URL) 53 | port, _ := strconv.Atoi(url.Port()) 54 | // Negotiate 55 | negResp := negotiateWebSocketTestServer(port) 56 | Expect(negResp["connectionId"]).NotTo(BeNil()) 57 | Expect(negResp["availableTransports"]).To(BeAssignableToTypeOf([]interface{}{})) 58 | avt := negResp["availableTransports"].([]interface{}) 59 | Expect(len(avt)).To(BeNumerically(">", 0)) 60 | Expect(avt[0]).To(BeAssignableToTypeOf(map[string]interface{}{})) 61 | avtVal := avt[0].(map[string]interface{}) 62 | Expect(avtVal["transport"]).To(Equal(transport[0])) 63 | Expect(avtVal["transferFormats"]).To(BeAssignableToTypeOf([]interface{}{})) 64 | tf := avtVal["transferFormats"].([]interface{}) 65 | Expect(tf).To(ContainElement("Text")) 66 | if transport[0] == "WebSockets" { 67 | Expect(tf).To(ContainElement("Binary")) 68 | } 69 | testServer.Close() 70 | close(done) 71 | }, 2.0) 72 | }) 73 | 74 | Context("A invalid 
negotiation request is sent", func() { 75 | It(fmt.Sprintf("should send a correct negotiation response with support for %v with text protocol", transport), func(done Done) { 76 | // Start server 77 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]), testLoggerOption()) 78 | Expect(err).NotTo(HaveOccurred()) 79 | router := http.NewServeMux() 80 | server.MapHTTP(WithHTTPServeMux(router), "/hub") 81 | testServer := httptest.NewServer(router) 82 | url, _ := url.Parse(testServer.URL) 83 | port, _ := strconv.Atoi(url.Port()) 84 | waitForPort(port) 85 | // Negotiate the wrong way 86 | resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port)) 87 | Expect(err).To(BeNil()) 88 | Expect(resp).NotTo(BeNil()) 89 | Expect(resp.StatusCode).ToNot(Equal(200)) 90 | testServer.Close() 91 | close(done) 92 | }, 2.0) 93 | }) 94 | 95 | Context("Connection with client", func() { 96 | It("should successfully handle an Invoke call", func(done Done) { 97 | logger := &nonProtocolLogger{testLogger()} 98 | // Start server 99 | ctx, cancel := context.WithCancel(context.Background()) 100 | server, err := NewServer(ctx, 101 | SimpleHubFactory(&addHub{}), HTTPTransports(transport[0]), 102 | MaximumReceiveMessageSize(50000), 103 | Logger(logger, true)) 104 | Expect(err).NotTo(HaveOccurred()) 105 | router := http.NewServeMux() 106 | server.MapHTTP(WithHTTPServeMux(router), "/hub") 107 | testServer := httptest.NewServer(router) 108 | url, _ := url.Parse(testServer.URL) 109 | port, _ := strconv.Atoi(url.Port()) 110 | waitForPort(port) 111 | // Try first connection 112 | conn, err := NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port)) 113 | Expect(err).NotTo(HaveOccurred()) 114 | client, err := NewClient(ctx, 115 | WithConnection(conn), 116 | MaximumReceiveMessageSize(60000), 117 | Logger(logger, true), 118 | TransferFormat(transport[1])) 119 | Expect(err).NotTo(HaveOccurred()) 120 | 
Expect(client).NotTo(BeNil()) 121 | client.Start() 122 | Expect(<-client.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) 123 | result := <-client.Invoke("Add2", 1) 124 | Expect(result.Error).NotTo(HaveOccurred()) 125 | Expect(result.Value).To(BeEquivalentTo(3)) 126 | 127 | // Try second connection 128 | conn2, err := NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port)) 129 | Expect(err).NotTo(HaveOccurred()) 130 | client2, err := NewClient(ctx, 131 | WithConnection(conn2), 132 | Logger(logger, true), 133 | TransferFormat(transport[1])) 134 | Expect(err).NotTo(HaveOccurred()) 135 | Expect(client2).NotTo(BeNil()) 136 | client2.Start() 137 | Expect(<-client2.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) 138 | result = <-client2.Invoke("Add2", 2) 139 | Expect(result.Error).NotTo(HaveOccurred()) 140 | Expect(result.Value).To(BeEquivalentTo(4)) 141 | // Huge message 142 | hugo := strings.Repeat("#", 2500) 143 | result = <-client.Invoke("Echo", hugo) 144 | Expect(result.Error).NotTo(HaveOccurred()) 145 | s := result.Value.(string) 146 | Expect(s).To(Equal(hugo)) 147 | cancel() 148 | go testServer.Close() 149 | close(done) 150 | }, 2.0) 151 | }) 152 | }) 153 | } 154 | Context("When no negotiation is send", func() { 155 | It("should serve websocket requests", func(done Done) { 156 | // Start server 157 | server, err := NewServer(context.TODO(), SimpleHubFactory(&addHub{}), HTTPTransports("WebSockets"), testLoggerOption()) 158 | Expect(err).NotTo(HaveOccurred()) 159 | router := http.NewServeMux() 160 | server.MapHTTP(WithHTTPServeMux(router), "/hub") 161 | testServer := httptest.NewServer(router) 162 | url, _ := url.Parse(testServer.URL) 163 | port, _ := strconv.Atoi(url.Port()) 164 | waitForPort(port) 165 | handShakeAndCallWebSocketTestServer(port, "") 166 | testServer.Close() 167 | close(done) 168 | }, 5.0) 169 | }) 170 | }) 171 | 172 | type nonProtocolLogger struct { 173 | logger 
StructuredLogger 174 | } 175 | 176 | func (n *nonProtocolLogger) Log(keyVals ...interface{}) error { 177 | for _, kv := range keyVals { 178 | if kv == "protocol" { 179 | return nil 180 | } 181 | } 182 | return n.logger.Log(keyVals...) 183 | } 184 | 185 | func negotiateWebSocketTestServer(port int) map[string]interface{} { 186 | waitForPort(port) 187 | buf := bytes.Buffer{} 188 | resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port), "text/plain;charset=UTF-8", &buf) 189 | Expect(err).To(BeNil()) 190 | Expect(resp).ToNot(BeNil()) 191 | defer func() { 192 | _ = resp.Body.Close() 193 | }() 194 | var body []byte 195 | body, err = io.ReadAll(resp.Body) 196 | Expect(err).To(BeNil()) 197 | response := make(map[string]interface{}) 198 | err = json.Unmarshal(body, &response) 199 | Expect(err).To(BeNil()) 200 | return response 201 | } 202 | 203 | func handShakeAndCallWebSocketTestServer(port int, connectionID string) { 204 | waitForPort(port) 205 | logger := testLogger() 206 | protocol := jsonHubProtocol{} 207 | protocol.setDebugLogger(level.Debug(logger)) 208 | var urlParam string 209 | if connectionID != "" { 210 | urlParam = fmt.Sprintf("?id=%v", connectionID) 211 | } 212 | ws, _, err := websocket.Dial(context.Background(), fmt.Sprintf("ws://127.0.0.1:%v/hub%v", port, urlParam), nil) 213 | Expect(err).To(BeNil()) 214 | defer func() { 215 | _ = ws.Close(websocket.StatusNormalClosure, "") 216 | }() 217 | wsConn := newWebSocketConnection(context.TODO(), connectionID, ws) 218 | cliConn := newHubConnection(wsConn, &protocol, 1<<15, testLogger()) 219 | _, _ = wsConn.Write(append([]byte(`{"protocol": "json","version": 1}`), 30)) 220 | _, _ = wsConn.Write(append([]byte(`{"type":1,"invocationId":"666","target":"add2","arguments":[1]}`), 30)) 221 | result := make(chan interface{}) 222 | go func() { 223 | for recvResult := range cliConn.Receive() { 224 | if completionMessage, ok := recvResult.message.(completionMessage); ok { 225 | result <- 
completionMessage.Result 226 | return 227 | } 228 | } 229 | }() 230 | select { 231 | case r := <-result: 232 | var f float64 233 | Expect(protocol.UnmarshalArgument(r, &f)).NotTo(HaveOccurred()) 234 | Expect(f).To(Equal(3.0)) 235 | case <-time.After(1000 * time.Millisecond): 236 | Fail("timed out") 237 | } 238 | } 239 | 240 | func waitForPort(port int) { 241 | for { 242 | if _, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", port)); err == nil { 243 | return 244 | } 245 | time.Sleep(100 * time.Millisecond) 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /hub.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // HubInterface is a hubs interface 9 | type HubInterface interface { 10 | Initialize(hubContext HubContext) 11 | OnConnected(connectionID string) 12 | OnDisconnected(connectionID string) 13 | } 14 | 15 | // Hub is a base class for hubs 16 | type Hub struct { 17 | context HubContext 18 | cm sync.RWMutex 19 | } 20 | 21 | // Initialize initializes a hub with a HubContext 22 | func (h *Hub) Initialize(ctx HubContext) { 23 | h.cm.Lock() 24 | defer h.cm.Unlock() 25 | h.context = ctx 26 | } 27 | 28 | // Clients returns the clients of this hub 29 | func (h *Hub) Clients() HubClients { 30 | h.cm.RLock() 31 | defer h.cm.RUnlock() 32 | return h.context.Clients() 33 | } 34 | 35 | // Groups returns the client groups of this hub 36 | func (h *Hub) Groups() GroupManager { 37 | h.cm.RLock() 38 | defer h.cm.RUnlock() 39 | return h.context.Groups() 40 | } 41 | 42 | // Items returns the items for this connection 43 | func (h *Hub) Items() *sync.Map { 44 | h.cm.RLock() 45 | defer h.cm.RUnlock() 46 | return h.context.Items() 47 | } 48 | 49 | // ConnectionID gets the ID of the current connection 50 | func (h *Hub) ConnectionID() string { 51 | h.cm.RLock() 52 | defer h.cm.RUnlock() 53 | return h.context.ConnectionID() 
54 | } 55 | 56 | // Context is the context.Context of the current connection 57 | func (h *Hub) Context() context.Context { 58 | h.cm.RLock() 59 | defer h.cm.RUnlock() 60 | return h.context.Context() 61 | } 62 | 63 | // Abort aborts the current connection 64 | func (h *Hub) Abort() { 65 | h.cm.RLock() 66 | defer h.cm.RUnlock() 67 | h.context.Abort() 68 | } 69 | 70 | // Logger returns the loggers used in this server. By this, derived hubs can use the same loggers as the server. 71 | func (h *Hub) Logger() (info StructuredLogger, dbg StructuredLogger) { 72 | h.cm.RLock() 73 | defer h.cm.RUnlock() 74 | return h.context.Logger() 75 | } 76 | 77 | // OnConnected is called when the hub is connected 78 | func (h *Hub) OnConnected(string) {} 79 | 80 | // OnDisconnected is called when the hub is disconnected 81 | func (h *Hub) OnDisconnected(string) {} 82 | -------------------------------------------------------------------------------- /hubclients.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | // HubClients gives the hub access to various client groups 4 | // All() gets a ClientProxy that can be used to invoke methods on all clients connected to the hub 5 | // Caller() gets a ClientProxy that can be used to invoke methods of the current calling client 6 | // Client() gets a ClientProxy that can be used to invoke methods on the specified client connection 7 | // Group() gets a ClientProxy that can be used to invoke methods on all connections in the specified group 8 | type HubClients interface { 9 | All() ClientProxy 10 | Caller() ClientProxy 11 | Client(connectionID string) ClientProxy 12 | Group(groupName string) ClientProxy 13 | } 14 | 15 | type defaultHubClients struct { 16 | lifetimeManager HubLifetimeManager 17 | allCache allClientProxy 18 | } 19 | 20 | func (c *defaultHubClients) All() ClientProxy { 21 | return &c.allCache 22 | } 23 | 24 | func (c *defaultHubClients) Client(connectionID string) 
ClientProxy { 25 | return &singleClientProxy{connectionID: connectionID, lifetimeManager: c.lifetimeManager} 26 | } 27 | 28 | func (c *defaultHubClients) Group(groupName string) ClientProxy { 29 | return &groupClientProxy{groupName: groupName, lifetimeManager: c.lifetimeManager} 30 | } 31 | 32 | // Caller is only implemented to fulfill the HubClients interface, so the servers defaultHubClients interface can be 33 | // used for implementing Server.HubClients. 34 | func (c *defaultHubClients) Caller() ClientProxy { 35 | return nil 36 | } 37 | 38 | type callerHubClients struct { 39 | defaultHubClients *defaultHubClients 40 | connectionID string 41 | } 42 | 43 | func (c *callerHubClients) All() ClientProxy { 44 | return c.defaultHubClients.All() 45 | } 46 | 47 | func (c *callerHubClients) Caller() ClientProxy { 48 | return c.defaultHubClients.Client(c.connectionID) 49 | } 50 | 51 | func (c *callerHubClients) Client(connectionID string) ClientProxy { 52 | return c.defaultHubClients.Client(connectionID) 53 | } 54 | 55 | func (c *callerHubClients) Group(groupName string) ClientProxy { 56 | return c.defaultHubClients.Group(groupName) 57 | } 58 | -------------------------------------------------------------------------------- /hubconnection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // hubConnection is used by HubContext, Server and Client to realize the external API. 13 | // hubConnection uses a transport connection (of type Connection) and a hubProtocol to send and receive SignalR messages. 
14 | type hubConnection interface { 15 | ConnectionID() string 16 | Receive() <-chan receiveResult 17 | SendInvocation(id string, target string, args []interface{}) error 18 | SendStreamInvocation(id string, target string, args []interface{}) error 19 | SendInvocationWithStreamIds(id string, target string, args []interface{}, streamIds []string) error 20 | StreamItem(id string, item interface{}) error 21 | Completion(id string, result interface{}, error string) error 22 | Close(error string, allowReconnect bool) error 23 | Ping() error 24 | LastWriteStamp() time.Time 25 | Items() *sync.Map 26 | Context() context.Context 27 | Abort() 28 | } 29 | 30 | type receiveResult struct { 31 | message interface{} 32 | err error 33 | } 34 | 35 | func newHubConnection(connection Connection, protocol hubProtocol, maximumReceiveMessageSize uint, info StructuredLogger) hubConnection { 36 | ctx, cancelFunc := context.WithCancel(connection.Context()) 37 | c := &defaultHubConnection{ 38 | ctx: ctx, 39 | cancelFunc: cancelFunc, 40 | protocol: protocol, 41 | mx: sync.Mutex{}, 42 | connection: connection, 43 | maximumReceiveMessageSize: maximumReceiveMessageSize, 44 | items: &sync.Map{}, 45 | info: info, 46 | } 47 | if connectionWithTransferMode, ok := connection.(ConnectionWithTransferMode); ok { 48 | connectionWithTransferMode.SetTransferMode(protocol.transferMode()) 49 | } 50 | return c 51 | } 52 | 53 | type defaultHubConnection struct { 54 | ctx context.Context 55 | cancelFunc context.CancelFunc 56 | protocol hubProtocol 57 | mx sync.Mutex 58 | connection Connection 59 | maximumReceiveMessageSize uint 60 | items *sync.Map 61 | lastWriteStamp time.Time 62 | info StructuredLogger 63 | } 64 | 65 | func (c *defaultHubConnection) Items() *sync.Map { 66 | return c.items 67 | } 68 | 69 | func (c *defaultHubConnection) Close(errorText string, allowReconnect bool) error { 70 | var closeMessage = closeMessage{ 71 | Type: 7, 72 | Error: errorText, 73 | AllowReconnect: allowReconnect, 74 | } 75 
| return c.protocol.WriteMessage(closeMessage, c.connection) 76 | } 77 | 78 | func (c *defaultHubConnection) ConnectionID() string { 79 | return c.connection.ConnectionID() 80 | } 81 | 82 | func (c *defaultHubConnection) Context() context.Context { 83 | return c.ctx 84 | } 85 | 86 | func (c *defaultHubConnection) Abort() { 87 | c.cancelFunc() 88 | } 89 | 90 | func (c *defaultHubConnection) Receive() <-chan receiveResult { 91 | recvChan := make(chan receiveResult, 20) 92 | // Prepare cleanup 93 | writerDone := make(chan struct{}, 1) 94 | // the pipe connects the goroutine which reads from the connection and the goroutine which parses the read data 95 | reader, writer := CtxPipe(c.ctx) 96 | p := make([]byte, c.maximumReceiveMessageSize) 97 | go func(ctx context.Context, connection io.Reader, writer io.Writer, recvChan chan<- receiveResult, writerDone chan<- struct{}) { 98 | loop: 99 | for { 100 | select { 101 | case <-ctx.Done(): 102 | break loop 103 | default: 104 | n, err := connection.Read(p) 105 | if err != nil { 106 | select { 107 | case recvChan <- receiveResult{err: err}: 108 | case <-ctx.Done(): 109 | break loop 110 | } 111 | } 112 | if n > 0 { 113 | _, err = writer.Write(p[:n]) 114 | if err != nil { 115 | select { 116 | case recvChan <- receiveResult{err: err}: 117 | case <-ctx.Done(): 118 | break loop 119 | } 120 | } 121 | } 122 | } 123 | } 124 | // The pipe writer is done 125 | close(writerDone) 126 | }(c.ctx, c.connection, writer, recvChan, writerDone) 127 | // parse 128 | go func(ctx context.Context, reader io.Reader, recvChan chan<- receiveResult, writerDone <-chan struct{}) { 129 | remainBuf := bytes.Buffer{} 130 | loop: 131 | for { 132 | select { 133 | case <-ctx.Done(): 134 | break loop 135 | case <-writerDone: 136 | break loop 137 | default: 138 | messages, err := c.protocol.ParseMessages(reader, &remainBuf) 139 | if err != nil { 140 | select { 141 | case recvChan <- receiveResult{err: err}: 142 | case <-ctx.Done(): 143 | break loop 144 | case 
<-writerDone: 145 | break loop 146 | } 147 | } else { 148 | for _, message := range messages { 149 | select { 150 | case recvChan <- receiveResult{message: message}: 151 | case <-ctx.Done(): 152 | break loop 153 | case <-writerDone: 154 | break loop 155 | } 156 | } 157 | } 158 | } 159 | } 160 | }(c.ctx, reader, recvChan, writerDone) 161 | return recvChan 162 | } 163 | 164 | func (c *defaultHubConnection) SendInvocation(id string, target string, args []interface{}) error { 165 | if args == nil { 166 | args = make([]interface{}, 0) 167 | } 168 | var invocationMessage = invocationMessage{ 169 | Type: 1, 170 | InvocationID: id, 171 | Target: target, 172 | Arguments: args, 173 | } 174 | return c.writeMessage(invocationMessage) 175 | } 176 | 177 | func (c *defaultHubConnection) SendStreamInvocation(id string, target string, args []interface{}) error { 178 | if args == nil { 179 | args = make([]interface{}, 0) 180 | } 181 | var invocationMessage = invocationMessage{ 182 | Type: 4, 183 | InvocationID: id, 184 | Target: target, 185 | Arguments: args, 186 | } 187 | return c.writeMessage(invocationMessage) 188 | } 189 | 190 | func (c *defaultHubConnection) SendInvocationWithStreamIds(id string, target string, args []interface{}, streamIds []string) error { 191 | var invocationMessage = invocationMessage{ 192 | Type: 1, 193 | InvocationID: id, 194 | Target: target, 195 | Arguments: args, 196 | StreamIds: streamIds, 197 | } 198 | return c.writeMessage(invocationMessage) 199 | } 200 | 201 | func (c *defaultHubConnection) StreamItem(id string, item interface{}) error { 202 | var streamItemMessage = streamItemMessage{ 203 | Type: 2, 204 | InvocationID: id, 205 | Item: item, 206 | } 207 | return c.writeMessage(streamItemMessage) 208 | } 209 | 210 | func (c *defaultHubConnection) Completion(id string, result interface{}, error string) error { 211 | var completionMessage = completionMessage{ 212 | Type: 3, 213 | InvocationID: id, 214 | Result: result, 215 | Error: error, 216 | } 217 
| return c.writeMessage(completionMessage) 218 | } 219 | 220 | func (c *defaultHubConnection) Ping() error { 221 | var pingMessage = hubMessage{ 222 | Type: 6, 223 | } 224 | return c.writeMessage(pingMessage) 225 | } 226 | 227 | func (c *defaultHubConnection) LastWriteStamp() time.Time { 228 | defer c.mx.Unlock() 229 | c.mx.Lock() 230 | return c.lastWriteStamp 231 | } 232 | 233 | func (c *defaultHubConnection) writeMessage(message interface{}) error { 234 | c.mx.Lock() 235 | c.lastWriteStamp = time.Now() 236 | c.mx.Unlock() 237 | err := func() error { 238 | if c.ctx.Err() != nil { 239 | return fmt.Errorf("hubConnection canceled: %w", c.ctx.Err()) 240 | } 241 | e := make(chan error, 1) 242 | go func() { e <- c.protocol.WriteMessage(message, c.connection) }() 243 | select { 244 | case <-c.ctx.Done(): 245 | return fmt.Errorf("hubConnection canceled: %w", c.ctx.Err()) 246 | case err := <-e: 247 | if err != nil { 248 | c.Abort() 249 | } 250 | return err 251 | } 252 | }() 253 | if err != nil { 254 | _ = c.info.Log(evt, msgSend, "message", fmtMsg(message), "error", err) 255 | } 256 | return err 257 | } 258 | -------------------------------------------------------------------------------- /hubcontext.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // HubContext is a context abstraction for a hub 9 | // Clients gets a HubClients that can be used to invoke methods on clients connected to the hub 10 | // Groups gets a GroupManager that can be used to add and remove connections to named groups 11 | // Items holds key/value pairs scoped to the hubs connection 12 | // ConnectionID gets the ID of the current connection 13 | // Abort aborts the current connection 14 | // Logger returns the logger used in this server 15 | type HubContext interface { 16 | Clients() HubClients 17 | Groups() GroupManager 18 | Items() *sync.Map 19 | ConnectionID() string 20 | Context() 
context.Context 21 | Abort() 22 | Logger() (info StructuredLogger, dbg StructuredLogger) 23 | } 24 | 25 | type connectionHubContext struct { 26 | abort context.CancelFunc 27 | connection hubConnection 28 | clients HubClients 29 | groups GroupManager 30 | info StructuredLogger 31 | dbg StructuredLogger 32 | } 33 | 34 | func (c *connectionHubContext) Clients() HubClients { 35 | return c.clients 36 | } 37 | 38 | func (c *connectionHubContext) Groups() GroupManager { 39 | return c.groups 40 | } 41 | 42 | func (c *connectionHubContext) Items() *sync.Map { 43 | return c.connection.Items() 44 | } 45 | 46 | func (c *connectionHubContext) ConnectionID() string { 47 | return c.connection.ConnectionID() 48 | } 49 | 50 | func (c *connectionHubContext) Context() context.Context { 51 | return c.connection.Context() 52 | } 53 | 54 | func (c *connectionHubContext) Abort() { 55 | c.abort() 56 | } 57 | 58 | func (c *connectionHubContext) Logger() (info StructuredLogger, dbg StructuredLogger) { 59 | return c.info, c.dbg 60 | } 61 | -------------------------------------------------------------------------------- /hubcontext_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | 13 | . "github.com/onsi/ginkgo" 14 | . 
"github.com/onsi/gomega" 15 | ) 16 | 17 | type contextHub struct { 18 | Hub 19 | } 20 | 21 | func (c *contextHub) OnConnected(string) { 22 | } 23 | 24 | func (c *contextHub) CallAll() { 25 | c.Clients().All().Send("clientFunc") 26 | } 27 | 28 | func (c *contextHub) CallCaller() { 29 | c.Clients().Caller().Send("clientFunc") 30 | } 31 | 32 | func (c *contextHub) CallClient(connectionID string) { 33 | c.Clients().Client(connectionID).Send("clientFunc") 34 | } 35 | 36 | func (c *contextHub) BuildGroup(connectionID1 string, connectionID2 string) { 37 | c.Groups().AddToGroup("local", connectionID1) 38 | c.Groups().AddToGroup("local", connectionID2) 39 | } 40 | 41 | func (c *contextHub) RemoveFromGroup(connectionID string) { 42 | c.Groups().RemoveFromGroup("local", connectionID) 43 | } 44 | 45 | func (c *contextHub) CallGroup() { 46 | c.Clients().Group("local").Send("clientFunc") 47 | } 48 | 49 | func (c *contextHub) AddItem(key string, value interface{}) { 50 | c.Items().Store(key, value) 51 | } 52 | 53 | func (c *contextHub) GetItem(key string) interface{} { 54 | if item, ok := c.Items().Load(key); ok { 55 | return item 56 | } 57 | return nil 58 | } 59 | 60 | func (c *contextHub) TestConnectionID() { 61 | } 62 | 63 | func (c *contextHub) Abort() { 64 | c.Hub.Abort() 65 | } 66 | 67 | type SimpleReceiver struct { 68 | ch chan struct{} 69 | } 70 | 71 | func (sr *SimpleReceiver) ClientFunc() { 72 | close(sr.ch) 73 | } 74 | 75 | func makeTCPServerAndClients(ctx context.Context, clientCount int) (Server, []Client, []*SimpleReceiver, []Connection, []Connection, error) { 76 | server, err := NewServer(ctx, SimpleHubFactory(&contextHub{}), testLoggerOption()) 77 | if err != nil { 78 | return nil, nil, nil, nil, nil, err 79 | } 80 | cliConn := make([]Connection, clientCount) 81 | srvConn := make([]Connection, clientCount) 82 | receiver := make([]*SimpleReceiver, clientCount) 83 | client := make([]Client, clientCount) 84 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 85 | 
if err != nil { 86 | return nil, nil, nil, nil, nil, err 87 | } 88 | for i := 0; i < clientCount; i++ { 89 | listener, err := net.ListenTCP("tcp", addr) 90 | go func(i int) { 91 | for { 92 | tcpConn, _ := listener.Accept() 93 | conn := NewNetConnection(ctx, tcpConn) 94 | conn.SetConnectionID(fmt.Sprint(i)) 95 | srvConn[i] = conn 96 | go func() { _ = server.Serve(conn) }() 97 | break 98 | } 99 | }(i) 100 | tcpConn, err := net.Dial("tcp", 101 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port)) 102 | if err != nil { 103 | return nil, nil, nil, nil, nil, err 104 | } 105 | cliConn[i] = NewNetConnection(ctx, tcpConn) 106 | receiver[i] = &SimpleReceiver{ch: make(chan struct{})} 107 | client[i], err = NewClient(ctx, WithConnection(cliConn[i]), WithReceiver(receiver[i]), TransferFormat("Text"), testLoggerOption()) 108 | if err != nil { 109 | return nil, nil, nil, nil, nil, err 110 | } 111 | client[i].Start() 112 | select { 113 | case err := <-client[i].WaitForState(ctx, ClientConnected): 114 | if err != nil { 115 | return nil, nil, nil, nil, nil, err 116 | } 117 | case <-ctx.Done(): 118 | return nil, nil, nil, nil, nil, ctx.Err() 119 | } 120 | } 121 | return server, client, receiver, srvConn, cliConn, nil 122 | } 123 | 124 | func makePipeClientsAndReceivers() ([]Client, []*SimpleReceiver, context.CancelFunc) { 125 | cliConn := make([]*pipeConnection, 3) 126 | srvConn := make([]*pipeConnection, 3) 127 | for i := 0; i < 3; i++ { 128 | cliConn[i], srvConn[i] = newClientServerConnections() 129 | cliConn[i].SetConnectionID(fmt.Sprint(i)) 130 | srvConn[i].SetConnectionID(fmt.Sprint(i)) 131 | } 132 | ctx, cancel := context.WithCancel(context.Background()) 133 | server, _ := NewServer(ctx, SimpleHubFactory(&contextHub{}), testLoggerOption()) 134 | var wg sync.WaitGroup 135 | wg.Add(3) 136 | for i := 0; i < 3; i++ { 137 | go func(i int) { 138 | wg.Done() 139 | _ = server.Serve(srvConn[i]) 140 | }(i) 141 | } 142 | wg.Wait() 143 | client := make([]Client, 3) 144 | 
receiver := make([]*SimpleReceiver, 3) 145 | for i := 0; i < 3; i++ { 146 | receiver[i] = &SimpleReceiver{ch: make(chan struct{})} 147 | client[i], _ = NewClient(ctx, WithConnection(cliConn[i]), WithReceiver(receiver[i]), testLoggerOption()) 148 | client[i].Start() 149 | <-client[i].WaitForState(ctx, ClientConnected) 150 | } 151 | return client, receiver, cancel 152 | } 153 | 154 | var _ = Describe("HubContext", func() { 155 | for i := 0; i < 10; i++ { 156 | Context("Clients().All()", func() { 157 | It("should invoke all clients", func() { 158 | client, receiver, cancel := makePipeClientsAndReceivers() 159 | r := <-client[0].Invoke("CallAll") 160 | Expect(r.Error).NotTo(HaveOccurred()) 161 | result := 0 162 | for result < 3 { 163 | select { 164 | case <-receiver[0].ch: 165 | result++ 166 | case <-receiver[1].ch: 167 | result++ 168 | case <-receiver[2].ch: 169 | result++ 170 | case <-time.After(2 * time.Second): 171 | Fail("timeout waiting for clients getting results") 172 | } 173 | } 174 | cancel() 175 | }) 176 | }) 177 | Context("Clients().Caller()", func() { 178 | It("should invoke only the caller", func() { 179 | client, receiver, cancel := makePipeClientsAndReceivers() 180 | r := <-client[0].Invoke("CallCaller") 181 | Expect(r.Error).NotTo(HaveOccurred()) 182 | select { 183 | case <-receiver[0].ch: 184 | case <-receiver[1].ch: 185 | Fail("Wrong client received message") 186 | case <-receiver[2].ch: 187 | Fail("Wrong client received message") 188 | case <-time.After(2 * time.Second): 189 | Fail("timeout waiting for clients getting results") 190 | } 191 | cancel() 192 | }) 193 | }) 194 | Context("Clients().Client()", func() { 195 | It("should invoke only the client which was addressed", func() { 196 | client, receiver, cancel := makePipeClientsAndReceivers() 197 | r := <-client[0].Invoke("CallClient", "1") 198 | Expect(r.Error).NotTo(HaveOccurred()) 199 | select { 200 | case <-receiver[0].ch: 201 | Fail("Wrong client received message") 202 | case 
<-receiver[1].ch: 203 | case <-receiver[2].ch: 204 | Fail("Wrong client received message") 205 | case <-time.After(2 * time.Second): 206 | Fail("timeout waiting for clients getting results") 207 | } 208 | cancel() 209 | }) 210 | }) 211 | } 212 | }) 213 | 214 | func TestGroupShouldInvokeOnlyTheClientsInTheGroup(t *testing.T) { 215 | ctx, cancel := context.WithCancel(context.Background()) 216 | defer cancel() 217 | _, client, receiver, srvConn, _, err := makeTCPServerAndClients(ctx, 3) 218 | assert.NoError(t, err) 219 | select { 220 | case ir := <-client[0].Invoke("buildgroup", srvConn[1].ConnectionID(), srvConn[2].ConnectionID()): 221 | assert.NoError(t, ir.Error) 222 | case <-time.After(100 * time.Millisecond): 223 | assert.Fail(t, "timeout in invoke") 224 | } 225 | select { 226 | case ir := <-client[0].Invoke("callgroup"): 227 | assert.NoError(t, ir.Error) 228 | case <-time.After(100 * time.Millisecond): 229 | assert.Fail(t, "timeout in invoke") 230 | } 231 | gotCalled := 0 232 | select { 233 | case <-receiver[0].ch: 234 | assert.Fail(t, "client 1 received message for client 2, 3") 235 | case <-receiver[1].ch: 236 | gotCalled++ 237 | case <-receiver[2].ch: 238 | gotCalled++ 239 | case <-time.After(100 * time.Millisecond): 240 | if gotCalled < 2 { 241 | assert.Fail(t, "timeout without client 2 and 3 got called") 242 | } 243 | } 244 | } 245 | 246 | func TestRemoveClientsShouldRemoveClientsFromTheGroup(t *testing.T) { 247 | ctx, cancel := context.WithCancel(context.Background()) 248 | defer cancel() 249 | _, client, receiver, srvConn, _, err := makeTCPServerAndClients(ctx, 3) 250 | assert.NoError(t, err) 251 | select { 252 | case ir := <-client[0].Invoke("buildgroup", srvConn[1].ConnectionID(), srvConn[2].ConnectionID()): 253 | assert.NoError(t, ir.Error) 254 | case <-time.After(100 * time.Millisecond): 255 | assert.Fail(t, "timeout in invoke") 256 | } 257 | select { 258 | case ir := <-client[0].Invoke("removefromgroup", srvConn[2].ConnectionID()): 259 | 
assert.NoError(t, ir.Error) 260 | case <-time.After(100 * time.Millisecond): 261 | assert.Fail(t, "timeout in invoke") 262 | } 263 | select { 264 | case ir := <-client[0].Invoke("callgroup"): 265 | assert.NoError(t, ir.Error) 266 | case <-time.After(100 * time.Millisecond): 267 | assert.Fail(t, "timeout in invoke") 268 | } 269 | gotCalled := false 270 | select { 271 | case <-receiver[0].ch: 272 | assert.Fail(t, "client 1 received message for client 2") 273 | case <-receiver[1].ch: 274 | gotCalled = true 275 | case <-receiver[2].ch: 276 | assert.Fail(t, "client 3 received message for client 2") 277 | case <-time.After(100 * time.Millisecond): 278 | if !gotCalled { 279 | assert.Fail(t, "timeout without client 3 got called") 280 | } 281 | } 282 | } 283 | 284 | func TestItemsShouldHoldItemsConnectionWise(t *testing.T) { 285 | ctx, cancel := context.WithCancel(context.Background()) 286 | defer cancel() 287 | _, client, _, _, _, err := makeTCPServerAndClients(ctx, 2) 288 | assert.NoError(t, err) 289 | select { 290 | case ir := <-client[0].Invoke("additem", "first", 1): 291 | assert.NoError(t, ir.Error) 292 | case <-time.After(100 * time.Millisecond): 293 | assert.Fail(t, "timeout in invoke") 294 | } 295 | select { 296 | case ir := <-client[0].Invoke("getitem", "first"): 297 | assert.NoError(t, ir.Error) 298 | assert.Equal(t, ir.Value, 1.0) 299 | case <-time.After(100 * time.Millisecond): 300 | assert.Fail(t, "timeout in invoke") 301 | } 302 | select { 303 | case ir := <-client[1].Invoke("getitem", "first"): 304 | assert.NoError(t, ir.Error) 305 | assert.Equal(t, ir.Value, nil) 306 | case <-time.After(100 * time.Millisecond): 307 | assert.Fail(t, "timeout in invoke") 308 | } 309 | } 310 | 311 | func TestAbortShouldAbortTheConnectionOfTheCurrentCaller(t *testing.T) { 312 | ctx, cancel := context.WithCancel(context.Background()) 313 | defer cancel() 314 | _, client, _, _, _, err := makeTCPServerAndClients(ctx, 2) 315 | assert.NoError(t, err) 316 | select { 317 | case ir := 
<-client[0].Invoke("abort"): 318 | assert.Error(t, ir.Error) 319 | select { 320 | case err := <-client[0].WaitForState(ctx, ClientClosed): 321 | assert.NoError(t, err) 322 | case <-time.After(500 * time.Millisecond): 323 | assert.Fail(t, "timeout waiting for client close") 324 | } 325 | case <-time.After(100 * time.Millisecond): 326 | assert.Fail(t, "timeout in invoke") 327 | } 328 | select { 329 | case ir := <-client[1].Invoke("additem", "first", 2): 330 | assert.NoError(t, ir.Error) 331 | case <-time.After(100 * time.Millisecond): 332 | assert.Fail(t, "timeout in invoke") 333 | } 334 | select { 335 | case ir := <-client[1].Invoke("getitem", "first"): 336 | assert.NoError(t, ir.Error) 337 | assert.Equal(t, ir.Value, 2.0) 338 | case <-time.After(100 * time.Millisecond): 339 | assert.Fail(t, "timeout in invoke") 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /hublifetimemanager.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/go-kit/log" 7 | ) 8 | 9 | // HubLifetimeManager is a lifetime manager abstraction for hub instances 10 | // OnConnected() is called when a connection is started 11 | // OnDisconnected() is called when a connection is finished 12 | // InvokeAll() sends an invocation message to all hub connections 13 | // InvokeClient() sends an invocation message to a specified hub connection 14 | // InvokeGroup() sends an invocation message to a specified group of hub connections 15 | // AddToGroup() adds a connection to the specified group 16 | // RemoveFromGroup() removes a connection from the specified group 17 | type HubLifetimeManager interface { 18 | OnConnected(conn hubConnection) 19 | OnDisconnected(conn hubConnection) 20 | InvokeAll(target string, args []interface{}) 21 | InvokeClient(connectionID string, target string, args []interface{}) 22 | InvokeGroup(groupName string, target string, args 
[]interface{}) 23 | AddToGroup(groupName, connectionID string) 24 | RemoveFromGroup(groupName, connectionID string) 25 | } 26 | 27 | func newLifeTimeManager(info StructuredLogger) defaultHubLifetimeManager { 28 | return defaultHubLifetimeManager{ 29 | info: log.WithPrefix(info, "ts", log.DefaultTimestampUTC, 30 | "class", "lifeTimeManager"), 31 | } 32 | } 33 | 34 | type defaultHubLifetimeManager struct { 35 | clients sync.Map 36 | groups sync.Map 37 | info StructuredLogger 38 | } 39 | 40 | func (d *defaultHubLifetimeManager) OnConnected(conn hubConnection) { 41 | d.clients.Store(conn.ConnectionID(), conn) 42 | } 43 | 44 | func (d *defaultHubLifetimeManager) OnDisconnected(conn hubConnection) { 45 | d.clients.Delete(conn.ConnectionID()) 46 | } 47 | 48 | func (d *defaultHubLifetimeManager) InvokeAll(target string, args []interface{}) { 49 | d.clients.Range(func(key, value interface{}) bool { 50 | go func() { 51 | _ = value.(hubConnection).SendInvocation("", target, args) 52 | }() 53 | return true 54 | }) 55 | } 56 | 57 | func (d *defaultHubLifetimeManager) InvokeClient(connectionID string, target string, args []interface{}) { 58 | if client, ok := d.clients.Load(connectionID); ok { 59 | go func() { 60 | _ = client.(hubConnection).SendInvocation("", target, args) 61 | }() 62 | } 63 | } 64 | 65 | func (d *defaultHubLifetimeManager) InvokeGroup(groupName string, target string, args []interface{}) { 66 | if groups, ok := d.groups.Load(groupName); ok { 67 | for _, v := range groups.(map[string]hubConnection) { 68 | conn := v 69 | go func() { 70 | _ = conn.SendInvocation("", target, args) 71 | }() 72 | } 73 | } 74 | } 75 | 76 | func (d *defaultHubLifetimeManager) AddToGroup(groupName string, connectionID string) { 77 | if client, ok := d.clients.Load(connectionID); ok { 78 | groups, _ := d.groups.LoadOrStore(groupName, make(map[string]hubConnection)) 79 | groups.(map[string]hubConnection)[connectionID] = client.(hubConnection) 80 | } 81 | } 82 | 83 | func (d 
*defaultHubLifetimeManager) RemoveFromGroup(groupName string, connectionID string) { 84 | if groups, ok := d.groups.Load(groupName); ok { 85 | delete(groups.(map[string]hubConnection), connectionID) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /hubprotocol.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | ) 7 | 8 | // hubProtocol interface 9 | // ParseMessages() parses messages from an io.Reader and stores unparsed bytes in remainBuf. 10 | // If buf does not contain the whole message, it returns a nil message and complete false 11 | // WriteMessage writes a message to the specified writer 12 | // UnmarshalArgument() unmarshals a raw message depending of the specified value type into a destination value 13 | type hubProtocol interface { 14 | ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) ([]interface{}, error) 15 | WriteMessage(message interface{}, writer io.Writer) error 16 | UnmarshalArgument(src interface{}, dst interface{}) error 17 | setDebugLogger(dbg StructuredLogger) 18 | transferMode() TransferMode 19 | } 20 | 21 | //easyjson:json 22 | type hubMessage struct { 23 | Type int `json:"type"` 24 | } 25 | 26 | // easyjson:json 27 | type invocationMessage struct { 28 | Type int `json:"type"` 29 | Target string `json:"target"` 30 | InvocationID string `json:"invocationId,omitempty"` 31 | Arguments []interface{} `json:"arguments"` 32 | StreamIds []string `json:"streamIds,omitempty"` 33 | } 34 | 35 | //easyjson:json 36 | type completionMessage struct { 37 | Type int `json:"type"` 38 | InvocationID string `json:"invocationId"` 39 | Result interface{} `json:"result,omitempty"` 40 | Error string `json:"error,omitempty"` 41 | } 42 | 43 | //easyjson:json 44 | type streamItemMessage struct { 45 | Type int `json:"type"` 46 | InvocationID string `json:"invocationId"` 47 | Item interface{} `json:"item"` 48 | } 49 | 
50 | //easyjson:json 51 | type cancelInvocationMessage struct { 52 | Type int `json:"type"` 53 | InvocationID string `json:"invocationId"` 54 | } 55 | 56 | //easyjson:json 57 | type closeMessage struct { 58 | Type int `json:"type"` 59 | Error string `json:"error"` 60 | AllowReconnect bool `json:"allowReconnect"` 61 | } 62 | 63 | //easyjson:json 64 | type handshakeRequest struct { 65 | Protocol string `json:"protocol"` 66 | Version int `json:"version"` 67 | } 68 | 69 | //easyjson:json 70 | type handshakeResponse struct { 71 | Error string `json:"error,omitempty"` 72 | } 73 | -------------------------------------------------------------------------------- /invokeclient.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type invokeClient struct { 11 | mx sync.Mutex 12 | resultChans map[string]invocationResultChans 13 | protocol hubProtocol 14 | chanReceiveTimeout time.Duration 15 | } 16 | 17 | func newInvokeClient(protocol hubProtocol, chanReceiveTimeout time.Duration) *invokeClient { 18 | return &invokeClient{ 19 | mx: sync.Mutex{}, 20 | resultChans: make(map[string]invocationResultChans), 21 | protocol: protocol, 22 | chanReceiveTimeout: chanReceiveTimeout, 23 | } 24 | } 25 | 26 | type invocationResultChans struct { 27 | resultChan chan interface{} 28 | errChan chan error 29 | } 30 | 31 | func (i *invokeClient) newInvocation(id string) (chan interface{}, chan error) { 32 | i.mx.Lock() 33 | r := invocationResultChans{ 34 | resultChan: make(chan interface{}, 1), 35 | errChan: make(chan error, 1), 36 | } 37 | i.resultChans[id] = r 38 | i.mx.Unlock() 39 | return r.resultChan, r.errChan 40 | } 41 | 42 | func (i *invokeClient) deleteInvocation(id string) { 43 | i.mx.Lock() 44 | if r, ok := i.resultChans[id]; ok { 45 | delete(i.resultChans, id) 46 | close(r.resultChan) 47 | close(r.errChan) 48 | } 49 | i.mx.Unlock() 50 | } 51 | 52 | func (i 
*invokeClient) cancelAllInvokes() { 53 | i.mx.Lock() 54 | for _, r := range i.resultChans { 55 | close(r.resultChan) 56 | go func(errChan chan error) { 57 | errChan <- errors.New("message loop ended") 58 | close(errChan) 59 | }(r.errChan) 60 | } 61 | // Clear map 62 | i.resultChans = make(map[string]invocationResultChans) 63 | i.mx.Unlock() 64 | } 65 | 66 | func (i *invokeClient) handlesInvocationID(invocationID string) bool { 67 | i.mx.Lock() 68 | _, ok := i.resultChans[invocationID] 69 | i.mx.Unlock() 70 | return ok 71 | } 72 | 73 | func (i *invokeClient) receiveCompletionItem(completion completionMessage) error { 74 | defer i.deleteInvocation(completion.InvocationID) 75 | i.mx.Lock() 76 | ir, ok := i.resultChans[completion.InvocationID] 77 | i.mx.Unlock() 78 | if ok { 79 | if completion.Error != "" { 80 | done := make(chan struct{}) 81 | go func() { 82 | ir.errChan <- errors.New(completion.Error) 83 | done <- struct{}{} 84 | }() 85 | select { 86 | case <-done: 87 | return nil 88 | case <-time.After(i.chanReceiveTimeout): 89 | return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client sent error", i.chanReceiveTimeout)} 90 | } 91 | } 92 | if completion.Result != nil { 93 | var result interface{} 94 | if err := i.protocol.UnmarshalArgument(completion.Result, &result); err != nil { 95 | return err 96 | } 97 | done := make(chan struct{}) 98 | go func() { 99 | ir.resultChan <- result 100 | if completion.Error != "" { 101 | ir.errChan <- errors.New(completion.Error) 102 | } else { 103 | ir.errChan <- nil 104 | } 105 | close(done) 106 | }() 107 | select { 108 | case <-done: 109 | return nil 110 | case <-time.After(i.chanReceiveTimeout): 111 | return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client sent value", i.chanReceiveTimeout)} 112 | } 113 | } 114 | return nil 115 | } 116 | return fmt.Errorf(`unknown completion id "%v"`, completion.InvocationID) 117 | } 118 | 
-------------------------------------------------------------------------------- /jsonhubprotocol.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "reflect" 10 | 11 | "github.com/go-kit/log" 12 | ) 13 | 14 | // jsonHubProtocol is the JSON based SignalR protocol 15 | type jsonHubProtocol struct { 16 | dbg log.Logger 17 | } 18 | 19 | // Protocol specific messages for correct unmarshaling of arguments or results. 20 | // jsonInvocationMessage is only used in ParseMessages, not in WriteMessage 21 | type jsonInvocationMessage struct { 22 | Type int `json:"type"` 23 | Target string `json:"target"` 24 | InvocationID string `json:"invocationId"` 25 | Arguments []json.RawMessage `json:"arguments"` 26 | StreamIds []string `json:"streamIds,omitempty"` 27 | } 28 | 29 | type jsonStreamItemMessage struct { 30 | Type int `json:"type"` 31 | InvocationID string `json:"invocationId"` 32 | Item json.RawMessage `json:"item"` 33 | } 34 | 35 | type jsonCompletionMessage struct { 36 | Type int `json:"type"` 37 | InvocationID string `json:"invocationId"` 38 | Result json.RawMessage `json:"result,omitempty"` 39 | Error string `json:"error,omitempty"` 40 | } 41 | 42 | type jsonError struct { 43 | raw string 44 | err error 45 | } 46 | 47 | func (j *jsonError) Error() string { 48 | return fmt.Sprintf("%v (source: %v)", j.err, j.raw) 49 | } 50 | 51 | // UnmarshalArgument unmarshals a json.RawMessage depending on the specified value type into value 52 | func (j *jsonHubProtocol) UnmarshalArgument(src interface{}, dst interface{}) error { 53 | rawSrc, ok := src.(json.RawMessage) 54 | if !ok { 55 | return fmt.Errorf("invalid source %#v for UnmarshalArgument", src) 56 | } 57 | if err := json.Unmarshal(rawSrc, dst); err != nil { 58 | return &jsonError{string(rawSrc), err} 59 | } 60 | _ = j.dbg.Log(evt, "UnmarshalArgument", 61 | "argument", string(rawSrc), 62 | 
"value", fmt.Sprintf("%v", reflect.ValueOf(dst).Elem())) 63 | return nil 64 | } 65 | 66 | // ParseMessages reads all messages from the reader and puts the remaining bytes into remainBuf 67 | func (j *jsonHubProtocol) ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) (messages []interface{}, err error) { 68 | frames, err := readJSONFrames(reader, remainBuf) 69 | if err != nil { 70 | return nil, err 71 | } 72 | message := hubMessage{} 73 | messages = make([]interface{}, 0) 74 | for _, frame := range frames { 75 | err = json.Unmarshal(frame, &message) 76 | _ = j.dbg.Log(evt, "read", msg, string(frame)) 77 | if err != nil { 78 | return nil, &jsonError{string(frame), err} 79 | } 80 | typedMessage, err := j.parseMessage(message.Type, frame) 81 | if err != nil { 82 | return nil, err 83 | } 84 | // No specific type (aka Ping), use hubMessage 85 | if typedMessage == nil { 86 | typedMessage = message 87 | } 88 | messages = append(messages, typedMessage) 89 | } 90 | return messages, nil 91 | } 92 | 93 | func (j *jsonHubProtocol) parseMessage(messageType int, text []byte) (message interface{}, err error) { 94 | switch messageType { 95 | case 1, 4: 96 | jsonInvocation := jsonInvocationMessage{} 97 | if err = json.Unmarshal(text, &jsonInvocation); err != nil { 98 | err = &jsonError{string(text), err} 99 | } 100 | arguments := make([]interface{}, len(jsonInvocation.Arguments)) 101 | for i, a := range jsonInvocation.Arguments { 102 | arguments[i] = a 103 | } 104 | return invocationMessage{ 105 | Type: jsonInvocation.Type, 106 | Target: jsonInvocation.Target, 107 | InvocationID: jsonInvocation.InvocationID, 108 | Arguments: arguments, 109 | StreamIds: jsonInvocation.StreamIds, 110 | }, err 111 | case 2: 112 | jsonStreamItem := jsonStreamItemMessage{} 113 | if err = json.Unmarshal(text, &jsonStreamItem); err != nil { 114 | err = &jsonError{string(text), err} 115 | } 116 | return streamItemMessage{ 117 | Type: jsonStreamItem.Type, 118 | InvocationID: 
jsonStreamItem.InvocationID, 119 | Item: jsonStreamItem.Item, 120 | }, err 121 | case 3: 122 | jsonCompletion := jsonCompletionMessage{} 123 | if err = json.Unmarshal(text, &jsonCompletion); err != nil { 124 | err = &jsonError{string(text), err} 125 | } 126 | completion := completionMessage{ 127 | Type: jsonCompletion.Type, 128 | InvocationID: jsonCompletion.InvocationID, 129 | Error: jsonCompletion.Error, 130 | } 131 | // Only assign Result when non nil. setting interface{} Result to (json.RawMessage)(nil) 132 | // will produce a value which can not compared to nil even if it is pointing towards nil! 133 | // See https://www.calhoun.io/when-nil-isnt-equal-to-nil/ for explanation 134 | if jsonCompletion.Result != nil { 135 | completion.Result = jsonCompletion.Result 136 | } 137 | return completion, err 138 | case 5: 139 | invocation := cancelInvocationMessage{} 140 | if err = json.Unmarshal(text, &invocation); err != nil { 141 | err = &jsonError{string(text), err} 142 | } 143 | return invocation, err 144 | case 7: 145 | cm := closeMessage{} 146 | if err = json.Unmarshal(text, &cm); err != nil { 147 | err = &jsonError{string(text), err} 148 | } 149 | return cm, err 150 | default: 151 | return nil, nil 152 | } 153 | } 154 | 155 | // readJSONFrames reads all complete frames (delimited by 0x1e) from the reader and puts the remaining bytes into remainBuf 156 | func readJSONFrames(reader io.Reader, remainBuf *bytes.Buffer) ([][]byte, error) { 157 | p := make([]byte, 1<<15) 158 | buf := &bytes.Buffer{} 159 | _, _ = buf.ReadFrom(remainBuf) 160 | // Try getting data until at least one frame is available 161 | for { 162 | n, err := reader.Read(p) 163 | // Some reader implementations return io.EOF additionally to n=0 if no data could be read 164 | if err != nil && !errors.Is(err, io.EOF) { 165 | return nil, err 166 | } 167 | if n > 0 { 168 | _, _ = buf.Write(p[:n]) 169 | frames, err := parseJSONFrames(buf) 170 | if err != nil { 171 | return nil, err 172 | } 173 | if 
len(frames) > 0 { 174 | _, _ = remainBuf.ReadFrom(buf) 175 | return frames, nil 176 | } 177 | } 178 | } 179 | } 180 | 181 | func parseJSONFrames(buf *bytes.Buffer) ([][]byte, error) { 182 | frames := make([][]byte, 0) 183 | for { 184 | frame, err := buf.ReadBytes(0x1e) 185 | if errors.Is(err, io.EOF) { 186 | // Restore incomplete frame in buffer 187 | _, _ = buf.Write(frame) 188 | break 189 | } 190 | if err != nil { 191 | return nil, err 192 | } 193 | frames = append(frames, frame[:len(frame)-1]) 194 | } 195 | return frames, nil 196 | } 197 | 198 | // WriteMessage writes a message as JSON to the specified writer 199 | func (j *jsonHubProtocol) WriteMessage(message interface{}, writer io.Writer) error { 200 | var b []byte 201 | var err error 202 | if marshaler, ok := message.(json.Marshaler); ok { 203 | b, err = marshaler.MarshalJSON() 204 | } else { 205 | b, err = json.Marshal(message) 206 | } 207 | if err != nil { 208 | return err 209 | } 210 | b = append(b, 0x1e) 211 | _ = j.dbg.Log(evt, "write", msg, string(b)) 212 | _, err = writer.Write(b) 213 | return err 214 | } 215 | 216 | func (j *jsonHubProtocol) transferMode() TransferMode { 217 | return TextTransferMode 218 | } 219 | 220 | func (j *jsonHubProtocol) setDebugLogger(dbg StructuredLogger) { 221 | j.dbg = log.WithPrefix(dbg, "ts", log.DefaultTimestampUTC, "protocol", "JSON") 222 | } 223 | -------------------------------------------------------------------------------- /logger_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io" 7 | "os" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/go-kit/log" 13 | ) 14 | 15 | type loggerConfig struct { 16 | Enabled bool 17 | Debug bool 18 | } 19 | 20 | var lConf loggerConfig 21 | 22 | var tLog StructuredLogger 23 | 24 | func testLoggerOption() func(Party) error { 25 | testLogger() 26 | return Logger(tLog, lConf.Debug) 27 | } 28 | 29 | func 
testLogger() StructuredLogger { 30 | if tLog == nil { 31 | lConf = loggerConfig{Enabled: false, Debug: false} 32 | b, err := os.ReadFile("testLogConf.json") 33 | if err == nil { 34 | err = json.Unmarshal(b, &lConf) 35 | if err != nil { 36 | lConf = loggerConfig{Enabled: false, Debug: false} 37 | } 38 | } 39 | writer := io.Discard 40 | if lConf.Enabled { 41 | writer = os.Stderr 42 | } 43 | tLog = log.NewLogfmtLogger(writer) 44 | } 45 | return tLog 46 | } 47 | 48 | type panicLogger struct { 49 | log log.Logger 50 | } 51 | 52 | func (p *panicLogger) Log(keyVals ...interface{}) error { 53 | _ = p.log.Log(keyVals...) 54 | panic("panic as expected") 55 | } 56 | 57 | type testLogWriter struct { 58 | mx sync.Mutex 59 | p []byte 60 | t *testing.T 61 | } 62 | 63 | func (t *testLogWriter) Write(p []byte) (n int, err error) { 64 | t.mx.Lock() 65 | defer t.mx.Unlock() 66 | t.p = append(t.p, p...) 67 | if len(p) > 0 && p[len(p)-1] == 10 { // Will not work on Windows, but doesn't matter. This is only to check if the logger output still looks as expected 68 | t.t.Log(string(t.p)) 69 | t.p = nil 70 | } 71 | return len(p), nil 72 | } 73 | 74 | func Test_PanicLogger(t *testing.T) { 75 | defer func() { 76 | if err := recover(); err != nil { 77 | t.Errorf("panic in logger: '%v'", err) 78 | } 79 | }() 80 | ctx, cancel := context.WithCancel(context.Background()) 81 | server, _ := NewServer(ctx, SimpleHubFactory(&simpleHub{}), 82 | Logger(&panicLogger{log: log.NewLogfmtLogger(&testLogWriter{t: t})}, true), 83 | ChanReceiveTimeout(200*time.Millisecond), 84 | StreamBufferCapacity(5)) 85 | // Create both ends of the connection 86 | cliConn, srvConn := newClientServerConnections() 87 | // Start the server 88 | go func() { _ = server.Serve(srvConn) }() 89 | // Create the Client 90 | client, _ := NewClient(ctx, WithConnection(cliConn), Logger(&panicLogger{log: log.NewLogfmtLogger(&testLogWriter{t: t})}, true)) 91 | // Start it 92 | client.Start() 93 | // Do something 94 | 
<-client.Send("InvokeMe") 95 | cancel() 96 | } 97 | -------------------------------------------------------------------------------- /messagepackhubprotocol_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | 7 | . "github.com/onsi/ginkgo" 8 | . "github.com/onsi/gomega" 9 | ) 10 | 11 | var _ = Describe("MessagePackHubProtocol", func() { 12 | protocol := messagePackHubProtocol{} 13 | protocol.setDebugLogger(testLogger()) 14 | Context("ParseMessages", func() { 15 | It("should encode/decode an InvocationMessage", func() { 16 | message := invocationMessage{ 17 | Type: 4, 18 | Target: "target", 19 | InvocationID: "1", 20 | // because DecodeSlice below decodes ints to the smallest type and arrays always to []interface{}, we need to be very specific 21 | Arguments: []interface{}{"1", int8(1), []interface{}{int8(7), int8(3)}}, 22 | StreamIds: []string{"0"}, 23 | } 24 | buf := bytes.Buffer{} 25 | err := protocol.WriteMessage(message, &buf) 26 | Expect(err).NotTo(HaveOccurred()) 27 | remainBuf := bytes.Buffer{} 28 | got, err := protocol.ParseMessages(&buf, &remainBuf) 29 | Expect(err).NotTo(HaveOccurred()) 30 | Expect(remainBuf.Len()).To(Equal(0)) 31 | Expect(len(got)).To(Equal(1)) 32 | Expect(got[0]).To(BeAssignableToTypeOf(invocationMessage{})) 33 | gotMsg := got[0].(invocationMessage) 34 | Expect(gotMsg.Type).To(Equal(message.Type)) 35 | Expect(gotMsg.Target).To(Equal(message.Target)) 36 | Expect(gotMsg.InvocationID).To(Equal(message.InvocationID)) 37 | Expect(gotMsg.StreamIds).To(Equal(message.StreamIds)) 38 | for i, gotArg := range gotMsg.Arguments { 39 | // We can not directly compare gotArg and want.Arguments[i] 40 | // because msgpack serializes numbers to the shortest possible type 41 | t := reflect.TypeOf(message.Arguments[i]) 42 | value := reflect.New(t) 43 | Expect(protocol.UnmarshalArgument(gotArg, value.Interface())).NotTo(HaveOccurred()) 44 | 
Expect(reflect.Indirect(value).Interface()).To(Equal(message.Arguments[i])) 45 | } 46 | }) 47 | }) 48 | }) 49 | -------------------------------------------------------------------------------- /negotiateresponse.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | type availableTransport struct { 4 | Transport string `json:"transport"` 5 | TransferFormats []string `json:"transferFormats"` 6 | } 7 | 8 | type negotiateResponse struct { 9 | ConnectionToken string `json:"connectionToken,omitempty"` 10 | ConnectionID string `json:"connectionId"` 11 | NegotiateVersion int `json:"negotiateVersion,omitempty"` 12 | AvailableTransports []availableTransport `json:"availableTransports"` 13 | } 14 | 15 | func (nr *negotiateResponse) getTransferFormats(transportType string) []string { 16 | for _, transport := range nr.AvailableTransports { 17 | if transport.Transport == transportType { 18 | return transport.TransferFormats 19 | } 20 | } 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /netconnection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "encoding/base64" 7 | "fmt" 8 | "net" 9 | "time" 10 | ) 11 | 12 | type netConnection struct { 13 | ConnectionBase 14 | conn net.Conn 15 | } 16 | 17 | // NewNetConnection wraps net.Conn into a Connection 18 | func NewNetConnection(ctx context.Context, conn net.Conn) Connection { 19 | netConn := &netConnection{ 20 | ConnectionBase: *NewConnectionBase(ctx, getConnectionID()), 21 | conn: conn, 22 | } 23 | go func() { 24 | <-ctx.Done() 25 | _ = conn.Close() 26 | }() 27 | return netConn 28 | } 29 | 30 | func (nc *netConnection) Write(p []byte) (n int, err error) { 31 | n, err = ReadWriteWithContext(nc.Context(), 32 | func() (int, error) { return nc.conn.Write(p) }, 33 | func() { _ = 
nc.conn.SetWriteDeadline(time.Now()) }) 34 | if err != nil { 35 | err = fmt.Errorf("%T: %w", nc, err) 36 | } 37 | return n, err 38 | } 39 | 40 | func (nc *netConnection) Read(p []byte) (n int, err error) { 41 | n, err = ReadWriteWithContext(nc.Context(), 42 | func() (int, error) { return nc.conn.Read(p) }, 43 | func() { _ = nc.conn.SetReadDeadline(time.Now()) }) 44 | if err != nil { 45 | err = fmt.Errorf("%T: %w", nc, err) 46 | } 47 | return n, err 48 | } 49 | 50 | func getConnectionID() string { 51 | bytes := make([]byte, 16) 52 | _, _ = rand.Read(bytes) 53 | return base64.StdEncoding.EncodeToString(bytes) 54 | } 55 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/go-kit/log" 9 | "github.com/go-kit/log/level" 10 | ) 11 | 12 | // TimeoutInterval is the interval one Party will consider the other Party disconnected 13 | // if it hasn't received a message (including keep-alive) in it. 14 | // The recommended value is double the KeepAliveInterval value. 15 | // Default is 30 seconds. 16 | func TimeoutInterval(timeout time.Duration) func(Party) error { 17 | return func(p Party) error { 18 | p.setTimeout(timeout) 19 | return nil 20 | } 21 | } 22 | 23 | // HandshakeTimeout is the interval if the other Party doesn't send an initial handshake message within, 24 | // the connection is closed. This is an advanced setting that should only be modified 25 | // if handshake timeout errors are occurring due to severe network latency. 
26 | // Default is 15 seconds. For more detail on the handshake process, 27 | // see https://github.com/dotnet/aspnetcore/blob/master/src/SignalR/docs/specs/HubProtocol.md 28 | func HandshakeTimeout(timeout time.Duration) func(Party) error { 29 | return func(p Party) error { 30 | p.setHandshakeTimeout(timeout) 31 | return nil 32 | } 33 | } 34 | 35 | // KeepAliveInterval is the interval if the Party hasn't sent a message within, 36 | // a ping message is sent automatically to keep the connection open. 37 | // When changing KeepAliveInterval, change the Timeout setting on the other Party. 38 | // The recommended Timeout value is double the KeepAliveInterval value. 39 | // Default is 5 seconds. 40 | func KeepAliveInterval(interval time.Duration) func(Party) error { 41 | return func(p Party) error { 42 | p.setKeepAliveInterval(interval) 43 | return nil 44 | } 45 | } 46 | 47 | // StreamBufferCapacity is the maximum number of items that can be buffered for client upload streams. 48 | // If this limit is reached, the processing of invocations is blocked until the server processes stream items. 49 | // Default is 10. A capacity of 0 is rejected with an error. 50 | func StreamBufferCapacity(capacity uint) func(Party) error { 51 | return func(p Party) error { 52 | if capacity == 0 { 53 | return errors.New("unsupported StreamBufferCapacity 0") 54 | } 55 | p.setStreamBufferCapacity(capacity) 56 | return nil 57 | } 58 | } 59 | 60 | // MaximumReceiveMessageSize is the maximum size in bytes of a single incoming hub message.
61 | // Default is 32768 bytes (32KB). A size of 0 is rejected with an error. 62 | func MaximumReceiveMessageSize(sizeInBytes uint) func(Party) error { 63 | return func(p Party) error { 64 | if sizeInBytes == 0 { 65 | return errors.New("unsupported maximumReceiveMessageSize 0") 66 | } 67 | p.setMaximumReceiveMessageSize(sizeInBytes) 68 | return nil 69 | } 70 | } 71 | 72 | // ChanReceiveTimeout is the timeout for processing stream items from the client, after StreamBufferCapacity was reached. 73 | // If the hub method is not able to process a stream item during the timeout duration, 74 | // the server will send a completion with error. 75 | // Default is 5 seconds. 76 | func ChanReceiveTimeout(timeout time.Duration) func(Party) error { 77 | return func(p Party) error { 78 | p.setChanReceiveTimeout(timeout) 79 | return nil 80 | } 81 | } 82 | 83 | // EnableDetailedErrors controls error detail: if true, detailed exception messages are returned to the other 84 | // Party when an exception is thrown in a Hub method. 85 | // The default is false, as these exception messages can contain sensitive information. 86 | func EnableDetailedErrors(enable bool) func(Party) error { 87 | return func(p Party) error { 88 | p.setEnableDetailedErrors(enable) 89 | return nil 90 | } 91 | } 92 | 93 | // StructuredLogger is the simplest logging interface for structured logging. 94 | // See github.com/go-kit/log 95 | type StructuredLogger interface { 96 | Log(keyVals ...interface{}) error // same method set as go-kit's log.Logger, so values of either type can be passed where the other is expected 97 | } 98 | 99 | // Logger sets the logger used by the Party to log info events.
100 | // If debug is true, debug log events are generated, too. 101 | func Logger(logger StructuredLogger, debug bool) func(Party) error { 102 | return func(p Party) error { 103 | i, d := buildInfoDebugLogger(logger, debug) 104 | p.setLoggers(i, d) 105 | return nil 106 | } 107 | } 108 | // recoverLogger wraps a log.Logger so that a panic raised while logging is recovered and printed to stdout instead of propagating. 109 | type recoverLogger struct { 110 | logger log.Logger 111 | } 112 | // Log forwards keyVals to the wrapped logger. If that panics, the deferred recover prints the panic value and Log returns nil. 113 | func (r *recoverLogger) Log(keyVals ...interface{}) error { 114 | defer func() { 115 | if err := recover(); err != nil { 116 | fmt.Printf("recovering from panic in logger: %v\n", err) 117 | } 118 | }() 119 | return r.logger.Log(keyVals...) 120 | } 121 | // buildInfoDebugLogger filters logger at info level (debug level when debug is true) and returns an info logger and a debug logger, each built on a recoverLogger; the debug logger also records the caller. 122 | func buildInfoDebugLogger(logger log.Logger, debug bool) (log.Logger, log.Logger) { 123 | if debug { 124 | logger = level.NewFilter(logger, level.AllowDebug()) 125 | } else { 126 | logger = level.NewFilter(logger, level.AllowInfo()) 127 | } 128 | infoLogger := &recoverLogger{level.Info(logger)} 129 | debugLogger := log.With(&recoverLogger{level.Debug(logger)}, "caller", log.DefaultCaller) 130 | return infoLogger, debugLogger 131 | } 132 | -------------------------------------------------------------------------------- /party.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/go-kit/log" 8 | ) 9 | 10 | // Party is the common base of Server and Client.
The Party methods are only used internally, 11 | // but the interface is public to allow using Options on Party as parameters for external functions 12 | type Party interface { 13 | context() context.Context 14 | cancel() 15 | 16 | onConnected(hc hubConnection) 17 | onDisconnected(hc hubConnection) 18 | 19 | invocationTarget(hc hubConnection) interface{} 20 | 21 | timeout() time.Duration 22 | setTimeout(timeout time.Duration) 23 | 24 | setHandshakeTimeout(timeout time.Duration) 25 | 26 | keepAliveInterval() time.Duration 27 | setKeepAliveInterval(interval time.Duration) 28 | 29 | insecureSkipVerify() bool 30 | setInsecureSkipVerify(skip bool) 31 | 32 | originPatterns() []string 33 | setOriginPatterns(origins []string) 34 | 35 | chanReceiveTimeout() time.Duration 36 | setChanReceiveTimeout(interval time.Duration) 37 | 38 | streamBufferCapacity() uint 39 | setStreamBufferCapacity(capacity uint) 40 | 41 | allowReconnect() bool 42 | 43 | enableDetailedErrors() bool 44 | setEnableDetailedErrors(enable bool) 45 | 46 | loggers() (info StructuredLogger, dbg StructuredLogger) 47 | setLoggers(info StructuredLogger, dbg StructuredLogger) 48 | 49 | prefixLoggers(connectionID string) (info StructuredLogger, dbg StructuredLogger) 50 | 51 | maximumReceiveMessageSize() uint 52 | setMaximumReceiveMessageSize(size uint) 53 | } 54 | // newPartyBase derives a cancelable context from parentContext and sets the defaults: timeout 30s, handshake timeout 15s, keep-alive interval 5s, chan receive timeout 5s, stream buffer capacity 10, maximum receive message size 32KB. 55 | func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase { 56 | ctx, cancelFunc := context.WithCancel(parentContext) 57 | return partyBase{ 58 | ctx: ctx, 59 | cancelFunc: cancelFunc, 60 | _timeout: time.Second * 30, 61 | _handshakeTimeout: time.Second * 15, 62 | _keepAliveInterval: time.Second * 5, 63 | _chanReceiveTimeout: time.Second * 5, 64 | _streamBufferCapacity: 10, 65 | _maximumReceiveMessageSize: 1 << 15, // 32KB 66 | _enableDetailedErrors: false, 67 | _insecureSkipVerify: false, 68 | _originPatterns: nil, 69 | info: info, 70 | dbg: dbg, 71 | } 72 | } 73 | // partyBase stores the configuration behind the Party getters and setters; server embeds it. 74 | type partyBase struct { 75 | ctx context.Context
76 | cancelFunc context.CancelFunc 77 | _timeout time.Duration 78 | _handshakeTimeout time.Duration 79 | _keepAliveInterval time.Duration 80 | _chanReceiveTimeout time.Duration 81 | _streamBufferCapacity uint 82 | _maximumReceiveMessageSize uint 83 | _enableDetailedErrors bool 84 | _insecureSkipVerify bool 85 | _originPatterns []string 86 | info StructuredLogger 87 | dbg StructuredLogger 88 | } 89 | 90 | func (p *partyBase) context() context.Context { 91 | return p.ctx 92 | } 93 | 94 | func (p *partyBase) cancel() { 95 | p.cancelFunc() 96 | } 97 | 98 | func (p *partyBase) timeout() time.Duration { 99 | return p._timeout 100 | } 101 | 102 | func (p *partyBase) setTimeout(timeout time.Duration) { 103 | p._timeout = timeout 104 | } 105 | // HandshakeTimeout is exported and, unlike setHandshakeTimeout, not part of the Party interface, so callers need the concrete type. 106 | func (p *partyBase) HandshakeTimeout() time.Duration { 107 | return p._handshakeTimeout 108 | } 109 | 110 | func (p *partyBase) setHandshakeTimeout(timeout time.Duration) { 111 | p._handshakeTimeout = timeout 112 | } 113 | 114 | func (p *partyBase) keepAliveInterval() time.Duration { 115 | return p._keepAliveInterval 116 | } 117 | 118 | func (p *partyBase) setKeepAliveInterval(interval time.Duration) { 119 | p._keepAliveInterval = interval 120 | } 121 | 122 | func (p *partyBase) insecureSkipVerify() bool { 123 | return p._insecureSkipVerify 124 | } 125 | func (p *partyBase) setInsecureSkipVerify(skip bool) { 126 | p._insecureSkipVerify = skip 127 | } 128 | 129 | func (p *partyBase) originPatterns() []string { 130 | return p._originPatterns 131 | } 132 | func (p *partyBase) setOriginPatterns(origins []string) { 133 | p._originPatterns = origins 134 | } 135 | 136 | func (p *partyBase) chanReceiveTimeout() time.Duration { 137 | return p._chanReceiveTimeout 138 | } 139 | 140 | func (p *partyBase) setChanReceiveTimeout(interval time.Duration) { 141 | p._chanReceiveTimeout = interval 142 | } 143 | 144 | func (p *partyBase) streamBufferCapacity() uint { 145 | return p._streamBufferCapacity 146 | } 147 | 148 | func (p *partyBase)
setStreamBufferCapacity(capacity uint) { 149 | p._streamBufferCapacity = capacity 150 | } 151 | 152 | func (p *partyBase) maximumReceiveMessageSize() uint { 153 | return p._maximumReceiveMessageSize 154 | } 155 | 156 | func (p *partyBase) setMaximumReceiveMessageSize(size uint) { 157 | p._maximumReceiveMessageSize = size 158 | } 159 | 160 | func (p *partyBase) enableDetailedErrors() bool { 161 | return p._enableDetailedErrors 162 | } 163 | 164 | func (p *partyBase) setEnableDetailedErrors(enable bool) { 165 | p._enableDetailedErrors = enable 166 | } 167 | 168 | func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) { 169 | p.info = info 170 | p.dbg = dbg 171 | } 172 | 173 | func (p *partyBase) loggers() (info StructuredLogger, dbg StructuredLogger) { 174 | return p.info, p.dbg 175 | } 176 | -------------------------------------------------------------------------------- /receiver.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | // ReceiverInterface allows receivers to interact with the server directly from the receiver methods 4 | // Init(Client) 5 | // Init is used by the Client to connect the receiver to the server. 6 | // Server() Client 7 | // Server can be used inside receiver methods to call Client methods, 8 | // e.g. Client.Send, Client.Invoke, Client.PullStream and Client.PushStreams 9 | type ReceiverInterface interface { 10 | Init(Client) 11 | Server() Client 12 | } 13 | 14 | // Receiver is a base class for receivers in the client. 15 | // It implements ReceiverInterface 16 | type Receiver struct { 17 | client Client 18 | } 19 | 20 | // Init is used by the Client to connect the receiver to the server.
21 | func (ch *Receiver) Init(client Client) { 22 | ch.client = client 23 | } 24 | 25 | // Server returns the Client passed to Init, so receiver methods can use it to call Client methods. 26 | func (ch *Receiver) Server() Client { 27 | return ch.client 28 | } 29 | -------------------------------------------------------------------------------- /router/Makefile: -------------------------------------------------------------------------------- 1 | PROJECT_NAME := "signalr/router" 2 | PKG := "github.com/philippseith/$(PROJECT_NAME)" 3 | PKG_LIST := $(shell go list ${PKG}/... | grep -v /vendor/) 4 | GO_FILES := $(shell find . -name '*.go' | grep -v /vendor/ | grep -v _test.go) 5 | 6 | .PHONY: all dep lint vet test test-coverage build clean help 7 | 8 | all: build 9 | 10 | dep: ## Get the dependencies 11 | @go mod download 12 | 13 | lint: ## Lint Golang files 14 | @golint -set_exit_status ${PKG_LIST} 15 | 16 | vet: ## Run go vet 17 | @go vet ${PKG_LIST} 18 | 19 | test: ## Run unittests 20 | @go test -short ${PKG_LIST} 21 | 22 | test-coverage: ## Run tests with coverage 23 | @go test -short -coverpkg=. -coverprofile cover.out -covermode=atomic ${PKG_LIST} 24 | @cat cover.out >> coverage.txt 25 | 26 | build: dep ## Build the binary file 27 | @go build -i -o build/main $(PKG) 28 | 29 | clean: ## Remove previous build 30 | @rm -rf build 31 | 32 | help: ## Display this help screen 33 | @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 34 | -------------------------------------------------------------------------------- /router/chirouter.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/go-chi/chi/v5" 7 | "github.com/philippseith/signalr" 8 | ) 9 | 10 | // WithChiRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP 11 | // which converts a chi.Router to a signalr.MappableRouter.
12 | func WithChiRouter(r chi.Router) func() signalr.MappableRouter { 13 | return func() signalr.MappableRouter { 14 | return &chiRouter{r: r} 15 | } 16 | } 17 | // chiRouter adapts a chi.Router to signalr.MappableRouter by forwarding Handle and HandleFunc unchanged. 18 | type chiRouter struct { 19 | r chi.Router 20 | } 21 | // HandleFunc registers handler for path on the wrapped chi.Router. 22 | func (j *chiRouter) HandleFunc(path string, handler func(w http.ResponseWriter, r *http.Request)) { 23 | j.r.HandleFunc(path, handler) 24 | } 25 | // Handle registers handler for pattern on the wrapped chi.Router. 26 | func (j *chiRouter) Handle(pattern string, handler http.Handler) { 27 | j.r.Handle(pattern, handler) 28 | } 29 | -------------------------------------------------------------------------------- /router/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package router contains signalr.MappableRouter factories for some popular http routers. 3 | (Listed in https://www.alexedwards.net/blog/which-go-router-should-i-use) 4 | The factories can be used to integrate the signalR server from github.com/philippseith/signalr with the router. 5 | If you don't like to reference router modules you aren't using, just copy the source code 6 | for the factory for your router.
7 | */ 8 | package router 9 | -------------------------------------------------------------------------------- /router/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippseith/signalr/router 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/go-chi/chi/v5 v5.0.5 7 | github.com/go-kit/log v0.2.0 8 | github.com/gorilla/mux v1.8.0 9 | github.com/julienschmidt/httprouter v1.3.0 10 | github.com/onsi/ginkgo v1.12.1 11 | github.com/onsi/gomega v1.11.0 12 | github.com/philippseith/signalr v0.4.2-0.20211029170321-9b04bbc12782 13 | ) 14 | 15 | require ( 16 | github.com/fsnotify/fsnotify v1.4.9 // indirect 17 | github.com/go-logfmt/logfmt v0.5.1 // indirect 18 | github.com/klauspost/compress v1.10.3 // indirect 19 | github.com/nxadm/tail v1.4.8 // indirect 20 | github.com/teivah/onecontext v1.3.0 // indirect 21 | github.com/vmihailenco/msgpack/v5 v5.3.4 // indirect 22 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 23 | golang.org/x/net v0.7.0 // indirect 24 | golang.org/x/sys v0.5.0 // indirect 25 | golang.org/x/text v0.7.0 // indirect 26 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect 27 | gopkg.in/yaml.v2 v2.4.0 // indirect 28 | nhooyr.io/websocket v1.8.7 // indirect 29 | ) 30 | -------------------------------------------------------------------------------- /router/gorillarouter.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/gorilla/mux" 7 | "github.com/philippseith/signalr" 8 | ) 9 | 10 | // WithGorillaRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP 11 | // which converts a mux.Router to a signalr.MappableRouter.
12 | func WithGorillaRouter(r *mux.Router) func() signalr.MappableRouter { 13 | return func() signalr.MappableRouter { 14 | return &gorillaMappableRouter{r: r} 15 | } 16 | } 17 | // gorillaMappableRouter adapts a *mux.Router to signalr.MappableRouter by forwarding Handle and HandleFunc unchanged. 18 | type gorillaMappableRouter struct { 19 | r *mux.Router 20 | } 21 | // Handle registers handler for path on the wrapped mux.Router. 22 | func (g *gorillaMappableRouter) Handle(path string, handler http.Handler) { 23 | g.r.Handle(path, handler) 24 | } 25 | // HandleFunc registers handleFunc for path on the wrapped mux.Router. 26 | func (g *gorillaMappableRouter) HandleFunc(path string, handleFunc func(w http.ResponseWriter, r *http.Request)) { 27 | g.r.HandleFunc(path, handleFunc) 28 | } 29 | -------------------------------------------------------------------------------- /router/httprouter.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/julienschmidt/httprouter" 7 | "github.com/philippseith/signalr" 8 | ) 9 | 10 | // WithHttpRouter is a signalr.MappableRouter factory for signalr.Server.MapHTTP 11 | // which converts a httprouter.Router to a signalr.MappableRouter.
12 | func WithHttpRouter(r *httprouter.Router) func() signalr.MappableRouter { 13 | return func() signalr.MappableRouter { 14 | return &julienRouter{r: r} 15 | } 16 | } 17 | // julienRouter adapts a *httprouter.Router to signalr.MappableRouter. httprouter registers routes per HTTP method, so HandleFunc (used by MapHTTP for the negotiate endpoint) accepts POST only, while Handle accepts POST and GET. 18 | type julienRouter struct { 19 | r *httprouter.Router 20 | } 21 | 22 | func (j *julienRouter) HandleFunc(path string, handler func(w http.ResponseWriter, r *http.Request)) { 23 | j.r.HandlerFunc("POST", path, handler) 24 | } 25 | 26 | func (j *julienRouter) Handle(pattern string, handler http.Handler) { 27 | j.r.Handler("POST", pattern, handler) 28 | j.r.Handler("GET", pattern, handler) 29 | } 30 | -------------------------------------------------------------------------------- /router/router_test/router_test.go: -------------------------------------------------------------------------------- 1 | package router_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io/ioutil" 10 | "net" 11 | "net/http" 12 | "testing" 13 | "time" 14 | 15 | "github.com/go-kit/log" 16 | 17 | "github.com/go-chi/chi/v5" 18 | 19 | "github.com/julienschmidt/httprouter" 20 | 21 | "github.com/gorilla/mux" 22 | 23 | "github.com/philippseith/signalr" 24 | "github.com/philippseith/signalr/router" 25 | 26 | . "github.com/onsi/ginkgo" 27 | .
"github.com/onsi/gomega" 28 | ) 29 | 30 | func TestRouter(t *testing.T) { 31 | RegisterFailHandler(Fail) 32 | RunSpecs(t, "signalr/mux Suite") 33 | } 34 | 35 | func initGorillaRouter(server signalr.Server, port int) { 36 | r := mux.NewRouter() 37 | server.MapHTTP(router.WithGorillaRouter(r), "/hub") 38 | go func() { 39 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r) 40 | }() 41 | } 42 | 43 | func initHttpRouter(server signalr.Server, port int) { 44 | r := httprouter.New() 45 | server.MapHTTP(router.WithHttpRouter(r), "/hub") 46 | go func() { 47 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r) 48 | }() 49 | } 50 | 51 | func initChiRouter(server signalr.Server, port int) { 52 | r := chi.NewRouter() 53 | server.MapHTTP(router.WithChiRouter(r), "/hub") 54 | go func() { 55 | _ = http.ListenAndServe(fmt.Sprintf("127.0.0.1:%v", port), r) 56 | }() 57 | } 58 | 59 | var _ = Describe("Router", func() { 60 | for i, initFunc := range []func(server signalr.Server, port int){ 61 | initGorillaRouter, 62 | initHttpRouter, 63 | initChiRouter, 64 | } { 65 | routerNames := []string{ 66 | "gorilla/mux.Router", 67 | "julienschmidt/httprouter", 68 | "chi/Router", 69 | } 70 | Context(fmt.Sprintf("With %v", routerNames[i]), func() { 71 | Context("A correct negotiation request is sent", func() { 72 | It("should send a correct negotiation response", func(done Done) { 73 | // Start server 74 | ctx, serverCancel := context.WithCancel(context.Background()) 75 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&addHub{}), signalr.HTTPTransports("WebSockets")) 76 | Expect(err).NotTo(HaveOccurred()) 77 | port := freePort() 78 | initFunc(server, port) 79 | // Negotiate 80 | negResp := negotiateWebSocketTestServer(port) 81 | Expect(negResp["connectionId"]).NotTo(BeNil()) 82 | Expect(negResp["availableTransports"]).To(BeAssignableToTypeOf([]interface{}{})) 83 | avt := negResp["availableTransports"].([]interface{}) 84 |
Expect(len(avt)).To(BeNumerically(">", 0)) 85 | Expect(avt[0]).To(BeAssignableToTypeOf(map[string]interface{}{})) 86 | avtVal := avt[0].(map[string]interface{}) 87 | Expect(avtVal["transferFormats"]).To(BeAssignableToTypeOf([]interface{}{})) 88 | tf := avtVal["transferFormats"].([]interface{}) 89 | Expect(tf).To(ContainElement("Text")) 90 | Expect(tf).To(ContainElement("Binary")) 91 | serverCancel() 92 | close(done) 93 | }) 94 | }) 95 | Context("Connection with client", func() { 96 | It("should successfully handle an Invoke call", func(done Done) { 97 | // Start server 98 | server, err := signalr.NewServer(context.Background(), 99 | signalr.SimpleHubFactory(&addHub{}), 100 | signalr.Logger(log.NewNopLogger(), false), 101 | signalr.HTTPTransports("WebSockets")) 102 | Expect(err).NotTo(HaveOccurred()) 103 | port := freePort() 104 | initFunc(server, port) 105 | waitForPort(port) 106 | // Start client 107 | conn, err := signalr.NewHTTPConnection(context.Background(), fmt.Sprintf("http://127.0.0.1:%v/hub", port)) 108 | Expect(err).NotTo(HaveOccurred()) 109 | client, err := signalr.NewClient(context.Background(), 110 | signalr.WithConnection(conn), 111 | signalr.Logger(log.NewNopLogger(), false), 112 | signalr.TransferFormat("Text")) 113 | Expect(err).NotTo(HaveOccurred()) 114 | Expect(client).NotTo(BeNil()) 115 | errCh := client.Start() 116 | Expect(client.WaitConnected(context.Background())).NotTo(HaveOccurred()) 117 | go func() { 118 | defer GinkgoRecover() 119 | Expect(errors.Is(<-errCh, context.Canceled)).To(BeTrue()) 120 | }() 121 | Expect(err).NotTo(HaveOccurred()) 122 | result := <-client.Invoke("Add2", 1) 123 | Expect(result.Error).NotTo(HaveOccurred()) 124 | Expect(result.Value).To(BeEquivalentTo(3)) 125 | close(done) 126 | }, 10.0) 127 | }) 128 | }) 129 | } 130 | }) 131 | 132 | type addHub struct { 133 | signalr.Hub 134 | } 135 | 136 | func (w *addHub) Add2(i int) int { 137 | return i + 2 138 | } 139 | 140 | func negotiateWebSocketTestServer(port int)
map[string]interface{} { 141 | waitForPort(port) 142 | buf := bytes.Buffer{} 143 | resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%v/hub/negotiate", port), "text/plain;charset=UTF-8", &buf) 144 | Expect(err).To(BeNil()) 145 | Expect(resp).ToNot(BeNil()) 146 | defer func() { 147 | _ = resp.Body.Close() 148 | }() 149 | var body []byte 150 | body, err = ioutil.ReadAll(resp.Body) 151 | Expect(err).To(BeNil()) 152 | response := make(map[string]interface{}) 153 | err = json.Unmarshal(body, &response) 154 | Expect(err).To(BeNil()) 155 | return response 156 | } 157 | // freePort asks the OS for an unused TCP port on localhost by listening on port 0 and closing the listener again; it returns 0 on failure. The port may be taken by someone else before the caller binds it. 158 | func freePort() int { 159 | if addr, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { 160 | if listener, err := net.ListenTCP("tcp", addr); err == nil { 161 | defer func() { 162 | _ = listener.Close() 163 | }() 164 | return listener.Addr().(*net.TCPAddr).Port 165 | } 166 | } 167 | return 0 168 | } 169 | // waitForPort blocks until a TCP connection to 127.0.0.1:port succeeds, retrying every 100ms. 170 | func waitForPort(port int) { 171 | for { 172 | if conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", port)); err == nil { 173 | _ = conn.Close() // only reachability matters; close the probe connection so it does not leak 174 | return 175 | } 176 | time.Sleep(100 * time.Millisecond) 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "net/http" 10 | "os" 11 | "reflect" 12 | "runtime/debug" 13 | 14 | "github.com/go-kit/log" 15 | ) 16 | 17 | // Server is a SignalR server for one type of hub. 18 | // 19 | // MapHTTP(routerFactory func() MappableRouter, path string) 20 | // maps the servers' hub to a path in the MappableRouter returned by routerFactory. 21 | // 22 | // Serve(conn Connection) 23 | // serves the hub of the server on one connection. 24 | // The same server might serve different connections in parallel. Serve does not return until the connection is closed 25 | // or the servers' context is canceled.
26 | // 27 | // HubClients() 28 | // allows to call all HubClients of the server from server-side, non-hub code. 29 | // Note that HubClients.Caller() returns nil, because there is no real caller which can be reached over a HubConnection. 30 | type Server interface { 31 | Party 32 | MapHTTP(routerFactory func() MappableRouter, path string) 33 | Serve(conn Connection) error 34 | HubClients() HubClients 35 | availableTransports() []string 36 | } 37 | // server implements Server for a single hub type; the embedded partyBase holds the shared Party configuration. 38 | type server struct { 39 | partyBase 40 | newHub func() HubInterface 41 | lifetimeManager HubLifetimeManager 42 | defaultHubClients *defaultHubClients 43 | groupManager GroupManager 44 | reconnectAllowed bool 45 | transports []string 46 | } 47 | 48 | // NewServer creates a new server for one type of hub. The hub type is set by one of the 49 | // options UseHub, HubFactory or SimpleHubFactory 50 | func NewServer(ctx context.Context, options ...func(Party) error) (Server, error) { 51 | info, dbg := buildInfoDebugLogger(log.NewLogfmtLogger(os.Stderr), false) 52 | lifetimeManager := newLifeTimeManager(info) 53 | server := &server{ 54 | lifetimeManager: &lifetimeManager, 55 | defaultHubClients: &defaultHubClients{ 56 | lifetimeManager: &lifetimeManager, 57 | allCache: allClientProxy{lifetimeManager: &lifetimeManager}, 58 | }, 59 | groupManager: &defaultGroupManager{ 60 | lifetimeManager: &lifetimeManager, 61 | }, 62 | partyBase: newPartyBase(ctx, info, dbg), 63 | reconnectAllowed: true, 64 | } 65 | for _, option := range options { 66 | if option != nil { 67 | if err := option(server); err != nil { 68 | return nil, err 69 | } 70 | } 71 | } 72 | if server.transports == nil { // no option set the transports: offer both defaults 73 | server.transports = []string{"WebSockets", "ServerSentEvents"} 74 | } 75 | if server.newHub == nil { // NOTE(review): the partially initialized server is returned together with the error below; callers should ignore it when err != nil 76 | return server, errors.New("cannot determine hub type.
Neither UseHub, HubFactory or SimpleHubFactory given as option") 77 | } 78 | return server, nil 79 | } 80 | 81 | // MappableRouter encapsulates the methods used by server.MapHTTP to configure the 82 | // handlers required by the signalr protocol. This abstraction removes the explicit 83 | // binding to http.ServeMux and allows use of any mux which implements those basic 84 | // Handle and HandleFunc methods. 85 | type MappableRouter interface { 86 | HandleFunc(string, func(w http.ResponseWriter, r *http.Request)) 87 | Handle(string, http.Handler) 88 | } 89 | 90 | // WithHTTPServeMux is a MappableRouter factory for MapHTTP which converts a 91 | // http.ServeMux to a MappableRouter. 92 | // For factories for other routers, see github.com/philippseith/signalr/router 93 | func WithHTTPServeMux(serveMux *http.ServeMux) func() MappableRouter { 94 | return func() MappableRouter { 95 | return serveMux 96 | } 97 | } 98 | 99 | // MapHTTP maps the servers' hub to a path in a MappableRouter: path+"/negotiate" goes to the negotiate handler via HandleFunc, and path itself goes to the server's httpMux via Handle. 100 | func (s *server) MapHTTP(routerFactory func() MappableRouter, path string) { 101 | httpMux := newHTTPMux(s) 102 | router := routerFactory() 103 | router.HandleFunc(fmt.Sprintf("%s/negotiate", path), httpMux.negotiate) 104 | router.Handle(path, httpMux) 105 | } 106 | 107 | // Serve serves the hub of the server on one connection. 108 | // The same server might serve different connections in parallel. Serve does not return until the connection is closed 109 | // or the servers' context is canceled.
110 | func (s *server) Serve(conn Connection) error { 111 | 112 | protocol, err := s.processHandshake(conn) 113 | if err != nil { 114 | info, _ := s.prefixLoggers("") 115 | _ = info.Log(evt, "processHandshake", "connectionId", conn.ConnectionID(), "error", err, react, "do not connect") 116 | return err 117 | } 118 | 119 | return newLoop(s, conn, protocol).Run(make(chan struct{}, 1)) 120 | } 121 | 122 | func (s *server) HubClients() HubClients { 123 | return s.defaultHubClients 124 | } 125 | 126 | func (s *server) availableTransports() []string { 127 | return s.transports 128 | } 129 | 130 | func (s *server) onConnected(hc hubConnection) { 131 | s.lifetimeManager.OnConnected(hc) 132 | go func() { 133 | defer s.recoverHubLifeCyclePanic() 134 | s.invocationTarget(hc).(HubInterface).OnConnected(hc.ConnectionID()) 135 | }() 136 | } 137 | 138 | func (s *server) onDisconnected(hc hubConnection) { 139 | go func() { 140 | defer s.recoverHubLifeCyclePanic() 141 | s.invocationTarget(hc).(HubInterface).OnDisconnected(hc.ConnectionID()) 142 | }() 143 | s.lifetimeManager.OnDisconnected(hc) 144 | 145 | } 146 | 147 | func (s *server) invocationTarget(conn hubConnection) interface{} { 148 | hub := s.newHub() 149 | hub.Initialize(s.newConnectionHubContext(conn)) 150 | return hub 151 | } 152 | 153 | func (s *server) allowReconnect() bool { 154 | return s.reconnectAllowed 155 | } 156 | 157 | func (s *server) recoverHubLifeCyclePanic() { 158 | if err := recover(); err != nil { 159 | s.reconnectAllowed = false // NOTE(review): this runs in the goroutines started by onConnected/onDisconnected while allowReconnect reads the field without synchronization; possible data race, confirm 160 | info, dbg := s.prefixLoggers("") 161 | _ = info.Log(evt, "panic in hub lifecycle", "error", err, react, "close connection, allow no reconnect") 162 | _ = dbg.Log(evt, "panic in hub lifecycle", "error", err, react, "close connection, allow no reconnect", "stack", string(debug.Stack())) 163 | s.cancel() 164 | } 165 | } 166 | 167 | func (s *server) prefixLoggers(connectionID string) (info StructuredLogger, dbg StructuredLogger) { 168 | return log.WithPrefix(s.info, "ts",
log.DefaultTimestampUTC, 169 | "class", "Server", 170 | "connection", connectionID, 171 | "hub", reflect.ValueOf(s.newHub()).Elem().Type()), 172 | log.WithPrefix(s.dbg, "ts", log.DefaultTimestampUTC, 173 | "class", "Server", 174 | "connection", connectionID, 175 | "hub", reflect.ValueOf(s.newHub()).Elem().Type()) 176 | } 177 | 178 | func (s *server) newConnectionHubContext(hubConn hubConnection) HubContext { 179 | return &connectionHubContext{ 180 | abort: hubConn.Abort, 181 | clients: &callerHubClients{ 182 | defaultHubClients: s.defaultHubClients, 183 | connectionID: hubConn.ConnectionID(), 184 | }, 185 | groups: s.groupManager, 186 | connection: hubConn, 187 | info: s.info, 188 | dbg: s.dbg, 189 | } 190 | } 191 | // processHandshake reads the client's handshake request and answers it; on success it returns the hubProtocol registered for the requested protocol name. 192 | func (s *server) processHandshake(conn Connection) (hubProtocol, error) { 193 | if request, err := s.receiveHandshakeRequest(conn); err != nil { 194 | return nil, err 195 | } else { 196 | return s.sendHandshakeResponse(conn, request) 197 | } 198 | } 199 | // receiveHandshakeRequest decodes the first JSON frame read from conn into a handshakeRequest; it returns ctx.Err() when HandshakeTimeout elapses or the server context ends first. 200 | func (s *server) receiveHandshakeRequest(conn Connection) (handshakeRequest, error) { 201 | _, dbg := s.prefixLoggers(conn.ConnectionID()) 202 | ctx, cancelRead := context.WithTimeout(s.context(), s.HandshakeTimeout()) 203 | defer cancelRead() 204 | readJSONFramesChan := make(chan []interface{}, 1) 205 | go func() { 206 | var remainBuf bytes.Buffer 207 | rawHandshake, err := readJSONFrames(conn, &remainBuf) 208 | readJSONFramesChan <- []interface{}{rawHandshake, err} 209 | }() 210 | request := handshakeRequest{} 211 | select { 212 | case result := <-readJSONFramesChan: 213 | if result[1] != nil { 214 | return request, result[1].(error) 215 | } 216 | rawHandshake := result[0].([][]byte) // NOTE(review): rawHandshake[0] below assumes readJSONFrames yields at least one frame whenever it returns no error; confirm 217 | _ = dbg.Log(evt, "handshake received", "msg", string(rawHandshake[0])) 218 | return request, json.Unmarshal(rawHandshake[0], &request) 219 | case <-ctx.Done(): 220 | return request, ctx.Err() 221 | } 222 | } 223 | // sendHandshakeResponse replies with an empty JSON object plus the record separator when request.Protocol is in protocolMap; otherwise it sends a JSON error object and returns the unsupported-protocol error, or the write error if that reply fails. 224 | func (s *server) sendHandshakeResponse(conn Connection, request handshakeRequest)
(protocol hubProtocol, err error) { 225 | info, dbg := s.prefixLoggers(conn.ConnectionID()) 226 | ctx, cancelWrite := context.WithTimeout(s.context(), s.HandshakeTimeout()) 227 | defer cancelWrite() 228 | var ok bool 229 | if protocol, ok = protocolMap[request.Protocol]; ok { 230 | // Send the handshake response 231 | const handshakeResponse = "{}\u001e" 232 | if _, err = ReadWriteWithContext(ctx, 233 | func() (int, error) { 234 | return conn.Write([]byte(handshakeResponse)) 235 | }, func() {}); err != nil { 236 | _ = dbg.Log(evt, "handshake sent", "error", err) 237 | } else { 238 | _ = dbg.Log(evt, "handshake sent", "msg", handshakeResponse) 239 | } 240 | } else { 241 | err = fmt.Errorf("protocol %v not supported", request.Protocol) 242 | _ = info.Log(evt, "protocol requested", "error", err) 243 | if _, respErr := ReadWriteWithContext(ctx, 244 | func() (int, error) { 245 | errorHandshakeResponse, _ := json.Marshal(map[string]string{"error": err.Error()}) // json.Marshal escapes quotes and control characters from the client-supplied protocol name; it cannot fail for a map[string]string 246 | return conn.Write(append(errorHandshakeResponse, '\u001e')) 247 | }, func() {}); respErr != nil { 248 | _ = dbg.Log(evt, "handshake sent", "error", respErr) 249 | err = respErr 250 | } 251 | } 252 | return protocol, err 253 | } 254 | 255 | var protocolMap = map[string]hubProtocol{ 256 | "json": &jsonHubProtocol{}, 257 | "messagepack": &messagePackHubProtocol{}, 258 | } 259 | 260 | // const for logging 261 | const evt string = "event" 262 | const msgRecv string = "message received" 263 | const msgSend string = "message send" 264 | const msg string = "message" 265 | const react string = "reaction" 266 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | . "github.com/onsi/ginkgo" 9 | .
"github.com/onsi/gomega" 10 | ) 11 | 12 | var _ = Describe("Server.HubClients", func() { 13 | Context("All().Send()", func() { 14 | j := 1 15 | It(fmt.Sprintf("should send clients %v", j), func(done Done) { 16 | // Create a simple server 17 | server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}), 18 | testLoggerOption(), 19 | ChanReceiveTimeout(200*time.Millisecond), 20 | StreamBufferCapacity(5)) 21 | Expect(err).NotTo(HaveOccurred()) 22 | Expect(server).NotTo(BeNil()) 23 | // Create both ends of the connection 24 | cliConn, srvConn := newClientServerConnections() 25 | // Start the server 26 | go func() { _ = server.Serve(srvConn) }() 27 | // Give the server some time. In contrast to the client, we have not connected state to query 28 | <-time.After(100 * time.Millisecond) 29 | // Create the Client 30 | receiver := &simpleReceiver{ch: make(chan string, 1)} 31 | ctx, cancelClient := context.WithCancel(context.Background()) 32 | client, _ := NewClient(ctx, 33 | WithConnection(cliConn), 34 | WithReceiver(receiver), 35 | testLoggerOption(), 36 | TransferFormat("Text")) 37 | Expect(client).NotTo(BeNil()) 38 | // Start it 39 | client.Start() 40 | // Wait for client running 41 | Expect(<-client.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) 42 | // Send from the server to "all" clients 43 | <-time.After(100 * time.Millisecond) 44 | server.HubClients().All().Send("OnCallback", fmt.Sprintf("All%v", j)) 45 | // Did the receiver get what we did send?
46 | Expect(<-receiver.ch).To(Equal(fmt.Sprintf("All%v", j))) 47 | cancelClient() 48 | server.cancel() 49 | close(done) 50 | }, 1.0) 51 | }) 52 | 53 | Context("Caller()", func() { 54 | It("should return nil", func() { 55 | server, _ := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}), 56 | testLoggerOption(), 57 | ChanReceiveTimeout(200*time.Millisecond), 58 | StreamBufferCapacity(5)) 59 | Expect(server.HubClients().Caller()).To(BeNil()) 60 | }) 61 | }) 62 | }) 63 | -------------------------------------------------------------------------------- /serveroptions.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | // UseHub sets the hub instance used by the server 10 | func UseHub(hub HubInterface) func(Party) error { 11 | return func(p Party) error { 12 | if s, ok := p.(*server); ok { 13 | s.newHub = func() HubInterface { return hub } 14 | return nil 15 | } 16 | return errors.New("option UseHub is server only") 17 | } 18 | } 19 | 20 | // HubFactory sets the function which returns the hub instance for every hub method invocation 21 | // The function might create a new hub instance on every invocation. 22 | // If hub instances should be created and initialized by a DI framework, 23 | // the frameworks' factory method can be called here. 24 | func HubFactory(factory func() HubInterface) func(Party) error { 25 | return func(p Party) error { 26 | if s, ok := p.(*server); ok { 27 | s.newHub = factory 28 | return nil 29 | } 30 | return errors.New("option HubFactory is server only") 31 | } 32 | } 33 | 34 | // SimpleHubFactory sets a HubFactory which creates a new hub with the underlying type 35 | // of hubProto on each hub method invocation.
36 | func SimpleHubFactory(hubProto HubInterface) func(Party) error { 37 | return HubFactory( 38 | func() HubInterface { 39 | return reflect.New(reflect.ValueOf(hubProto).Elem().Type()).Interface().(HubInterface) 40 | }) 41 | } 42 | 43 | // HTTPTransports sets the list of available transports for http connections. Allowed transports are 44 | // "WebSockets", "ServerSentEvents". Default is both transports are available. 45 | func HTTPTransports(transports ...string) func(Party) error { 46 | return func(p Party) error { 47 | if s, ok := p.(*server); ok { 48 | for _, transport := range transports { 49 | switch transport { 50 | case "WebSockets", "ServerSentEvents": 51 | s.transports = append(s.transports, transport) 52 | default: 53 | return fmt.Errorf("unsupported transport: %v", transport) 54 | } 55 | } 56 | return nil 57 | } 58 | return errors.New("option Transports is server only") 59 | } 60 | } 61 | 62 | // InsecureSkipVerify disables Accepts origin verification behaviour which is used to avoid same origin strategy. 63 | // See https://pkg.go.dev/nhooyr.io/websocket#AcceptOptions 64 | func InsecureSkipVerify(skip bool) func(Party) error { 65 | return func(p Party) error { 66 | p.setInsecureSkipVerify(skip) 67 | return nil 68 | } 69 | } 70 | 71 | // AllowOriginPatterns lists the host patterns for authorized origins which is used for avoid same origin strategy. 
72 | // See https://pkg.go.dev/nhooyr.io/websocket#AcceptOptions 73 | func AllowOriginPatterns(origins []string) func(Party) error { 74 | return func(p Party) error { 75 | p.setOriginPatterns(origins) 76 | return nil 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /serversseconnection.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | type serverSSEConnection struct { 14 | ConnectionBase 15 | mx sync.Mutex 16 | postWriting bool 17 | postWriter io.Writer 18 | postReader io.Reader 19 | jobChan chan []byte 20 | jobResultChan chan RWJobResult 21 | } 22 | 23 | func newServerSSEConnection(ctx context.Context, connectionID string) (*serverSSEConnection, <-chan []byte, chan RWJobResult, error) { 24 | s := serverSSEConnection{ 25 | ConnectionBase: *NewConnectionBase(ctx, connectionID), 26 | jobChan: make(chan []byte, 1), 27 | jobResultChan: make(chan RWJobResult, 1), 28 | } 29 | s.postReader, s.postWriter = io.Pipe() 30 | go func() { 31 | <-s.Context().Done() 32 | s.mx.Lock() 33 | close(s.jobChan) 34 | s.mx.Unlock() 35 | }() 36 | return &s, s.jobChan, s.jobResultChan, nil 37 | } 38 | 39 | func (s *serverSSEConnection) consumeRequest(request *http.Request) int { 40 | if err := s.Context().Err(); err != nil { 41 | return http.StatusGone // 410 42 | } 43 | s.mx.Lock() 44 | if s.postWriting { 45 | s.mx.Unlock() 46 | return http.StatusConflict // 409 47 | } 48 | s.postWriting = true 49 | s.mx.Unlock() 50 | defer func() { 51 | _ = request.Body.Close() 52 | }() 53 | body, err := io.ReadAll(request.Body) 54 | if err != nil { 55 | return http.StatusBadRequest // 400 56 | } else if _, err := s.postWriter.Write(body); err != nil { 57 | return http.StatusInternalServerError // 500 58 | } 59 | s.mx.Lock() 60 | s.postWriting = false 61 | s.mx.Unlock() 62 | 
<-time.After(50 * time.Millisecond) 63 | return http.StatusOK // 200 64 | } 65 | 66 | func (s *serverSSEConnection) Read(p []byte) (n int, err error) { 67 | n, err = ReadWriteWithContext(s.Context(), 68 | func() (int, error) { return s.postReader.Read(p) }, 69 | func() { _, _ = s.postWriter.Write([]byte("\n")) }) 70 | if err != nil { 71 | err = fmt.Errorf("%T: %w", s, err) 72 | } 73 | return n, err 74 | } 75 | 76 | func (s *serverSSEConnection) Write(p []byte) (n int, err error) { 77 | if err := s.Context().Err(); err != nil { 78 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 79 | } 80 | payload := "" 81 | for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") { 82 | payload = payload + "data: " + line + "\n" 83 | } 84 | // prevent race with goroutine closing the jobChan 85 | s.mx.Lock() 86 | if s.Context().Err() == nil { 87 | s.jobChan <- []byte(payload + "\n") 88 | } else { 89 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 90 | } 91 | s.mx.Unlock() 92 | select { 93 | case <-s.Context().Done(): 94 | return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 95 | case r := <-s.jobResultChan: 96 | return r.n, r.err 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /signalr_suite_test.go: -------------------------------------------------------------------------------- 1 | package signalr 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . 
"github.com/onsi/gomega" 10 | ) 11 | 12 | func TestSignalR(t *testing.T) { 13 | RegisterFailHandler(Fail) 14 | RunSpecs(t, "SignalR Suite") 15 | } 16 | 17 | func connect(hubProto HubInterface) (Server, *testingConnection) { 18 | server, err := NewServer(context.TODO(), SimpleHubFactory(hubProto), 19 | testLoggerOption(), 20 | ChanReceiveTimeout(200*time.Millisecond), 21 | StreamBufferCapacity(5)) 22 | if err != nil { 23 | Fail(err.Error()) 24 | return nil, nil 25 | } 26 | conn := newTestingConnectionForServer() 27 | go func() { _ = server.Serve(conn) }() 28 | return server, conn 29 | } 30 | -------------------------------------------------------------------------------- /signalr_test/logger_test.go: -------------------------------------------------------------------------------- 1 | package signalr_test 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "os" 7 | 8 | "github.com/go-kit/log" 9 | "github.com/philippseith/signalr" 10 | ) 11 | 12 | type loggerConfig struct { 13 | Enabled bool 14 | Debug bool 15 | } 16 | 17 | var lConf loggerConfig 18 | 19 | var tLog signalr.StructuredLogger 20 | 21 | func testLoggerOption() func(signalr.Party) error { 22 | testLogger() 23 | return signalr.Logger(tLog, lConf.Debug) 24 | } 25 | 26 | func testLogger() signalr.StructuredLogger { 27 | if tLog == nil { 28 | lConf = loggerConfig{Enabled: false, Debug: false} 29 | b, err := ioutil.ReadFile("../testLogConf.json") 30 | if err == nil { 31 | err = json.Unmarshal(b, &lConf) 32 | if err != nil { 33 | lConf = loggerConfig{Enabled: false, Debug: false} 34 | } 35 | } 36 | writer := ioutil.Discard 37 | if lConf.Enabled { 38 | writer = os.Stderr 39 | } 40 | tLog = log.NewLogfmtLogger(writer) 41 | } 42 | return tLog 43 | } 44 | -------------------------------------------------------------------------------- /signalr_test/netconnection_test.go: -------------------------------------------------------------------------------- 1 | package signalr_test 2 | 3 | import ( 4 | "context" 5 | 
"fmt" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | . "github.com/onsi/ginkgo" 11 | . "github.com/onsi/gomega" 12 | "github.com/philippseith/signalr" 13 | ) 14 | 15 | func TestSignalR(t *testing.T) { 16 | RegisterFailHandler(Fail) 17 | RunSpecs(t, "SignalR external Suite") 18 | } 19 | 20 | type NetHub struct { 21 | signalr.Hub 22 | } 23 | 24 | func (n *NetHub) Smoke() string { 25 | return "no smoke!" 26 | } 27 | 28 | func (n *NetHub) ContinuousSmoke() chan string { 29 | ch := make(chan string, 1) 30 | go func() { 31 | loop: 32 | for i := 0; i < 5; i++ { 33 | select { 34 | case ch <- "smoke...": 35 | case <-n.Context().Done(): 36 | break loop 37 | } 38 | <-time.After(100 * time.Millisecond) 39 | } 40 | close(ch) 41 | }() 42 | return ch 43 | } 44 | 45 | var _ = Describe("NetConnection", func() { 46 | Context("Smoke", func() { 47 | It("should transport a simple invocation over raw rcp", func(done Done) { 48 | ctx, cancel := context.WithCancel(context.Background()) 49 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&NetHub{}), testLoggerOption()) 50 | Expect(err).NotTo(HaveOccurred()) 51 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 52 | Expect(err).NotTo(HaveOccurred()) 53 | listener, err := net.ListenTCP("tcp", addr) 54 | Expect(err).NotTo(HaveOccurred()) 55 | go func() { 56 | for { 57 | tcpConn, err := listener.Accept() 58 | Expect(err).NotTo(HaveOccurred()) 59 | go func() { _ = server.Serve(signalr.NewNetConnection(ctx, tcpConn)) }() 60 | break 61 | } 62 | }() 63 | var client signalr.Client 64 | for { 65 | if clientConn, err := net.Dial("tcp", 66 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port)); err == nil { 67 | client, err = signalr.NewClient(ctx, signalr.WithConnection(signalr.NewNetConnection(ctx, clientConn)), testLoggerOption()) 68 | Expect(err).NotTo(HaveOccurred()) 69 | break 70 | } 71 | time.Sleep(100 * time.Millisecond) 72 | } 73 | client.Start() 74 | result := <-client.Invoke("smoke") 75 | 
Expect(result.Value).To(Equal("no smoke!")) 76 | cancel() 77 | close(done) 78 | }) 79 | }) 80 | Context("Stream and Timeout", func() { 81 | It("Client and Server should timeout when no messages are exchanged, but message exchange should prevent timeout", func(done Done) { 82 | ctx, cancel := context.WithCancel(context.Background()) 83 | server, err := signalr.NewServer(ctx, signalr.SimpleHubFactory(&NetHub{}), testLoggerOption(), 84 | // Set KeepAlive and Timeout so KeepAlive can't keep it alive 85 | signalr.TimeoutInterval(500*time.Millisecond), signalr.KeepAliveInterval(2*time.Second)) 86 | Expect(err).NotTo(HaveOccurred()) 87 | addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 88 | Expect(err).NotTo(HaveOccurred()) 89 | listener, err := net.ListenTCP("tcp", addr) 90 | Expect(err).NotTo(HaveOccurred()) 91 | serverDone := make(chan struct{}, 1) 92 | go func() { 93 | tcpConn, err := listener.Accept() 94 | Expect(err).NotTo(HaveOccurred()) 95 | go func() { 96 | _ = server.Serve(signalr.NewNetConnection(ctx, tcpConn)) 97 | serverDone <- struct{}{} 98 | }() 99 | }() 100 | var client signalr.Client 101 | for { 102 | if clientConn, err := net.Dial("tcp", 103 | fmt.Sprintf("localhost:%v", listener.Addr().(*net.TCPAddr).Port)); err == nil { 104 | client, err = signalr.NewClient(ctx, signalr.WithConnection(signalr.NewNetConnection(ctx, clientConn)), testLoggerOption(), 105 | // Set KeepAlive and Timeout so KeepAlive can't keep it alive 106 | signalr.TimeoutInterval(500*time.Millisecond), signalr.KeepAliveInterval(2*time.Second)) 107 | Expect(err).NotTo(HaveOccurred()) 108 | break 109 | } 110 | time.Sleep(100 * time.Millisecond) 111 | } 112 | client.Start() 113 | // The Server will send values each 100ms, so this should keep the connection alive 114 | i := 0 115 | for range client.PullStream("continuoussmoke") { 116 | i++ 117 | } 118 | // some smoke messages and one error message when timed out 119 | Expect(i).To(BeNumerically(">", 1)) 120 | // Wait for client and 
server to timeout 121 | <-time.After(time.Second) 122 | select { 123 | case <-serverDone: 124 | Expect(client.State() == signalr.ClientClosed) 125 | case <-time.After(10 * time.Millisecond): 126 | Fail("server not closed") 127 | } 128 | cancel() 129 | close(done) 130 | }, 5.0) 131 | }) 132 | Context("SetConnectionID", func() { 133 | It("should change the ConnectionID", func() { 134 | clientConn, _ := net.Pipe() 135 | conn := signalr.NewNetConnection(context.Background(), clientConn) 136 | id := conn.ConnectionID() 137 | conn.SetConnectionID("Other" + id) 138 | Expect(conn.ConnectionID()).To(Equal("Other" + id)) 139 | }) 140 | }) 141 | Context("Cancel", func() { 142 | It("should cancel the connection", func(done Done) { 143 | clientConn, serverConn := net.Pipe() 144 | ctx, cancel := context.WithCancel(context.Background()) 145 | conn := signalr.NewNetConnection(ctx, clientConn) 146 | // Server loop 147 | go func() { 148 | b := make([]byte, 1024) 149 | for { 150 | if _, err := serverConn.Read(b); err != nil { 151 | break 152 | } 153 | } 154 | }() 155 | go func() { 156 | time.Sleep(500 * time.Millisecond) 157 | // cancel the connection 158 | cancel() 159 | }() 160 | for { 161 | if _, err := conn.Write([]byte("foobar")); err != nil { 162 | // This will never happen if the connection is not canceled 163 | break 164 | } 165 | } 166 | close(done) 167 | }, 2.0) 168 | }) 169 | }) 170 | -------------------------------------------------------------------------------- /signalr_test/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "signalr_test", 3 | "version": "0.0.0", 4 | "description": "e2e test for signalr server", 5 | "private": true, 6 | "license": "UNLICENSED", 7 | "scripts": { 8 | "test": "run-p test:wsgo test:jest --race", 9 | "test:wsgo": "go test -v -run TestServerWebSockets", 10 | "test:jest": "npx jest -t 'e2e test with microsoft/signalr client should work'" 11 | }, 12 | "dependencies": { 13 | 
"@microsoft/signalr": "^7.0.7", 14 | "@microsoft/signalr-protocol-msgpack": "^5.0.9", 15 | "@types/jest": "^27.0.2", 16 | "@types/node": "^12.20.4", 17 | "jest": "^27.2.5", 18 | "jest-preset-typescript": "^1.2.0", 19 | "path-parse": ">=1.0.7", 20 | "rxjs": "^6.6.6", 21 | "rxjs-for-await": "^0.0.2", 22 | "set-value": ">=4.0.1", 23 | "ts-jest": "^27.0.7", 24 | "typescript": "^4.2.2" 25 | }, 26 | "devDependencies": { 27 | "npm-run-all": "^4.1.5" 28 | }, 29 | "jest": { 30 | "testEnvironment": "node", 31 | "testTimeout": 10000, 32 | "preset": "jest-preset-typescript", 33 | "setupFilesAfterEnv": [ 34 | "/setupJest.ts" 35 | ], 36 | "transformIgnorePatterns": [ 37 | "node_modules/(?!(@aspnet/signalr)/)" 38 | ] 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /signalr_test/server_test.go: -------------------------------------------------------------------------------- 1 | package signalr_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net/http" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | "strings" 13 | "testing" 14 | "time" 15 | 16 | "github.com/philippseith/signalr" 17 | ) 18 | 19 | func TestMain(m *testing.M) { 20 | npmInstall := exec.Command("npm", "install") 21 | stdout, err := npmInstall.StdoutPipe() 22 | if err != nil { 23 | println(err.Error()) 24 | os.Exit(120) 25 | } 26 | stderr, err := npmInstall.StderrPipe() 27 | if err != nil { 28 | println(err.Error()) 29 | os.Exit(121) 30 | } 31 | if err := npmInstall.Start(); err != nil { 32 | println(err.Error()) 33 | os.Exit(122) 34 | } 35 | outSlurp, _ := io.ReadAll(stdout) 36 | errSlurp, _ := io.ReadAll(stderr) 37 | err = npmInstall.Wait() 38 | if err != nil { 39 | println(err.Error()) 40 | fmt.Println(string(outSlurp)) 41 | fmt.Println(string(errSlurp)) 42 | os.Exit(123) 43 | } 44 | os.Exit(m.Run()) 45 | } 46 | 47 | func TestServerSmoke(t *testing.T) { 48 | testServer(t, "^smoke", signalr.HTTPTransports("WebSockets")) 49 | } 50 | 51 | 
func TestServerJsonWebSockets(t *testing.T) { 52 | testServer(t, "^JSON", signalr.HTTPTransports("WebSockets")) 53 | } 54 | 55 | func TestServerJsonSSE(t *testing.T) { 56 | testServer(t, "^JSON", signalr.HTTPTransports("ServerSentEvents")) 57 | } 58 | 59 | func TestServerMessagePack(t *testing.T) { 60 | testServer(t, "^MessagePack", signalr.HTTPTransports("WebSockets")) 61 | } 62 | 63 | func testServer(t *testing.T, testNamePattern string, transports func(signalr.Party) error) { 64 | serverIsUp := make(chan struct{}, 1) 65 | quitServer := make(chan struct{}, 1) 66 | serverIsDown := make(chan struct{}, 1) 67 | go func() { 68 | runServer(t, serverIsUp, quitServer, transports) 69 | serverIsDown <- struct{}{} 70 | }() 71 | <-serverIsUp 72 | runJest(t, testNamePattern, quitServer) 73 | <-serverIsDown 74 | } 75 | 76 | func runJest(t *testing.T, testNamePattern string, quitServer chan struct{}) { 77 | defer func() { quitServer <- struct{}{} }() 78 | var jest = exec.Command(filepath.FromSlash("node_modules/.bin/jest"), fmt.Sprintf("--testNamePattern=%v", testNamePattern)) 79 | stdout, err := jest.StdoutPipe() 80 | if err != nil { 81 | t.Error(err) 82 | } 83 | stderr, err := jest.StderrPipe() 84 | if err != nil { 85 | t.Error(err) 86 | } 87 | if err := jest.Start(); err != nil { 88 | t.Error(err) 89 | } 90 | outSlurp, _ := io.ReadAll(stdout) 91 | errSlurp, _ := io.ReadAll(stderr) 92 | err = jest.Wait() 93 | if err != nil { 94 | t.Error(err, fmt.Sprintf("\n%s\n%s", outSlurp, errSlurp)) 95 | } else { 96 | // Strange: Jest reports test results to stderr 97 | t.Log(fmt.Sprintf("\n%s", errSlurp)) 98 | } 99 | } 100 | 101 | func runServer(t *testing.T, serverIsUp chan struct{}, quitServer chan struct{}, transports func(signalr.Party) error) { 102 | // Install a handler to cancel the server 103 | doneQuit := make(chan struct{}, 1) 104 | ctx, cancelSignalRServer := context.WithCancel(context.Background()) 105 | sRServer, _ := signalr.NewServer(ctx, signalr.SimpleHubFactory(&hub{}), 
106 | signalr.KeepAliveInterval(2*time.Second), 107 | transports, 108 | testLoggerOption()) 109 | router := http.NewServeMux() 110 | sRServer.MapHTTP(signalr.WithHTTPServeMux(router), "/hub") 111 | 112 | server := &http.Server{ 113 | Addr: "127.0.0.1:5001", 114 | Handler: router, 115 | ReadTimeout: 2 * time.Second, 116 | WriteTimeout: 5 * time.Second, 117 | IdleTimeout: 10 * time.Second, 118 | } 119 | // wait for someone triggering quitServer 120 | go func() { 121 | <-quitServer 122 | // Cancel the signalR server and all its connections 123 | cancelSignalRServer() 124 | // Now shutdown the http server 125 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 126 | // If it does not Shutdown during 10s, try to end it by canceling the context 127 | defer cancel() 128 | server.SetKeepAlivesEnabled(false) 129 | _ = server.Shutdown(ctx) 130 | doneQuit <- struct{}{} 131 | }() 132 | // alternate method for quiting 133 | router.HandleFunc("/quit", func(http.ResponseWriter, *http.Request) { 134 | quitServer <- struct{}{} 135 | }) 136 | t.Logf("Server %v is up", server.Addr) 137 | serverIsUp <- struct{}{} 138 | // Run the server 139 | _ = server.ListenAndServe() 140 | <-doneQuit 141 | } 142 | 143 | type hub struct { 144 | signalr.Hub 145 | } 146 | 147 | func (h *hub) Ping() string { 148 | return "Pong!" 
149 | } 150 | 151 | func (h *hub) Touch() { 152 | h.Clients().Caller().Send("touched") 153 | } 154 | 155 | func (h *hub) TriumphantTriple(club string) []string { 156 | if strings.Contains(club, "FC Bayern") { 157 | return []string{"German Championship", "DFB Cup", "Champions League"} 158 | } 159 | return []string{} 160 | } 161 | 162 | type AlcoholicContent struct { 163 | Drink string `json:"drink"` 164 | Strength float32 `json:"strength"` 165 | } 166 | 167 | func (h *hub) AlcoholicContents() []AlcoholicContent { 168 | return []AlcoholicContent{ 169 | { 170 | Drink: "Brunello", 171 | Strength: 13.5, 172 | }, 173 | { 174 | Drink: "Beer", 175 | Strength: 4.9, 176 | }, 177 | { 178 | Drink: "Lagavulin Cask Strength", 179 | Strength: 56.2, 180 | }, 181 | } 182 | } 183 | 184 | func (h *hub) AlcoholicContentMap() map[string]float64 { 185 | return map[string]float64{ 186 | "Brunello": 13.5, 187 | "Beer": 4.9, 188 | "Lagavulin Cask Strength": 56.2, 189 | } 190 | } 191 | 192 | func (h *hub) LargeCompressableContent() string { 193 | return strings.Repeat("data_", 10000) 194 | } 195 | 196 | func (h *hub) LargeUncompressableContent() string { 197 | return randString(20000) 198 | } 199 | 200 | var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") 201 | 202 | func randString(n int) string { 203 | b := make([]rune, n) 204 | for i := range b { 205 | b[i] = letters[rand.Intn(len(letters))] 206 | } 207 | return string(b) 208 | } 209 | 210 | func (h *hub) FiveDates() <-chan string { 211 | r := make(chan string) 212 | go func() { 213 | for i := 0; i < 5; i++ { 214 | r <- fmt.Sprint(time.Now().Nanosecond()) 215 | time.Sleep(time.Millisecond * 100) 216 | } 217 | close(r) 218 | }() 219 | return r 220 | } 221 | 222 | func (h *hub) UploadStream(upload <-chan int) { 223 | var allUs []int 224 | for u := range upload { 225 | allUs = append(allUs, u) 226 | } 227 | h.Clients().Caller().Send("OnUploadComplete", allUs) 228 | } 229 | 
--------------------------------------------------------------------------------
/signalr_test/setupJest.ts:
--------------------------------------------------------------------------------
import 'jest-preset-typescript';

--------------------------------------------------------------------------------
/signalr_test/spec/server.spec.ts:
--------------------------------------------------------------------------------
import * as signalR from '@microsoft/signalr';
import { Subject, from } from 'rxjs'
import { eachValueFrom } from 'rxjs-for-await';
import {MessagePackHubProtocol} from "@microsoft/signalr-protocol-msgpack";

// IMPORTANT: When a proxy (e.g. px) is in use, the server will get the request,
// but the client will not get the response.
// So disable the proxy for this process.
process.env.http_proxy = "";

// The with* methods of HubConnectionBuilder configure and return the same builder,
// so every connection gets a fresh one. A shared builder would carry the hub
// protocol chosen by one test over into the next.
const newBuilder = (): signalR.HubConnectionBuilder =>
    new signalR.HubConnectionBuilder().configureLogging(signalR.LogLevel.Debug);

const hubUrl = "http://127.0.0.1:5001/hub";


describe("smoke test", () => {
    it("should connect on a clients request for connection and answer a simple request",
        async () => {
            const connection: signalR.HubConnection = newBuilder()
                .withUrl(hubUrl)
                .build();
            await connection.start();
            const pong = await connection.invoke("ping");
            expect(pong).toEqual("Pong!");
            await connection.stop();
        });
});

describe("MessagePack smoke test", () => {
    it("should connect on a clients request for connection and answer a simple request",
        async () => {
            const connection: signalR.HubConnection = newBuilder()
                .withUrl(hubUrl)
                .withHubProtocol(new MessagePackHubProtocol())
                .build();
            await connection.start();
            const pong = await connection.invoke("ping");
            expect(pong).toEqual("Pong!");
            await connection.stop();
        });
});


// Shape of the items returned by the hub's AlcoholicContents method.
class AlcoholicContent {
    drink: string
    strength: number
}

// runE2E registers the end-to-end specs, run over a fresh connection per spec
// that uses the given hub protocol.
function runE2E(protocol: signalR.IHubProtocol) {
    let connection: signalR.HubConnection;
    beforeEach(async() => {
        connection = newBuilder()
            .withUrl(hubUrl)
            .withHubProtocol(protocol)
            .build();
        await connection.start();
    })
    afterEach(async() => {
        await connection.stop();
    })
    it("should answer a simple request", async () => {
        const pong = await connection.invoke("ping");
        expect(pong).toEqual("Pong!");
    })
    it("should send correct ping messages", async () => {
        const pong = await connection.invoke("ping");
        expect(pong).toEqual("Pong!");
        // Wait for a ping
        await new Promise(r => setTimeout(r, 5000));
    })
    it("should answer a simple request with multiple results", async () => {
        const triple = await connection.invoke("triumphantTriple", "1.FC Bayern München");
        expect(triple).toEqual(["German Championship", "DFB Cup", "Champions League"]);
    })
    it("should answer a request with an resulting array of structs", async () => {
        const contents = await connection.invoke<AlcoholicContent[]>("alcoholicContents");
        expect(contents.length).toEqual(3);
        expect(contents[0].drink).toEqual('Brunello');
        expect(Math.abs(contents[2].strength - 56.2)).toBeLessThan(0.0001);
    })
    it("should answer a request with an resulting map", async () => {
        const contents = await connection.invoke("alcoholicContentMap");
        expect(contents["Beer"]).toEqual(4.9);
        expect(Math.abs(contents["Lagavulin Cask Strength"] - 56.2)).toBeLessThan(0.0001);
    })
    it("should answer a request with a large amount of compressable data", async () => {
        const data = await connection.invoke("largeCompressableContent");
        expect(data.length).toEqual(50000);
    })

    it("should answer a request with a large amount of uncompressable data", async () => {
        const data = await connection.invoke("largeUncompressableContent");
        expect(data.length).toEqual(20000);
    })
    it("should receive a stream", async () => {
        const fiveDates: Subject = new Subject();
        connection.stream("FiveDates").subscribe(fiveDates);
        let i = 0;
        let lastValue = '';
        for await (const value of eachValueFrom(fiveDates)) {
            expect(value).toBeDefined();
            expect(value).not.toEqual(lastValue);
            lastValue = value;
            i++;
        }
        expect(i).toEqual(5)
    })
    it("should upload a stream", async() =>{
        const receive = new Promise(resolve => {
            connection.on("onUploadComplete", (r: number[]) => {
                resolve(r);
            });
        });
        await connection.send("uploadStream", from([2, 0, 7]));
        expect(await receive).toEqual([2, 0, 7])
    })
    it("should receive subsequent sends without await", async() => {
        let or: (value?: unknown) => void;
        const p = new Promise((r) => {
            or = r;
        });
        let tc = 0
        connection.on("touched", () => {
            tc++;
            if (tc === 2) {
                or();
            }
        })
        // Deliberately not awaited: both sends must be delivered back to back.
        connection.send("touch");
        connection.send("touch");
        await p;
    })
}

describe("JSON e2e test with microsoft/signalr client", () => {
    runE2E(new signalR.JsonHubProtocol());
})

describe("MessagePack e2e test with microsoft/signalr client", () => {
    runE2E(new MessagePackHubProtocol());
})

--------------------------------------------------------------------------------
/signalr_test/tsconfig.json:
--------------------------------------------------------------------------------
{
  "compilerOptions": {
    "module": "commonjs",
    "moduleResolution": "node",
    "target": "ES5",
    "typeRoots": [
      "node_modules/@types"
    ],
    "sourceMap": true,
    "noImplicitAny": true,
    "allowJs": true
  },
  "exclude": [
    "node_modules"
  ]
}
--------------------------------------------------------------------------------
/signalr_test/tsconfig.spec.json:
--------------------------------------------------------------------------------
1 | {
2 |   "extends": "./tsconfig.json",
3 |   "compilerOptions": {
4 |     "outDir": "./out-tsc/spec",
5 |     "types": [
6 |       "node",
7 |       "webpack-env"
8 |     ]
9 |   },
10 |   "include": [
11 |     "**/*.spec.ts",
12 |     "**/*.d.ts"
13 |   ]
14 | }
15 | 
--------------------------------------------------------------------------------
/streamclient.go:
--------------------------------------------------------------------------------
1 | package signalr
2 | 
3 | import (
4 | 	"fmt"
5 | 	"reflect"
6 | 	"sync"
7 | 	"time"
8 | )
9 | 
10 | // newStreamClient creates a streamClient that decodes upstream (client to server) stream
11 | // items with protocol, buffers streamBufferCapacity items per upstream channel and waits at
12 | // most chanReceiveTimeout for a stream value to be taken from its channel.
13 | func newStreamClient(protocol hubProtocol, chanReceiveTimeout time.Duration, streamBufferCapacity uint) *streamClient {
14 | 	return &streamClient{
15 | 		mx:                   sync.Mutex{},
16 | 		upstreamChannels:     make(map[string]reflect.Value),
17 | 		runningStreams:       make(map[string]bool),
18 | 		chanReceiveTimeout:   chanReceiveTimeout,
19 | 		streamBufferCapacity: streamBufferCapacity,
20 | 		protocol:             protocol,
21 | 	}
22 | }
23 | 
24 | // streamClient delivers the StreamItem and Completion messages of client-side streams to the
25 | // channels registered for their ids. The maps are only accessed while holding mx.
26 | type streamClient struct {
27 | 	mx                   sync.Mutex
28 | 	upstreamChannels     map[string]reflect.Value // id -> channel the stream items for that id are sent to
29 | 	runningStreams       map[string]bool          // ids that have already received a stream item
30 | 	chanReceiveTimeout   time.Duration
31 | 	streamBufferCapacity uint
32 | 	protocol             hubProtocol
33 | }
34 | 
35 | // buildChannelArgument creates the value for a chan parameter of the hub method targeted by
36 | // invocation and registers it under invocation.StreamIds[chanCount].
37 | // It returns canClientStreaming == false and no error when argType is not a channel type or
38 | // is a send-only channel type. It returns an error when invocation carries no stream id at
39 | // index chanCount, i.e. the method has more chan parameters than the client streams to.
40 | func (c *streamClient) buildChannelArgument(invocation invocationMessage, argType reflect.Type, chanCount int) (arg reflect.Value, canClientStreaming bool, err error) {
41 | 	c.mx.Lock()
42 | 	defer c.mx.Unlock()
43 | 	if argType.Kind() != reflect.Chan || argType.ChanDir() == reflect.SendDir {
44 | 		return reflect.Value{}, false, nil
45 | 	}
46 | 	if len(invocation.StreamIds) <= chanCount {
47 | 		// Too many chan parameters for this method. The client will not send stream items for these.
48 | 		return reflect.Value{}, true, fmt.Errorf("method %s has more chan parameters than the client will stream", invocation.Target)
49 | 	}
50 | 	// MakeChan only accepts bidirectional channel types, and we need to send to this channel anyway.
51 | 	arg = reflect.MakeChan(reflect.ChanOf(reflect.BothDir, argType.Elem()), int(c.streamBufferCapacity))
52 | 	c.upstreamChannels[invocation.StreamIds[chanCount]] = arg
53 | 	return arg, true, nil
54 | }
55 | 
56 | // newUpstreamChannel registers a new channel, buffered with streamBufferCapacity items, under
57 | // invocationID and returns its receive side.
58 | func (c *streamClient) newUpstreamChannel(invocationID string) <-chan interface{} {
59 | 	c.mx.Lock()
60 | 	defer c.mx.Unlock()
61 | 	upChan := make(chan interface{}, c.streamBufferCapacity)
62 | 	c.upstreamChannels[invocationID] = reflect.ValueOf(upChan)
63 | 	return upChan
64 | }
65 | 
66 | // deleteUpstreamChannel closes and unregisters the channel registered under invocationID, if
67 | // any, and forgets whether a stream was running under that id.
68 | func (c *streamClient) deleteUpstreamChannel(invocationID string) {
69 | 	c.mx.Lock()
70 | 	defer c.mx.Unlock()
71 | 	if upChan, ok := c.upstreamChannels[invocationID]; ok {
72 | 		upChan.Close()
73 | 		delete(c.upstreamChannels, invocationID)
74 | 	}
75 | 	// receiveCompletionItem is the only other place that removes runningStreams entries,
76 | 	// so without this the entry would stay forever for streams ended here.
77 | 	delete(c.runningStreams, invocationID)
78 | }
79 | 
80 | // receiveStreamItem decodes streamItem.Item into the element type of the channel registered
81 | // under streamItem.InvocationID and sends the value to that channel.
82 | // It returns an error when the id is unknown, the item cannot be decoded or the value is not
83 | // taken from the channel within chanReceiveTimeout.
84 | func (c *streamClient) receiveStreamItem(streamItem streamItemMessage) error {
85 | 	c.mx.Lock()
86 | 	upChan, ok := c.upstreamChannels[streamItem.InvocationID]
87 | 	if ok {
88 | 		// Mark the stream as running to detect illegal completion with result on this id
89 | 		c.runningStreams[streamItem.InvocationID] = true
90 | 	}
91 | 	c.mx.Unlock()
92 | 	if !ok {
93 | 		return fmt.Errorf(`unknown stream id "%v"`, streamItem.InvocationID)
94 | 	}
95 | 	chanVal := reflect.New(upChan.Type().Elem())
96 | 	if err := c.protocol.UnmarshalArgument(streamItem.Item, chanVal.Interface()); err != nil {
97 | 		return err
98 | 	}
99 | 	// The send can block for up to chanReceiveTimeout, so it runs without holding mx.
100 | 	// If the channel is closed concurrently, sendChanValSave reports that as an error.
101 | 	return c.sendChanValSave(upChan, chanVal.Elem())
102 | }
103 | 
104 | // sendChanValSave sends chanVal to upChan, but waits at most chanReceiveTimeout for the send
105 | // to complete and then returns a *hubChanTimeoutError. A panic raised by the send, e.g.
106 | // because upChan has been closed, is recovered and returned as an error.
107 | func (c *streamClient) sendChanValSave(upChan reflect.Value, chanVal reflect.Value) error {
108 | 	// Buffered so the sending goroutine can always hand over its result and exit, even when
109 | 	// this function has already returned because of the timeout.
110 | 	done := make(chan error, 1)
111 | 	go func() {
112 | 		defer func() {
113 | 			if r := recover(); r != nil {
114 | 				done <- fmt.Errorf("%v", r)
115 | 			}
116 | 		}()
117 | 		// Blocks until a receiver takes the value or upChan is closed (which panics).
118 | 		upChan.Send(chanVal)
119 | 		done <- nil
120 | 	}()
121 | 	select {
122 | 	case err := <-done:
123 | 		return err
124 | 	case <-time.After(c.chanReceiveTimeout):
125 | 		return &hubChanTimeoutError{fmt.Sprintf("timeout (%v) waiting for hub to receive client streamed value", c.chanReceiveTimeout)}
126 | 	}
127 | }
128 | 
129 | // hubChanTimeoutError is returned by sendChanValSave when a streamed value is not received in time.
130 | type hubChanTimeoutError struct {
131 | 	msg string
132 | }
133 | 
134 | // Error implements the error interface.
135 | func (h *hubChanTimeoutError) Error() string {
136 | 	return h.msg
137 | }
138 | 
139 | // handlesInvocationID reports whether a channel is registered under invocationID.
140 | func (c *streamClient) handlesInvocationID(invocationID string) bool {
141 | 	c.mx.Lock()
142 | 	defer c.mx.Unlock()
143 | 	_, ok := c.upstreamChannels[invocationID]
144 | 	return ok
145 | }
146 | 
147 | // receiveCompletionItem ends the client-side stream registered under completion.InvocationID.
148 | // A completion with an error is passed to invokeClient. A completion with a result is
149 | // delivered like a stream item, unless stream items have already been received under that
150 | // id, which is reported as an error. Afterwards the channel is closed and the id is
151 | // unregistered here and in invokeClient. Unknown ids are reported as an error.
152 | func (c *streamClient) receiveCompletionItem(completion completionMessage, invokeClient *invokeClient) error {
153 | 	c.mx.Lock()
154 | 	_, ok := c.upstreamChannels[completion.InvocationID]
155 | 	running := c.runningStreams[completion.InvocationID]
156 | 	c.mx.Unlock()
157 | 	if !ok {
158 | 		return fmt.Errorf("received completion with unknown id %v", completion.InvocationID)
159 | 	}
160 | 	var err error
161 | 	switch {
162 | 	case completion.Error != "":
163 | 		// Push error to the error channel
164 | 		err = invokeClient.receiveCompletionItem(completion)
165 | 	case completion.Result != nil && running:
166 | 		err = fmt.Errorf("client side streaming: received completion with result %v", completion)
167 | 	case completion.Result != nil:
168 | 		// handle result like a stream item
169 | 		err = c.receiveStreamItem(streamItemMessage{
170 | 			Type:         2,
171 | 			InvocationID: completion.InvocationID,
172 | 			Item:         completion.Result,
173 | 		})
174 | 	}
175 | 	// Close and unregister under the lock, and only if the channel is still registered, so a
176 | 	// concurrent deleteUpstreamChannel cannot cause a second close (which would panic).
177 | 	c.mx.Lock()
178 | 	if channel, ok := c.upstreamChannels[completion.InvocationID]; ok {
179 | 		channel.Close()
180 | 		delete(c.upstreamChannels, completion.InvocationID)
181 | 	}
182 | 	delete(c.runningStreams, completion.InvocationID)
183 | 	c.mx.Unlock()
184 | 	// Close error channel
185 | 	invokeClient.deleteInvocation(completion.InvocationID)
186 | 	return err
187 | }
188 | 
--------------------------------------------------------------------------------
/streamer.go:
--------------------------------------------------------------------------------
1 | package signalr
2 | 
3 | import (
4 | 	"reflect"
5 | 	"sync"
6 | )
7 | 
8 | // streamer forwards the values received from server-side stream result channels to a
9 | // hubConnection as stream items and ends each stream with a completion.
10 | type streamer struct {
11 | 	cancels sync.Map // invocation ids for which Stop has been called
12 | 	conn    hubConnection
13 | }
14 | 
15 | // Start forwards, on a new goroutine, every value received from reflectedChannel to the
16 | // connection as a stream item for invocationID. A completion without result is sent when
17 | // the channel is closed (unless the connection context is done) or when the next value
18 | // arrives after Stop(invocationID) has been called. Once the connection context is done,
19 | // the goroutine ends without sending anything more.
20 | func (s *streamer) Start(invocationID string, reflectedChannel reflect.Value) {
21 | 	go func() {
22 | 		// Remove a Stop marker that is still stored when the stream ends, so cancels does not
23 | 		// keep an entry for every stopped stream.
24 | 		defer s.cancels.Delete(invocationID)
25 | 		for {
26 | 			// Waits for channel, so might hang
27 | 			chanResult, ok := reflectedChannel.Recv()
28 | 			if !ok {
29 | 				if s.conn.Context().Err() == nil {
30 | 					_ = s.conn.Completion(invocationID, nil, "")
31 | 				}
32 | 				return
33 | 			}
34 | 			if _, stopped := s.cancels.Load(invocationID); stopped {
35 | 				_ = s.conn.Completion(invocationID, nil, "")
36 | 				return
37 | 			}
38 | 			if s.conn.Context().Err() != nil {
39 | 				return
40 | 			}
41 | 			// The send error is ignored; NOTE(review): presumably connection failures are
42 | 			// handled by the connection's own loop — confirm.
43 | 			_ = s.conn.StreamItem(invocationID, chanResult.Interface())
44 | 		}
45 | 	}()
46 | }
47 | 
48 | // Stop marks the stream started for invocationID as stopped. Start's goroutine notices the
49 | // mark only when the stream channel delivers its next value.
50 | func (s *streamer) Stop(invocationID string) {
51 | 	s.cancels.Store(invocationID, struct{}{})
52 | }
53 | 
--------------------------------------------------------------------------------
/streaminvocation_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 | 
3 | import (
4 | 	"time"
5 | 
6 | 	. "github.com/onsi/ginkgo"
7 | 	. "github.com/onsi/gomega"
8 | )
9 | 
10 | var streamInvocationQueue = make(chan string, 20)
11 | 
12 | type streamHub struct {
13 | 	Hub
14 | }
15 | 
16 | func (s *streamHub) SimpleStream() <-chan int {
17 | 	r := make(chan int)
18 | 	go func() {
19 | 		defer close(r)
20 | 		for i := 1; i < 4; i++ {
21 | 			r <- i
22 | 		}
23 | 	}()
24 | 	streamInvocationQueue <- "SimpleStream()"
25 | 	return r
26 | }
27 | 
28 | func (s *streamHub) EndlessStream() <-chan int {
29 | 	r := make(chan int)
30 | 	go func() {
31 | 		defer close(r)
32 | 		for i := 1; ; i++ {
33 | 			r <- i
34 | 		}
35 | 	}()
36 | 	streamInvocationQueue <- "EndlessStream()"
37 | 	return r
38 | }
39 | 
40 | func (s *streamHub) SliceStream() <-chan []int {
41 | 	r := make(chan []int)
42 | 	go func() {
43 | 		defer close(r)
44 | 		for i := 1; i < 4; i++ {
45 | 			s := make([]int, 2)
46 | 			s[0] = i
47 | 			s[1] = i * 2
48 | 			r <- s
49 | 		}
50 | 	}()
51 | 	streamInvocationQueue <- "SliceStream()"
52 | 	return r
53 | }
54 | 
55 | func (s *streamHub) SimpleInt() int {
56 | 	streamInvocationQueue <- "SimpleInt()"
57 | 	return -1
58 | }
59 | 
60 | var _ = Describe("StreamInvocation", func() {
61 | 
62 | 	Describe("Simple stream invocation", func() {
63 | 		var server Server
64 | 		var conn *testingConnection
65 | 		BeforeEach(func(done Done) {
66 | 			server, conn = connect(&streamHub{})
67 | 			close(done)
68 | 		})
69 | 		AfterEach(func(done Done) {
70 | 			server.cancel()
71 | 			close(done)
72 | 		})
73 | 		Context("When invoked by the client", func() {
74 | 			It("should be invoked on the server, return stream items and a final completion without result", func(done Done) {
75 | 				p := &jsonHubProtocol{dbg: testLogger()}
76 | 				conn.ClientSend(`{"type":4,"invocationId": "zzz","target":"simplestream"}`)
77 | 				Expect(<-streamInvocationQueue).To(Equal("SimpleStream()"))
78 | 				for i := 1; i < 4; i++ {
79 | 					recv := (<-conn.received).(streamItemMessage)
80 | 					Expect(recv).NotTo(BeNil())
81 | 					Expect(recv.InvocationID).To(Equal("zzz"))
82 | 					var f float64
83 | 					Expect(p.UnmarshalArgument(recv.Item, &f)).NotTo(HaveOccurred())
84 | 					Expect(f).To(Equal(float64(i)))
85 | 				}
86 | 				recv := (<-conn.received).(completionMessage)
87 | 				Expect(recv).NotTo(BeNil())
88 | 				Expect(recv.InvocationID).To(Equal("zzz"))
89 | 				Expect(recv.Result).To(BeNil())
90 | 				Expect(recv.Error).To(Equal(""))
91 | 				close(done)
92 | 			})
93 | 		})
94 | 	})
95 | 
96 | 	Describe("Slice stream invocation", func() {
97 | 		var server Server
98 | 		var conn *testingConnection
99 | 		BeforeEach(func(done Done) {
100 | 			server, conn = connect(&streamHub{})
101 | 			close(done)
102 | 		})
103 | 		AfterEach(func(done Done) {
104 | 			server.cancel()
105 | 			close(done)
106 | 		})
107 | 		Context("When invoked by the client", func() {
108 | 			It("should be invoked on the server, return stream items and a final completion without result", func(done Done) {
109 | 				protocol := jsonHubProtocol{dbg: testLogger()}
110 | 				conn.ClientSend(`{"type":4,"invocationId": "slice","target":"slicestream"}`)
111 | 				Expect(<-streamInvocationQueue).To(Equal("SliceStream()"))
112 | 				for i := 1; i < 4; i++ {
113 | 					recv := (<-conn.received).(streamItemMessage)
114 | 					Expect(recv).NotTo(BeNil())
115 | 					Expect(recv.InvocationID).To(Equal("slice"))
116 | 					exp := make([]int, 0, 2)
117 | 					exp = append(exp, i)
118 | 					exp = append(exp, i*2)
119 | 					var got []int
120 | 					Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
121 | 					Expect(got).To(Equal(exp))
122 | 				}
123 | 				recv := (<-conn.received).(completionMessage)
124 | 				Expect(recv).NotTo(BeNil())
125 | 				Expect(recv.InvocationID).To(Equal("slice"))
126 | 				Expect(recv.Result).To(BeNil())
127 | 				Expect(recv.Error).To(Equal(""))
128 | 				close(done)
129 | 			})
130 | 		})
131 | 	})
132 | 
133 | 	Describe("Stop simple stream invocation", func() {
134 | 		var server Server
135 | 		var conn *testingConnection
136 | 		BeforeEach(func(done Done) {
137 | 			server, conn = connect(&streamHub{})
138 | 			close(done)
139 | 		})
140 | 		AfterEach(func(done Done) {
141 | 			server.cancel()
142 | 			close(done)
143 | 		})
144 | 		Context("When invoked by the client and stop after one result", func() {
145 | 			It("should be invoked on the server, return stream one item and a final completion without result", func(done Done) {
146 | 				protocol := jsonHubProtocol{dbg: testLogger()}
147 | 				conn.ClientSend(`{"type":4,"invocationId": "xxx","target":"endlessstream"}`)
148 | 				Expect(<-streamInvocationQueue).To(Equal("EndlessStream()"))
149 | 				recv := (<-conn.received).(streamItemMessage)
150 | 				Expect(recv).NotTo(BeNil())
151 | 				Expect(recv.InvocationID).To(Equal("xxx"))
152 | 				var got int
153 | 				Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
154 | 				Expect(got).To(Equal(1))
155 | 				// stop it
156 | 				conn.ClientSend(`{"type":5,"invocationId": "xxx"}`)
157 | 			loop:
158 | 				for {
159 | 					recv := <-conn.received
160 | 					Expect(recv).NotTo(BeNil())
161 | 					switch recv := recv.(type) {
162 | 					case streamItemMessage:
163 | 						Expect(recv.InvocationID).To(Equal("xxx"))
164 | 					case completionMessage:
165 | 						Expect(recv.InvocationID).To(Equal("xxx"))
166 | 						Expect(recv.Result).To(BeNil())
167 | 						Expect(recv.Error).To(Equal(""))
168 | 						break loop
169 | 					}
170 | 				}
171 | 				close(done)
172 | 			})
173 | 		})
174 | 	})
175 | 
176 | 	Describe("Invalid CancelInvocation", func() {
177 | 		var server Server
178 | 		var conn *testingConnection
179 | 		BeforeEach(func(done Done) {
180 | 			server, conn = connect(&streamHub{})
181 | 			close(done)
182 | 		})
183 | 		AfterEach(func(done Done) {
184 | 			server.cancel()
185 | 			close(done)
186 | 		})
187 | 		Context("When invoked by the client and receiving an invalid CancelInvocation", func() {
188 | 			It("should close the connection with an error", func(done Done) {
189 | 				protocol := &jsonHubProtocol{dbg: testLogger()}
190 | 				conn.ClientSend(`{"type":4,"invocationId": "xyz","target":"endlessstream"}`)
191 | 				Expect(<-streamInvocationQueue).To(Equal("EndlessStream()"))
192 | 				recv := (<-conn.received).(streamItemMessage)
193 | 				Expect(recv).NotTo(BeNil())
194 | 				Expect(recv.InvocationID).To(Equal("xyz"))
195 | 				var got int
196 | 				Expect(protocol.UnmarshalArgument(recv.Item, &got)).NotTo(HaveOccurred())
197 | 				Expect(got).To(Equal(1))
198 | 				// try to stop it, but do not get it right
199 | 				conn.ClientSend(`{"type":5,"invocationId":1}`)
200 | 			loop:
201 | 				for {
202 | 					message := <-conn.received
203 | 					switch message := message.(type) {
204 | 					case closeMessage:
205 | 						Expect(message.Error).NotTo(BeNil())
206 | 						break loop
207 | 					default:
208 | 					}
209 | 				}
210 | 				close(done)
211 | 			})
212 | 		})
213 | 	})
214 | 
215 | 	Describe("Stream invocation of method with no stream result", func() {
216 | 		var server Server
217 | 		var conn *testingConnection
218 | 		BeforeEach(func(done Done) {
219 | 			server, conn = connect(&streamHub{})
220 | 			close(done)
221 | 		})
222 | 		AfterEach(func(done Done) {
223 | 			server.cancel()
224 | 			close(done)
225 | 		})
226 | 		Context("When invoked by the client", func() {
227 | 			It("should be invoked on the server, return one stream item with the \"no stream\" result and a final completion without result", func(done Done) {
228 | 				protocol := &jsonHubProtocol{dbg: testLogger()}
229 | 				conn.ClientSend(`{"type":4,"invocationId": "yyy","target":"simpleint"}`)
230 | 				Expect(<-streamInvocationQueue).To(Equal("SimpleInt()"))
231 | 				sRecv := (<-conn.received).(streamItemMessage)
232 | 				Expect(sRecv).NotTo(BeNil())
233 | 				Expect(sRecv.InvocationID).To(Equal("yyy"))
234 | 				var got int
235 | 				Expect(protocol.UnmarshalArgument(sRecv.Item, &got)).NotTo(HaveOccurred())
236 | 				Expect(got).To(Equal(-1))
237 | 				cRecv := (<-conn.received).(completionMessage)
238 | 				Expect(cRecv).NotTo(BeNil())
239 | 				Expect(cRecv.InvocationID).To(Equal("yyy"))
240 | 				Expect(cRecv.Result).To(BeNil())
241 | 				Expect(cRecv.Error).To(Equal(""))
242 | 				close(done)
243 | 			})
244 | 		})
245 | 	})
246 | 
247 | 	Describe("invalid messages", func() {
248 | 		var server Server
249 | 		var conn *testingConnection
250 | 		BeforeEach(func(done Done) {
251 | 			server, conn = connect(&streamHub{})
252 | 			close(done)
253 | 		})
254 | 		AfterEach(func(done Done) {
255 | 			server.cancel()
256 | 			close(done)
257 | 		})
258 | 		Context("When an invalid stream invocation message is sent", func() {
259 | 			It("should return a completion with error", func(done Done) {
260 | 				conn.ClientSend(`{"type":4}`)
261 | 				select {
262 | 				case message := <-conn.received:
263 | 					completionMessage := message.(completionMessage)
264 | 					Expect(completionMessage).NotTo(BeNil())
265 | 					Expect(completionMessage.Error).NotTo(BeNil())
266 | 				case <-time.After(100 * time.Millisecond):
267 | 				}
268 | 				close(done)
269 | 			})
270 | 		})
271 | 	})
272 | 
273 | })
274 | 
--------------------------------------------------------------------------------
/testLogConf.json:
--------------------------------------------------------------------------------
1 | {
2 |   "Enabled": false,
3 |   "Debug": false
4 | }
--------------------------------------------------------------------------------
/testingconnection_test.go:
--------------------------------------------------------------------------------
1 | package signalr
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"encoding/json"
7 | 	"errors"
8 | 	"fmt"
9 | 	"io"
10 | 	"sync"
11 | 	"time"
12 | 
13 | 	"github.com/onsi/ginkgo"
14 | )
15 | 
16 | type testingConnection struct {
17 | 	timeout      time.Duration
18 | 	connectionID string
19 | 	srvWriter    io.Writer
20 | 	srvReader    io.Reader
21 | 	cliWriter    io.Writer
22 | 	cliReader    io.Reader
23 | 	received     chan interface{}
24 | 	cnMutex      sync.Mutex
25 | 	connected    bool
26 | 	cliSendChan  chan string
27 | 	srvSendChan  chan []byte
28 | 	failRead     string
29 | 	failWrite    string
30 | 	failMx       sync.Mutex
31 | }
32 | 
33 | func (t *testingConnection) Context() context.Context {
34 | 	return context.TODO()
35 | }
36 | 
37 | var connNum = 0
38 | var connNumMx sync.Mutex
39 | 
40 | func (t *testingConnection) SetTimeout(timeout time.Duration) {
41 | 	t.timeout = timeout
42 | }
43 | 
44 | func (t *testingConnection) Timeout() time.Duration {
45 | 	return t.timeout
46 | }
47 | 
48 | func (t *testingConnection) ConnectionID() string {
49 | 	connNumMx.Lock()
50 | 	defer connNumMx.Unlock()
51 | 	if t.connectionID == "" {
52 | 		connNum++
53 | 		t.connectionID = fmt.Sprintf("test%v", connNum)
54 | 	}
55 | 	return t.connectionID
56 | }
57 | 
58 | func (t *testingConnection) SetConnectionID(id string) {
59 | 	t.connectionID = id
60 | }
61 | 
62 | func (t *testingConnection) Read(b []byte) (n int, err error) {
63 | 	if fr := t.FailRead(); fr != "" {
64 | 		defer func() { t.SetFailRead("") }()
65 | 		return 0, errors.New(fr)
66 | 	}
67 | 	timer := make(<-chan time.Time)
68 | 	if t.Timeout() > 0 {
69 | 		timer = time.After(t.Timeout())
70 | 	}
71 | 	nch := make(chan int)
72 | 	go func() {
73 | 		n, _ := t.srvReader.Read(b)
74 | 		nch <- n
75 | 	}()
76 | 	select {
77 | 	case n := <-nch:
78 | 		return n, nil
79 | 	case <-timer:
80 | 		return 0, fmt.Errorf("timeout %v", t.Timeout())
81 | 	}
82 | }
83 | 
84 | func (t *testingConnection) Write(b []byte) (n int, err error) {
85 | 	if fw := t.FailWrite(); fw != "" {
86 | 		defer func() { t.SetFailWrite("") }()
87 | 		return 0, errors.New(fw)
88 | 	}
89 | 	t.srvSendChan <- b
90 | 	return len(b), nil
91 | }
92 | 
93 | func (t *testingConnection) Connected() bool {
94 | 	t.cnMutex.Lock()
95 | 	defer t.cnMutex.Unlock()
96 | 	return t.connected
97 | }
98 | 
99 | func (t *testingConnection) SetConnected(connected bool) {
100 | 	t.cnMutex.Lock()
101 | 	defer t.cnMutex.Unlock()
102 | 	t.connected = connected
103 | }
104 | 
105 | func (t *testingConnection) FailRead() string {
106 | 	defer t.failMx.Unlock()
107 | 	t.failMx.Lock()
108 | 	return t.failRead
109 | }
110 | 
111 | func (t *testingConnection) FailWrite() string {
112 | 	defer t.failMx.Unlock()
113 | 	t.failMx.Lock()
114 | 	return t.failWrite
115 | }
116 | 
117 | func (t *testingConnection) SetFailRead(fail string) {
118 | 	defer t.failMx.Unlock()
119 | 	t.failMx.Lock()
120 | 	t.failRead = fail
121 | }
122 | 
123 | func (t *testingConnection) SetFailWrite(fail string) {
124 | 	defer t.failMx.Unlock()
125 | 	t.failMx.Lock()
126 | 	t.failWrite = fail
127 | }
128 | 
129 | // newTestingConnectionForServer builds a testingConnection with a sent (but not yet received) handshake for testing a server
130 | func newTestingConnectionForServer() *testingConnection {
131 | 	conn := newTestingConnection()
132 | 	// client receive loop
133 | 	go receiveLoop(conn)()
134 | 	// Send initial Handshake
135 | 	conn.ClientSend(`{"protocol": "json","version": 1}`)
136 | 	conn.SetConnected(true)
137 | 	return conn
138 | }
139 | 
140 | func newTestingConnection() *testingConnection {
141 | 	cliReader, srvWriter := io.Pipe()
142 | 	srvReader, cliWriter := io.Pipe()
143 | 	conn := testingConnection{
144 | 		srvWriter:   srvWriter,
145 | 		srvReader:   srvReader,
146 | 		cliWriter:   cliWriter,
147 | 		cliReader:   cliReader,
148 | 		received:    make(chan interface{}, 20),
149 | 		cliSendChan: make(chan string, 20),
150 | 		srvSendChan: make(chan []byte, 20),
151 | 		timeout:     time.Second * 5,
152 | 	}
153 | 	// client send loop
154 | 	go func() {
155 | 		for {
156 | 			_, _ = conn.cliWriter.Write(append([]byte(<-conn.cliSendChan), 30))
157 | 		}
158 | 	}()
159 | 	// server send loop
160 | 	go func() {
161 | 		for {
162 | 			_, _ = conn.srvWriter.Write(<-conn.srvSendChan)
163 | 		}
164 | 	}()
165 | 	return &conn
166 | }
167 | 
168 | func (t *testingConnection) ClientSend(message string) {
169 | 	t.cliSendChan <- message
170 | }
171 | 
172 | func (t *testingConnection) ClientReceive() (string, error) {
173 | 	var buf bytes.Buffer
174 | 	var data = make([]byte, 1<<15) // 32K
175 | 	var nn int
176 | 	for {
177 | 		if message, err := buf.ReadString(30); err != nil {
178 | 			buf.Write(data[:nn])
179 | 			if n, err := t.cliReader.Read(data[nn:]); err == nil {
180 | 				buf.Write(data[nn : nn+n])
181 | 				nn = nn + n
182 | 			} else {
183 | 				return "", err
184 | 			}
185 | 		} else {
186 | 			return message[:len(message)-1], nil
187 | 		}
188 | 	}
189 | }
190 | 
191 | func (t *testingConnection) ReceiveChan() chan interface{} {
192 | 	return t.received
193 | }
194 | 
195 | type clientReceiver interface {
196 | 	ClientReceive() (string, error)
197 | 	ReceiveChan() chan interface{}
198 | 	SetConnected(bool)
199 | }
200 | 
201 | func receiveLoop(conn clientReceiver) func() {
202 | 	return func() {
203 | 		defer ginkgo.GinkgoRecover()
204 | 		errorHandler := func(err error) { ginkgo.Fail(fmt.Sprintf("received invalid message from server %v", err.Error())) }
205 | 		for {
206 | 			if message, err := conn.ClientReceive(); err == nil {
207 | 				var hubMessage hubMessage
208 | 				if err = json.Unmarshal([]byte(message), &hubMessage); err == nil {
209 | 					switch hubMessage.Type {
210 | 					case 1, 4:
211 | 						var invocationMessage invocationMessage
212 | 						if err = json.Unmarshal([]byte(message), &invocationMessage); err == nil {
213 | 							conn.ReceiveChan() <- invocationMessage
214 | 						} else {
215 | 							errorHandler(err)
216 | 						}
217 | 					case 2:
218 | 						var jsonStreamItemMessage jsonStreamItemMessage
219 | 						if err = json.Unmarshal([]byte(message), &jsonStreamItemMessage); err == nil {
220 | 
221 | 							conn.ReceiveChan() <- streamItemMessage{
222 | 								Type:         jsonStreamItemMessage.Type,
223 | 								InvocationID: jsonStreamItemMessage.InvocationID,
224 | 								Item:         jsonStreamItemMessage.Item,
225 | 							}
226 | 						} else {
227 | 							errorHandler(err)
228 | 						}
229 | 					case 3:
230 | 						var completionMessage completionMessage
231 | 						if err = json.Unmarshal([]byte(message), &completionMessage); err == nil {
232 | 							conn.ReceiveChan() <- completionMessage
233 | 						} else {
234 | 							errorHandler(err)
235 | 						}
236 | 					case 7:
237 | 						var closeMessage closeMessage
238 | 						if err = json.Unmarshal([]byte(message), &closeMessage); err == nil {
239 | 							conn.SetConnected(false)
240 | 							conn.ReceiveChan() <- closeMessage
241 | 						} else {
242 | 							errorHandler(err)
243 | 						}
244 | 					}
245 | 				} else {
246 | 					errorHandler(err)
247 | 				}
248 | 			}
249 | 		}
250 | 	}
251 | }
252 | 
--------------------------------------------------------------------------------
/websocketconnection.go:
--------------------------------------------------------------------------------
1 | package signalr
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"fmt"
7 | 
8 | 	"nhooyr.io/websocket"
9 | )
10 | 
11 | // webSocketConnection adapts a *websocket.Conn to the Read/Write byte interface used by the
12 | // connection code: every Write is sent as one websocket message, and Read hands out the
13 | // content of received messages, spreading a message over several calls if p is too small.
14 | type webSocketConnection struct {
15 | 	ConnectionBase
16 | 	conn         *websocket.Conn
17 | 	transferMode TransferMode
18 | 	// pending holds the unread rest of the last received message.
19 | 	pending *bytes.Reader
20 | }
21 | 
22 | // newWebSocketConnection wraps conn; its ConnectionBase is created from ctx and connectionID.
23 | func newWebSocketConnection(ctx context.Context, connectionID string, conn *websocket.Conn) *webSocketConnection {
24 | 	w := &webSocketConnection{
25 | 		conn:           conn,
26 | 		ConnectionBase: *NewConnectionBase(ctx, connectionID),
27 | 	}
28 | 	return w
29 | }
30 | 
31 | // Write sends p as one websocket message: a binary message in BinaryTransferMode, a text
32 | // message otherwise. On error the websocket is closed and the error is returned wrapped.
33 | func (w *webSocketConnection) Write(p []byte) (n int, err error) {
34 | 	messageType := websocket.MessageText
35 | 	if w.transferMode == BinaryTransferMode {
36 | 		messageType = websocket.MessageBinary
37 | 	}
38 | 	n, err = ReadWriteWithContext(w.Context(),
39 | 		func() (int, error) {
40 | 			err := w.conn.Write(w.Context(), messageType, p)
41 | 			if err != nil {
42 | 				return 0, err
43 | 			}
44 | 			return len(p), nil
45 | 		},
46 | 		func() {})
47 | 	if err != nil {
48 | 		err = fmt.Errorf("%T: %w", w, err)
49 | 		_ = w.conn.Close(websocket.StatusNormalClosure, err.Error())
50 | 	}
51 | 	return n, err
52 | }
53 | 
54 | // Read copies the content of received websocket messages into p. A message that does not fit
55 | // into p is handed out over several calls instead of being cut off, and an empty message
56 | // yields (0, nil). On a receive error the websocket is closed and the error is returned wrapped.
57 | func (w *webSocketConnection) Read(p []byte) (n int, err error) {
58 | 	// Hand out the rest of a previously received message before reading the next one.
59 | 	if w.pending != nil && w.pending.Len() > 0 {
60 | 		return w.pending.Read(p)
61 | 	}
62 | 	var message []byte
63 | 	_, err = ReadWriteWithContext(w.Context(),
64 | 		func() (int, error) {
65 | 			_, data, err := w.conn.Read(w.Context())
66 | 			if err != nil {
67 | 				return 0, err
68 | 			}
69 | 			// p is filled below, in the calling goroutine, only after a successful return.
70 | 			// NOTE(review): assumes ReadWriteWithContext returns a nil error only after this
71 | 			// func has finished — confirm.
72 | 			message = data
73 | 			return len(data), nil
74 | 		},
75 | 		func() {})
76 | 	if err != nil {
77 | 		err = fmt.Errorf("%T: %w", w, err)
78 | 		_ = w.conn.Close(websocket.StatusNormalClosure, err.Error())
79 | 		return 0, err
80 | 	}
81 | 	if len(message) == 0 {
82 | 		// Nothing to hand out; bytes.Reader would report io.EOF for empty content.
83 | 		return 0, nil
84 | 	}
85 | 	w.pending = bytes.NewReader(message)
86 | 	return w.pending.Read(p)
87 | }
88 | 
89 | // TransferMode returns the transfer mode that decides which message type Write sends.
90 | func (w *webSocketConnection) TransferMode() TransferMode {
91 | 	return w.transferMode
92 | }
93 | 
94 | // SetTransferMode sets the transfer mode that decides which message type Write sends.
95 | func (w *webSocketConnection) SetTransferMode(transferMode TransferMode) {
96 | 	w.transferMode = transferMode
97 | }
98 | 
--------------------------------------------------------------------------------