├── .github └── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation-improvement.md │ ├── feature-request.md │ └── generic-issue.md ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Gopkg.lock ├── Gopkg.toml ├── LICENSE ├── README.md ├── connection.go ├── connectionOptions.go ├── defaultSessionManager.go ├── deregisterHandler.go ├── docs ├── img │ ├── webwire_logo.svg │ ├── webwire_logo_black.svg │ ├── wwr_msgproto_diagram.svg │ └── wwr_msgproto_diagram.xml ├── protocol-sequences.svg └── protocol-sequences.txt ├── errors.go ├── failMsg.go ├── fulfillMsg.go ├── genericSessionInfo.go ├── genericSessionInfo_test.go ├── handleConnection.go ├── handleMessage.go ├── handleRequest.go ├── handleSessionClosure.go ├── handleSessionRestore.go ├── handleSignal.go ├── interfaces.go ├── isShuttingDown.go ├── message ├── bench_test.go ├── buffer.go ├── calcMsgLen.go ├── message.go ├── messageParts.go ├── message_test.go ├── newAcceptConfMessage.go ├── newMessage.go ├── parse.go ├── parseAcceptConf.go ├── parseAcceptConf_test.go ├── parseCloseSession.go ├── parseErrorReply.go ├── parseHeartbeat.go ├── parseReply.go ├── parseReplyUtf16.go ├── parseRequest.go ├── parseRequestUtf16.go ├── parseRestoreSession.go ├── parseSessionClosed.go ├── parseSessionCreated.go ├── parseSignal.go ├── parseSignalUtf16.go ├── parseSpecialReplyMessage.go ├── parser_corruptnamelen_test.go ├── parser_corruptpayload_test.go ├── parser_invtoolong_test.go ├── parser_invtooshort_test.go ├── parser_test.go ├── pool.go ├── read.go ├── syncPool.go ├── testutil_test.go ├── writeMsgErrorReply.go ├── writeMsgHeartbeat.go ├── writeMsgNamelessRequest.go ├── writeMsgReply.go ├── writeMsgRequest.go ├── writeMsgSessionClosed.go ├── writeMsgSessionCreated.go ├── writeMsgSignal.go ├── writeMsgSpecialRequestReply.go ├── write_corruptUtf16Payload_test.go ├── write_test.go └── write_unexpparams_test.go ├── newServer.go ├── payload.go ├── payload ├── encoding.go ├── encoding_test.go └── payload.go 
├── registerHandler.go ├── requestManager ├── reply.go ├── request.go ├── requestManager.go └── requestManager_test.go ├── server.go ├── serverOptions.go ├── session.go ├── sessionInfoToVarMap.go ├── sessionInfoToVarMap_test.go ├── sessionLookupResult.go ├── sessionRegistry.go ├── sessionRegistry_test.go ├── socket.go ├── test ├── activeSessionRegistry_test.go ├── benchmark_test.go ├── clientInitiatedSessionDestruction_test.go ├── connIsConnected_test.go ├── connSessionNoOverride_test.go ├── connSignalBufferOverflow_test.go ├── connSignalNoNameNoPayload_test.go ├── connectionInfo_test.go ├── connectionSessionGetters_test.go ├── customSessKeyGenInvalid_test.go ├── customSessKeyGen_test.go ├── disabledSessions_test.go ├── emptyReplyUtf16_test.go ├── emptyReplyUtf8_test.go ├── emptyReply_test.go ├── gracefulShutdown_test.go ├── handshake_test.go ├── maxConcSessConn_test.go ├── namedRequest_test.go ├── namedSignal_test.go ├── protocolViolation_test.go ├── refuseConnections_test.go ├── requestError_test.go ├── requestInternalError_test.go ├── requestNameOnly_test.go ├── requestNoNameNoPayload_test.go ├── requestPayloadOnly_test.go ├── requestUtf16_test.go ├── requestUtf8_test.go ├── serverImpl.go ├── serverInitiatedSessionDestruction_test.go ├── serverSignal_test.go ├── sessionCreationOnClosedConn_test.go ├── sessionKeyGen.go ├── sessionManagers.go ├── sessionNotFound_test.go ├── sessionRestoration_test.go ├── sessionStatus_test.go ├── signalUtf16_test.go ├── signalUtf8_test.go ├── simpleShutdown_test.go └── util.go ├── translateContextError.go ├── transport.go ├── transport └── memchan │ ├── buffer.go │ ├── buffer_test.go │ ├── clientTransport.go │ ├── entangleSockets.go │ ├── memchan_test.go │ ├── newEntangledSockets.go │ ├── newSocket.go │ ├── remoteAddress.go │ ├── sockReadErr.go │ ├── socket.go │ └── transport.go └── vendor ├── github.com ├── davecgh │ └── go-spew │ │ └── spew │ │ ├── bypass.go │ │ ├── bypasssafe.go │ │ ├── common.go │ │ ├── config.go │ │ ├── 
doc.go │ │ ├── dump.go │ │ ├── format.go │ │ └── spew.go ├── pmezard │ └── go-difflib │ │ └── difflib │ │ └── difflib.go └── stretchr │ └── testify │ ├── assert │ ├── assertion_format.go │ ├── assertion_format.go.tmpl │ ├── assertion_forward.go │ ├── assertion_forward.go.tmpl │ ├── assertions.go │ ├── doc.go │ ├── errors.go │ ├── forward_assertions.go │ └── http_assertions.go │ └── require │ ├── doc.go │ ├── forward_requirements.go │ ├── require.go │ ├── require.go.tmpl │ ├── require_forward.go │ ├── require_forward.go.tmpl │ └── requirements.go └── golang.org └── x ├── net └── context │ ├── context.go │ ├── go17.go │ ├── go19.go │ ├── pre_go17.go │ └── pre_go19.go └── sync └── semaphore └── semaphore.go /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Report a bug 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | # Bug Report 11 | ## Expected Behavior 12 | _Describe how you expect either the server or the client implementation to behave in what case._ 13 | 14 | ## Actual Behavior 15 | _Describe how either the server or the client implementation actually behaves right now._ 16 | 17 | ## Reproduction Steps 18 | _List and describe the steps to reproduce the actual behavior._ 19 | - step 1 20 | - step 2 21 | - step 3 22 | 23 | ## Environment 24 | - **OS**: _Windows 10, Ubuntu 16.04, Mac OS 11.5 ..._ 25 | - **Go version**: _Go 1.10.3_ 26 | - **webwire-go version**: _v1.0.0 rc1 / git revision_ 27 | - **System**: _CPU, RAM, Disk, Network (when necessary)_ 28 | 29 | ## Additional information 30 | _**Optional section, remove when empty**_
31 | _Optionally add any other relevant information if necessary._ 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-improvement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation Improvement 3 | about: Propose a documentation improvement 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | # Documentation Improvement 11 | 12 | ## Proposed Improvement 13 | _Describe your improvement proposal._ 14 | 15 | ## Current State 16 | _Describe the current state of the documentation part that you propose to improve._ 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Propose a new feature for the webwire protocol or server implementation 4 | title: '' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | # Feature Request 11 | _Describe the feature you propose and how you expect the webwire server to behave in what case._ 12 | 13 | ## Problem 14 | _Describe how the webwire server actually behaves right now 15 | and what exactly made you create this feature request._ 16 | 17 | ## Proposed API 18 | _**Optional section, remove when empty**_
19 | _If your feature requires API changes and you already have ideas about how they could look, 20 | then please describe them._ 21 | 22 | ## Proposed Implementation 23 | _**Optional section, remove when empty**_
24 | _If you already have ideas on how to implement your feature then please describe it._ 25 | 26 | ## Additional information 27 | _**Optional section, remove when empty**_
28 | _Optionally add any other relevant information if necessary._ 29 | 30 | ### Environment 31 | - **OS**: _Windows 10, Ubuntu 16.04, Mac OS 11.5 ..._ 32 | - **Go version**: _Go 1.11.2_ 33 | - **webwire-go version**: _v1.0.0 rc1 / git revision_ 34 | - **System**: _CPU, RAM, Disk, Network (when necessary)_ 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/generic-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Generic Issue 3 | about: Use this template in exceptional cases, when none of the other templates 4 | suits your issue 5 | title: '' 6 | labels: '' 7 | assignees: '' 8 | 9 | --- 10 | 11 | ## Generic Issue 12 | _Describe the problem you're facing using webwire-go or the improvement you're proposing._ 13 | _Please use this template in exceptional cases, when none of the other templates suits your issue._ 14 | 15 | ## Environment 16 | _**Optional section, remove when empty**_
17 | - **OS**: _Windows 10, Ubuntu 16.04, Mac OS 11.5 ..._ 18 | - **Go version**: _Go 1.10.3_ 19 | - **webwire-go version**: _v1.0.0 rc1 / git revision_ 20 | - **System**: _CPU, RAM, Disk, Network (when necessary)_ 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore extensionless files 2 | * 3 | !*/ 4 | !*.* 5 | 6 | # Binaries for programs and plugins 7 | *.exe 8 | *.dll 9 | *.so 10 | *.dylib 11 | 12 | # Test binary, build with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 19 | .glide/ 20 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - master 5 | - "1.11" 6 | 7 | install: true 8 | 9 | matrix: 10 | allow_failures: 11 | - go: master 12 | fast_finish: true 13 | 14 | notifications: 15 | email: true 16 | 17 | before_script: 18 | - GO_FILES=$(find . -iname '*.go' -type f | grep -v /vendor/) # All the .go files, excluding vendor/ 19 | - go get golang.org/x/lint/golint # Linter 20 | - go get honnef.co/go/tools/cmd/megacheck # Badass static analyzer/linter 21 | - go get golang.org/x/tools/cmd/cover 22 | - go get github.com/mattn/goveralls 23 | - go get github.com/go-playground/overalls 24 | 25 | script: 26 | # Run all the tests with the race detector enabled 27 | - overalls -project=github.com/qbeon/webwire-go -covermode=atomic -debug -- -race -v -coverpkg=./... 28 | - $HOME/gopath/bin/goveralls -coverprofile=overalls.coverprofile -service=travis-ci -repotoken=$COVERALLS_TOKEN 29 | 30 | # go vet is the official Go static analyzer 31 | - go vet ./... 32 | 33 | # "go vet on steroids" + linter 34 | - megacheck ./... 
35 | 36 | # one last linter 37 | - golint -set_exit_status $(go list ./...) 38 | 39 | after_success: 40 | - "curl -s -X POST 41 | -H \"Content-Type: application/json\" 42 | -H \"Accept: application/json\" 43 | -H \"Travis-API-Version: 3\" 44 | -H \"Authorization: token $TRAVIS_API_TOKEN\" 45 | -d '{\"request\": {\"branch\":\"master\"}}' 46 | https://api.travis-ci.org/repo/qbeon%2Fwebwire-go-gorilla/requests" 47 | - "curl -s -X POST 48 | -H \"Content-Type: application/json\" 49 | -H \"Accept: application/json\" 50 | -H \"Travis-API-Version: 3\" 51 | -H \"Authorization: token $TRAVIS_API_TOKEN\" 52 | -d '{\"request\": {\"branch\":\"master\"}}' 53 | https://api.travis-ci.org/repo/qbeon%2Fwebwire-go-fasthttp/requests" 54 | - "curl -s -X POST 55 | -H \"Content-Type: application/json\" 56 | -H \"Accept: application/json\" 57 | -H \"Travis-API-Version: 3\" 58 | -H \"Authorization: token $TRAVIS_API_TOKEN\" 59 | -d '{\"request\": {\"branch\":\"master\"}}' 60 | https://api.travis-ci.org/repo/qbeon%2Fwebwire-go-examples/requests" 61 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
2 | 3 | 4 | [[projects]] 5 | digest = "1:ffe9824d294da03b391f44e1ae8281281b4afc1bdaa9588c9097785e3af10cec" 6 | name = "github.com/davecgh/go-spew" 7 | packages = ["spew"] 8 | pruneopts = "UT" 9 | revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" 10 | version = "v1.1.1" 11 | 12 | [[projects]] 13 | digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" 14 | name = "github.com/pmezard/go-difflib" 15 | packages = ["difflib"] 16 | pruneopts = "UT" 17 | revision = "792786c7400a136282c1664665ae0a8db921c6c2" 18 | version = "v1.0.0" 19 | 20 | [[projects]] 21 | digest = "1:c40d65817cdd41fac9aa7af8bed56927bb2d6d47e4fea566a74880f5c2b1c41e" 22 | name = "github.com/stretchr/testify" 23 | packages = [ 24 | "assert", 25 | "require", 26 | ] 27 | pruneopts = "UT" 28 | revision = "f35b8ab0b5a2cef36673838d662e249dd9c94686" 29 | version = "v1.2.2" 30 | 31 | [[projects]] 32 | branch = "master" 33 | digest = "1:76ee51c3f468493aff39dbacc401e8831fbb765104cbf613b89bef01cf4bad70" 34 | name = "golang.org/x/net" 35 | packages = ["context"] 36 | pruneopts = "UT" 37 | revision = "c10e9556a7bc0e7c942242b606f0acf024ad5d6a" 38 | 39 | [[projects]] 40 | branch = "master" 41 | digest = "1:e0140c0c868c6e0f01c0380865194592c011fe521d6e12d78bfd33e756fe018a" 42 | name = "golang.org/x/sync" 43 | packages = ["semaphore"] 44 | pruneopts = "UT" 45 | revision = "1d60e4601c6fd243af51cc01ddf169918a5407ca" 46 | 47 | [solve-meta] 48 | analyzer-name = "dep" 49 | analyzer-version = 1 50 | input-imports = [ 51 | "github.com/stretchr/testify/assert", 52 | "github.com/stretchr/testify/require", 53 | "golang.org/x/sync/semaphore", 54 | ] 55 | solver-name = "gps-cdcl" 56 | solver-version = 1 57 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | [prune] 2 | go-tests = true 3 | unused-packages = true 4 | 5 | [[constraint]] 6 | name = "github.com/stretchr/testify" 7 | 
version = "1.2.2" 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Roman Sharkov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /connectionOptions.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | // ConnectionAcceptance defines whether a connection is to be accepted 4 | type ConnectionAcceptance byte 5 | 6 | const ( 7 | // Accept instructs the server to accept the incoming connection 8 | Accept ConnectionAcceptance = iota 9 | 10 | // Refuse instructs the server to refuse the incoming connection 11 | Refuse 12 | ) 13 | 14 | // ConnectionOptions represents the options applied to an individual connection 15 | // during accept 16 | type ConnectionOptions struct { 17 | // Info stores arbitrary connection information provided by the transport 18 | // layer implementation 19 | Info map[int]interface{} 20 | 21 | // Connection refuses the incoming connection when explicitly set to 22 | // wwr.Refuse. It's set to wwr.Accept by default. 23 | Connection ConnectionAcceptance 24 | 25 | // ConcurrencyLimit defines the maximum number of operations to be processed 26 | // concurrently for this particular client connection. If ConcurrencyLimit 27 | // is 0 (which it is by default) then the number of concurrent operations 28 | // for this particular connection will be limited to 1.
Anything below 0 29 | // will lift the limitation entirely while everything above 0 will set the 30 | // limit to the specified number of handlers 31 | ConcurrencyLimit int 32 | } 33 | -------------------------------------------------------------------------------- /deregisterHandler.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | // deregisterHandler decrements the number of currently executed operations, 4 | // closing shutdownRdy once a scheduled shutdown has no more operations left 5 | func (srv *server) deregisterHandler(con *connection) { 6 | srv.opsLock.Lock() 7 | srv.currentOps-- 8 | if srv.shutdown && srv.currentOps < 1 { 9 | close(srv.shutdownRdy) 10 | } 11 | srv.opsLock.Unlock() 12 | 13 | con.deregisterTask() 14 | 15 | // Release a handler slot (only done when ConcurrencyLimit is above 1) 16 | if con.options.ConcurrencyLimit > 1 { 17 | con.handlerSlots.Release(1) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /docs/img/webwire_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | image/svg+xml 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/img/webwire_logo_black.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | image/svg+xml 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/img/wwr_msgproto_diagram.xml: -------------------------------------------------------------------------------- 1 |
7V1bc5s4FP41fuwOQlwfbRK33Wl3Z+p29pkYxWaKwcWkifvrV+LiYCRibCMsxycPDhxAt3O+T5ejywh7q5ePqb9efk0CEo10LXgZ4buRriNs2PQfk2xLiaWhQrJIw6CUvQpm4R9SCrVS+hQGZLP3YpYkURau94XzJI7JPNuT+WmaPO+/9phE+7Gu/QXhBLO5H/HS/8IgW1apY3+vjz6RcLHM+GcP/vznIk2e4jLSkY4f87/i8cqvAizf3yz9IHmuifD9CHtpkmTF1erFIxEr4arsiu+mLU93iU9JnHX6wNSLT3770ROp0pynLNtWRZLnh7AvtBGePC/DjMzW/pw9faZWQGXLbBXRO0Qvdzli7y4if7MprzdZmvwkXhIlKZXESUx2wqqg2fePYRQ1XlqkfhDSDFViWqgYe/b9HXs9ibOpvwojZm+fSPSbZOHcLx+U1qVr5X0jXD8KFzG9ndOwCZVP4uSL/0CqnJQFQ9KMvLQWL9opjUKCJCuSpVv6SvmBXdlFhQbHKQXPNdsyK+GyblVIc0qrLi16sQv+VaH0otRpi36RQL1WROOYrPeUbP16YjY3icKYfKjSMaavIEwTZ76+QK8W5f88GFauwpDYgw+bXAN5OM76hQ8l265JFRLNSxHYfgQPaSVBR8dpieJkMN1m3aOtvbh+lbViBJ1m69Sop7bhmdqpRs0CyP9qlh2RR5baDUVrGC9KIDKLpsFF4/KlVRgEEWlCN2JImCRpQNImbNJ5mYY8sWXgX/Ko7nSjAXsBeKemMXb12rO7MKVUHiYsNcTf1JL8rQREHmwp+56sy7CXSRr+oUXgV5Ct8ZCIpnpAtGs2EO0aAkRjAaArkJ8DZ8Ph4PyN/HoitMiaBkmzmLVZooBjayUnMrJZGTB6y34qq0uLTE8amGirGGTwrEgruiytmCanlRktCWqVoJSaUmjtN6BSLI1Tyo849lc067oG6hGoB2tDqsd4Qz3AaSL9aNaQ+uE5TamGI+0kxln4GJK0QzvucCRUMcKW4nFJFYfiHEhhL63ZzQDFcEQD+FAn8bT28dTxsPv+28fkcU7yfJ7bPu6BhZDmdqEhSxYNue0NXlpNrKMtVBJ72jH0ASsJmx9dYMEewzhidj/ALgN0r5HTgT52A0cDEkjTiPrpX9fiUap7rdsO5uxZRDZmH+YsGgvto7EyppnW7m36g9kPY1QNsZgci/5Mc9m4lwr8JDhdvrK+ZrSdUFsL0dZLZa13gM8OK73jh+/Tafha7de2XcfzbsB+d66UN+y3D9vctTguQu18d5axfY4XjT0sriwms6ZMll+xB2ykQsPTa7VjbHu6cws8PDHvEFWYEjxsdmnG6K4sY+c7TZX10tzGp9uvptUaKUVQ191IsdAYj50bAIftepO8ACSTvHFJknftVrtXYsxy7W+jxA/orb9i5hs/bIpuUx1WZ40QGqJoJ2Hss4LkYv3xfeqIxcgaYDgxrARxLbLw5Lzf+ujoURX9FY+OnlDPyxsdNTs5aWSNjiLNUpvxwEtzezwEXpoL8FDTSyOeACOPhxRveQEPAQ8BD13CWyz0R8rjofaRDyV4CKY7w3Tnq57uLMSztOnOqFoUBHgGPAOeZeBZNOlXIp5hORLgGfAsEc+iNRYS8dw+YwrwDHgGPJ+NZ5E/QSKesdp4Bg8qeFDBg/quPKgu2me8gT0XiJ+zCowHjAeMB4wni/Eso0OfTSLjKb6yFxgPGA8Y710x3m4B6ptbTchjPMVnyQHjAeMB470rxuPmBQ87D0ZXfFyebfQTkVgi+gYfzL92mjhqvesV00SX5a6D0QS2Ogx+yRvu15Uf7g8CZjqqz9lFB1J4Xin02A4UclxBcNSkpwMU9KVp7qgVn6fRnCOb5mgrXJkFn7Z9WQZTfMIvNHSUYwBo6CjQ0Bl4XkMVMNAE0ATQhLo0wS0fGnY6I1Z8ejLQBNAE0ASlCcf9yzzsUZJIFFcwvjo60Z20L+u8YZTQxWR/eAhZNsYz7/Pn
Wuxte0f1OShThRKOWtmTcyq1FMwR5XUubzbSvK+0nbir76pr7i5LochA3uT9U6hrWe50ogiFNh1UA0+7xIqPPAOBAoECgQKBHjGnadhZnFj5eevguuteZYDrDlx3g7vuuDlKAw+2KT8PHRgMGAwYTGEG47yKw+5OhPnTWKUwmKhDWcgsdrBMcWgBrhl5t97hzfRmFN2T/YTuTC2env3zXQ4sQLoASX1s3I7wqScWPNfAFCfpih3qWpSJ1qlmuH/xV+uI1UDVsXJ43AqkBhZu+6A523L47q40qq0GJ4XnAE79MHpKCShoT0HmkEdgU85WuzUPOwDf3son2AG4tSEx4BQeIQ9JGxc1FPfMw45ksCPZde9IJsKzvD62obijmKRpwoDT8BRz971U6EIXcbFaW/t79u8/A9TpsDgbFmdf++Jsbthw4CYK7/jwtOOMGk5Nh1PT36qjHWu4UxeR0b69ynl1Kxyb/o7hpvSx6SL8SDs2HRm8Hyk/kjo/g9owRtVB1MzqLZf93FdX9utTj926owMnWfN9PcUbPdds4yodSc0dvy6ycUPWkdTIbB8fHs7Ds2aGBv6dTu4Dx5Dn36G3acK0tHv2kVLL8msSEPbG/w== -------------------------------------------------------------------------------- /docs/protocol-sequences.txt: -------------------------------------------------------------------------------- 1 | title webwire binary protocol 2 | 3 | # Connection establishment 4 | group connection establishment 5 | Client-->Server: connect 6 | Client<-Server: AcceptConf 7 | end 8 | 9 | # Request 10 | group request 11 | 12 | alt binary payload 13 | Client->Server: RequestBinary 14 | else UTF8 encoded payload 15 | Client->Server: RequestUtf8 16 | else UTF16 encoded payload 17 | Client->Server: RequestUtf16 18 | end 19 | 20 | alt success 21 | alt binary reply payload 22 | Client<-Server: ReplyBinary 23 | else UTF8 encoded reply payload 24 | Client<-Server: ReplyUtf8 25 | else UTF16 encoded reply payload 26 | Client<-Server: ReplyUtf16 27 | end 28 | else failure 29 | alt error 30 | Client<-Server: ReplyError 31 | else internal error 32 | Client<-Server: ReplyInternalError 33 | else server shutting down 34 | Client<-Server: ReplyShutdown 35 | end 36 | end 37 | 38 | end 39 | 40 | # Client-side signal 41 | group client-side signal 42 | alt binary payload 43 | Client->Server: SignalBinary 44 | else UTF8 encoded payload 45 | Client->Server: SignalUtf8 46 | else UTF16 encoded payload 47 | Client->Server: SignalUtf16 48 | end 49 | end 50 | 51 | # Server-side signal 52 | group 
server-side signal 53 | alt binary payload 54 | Client<-Server: SignalBinary 55 | else UTF8 encoded payload 56 | Client<-Server: SignalUtf8 57 | else UTF16 encoded payload 58 | Client<-Server: SignalUtf16 59 | end 60 | end 61 | 62 | # Session restoration request 63 | group restore session 64 | Client->Server: RequestRestoreSession 65 | 66 | alt success 67 | 68 | Client<-Server: ReplyBinary 69 | else failure 70 | alt session not found 71 | Client<-Server: ReplySessionNotFound 72 | else session connections limit reached 73 | Client<-Server: ReplyMaxSessConnsReached 74 | else server shutting down 75 | Client<-Server: ReplyShutdown 76 | else sessions disabled 77 | Client<-Server: ReplySessionsDisabled 78 | else internal error 79 | Client<-Server: ReplyInternalError 80 | end 81 | end 82 | end 83 | 84 | # Session closure request 85 | group session closure request 86 | Client->Server: RequestCloseSession 87 | box over Server: close active session 88 | Client<-Server: ReplyBinary 89 | end 90 | 91 | # Session creation notification 92 | group session creation notification 93 | box over Server: session created 94 | Client<-Server: NotifySessionCreated 95 | end 96 | 97 | # Session closure notification 98 | group session closure notification 99 | box over Server: session closed 100 | Client<-Server: NotifySessionClosed 101 | end 102 | 103 | # Heartbeat 104 | group heartbeat 105 | Client-->Server: Heartbeat 106 | end 107 | -------------------------------------------------------------------------------- /failMsg.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "github.com/qbeon/webwire-go/message" 5 | ) 6 | 7 | // failMsg fails the message by writing the failure reply matching reqErr: 8 | // ErrRequest and non-nil *ErrRequest values become error replies carrying 9 | // their code and message, the known session errors become their dedicated 10 | // special replies and any other error is reported as an internal error. 11 | // Nothing is sent if the message doesn't require a reply. 12 | func (srv *server) failMsg( 13 | con *connection, 14 | msg *message.Message, 15 | reqErr error, 16 | ) { 17 | // Don't send any failure reply if the type of the message 18 | // doesn't expect any response 19 | if !msg.RequiresReply() { 20 |
return 21 | } 22 | 23 | writer, err := con.sock.GetWriter() 24 | if err != nil { 25 | srv.errorLog.Printf( 26 | "couldn't get writer for connection %p: %s", 27 | con, 28 | err, 29 | ) 30 | return 31 | } 32 | 33 | // Normalize *ErrRequest to ErrRequest so both share a single reply path. 34 | // A nil *ErrRequest is left as is and falls through to the default 35 | // (internal error) case instead of being dereferenced 36 | if reqErrPtr, ok := reqErr.(*ErrRequest); ok && reqErrPtr != nil { 37 | reqErr = *reqErrPtr 38 | } 39 | 40 | switch err := reqErr.(type) { 41 | case ErrRequest: 42 | if err := message.WriteMsgReplyError( 43 | writer, 44 | msg.MsgIdentifierBytes, 45 | []byte(err.Code), 46 | []byte(err.Message), 47 | true, 48 | ); err != nil { 49 | srv.errorLog.Println("couldn't write error reply message: ", err) 50 | return 51 | } 52 | case ErrMaxSessConnsReached: 53 | if err := message.WriteMsgSpecialRequestReply( 54 | writer, 55 | message.MsgReplyMaxSessConnsReached, 56 | msg.MsgIdentifierBytes, 57 | ); err != nil { 58 | srv.errorLog.Println( 59 | "couldn't write max sessions reached message: ", 60 | err, 61 | ) 62 | return 63 | } 64 | case ErrSessionNotFound: 65 | if err := message.WriteMsgSpecialRequestReply( 66 | writer, 67 | message.MsgReplySessionNotFound, 68 | msg.MsgIdentifierBytes, 69 | ); err != nil { 70 | srv.errorLog.Println( 71 | "couldn't write session not found message: ", 72 | err, 73 | ) 74 | return 75 | } 76 | case ErrSessionsDisabled: 77 | if err := message.WriteMsgSpecialRequestReply( 78 | writer, 79 | message.MsgReplySessionsDisabled, 80 | msg.MsgIdentifierBytes, 81 | ); err != nil { 82 | srv.errorLog.Println( 83 | "couldn't write sessions disabled message: ", 84 | err, 85 | ) 86 | return 87 | } 88 | default: 89 | if err := message.WriteMsgSpecialRequestReply( 90 | writer, 91 | message.MsgReplyInternalError, 92 | msg.MsgIdentifierBytes, 93 | ); err != nil { 94 | srv.errorLog.Println( 95 | "couldn't write internal error message: ", 96 | err, 97 | ) 98 | return 99 | } 100 | } 101 | } 102
| 103 | // failMsgShutdown sends a request failure reply due to the current server 104 | // shutdown. Nothing is sent if no writer can be obtained for the connection 105 | func (srv *server) failMsgShutdown(con *connection, msg *message.Message) { 106 | writer, err := con.sock.GetWriter() 107 | if err != nil { 108 | srv.errorLog.Printf( 109 | "couldn't get writer for connection %p: %s", 110 | con, 111 | err, 112 | ) 113 | // Without a writer there's nothing to write the reply to 114 | return 115 | } 116 | 117 | if err := message.WriteMsgSpecialRequestReply( 118 | writer, 119 | message.MsgReplyShutdown, 120 | msg.MsgIdentifierBytes, 121 | ); err != nil { 122 | srv.errorLog.Println("failed writing shutdown reply message: ", err) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /fulfillMsg.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import "github.com/qbeon/webwire-go/message" 4 | 5 | // fulfillMsg fulfills the message sending the reply 6 | func (srv *server) fulfillMsg( 7 | con *connection, 8 | msg *message.Message, 9 | replyPayload Payload, 10 | ) { 11 | writer, err := con.sock.GetWriter() 12 | if err != nil { 13 | srv.errorLog.Printf( 14 | "couldn't get writer for connection %p: %s", 15 | con, 16 | err, 17 | ) 18 | return 19 | } 20 | 21 | if err := message.WriteMsgReply( 22 | writer, 23 | msg.MsgIdentifierBytes, 24 | replyPayload.Encoding, 25 | replyPayload.Data, 26 | ); err != nil { 27 | srv.errorLog.Printf( 28 | "couldn't write reply message for connection %p: %s", 29 | con, 30 | err, 31 | ) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /genericSessionInfo.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "reflect" 5 | ) 6 | 7 | func deepCopy(src interface{}) interface{} { 8 | if src == nil { 9 | return nil 10 | } 11 | 12 | // Make the interface a reflect.Value 13 | original := reflect.ValueOf(src) 14 | 15 | // Make a copy of the same type as the original.
16 | cpy := reflect.New(original.Type()).Elem() 17 | 18 | // Recursively copy the original. 19 | copyRecursive(original, cpy) 20 | 21 | // Return the copy as an interface. 22 | return cpy.Interface() 23 | } 24 | 25 | func copyRecursive(original, cpy reflect.Value) { 26 | // handle according to original's Kind 27 | switch original.Kind() { 28 | case reflect.Interface: 29 | // If this is a nil, don't do anything 30 | if original.IsNil() { 31 | return 32 | } 33 | // Deep-copy the concrete value stored in the interface 34 | originalValue := original.Elem() 35 | copyValue := reflect.New(originalValue.Type()).Elem() 36 | copyRecursive(originalValue, copyValue) 37 | cpy.Set(copyValue) 38 | 39 | case reflect.Slice: 40 | if original.IsNil() { 41 | return 42 | } 43 | // Make a new slice and copy each element. 44 | cpy.Set(reflect.MakeSlice( 45 | original.Type(), 46 | original.Len(), 47 | original.Cap(), 48 | )) 49 | for i := 0; i < original.Len(); i++ { 50 | copyRecursive(original.Index(i), cpy.Index(i)) 51 | } 52 | 53 | case reflect.Map: 54 | if original.IsNil() { 55 | return 56 | } 57 | cpy.Set(reflect.MakeMap(original.Type())) 58 | for _, key := range original.MapKeys() { 59 | originalValue := original.MapIndex(key) 60 | copyValue := reflect.New(originalValue.Type()).Elem() 61 | copyRecursive(originalValue, copyValue) 62 | // Copy keys like values so that nil interface-typed keys don't panic 63 | copyKey := reflect.New(key.Type()).Elem() 64 | copyRecursive(key, copyKey) 65 | cpy.SetMapIndex(copyKey, copyValue) 66 | } 67 | 68 | default: 69 | cpy.Set(original) 70 | } 71 | } 72 | 73 | // GenericSessionInfo defines a default webwire.SessionInfo interface 74 | // implementation type used by the client when no explicit session info 75 | // parser is used 76 | type GenericSessionInfo struct { 77 | data map[string]interface{} 78 | } 79 | 80 | // Copy implements the webwire.SessionInfo interface.
81 | // It deep-copies the object and returns its exact clone 82 | func (sinf *GenericSessionInfo) Copy() SessionInfo { 83 | return &GenericSessionInfo{ 84 | data: deepCopy(sinf.data).(map[string]interface{}), 85 | } 86 | } 87 | 88 | // Fields implements the webwire.SessionInfo interface. 89 | // It returns a new list of the names of all fields of the object (unordered) 90 | func (sinf *GenericSessionInfo) Fields() []string { 91 | if sinf.data == nil { 92 | return make([]string, 0) 93 | } 94 | names := make([]string, len(sinf.data)) 95 | index := 0 96 | for fieldName := range sinf.data { 97 | names[index] = fieldName 98 | index++ 99 | } 100 | return names 101 | } 102 | 103 | // Value implements the webwire.SessionInfo interface. 104 | // It returns an exact deep copy of a session info field value 105 | func (sinf *GenericSessionInfo) Value(fieldName string) interface{} { 106 | if sinf.data == nil { 107 | return nil 108 | } 109 | if val, exists := sinf.data[fieldName]; exists { 110 | return deepCopy(val) 111 | } 112 | return nil 113 | } 114 | 115 | // GenericSessionInfoParser represents a default implementation of a 116 | // session info object parser.
It parses the info object into a generic 117 | // session info type implementing the webwire.SessionInfo interface 118 | func GenericSessionInfoParser(data map[string]interface{}) SessionInfo { 119 | return &GenericSessionInfo{data} 120 | } 121 | -------------------------------------------------------------------------------- /genericSessionInfo_test.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | // TestGenericSessionInfoCopy tests the Copy method 10 | // of the generic session info implementation 11 | func TestGenericSessionInfoCopy(t *testing.T) { 12 | original := SessionInfo(&GenericSessionInfo{ 13 | data: map[string]interface{}{ 14 | "field1": "value1", 15 | "field2": "value2", 16 | }, 17 | }) 18 | 19 | copied := original.Copy() 20 | 21 | check := func() { 22 | require.ElementsMatch(t, []string{"field1", "field2"}, copied.Fields()) 23 | require.Equal(t, "value1", copied.Value("field1")) 24 | require.Equal(t, "value2", copied.Value("field2")) 25 | } 26 | 27 | // Verify consistency 28 | check() 29 | 30 | // Verify immutability 31 | delete(original.(*GenericSessionInfo).data, "field1") 32 | original.(*GenericSessionInfo).data["field2"] = "another_value" 33 | original.(*GenericSessionInfo).data["field3"] = "another_value" 34 | check() 35 | } 36 | 37 | // TestGenericSessionInfoValue tests the Value getter method 38 | // of the generic session info implementation 39 | func TestGenericSessionInfoValue(t *testing.T) { 40 | info := SessionInfo(&GenericSessionInfo{ 41 | data: map[string]interface{}{ 42 | "string": "stringValue", 43 | "float": 12.5, 44 | "arrayOfStrings": []string{"item1", "item2"}, 45 | }, 46 | }) 47 | 48 | // Check types 49 | require.IsType(t, float64(12.5), info.Value("float")) 50 | require.IsType(t, []string{}, info.Value("arrayOfStrings")) 51 | 52 | // Check values 53 | require.Equal(t, "stringValue", 
info.Value("string")) 54 | require.Equal(t, 12.5, info.Value("float")) 55 | require.Equal(t, []string{"item1", "item2"}, info.Value("arrayOfStrings")) 56 | 57 | // Check value of inexistent field 58 | require.Nil(t, info.Value("inexistent")) 59 | } 60 | 61 | // TestGenericSessionInfoEmpty tests working with an empty 62 | // generic session info instance 63 | func TestGenericSessionInfoEmpty(t *testing.T) { 64 | info := SessionInfo(&GenericSessionInfo{}) 65 | 66 | check := func(info SessionInfo) { 67 | require.Equal(t, nil, info.Value("inexistent")) 68 | require.Equal(t, make([]string, 0), info.Fields()) 69 | } 70 | 71 | // Check values 72 | check(info) 73 | 74 | copied := info.Copy() 75 | require.NotNil(t, copied) 76 | check(copied) 77 | } 78 | -------------------------------------------------------------------------------- /handleConnection.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | func (srv *server) writeConfMessage(sock Socket) error { 9 | writer, err := sock.GetWriter() 10 | if err != nil { 11 | return fmt.Errorf( 12 | "couldn't get writer for configuration message: %s", 13 | err, 14 | ) 15 | } 16 | 17 | if _, err := writer.Write(srv.configMsg); err != nil { 18 | if closeErr := writer.Close(); closeErr != nil { 19 | return fmt.Errorf( 20 | "couldn't close writer after failed conf message write: %s: %s", 21 | err, 22 | closeErr, 23 | ) 24 | } 25 | return fmt.Errorf("couldn't write configuration message: %s", err) 26 | } 27 | 28 | if err := writer.Close(); err != nil { 29 | return fmt.Errorf("couldn't close writer: %s", err) 30 | } 31 | 32 | return nil 33 | } 34 | 35 | func (srv *server) handleConnection( 36 | connectionOptions ConnectionOptions, 37 | sock Socket, 38 | ) { 39 | // Send server configuration message 40 | if err := srv.writeConfMessage(sock); err != nil { 41 | srv.errorLog.Println("couldn't write config message: ", err) 42 | if closeErr 
:= sock.Close(); closeErr != nil { 43 | srv.errorLog.Println("couldn't close socket: ", closeErr) 44 | } 45 | return 46 | } 47 | 48 | // Register connected client 49 | connection := newConnection( 50 | sock, 51 | srv, 52 | connectionOptions, 53 | ) 54 | 55 | srv.connectionsLock.Lock() 56 | srv.connections = append(srv.connections, connection) 57 | srv.connectionsLock.Unlock() 58 | 59 | // Call hook on successful connection 60 | srv.impl.OnClientConnected(connectionOptions, connection) 61 | 62 | for { 63 | // Get a message buffer 64 | msg := srv.messagePool.Get() 65 | 66 | // Await message 67 | if err := sock.Read( 68 | msg, 69 | time.Now().Add(srv.options.ReadTimeout), // Deadline 70 | ); err != nil { 71 | msg.Close() 72 | 73 | if !err.IsCloseErr() { 74 | srv.warnLog.Printf("abnormal closure error: %s", err) 75 | } 76 | 77 | connection.Close() 78 | srv.impl.OnClientDisconnected(connection, err) 79 | break 80 | } 81 | 82 | // Parse & handle the message 83 | if err := srv.handleMessage(connection, msg); err != nil { 84 | srv.errorLog.Print("message handler failed: ", err) 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /handleMessage.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import "github.com/qbeon/webwire-go/message" 4 | 5 | // handleMessage parses and handles incoming messages 6 | func (srv *server) handleMessage( 7 | con *connection, 8 | msg *message.Message, 9 | ) (err error) { 10 | // Don't register a task handler for heartbeat messages 11 | // 12 | // TODO: probably this check should include any message type that's not 13 | // handled by handleMessage to avoid registering a handler 14 | if msg.MsgType == message.MsgHeartbeat { 15 | // Release message buffer 16 | msg.Close() 17 | return nil 18 | } 19 | 20 | if !srv.registerHandler(con, msg) { 21 | // Release message buffer 22 | msg.Close() 23 | return nil 24 | } 25 | 26 | // Message 
buffers are released by the individual handlers 27 | switch msg.MsgType { 28 | case message.MsgSignalBinary, 29 | message.MsgSignalUtf8, 30 | message.MsgSignalUtf16: 31 | if con.options.ConcurrencyLimit < 0 || 32 | con.options.ConcurrencyLimit > 1 { 33 | go srv.handleSignal(con, msg) 34 | } else { 35 | srv.handleSignal(con, msg) 36 | } 37 | 38 | case message.MsgRequestBinary, 39 | message.MsgRequestUtf8, 40 | message.MsgRequestUtf16: 41 | if con.options.ConcurrencyLimit < 0 || 42 | con.options.ConcurrencyLimit > 1 { 43 | go srv.handleRequest(con, msg) 44 | } else { 45 | srv.handleRequest(con, msg) 46 | } 47 | 48 | case message.MsgRequestRestoreSession: 49 | srv.handleSessionRestore(con, msg) 50 | case message.MsgRequestCloseSession: 51 | srv.handleSessionClosure(con, msg) 52 | 53 | default: 54 | // Immediately deregister handlers for unexpected message types 55 | srv.deregisterHandler(con) 56 | 57 | // Release message buffer 58 | msg.Close() 59 | } 60 | 61 | return nil 62 | } 63 | -------------------------------------------------------------------------------- /handleRequest.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/qbeon/webwire-go/message" 7 | ) 8 | 9 | // handleRequest handles incoming requests 10 | // and returns an error if the ongoing connection cannot be proceeded 11 | func (srv *server) handleRequest(con *connection, msg *message.Message) { 12 | // Execute user-space hook 13 | replyPayload, returnedErr := srv.impl.OnRequest( 14 | context.Background(), 15 | con, 16 | msg, 17 | ) 18 | 19 | // Handle returned error 20 | switch returnedErr.(type) { 21 | case nil: 22 | srv.fulfillMsg(con, msg, replyPayload) 23 | case ErrRequest: 24 | srv.failMsg(con, msg, returnedErr) 25 | case *ErrRequest: 26 | srv.failMsg(con, msg, returnedErr) 27 | default: 28 | srv.errorLog.Printf( 29 | "request handler internal error: %v", 30 | returnedErr, 31 | ) 32 | 
srv.failMsg(con, msg, nil) 33 | } 34 | 35 | srv.deregisterHandler(con) 36 | 37 | // Release message buffer 38 | msg.Close() 39 | } 40 | -------------------------------------------------------------------------------- /handleSessionClosure.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "github.com/qbeon/webwire-go/message" 5 | ) 6 | 7 | // handleSessionClosure handles session destruction requests 8 | // and returns an error if the ongoing connection cannot be proceeded 9 | func (srv *server) handleSessionClosure( 10 | con *connection, 11 | msg *message.Message, 12 | ) { 13 | finalize := func() { 14 | srv.deregisterHandler(con) 15 | 16 | // Release message buffer 17 | msg.Close() 18 | } 19 | 20 | if !srv.sessionsEnabled { 21 | srv.failMsg(con, msg, ErrSessionsDisabled{}) 22 | finalize() 23 | return 24 | } 25 | 26 | if !con.HasSession() { 27 | // Send confirmation even though no session was closed 28 | srv.fulfillMsg(con, msg, Payload{}) 29 | finalize() 30 | return 31 | } 32 | 33 | // Deregister session from active sessions registry destroying it if it's 34 | // the last connection left 35 | srv.sessionRegistry.deregister(con, true) 36 | 37 | // Reset the session on the connection 38 | con.setSession(nil) 39 | 40 | // Send confirmation 41 | srv.fulfillMsg(con, msg, Payload{}) 42 | finalize() 43 | } 44 | -------------------------------------------------------------------------------- /handleSessionRestore.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/qbeon/webwire-go/message" 8 | ) 9 | 10 | // handleSessionRestore handles session restoration (by session key) requests 11 | // and returns an error if the ongoing connection cannot be proceeded 12 | func (srv *server) handleSessionRestore( 13 | con *connection, 14 | msg *message.Message, 15 | ) { 16 | finalize := func() { 
17 | srv.deregisterHandler(con) 18 | 19 | // Release message buffer 20 | msg.Close() 21 | } 22 | 23 | if !srv.sessionsEnabled { 24 | srv.failMsg(con, msg, ErrSessionsDisabled{}) 25 | finalize() 26 | return 27 | } 28 | 29 | key := string(msg.MsgPayload.Data) 30 | 31 | sessConsNum := srv.sessionRegistry.sessionConnectionsNum(key) 32 | if sessConsNum >= 0 && srv.sessionRegistry.maxConns > 0 && 33 | uint(sessConsNum+1) > srv.sessionRegistry.maxConns { 34 | srv.failMsg(con, msg, ErrMaxSessConnsReached{}) 35 | finalize() 36 | return 37 | } 38 | 39 | // Call session manager lookup hook 40 | result, err := srv.sessionManager.OnSessionLookup(key) 41 | 42 | if err != nil { 43 | // Fail message with internal error and log it in case the handler fails 44 | srv.failMsg(con, msg, nil) 45 | finalize() 46 | srv.errorLog.Printf("session search handler failed: %s", err) 47 | return 48 | } 49 | 50 | if result == nil { 51 | // Fail message with special error if the session wasn't found 52 | srv.failMsg(con, msg, ErrSessionNotFound{}) 53 | finalize() 54 | return 55 | } 56 | 57 | sessionCreation := result.Creation() 58 | sessionLastLookup := result.LastLookup() 59 | sessionInfo := result.Info() 60 | 61 | // JSON encode the session 62 | encodedSessionObj := JSONEncodedSession{ 63 | Key: key, 64 | Creation: sessionCreation, 65 | LastLookup: sessionLastLookup, 66 | Info: sessionInfo, 67 | } 68 | encodedSession, err := json.Marshal(&encodedSessionObj) 69 | if err != nil { 70 | srv.failMsg(con, msg, nil) 71 | finalize() 72 | srv.errorLog.Printf( 73 | "couldn't encode session object (%v): %s", 74 | encodedSessionObj, 75 | err, 76 | ) 77 | return 78 | } 79 | 80 | // Parse attached session info 81 | var parsedSessInfo SessionInfo 82 | if sessionInfo != nil && srv.sessionInfoParser != nil { 83 | parsedSessInfo = srv.sessionInfoParser(sessionInfo) 84 | } 85 | 86 | con.setSession(&Session{ 87 | Key: key, 88 | Creation: sessionCreation, 89 | LastLookup: sessionLastLookup, 90 | Info: parsedSessInfo, 
91 | }) 92 | if err := srv.sessionRegistry.register(con); err != nil { 93 | panic(fmt.Errorf("the number of concurrent session connections was " + 94 | "unexpectedly exceeded", 95 | )) 96 | } 97 | 98 | srv.fulfillMsg( 99 | con, 100 | msg, 101 | Payload{ 102 | Encoding: EncodingUtf8, 103 | Data: encodedSession, 104 | }, 105 | ) 106 | finalize() 107 | } 108 | -------------------------------------------------------------------------------- /handleSignal.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/qbeon/webwire-go/message" 7 | ) 8 | 9 | // handleSignal handles incoming signals 10 | // and returns an error if the ongoing connection cannot be proceeded 11 | func (srv *server) handleSignal(con *connection, msg *message.Message) { 12 | srv.impl.OnSignal(context.Background(), con, msg) 13 | 14 | srv.deregisterHandler(con) 15 | 16 | // Release message buffer 17 | msg.Close() 18 | } 19 | -------------------------------------------------------------------------------- /isShuttingDown.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | // isShuttingDown returns true if the server is currently shutting down, 4 | // otherwise returns false 5 | func (srv *server) isShuttingDown() bool { 6 | srv.opsLock.Lock() 7 | if srv.shutdown { 8 | srv.opsLock.Unlock() 9 | return true 10 | } 11 | srv.opsLock.Unlock() 12 | return false 13 | } 14 | -------------------------------------------------------------------------------- /message/buffer.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "io/ioutil" 7 | ) 8 | 9 | // Buffer represents a message buffer 10 | type Buffer struct { 11 | buf []byte 12 | len int 13 | } 14 | 15 | // Bytes returns a full-length slice of the buffer 16 | func (buf *Buffer) Bytes() []byte { 17 | return 
buf.buf 18 | } 19 | 20 | // IsEmpty returns true if the buffer is empty, otherwise returns false 21 | func (buf *Buffer) IsEmpty() bool { 22 | return buf.len < 1 23 | } 24 | 25 | // Close resets the message buffer and puts it back into the original pool 26 | func (buf *Buffer) Close() { 27 | buf.len = 0 28 | } 29 | 30 | // Read reads from the given reader until EOF or error 31 | func (buf *Buffer) Read(reader io.Reader) error { 32 | cursor := 0 33 | for { 34 | if cursor >= len(buf.buf) { 35 | // Expect EOF on full buffer 36 | _, err := reader.Read(buf.buf) 37 | if err != io.EOF { 38 | // Overflow! Discard the message that's bigger than the buffer 39 | buf.Close() 40 | io.Copy(ioutil.Discard, reader) 41 | return errors.New("message buffer overflow") 42 | } 43 | 44 | // Successfully read out the reader 45 | buf.len = cursor 46 | return nil 47 | } 48 | 49 | readBytes, err := reader.Read(buf.buf[cursor:]) 50 | cursor += readBytes 51 | 52 | if readBytes < 0 { 53 | panic("negative read len") 54 | } 55 | if err != nil { 56 | if err == io.EOF { 57 | buf.len = cursor 58 | return nil 59 | } 60 | 61 | buf.Close() 62 | return err 63 | } 64 | } 65 | } 66 | 67 | // Data returns a slice of the usable part of the buffer 68 | func (buf *Buffer) Data() []byte { 69 | return buf.buf[:buf.len] 70 | } 71 | -------------------------------------------------------------------------------- /message/calcMsgLen.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import pld "github.com/qbeon/webwire-go/payload" 4 | 5 | // CalcMsgLenSignal returns the size of a signal message with the given name and 6 | // payload 7 | func CalcMsgLenSignal( 8 | name []byte, 9 | encoding pld.Encoding, 10 | payload []byte, 11 | ) int { 12 | if encoding == pld.Utf16 && len(name)%2 != 0 { 13 | return 3 + len(name) + len(payload) 14 | } 15 | return 2 + len(name) + len(payload) 16 | } 17 | 18 | // CalcMsgLenRequest returns the size of a request message with 
the given name 19 | // and payload 20 | func CalcMsgLenRequest( 21 | name []byte, 22 | encoding pld.Encoding, 23 | payload []byte, 24 | ) int { 25 | if encoding == pld.Utf16 && len(name)%2 != 0 { 26 | return 11 + len(name) + len(payload) 27 | } 28 | return 10 + len(name) + len(payload) 29 | } 30 | -------------------------------------------------------------------------------- /message/messageParts.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | var msgTypeHeartbeat = []byte{MsgHeartbeat} 4 | var msgTypeSessionCreated = []byte{MsgNotifySessionCreated} 5 | var msgTypeSessionClosed = []byte{MsgNotifySessionClosed} 6 | 7 | var msgTypeSignalBinary = []byte{MsgSignalBinary} 8 | var msgTypeSignalUtf8 = []byte{MsgSignalUtf8} 9 | var msgTypeSignalUtf16 = []byte{MsgSignalUtf16} 10 | 11 | var msgTypeRequestCloseSession = []byte{MsgRequestCloseSession} 12 | var msgTypeRequestRestoreSession = []byte{MsgRequestRestoreSession} 13 | 14 | var msgTypeRequestBinary = []byte{MsgRequestBinary} 15 | var msgTypeRequestUtf8 = []byte{MsgRequestUtf8} 16 | var msgTypeRequestUtf16 = []byte{MsgRequestUtf16} 17 | 18 | var msgTypeReplyError = []byte{MsgReplyError} 19 | 20 | var msgTypeReplyBinary = []byte{MsgReplyBinary} 21 | var msgTypeReplyUtf8 = []byte{MsgReplyUtf8} 22 | var msgTypeReplyUtf16 = []byte{MsgReplyUtf16} 23 | 24 | var msgTypeReplyInternalError = []byte{MsgReplyInternalError} 25 | var msgTypeReplyMaxSessConnsReached = []byte{MsgReplyMaxSessConnsReached} 26 | var msgTypeReplySessionNotFound = []byte{MsgReplySessionNotFound} 27 | var msgTypeSessionsDisabled = []byte{MsgReplySessionsDisabled} 28 | var msgTypeReplyShutdown = []byte{MsgReplyShutdown} 29 | 30 | var msgHeaderPadding = []byte{0} 31 | 32 | var msgNameLenBytes = [256]byte{ 33 | 0, 1, 2, 3, 34 | 4, 5, 6, 7, 35 | 8, 9, 10, 11, 36 | 12, 13, 14, 15, 37 | 16, 17, 18, 19, 38 | 20, 21, 22, 23, 39 | 24, 25, 26, 27, 40 | 28, 29, 30, 31, 41 | 32, 33, 34, 35, 42 | 36, 
37, 38, 39, 43 | 40, 41, 42, 43, 44 | 44, 45, 46, 47, 45 | 48, 49, 50, 51, 46 | 52, 53, 54, 55, 47 | 56, 57, 58, 59, 48 | 60, 61, 62, 63, 49 | 64, 65, 66, 67, 50 | 68, 69, 70, 71, 51 | 72, 73, 74, 75, 52 | 76, 77, 78, 79, 53 | 80, 81, 82, 83, 54 | 84, 85, 86, 87, 55 | 88, 89, 90, 91, 56 | 92, 93, 94, 95, 57 | 96, 97, 98, 99, 58 | 100, 101, 102, 103, 59 | 104, 105, 106, 107, 60 | 108, 109, 110, 111, 61 | 112, 113, 114, 115, 62 | 116, 117, 118, 119, 63 | 120, 121, 122, 123, 64 | 124, 125, 126, 127, 65 | 128, 129, 130, 131, 66 | 132, 133, 134, 135, 67 | 136, 137, 138, 139, 68 | 140, 141, 142, 143, 69 | 144, 145, 146, 147, 70 | 148, 149, 150, 151, 71 | 152, 153, 154, 155, 72 | 156, 157, 158, 159, 73 | 160, 161, 162, 163, 74 | 164, 165, 166, 167, 75 | 168, 169, 170, 171, 76 | 172, 173, 174, 175, 77 | 176, 177, 178, 179, 78 | 180, 181, 182, 183, 79 | 184, 185, 186, 187, 80 | 188, 189, 190, 191, 81 | 192, 193, 194, 195, 82 | 196, 197, 198, 199, 83 | 200, 201, 202, 203, 84 | 204, 205, 206, 207, 85 | 208, 209, 210, 211, 86 | 212, 213, 214, 215, 87 | 216, 217, 218, 219, 88 | 220, 221, 222, 223, 89 | 224, 225, 226, 227, 90 | 228, 229, 230, 231, 91 | 232, 233, 234, 235, 92 | 236, 237, 238, 239, 93 | 240, 241, 242, 243, 94 | 244, 245, 246, 247, 95 | 248, 249, 250, 251, 96 | 252, 253, 254, 255, 97 | } 98 | -------------------------------------------------------------------------------- /message/newAcceptConfMessage.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | // NewAcceptConfMessage composes a server configuration message and writes it to the 10 | // given buffer 11 | func NewAcceptConfMessage(conf ServerConfiguration) ([]byte, error) { 12 | buf := make([]byte, MinLenAcceptConf+len(conf.SubProtocolName)) 13 | 14 | buf[0] = byte(MsgAcceptConf) 15 | buf[1] = byte(conf.MajorProtocolVersion) 16 | buf[2] = byte(conf.MinorProtocolVersion) 17 | 18 | 
readTimeoutMs := conf.ReadTimeout / time.Millisecond 19 | if readTimeoutMs > 4294967295 { 20 | return nil, fmt.Errorf( 21 | "read timeout (milliseconds) overflow in server conf message (%s)", 22 | conf.ReadTimeout.String(), 23 | ) 24 | } else if readTimeoutMs < 0 { 25 | return nil, fmt.Errorf( 26 | "negative read timeout (milliseconds) in server conf message (%d)", 27 | readTimeoutMs, 28 | ) 29 | } 30 | 31 | binary.LittleEndian.PutUint32(buf[3:7], uint32(readTimeoutMs)) 32 | binary.LittleEndian.PutUint32(buf[7:11], conf.MessageBufferSize) 33 | 34 | copy(buf[11:], conf.SubProtocolName) 35 | 36 | return buf, nil 37 | } 38 | -------------------------------------------------------------------------------- /message/newMessage.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | // NewMessage creates a new buffered message instance 4 | func NewMessage(bufferSize uint32) *Message { 5 | return &Message{ 6 | MsgBuffer: Buffer{ 7 | buf: make([]byte, bufferSize), 8 | len: 0, 9 | }, 10 | MsgIdentifierBytes: make([]byte, 8), 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /message/parse.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | 6 | pld "github.com/qbeon/webwire-go/payload" 7 | ) 8 | 9 | // parse tries to parse the message from a byte slice. 10 | // the returned parsedMsgType is set to false if the message type 11 | // couldn't be determined, otherwise it's set to true. 
12 | func (msg *Message) parse() (parsedMsgType bool, err error) { 13 | if msg.MsgBuffer.IsEmpty() { 14 | return false, nil 15 | } 16 | var payloadEncoding pld.Encoding 17 | msgType := msg.MsgBuffer.buf[0:1][0] 18 | 19 | switch msgType { 20 | 21 | // Server Configuration 22 | case MsgAcceptConf: 23 | err = msg.parseAcceptConf() 24 | 25 | // Heartbeat 26 | case MsgHeartbeat: 27 | err = msg.parseHeartbeat() 28 | 29 | // Request error reply message 30 | case MsgReplyError: 31 | err = msg.parseErrorReply() 32 | 33 | // Session creation notification message 34 | case MsgNotifySessionCreated: 35 | err = msg.parseSessionCreated() 36 | 37 | // Session closure notification message 38 | case MsgNotifySessionClosed: 39 | err = msg.parseSessionClosed() 40 | 41 | // Session destruction request message 42 | case MsgRequestCloseSession: 43 | err = msg.parseCloseSession() 44 | 45 | // Signal messages 46 | case MsgSignalBinary: 47 | payloadEncoding = pld.Binary 48 | err = msg.parseSignal() 49 | case MsgSignalUtf8: 50 | payloadEncoding = pld.Utf8 51 | err = msg.parseSignal() 52 | case MsgSignalUtf16: 53 | payloadEncoding = pld.Utf16 54 | err = msg.parseSignalUtf16() 55 | 56 | // Request messages 57 | case MsgRequestBinary: 58 | payloadEncoding = pld.Binary 59 | err = msg.parseRequest() 60 | case MsgRequestUtf8: 61 | payloadEncoding = pld.Utf8 62 | err = msg.parseRequest() 63 | case MsgRequestUtf16: 64 | payloadEncoding = pld.Utf16 65 | err = msg.parseRequestUtf16() 66 | 67 | // Reply messages 68 | case MsgReplyBinary: 69 | payloadEncoding = pld.Binary 70 | err = msg.parseReply() 71 | case MsgReplyUtf8: 72 | payloadEncoding = pld.Utf8 73 | err = msg.parseReply() 74 | case MsgReplyUtf16: 75 | payloadEncoding = pld.Utf16 76 | err = msg.parseReplyUtf16() 77 | 78 | // Session restoration request message 79 | case MsgRequestRestoreSession: 80 | err = msg.parseRestoreSession() 81 | 82 | // Special reply messages 83 | case MsgReplyShutdown: 84 | err = msg.parseSpecialReplyMessage() 85 | 
case MsgReplyInternalError: 86 | err = msg.parseSpecialReplyMessage() 87 | case MsgReplySessionNotFound: 88 | err = msg.parseSpecialReplyMessage() 89 | case MsgReplyMaxSessConnsReached: 90 | err = msg.parseSpecialReplyMessage() 91 | case MsgReplySessionsDisabled: 92 | err = msg.parseSpecialReplyMessage() 93 | 94 | // Ignore messages of invalid message type 95 | default: 96 | return false, errors.New("invalid message type") 97 | } 98 | 99 | msg.MsgType = msgType 100 | msg.MsgPayload.Encoding = payloadEncoding 101 | return true, err 102 | } 103 | -------------------------------------------------------------------------------- /message/parseAcceptConf.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "time" 7 | ) 8 | 9 | // parseAcceptConf parses MsgAcceptConf messages 10 | func (msg *Message) parseAcceptConf() error { 11 | if msg.MsgBuffer.len < MinLenAcceptConf { 12 | return errors.New("invalid msg length, too short") 13 | } 14 | dat := msg.MsgBuffer.Data() 15 | 16 | subProtocolName := []byte(nil) 17 | if msg.MsgBuffer.len > MinLenAcceptConf { 18 | subProtocolName = dat[11:] 19 | } 20 | 21 | msg.ServerConfiguration = ServerConfiguration{ 22 | MajorProtocolVersion: dat[1:2][0], 23 | MinorProtocolVersion: dat[2:3][0], 24 | ReadTimeout: time.Duration( 25 | binary.LittleEndian.Uint32(dat[3:7]), 26 | ) * time.Millisecond, 27 | MessageBufferSize: binary.LittleEndian.Uint32(dat[7:11]), 28 | SubProtocolName: subProtocolName, 29 | } 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /message/parseAcceptConf_test.go: -------------------------------------------------------------------------------- 1 | package message_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/qbeon/webwire-go/message" 8 | pld "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | 
// TestMsgParseAcceptNoSubprotoConf tests parsing of server configuration 13 | // messages with no subprotocol name 14 | func TestMsgParseAcceptNoSubprotoConf(t *testing.T) { 15 | srvConf := message.ServerConfiguration{ 16 | MajorProtocolVersion: 22, 17 | MinorProtocolVersion: 33, 18 | ReadTimeout: 11 * time.Second, 19 | MessageBufferSize: 8192, 20 | SubProtocolName: []byte(nil), 21 | } 22 | 23 | // Compose encoded message 24 | buf, err := message.NewAcceptConfMessage(srvConf) 25 | require.NoError(t, err) 26 | require.True(t, len(buf) > 0) 27 | 28 | // Parse 29 | actual := tryParseNoErr(t, buf) 30 | 31 | // Compare 32 | require.NotNil(t, actual.MsgBuffer) 33 | require.Equal(t, message.MsgAcceptConf, actual.MsgType) 34 | require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, actual.MsgIdentifierBytes) 35 | require.Equal(t, [8]byte{}, actual.MsgIdentifier) 36 | require.Nil(t, actual.MsgName) 37 | require.Equal(t, pld.Payload{}, actual.MsgPayload) 38 | require.Equal(t, srvConf, actual.ServerConfiguration) 39 | } 40 | 41 | // TestMsgParseAcceptConf tests parsing of server configuration messages 42 | func TestMsgParseAcceptConf(t *testing.T) { 43 | srvConf := message.ServerConfiguration{ 44 | MajorProtocolVersion: 22, 45 | MinorProtocolVersion: 33, 46 | ReadTimeout: 11 * time.Second, 47 | MessageBufferSize: 8192, 48 | SubProtocolName: []byte("test - sub-protocol name"), 49 | } 50 | 51 | // Compose encoded message 52 | buf, err := message.NewAcceptConfMessage(srvConf) 53 | require.NoError(t, err) 54 | require.True(t, len(buf) > 0) 55 | 56 | // Parse 57 | actual := tryParseNoErr(t, buf) 58 | 59 | // Compare 60 | require.NotNil(t, actual.MsgBuffer) 61 | require.Equal(t, message.MsgAcceptConf, actual.MsgType) 62 | require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, actual.MsgIdentifierBytes) 63 | require.Equal(t, [8]byte{}, actual.MsgIdentifier) 64 | require.Nil(t, actual.MsgName) 65 | require.Equal(t, pld.Payload{}, actual.MsgPayload) 66 | require.Equal(t, srvConf, 
actual.ServerConfiguration) 67 | } 68 | -------------------------------------------------------------------------------- /message/parseCloseSession.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "errors" 4 | 5 | // parseCloseSession parses MsgRequestCloseSession messages 6 | func (msg *Message) parseCloseSession() error { 7 | if msg.MsgBuffer.len != MinLenDoCloseSession { 8 | return errors.New( 9 | "invalid session destruction request message, too short", 10 | ) 11 | } 12 | 13 | // Read identifier 14 | msg.MsgIdentifierBytes = msg.MsgBuffer.Data()[1:9] 15 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 16 | 17 | return nil 18 | } 19 | -------------------------------------------------------------------------------- /message/parseErrorReply.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // parseErrorReply parses MsgReplyError messages writing the error code into the 11 | // name field and the UTF8 encoded error message into the payload 12 | func (msg *Message) parseErrorReply() error { 13 | if msg.MsgBuffer.len < MinLenReplyError { 14 | return errors.New("invalid error reply message, too short") 15 | } 16 | 17 | dat := msg.MsgBuffer.Data() 18 | 19 | // Read identifier 20 | msg.MsgIdentifierBytes = dat[1:9] 21 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 22 | 23 | // Read error code length flag 24 | errCodeLen := int(byte(dat[9:10][0])) 25 | errMessageOffset := 10 + errCodeLen 26 | 27 | // Verify error code length (must be at least 1 character long) 28 | if errCodeLen < 1 { 29 | return errors.New( 30 | "invalid error reply message, error code length flag is zero", 31 | ) 32 | } 33 | 34 | // Verify total message size to prevent segmentation faults 35 | // caused by inconsistent flags. 
This could happen if the specified 36 | // error code length doesn't correspond to the actual length 37 | // of the provided error code. 38 | // Subtract 1 character already taken into account by MinLenReplyError 39 | if msg.MsgBuffer.len < MinLenReplyError+errCodeLen-1 { 40 | return fmt.Errorf( 41 | "invalid error reply message, "+ 42 | "too short for specified code length (%d)", 43 | errCodeLen, 44 | ) 45 | } 46 | 47 | // Read UTF8 encoded error message into the payload 48 | msg.MsgName = dat[10 : 10+errCodeLen] 49 | msg.MsgPayload = pld.Payload{ 50 | Encoding: pld.Utf8, 51 | Data: dat[errMessageOffset:], 52 | } 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /message/parseHeartbeat.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "fmt" 4 | 5 | // parseHeartbeat parses heartbeat messages 6 | func (msg *Message) parseHeartbeat() error { 7 | if msg.MsgBuffer.len != 1 { 8 | return fmt.Errorf( 9 | "invalid heartbeat message (len: %d)", 10 | msg.MsgBuffer.len, 11 | ) 12 | } 13 | return nil 14 | } 15 | -------------------------------------------------------------------------------- /message/parseReply.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | 6 | pld "github.com/qbeon/webwire-go/payload" 7 | ) 8 | 9 | // parseReply parses MsgReplyBinary and MsgReplyUtf8 messages 10 | func (msg *Message) parseReply() error { 11 | if msg.MsgBuffer.len < MinLenReply { 12 | return errors.New("invalid reply message, too short") 13 | } 14 | 15 | dat := msg.MsgBuffer.Data() 16 | 17 | // Read identifier 18 | msg.MsgIdentifierBytes = dat[1:9] 19 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 20 | 21 | // Skip payload if there's none 22 | if msg.MsgBuffer.len == MinLenReply { 23 | return nil 24 | } 25 | 26 | // Read payload 27 | msg.MsgPayload = pld.Payload{ 28 | 
Data: dat[9:], 29 | } 30 | return nil 31 | } 32 | -------------------------------------------------------------------------------- /message/parseReplyUtf16.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | 6 | pld "github.com/qbeon/webwire-go/payload" 7 | ) 8 | 9 | func (msg *Message) parseReplyUtf16() error { 10 | if msg.MsgBuffer.len < MinLenReplyUtf16 { 11 | return errors.New("invalid UTF16 reply message, too short") 12 | } 13 | 14 | if msg.MsgBuffer.len%2 != 0 { 15 | return errors.New( 16 | "unaligned UTF16 encoded reply message " + 17 | "(probably missing header padding)", 18 | ) 19 | } 20 | 21 | dat := msg.MsgBuffer.Data() 22 | 23 | // Read identifier 24 | msg.MsgIdentifierBytes = dat[1:9] 25 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 26 | 27 | // Skip payload if there's none 28 | if msg.MsgBuffer.len == MinLenReplyUtf16 { 29 | msg.MsgPayload = pld.Payload{ 30 | Encoding: pld.Utf16, 31 | } 32 | return nil 33 | } 34 | 35 | // Read payload 36 | msg.MsgPayload = pld.Payload{ 37 | // Take header padding byte into account 38 | Data: dat[10:], 39 | } 40 | 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /message/parseRequest.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // parseRequest parses MsgRequestBinary and MsgRequestUtf8 messages 11 | func (msg *Message) parseRequest() error { 12 | if msg.MsgBuffer.len < MinLenRequest { 13 | return errors.New("invalid request message, too short") 14 | } 15 | 16 | dat := msg.MsgBuffer.Data() 17 | 18 | // Read identifier 19 | msg.MsgIdentifierBytes = dat[1:9] 20 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 21 | 22 | // Read name length 23 | nameLen := int(dat[9]) 24 | payloadOffset := 10 + nameLen 25 | 
26 | // Verify total message size to prevent segmentation faults caused 27 | // by inconsistent flags. This could happen if the specified name length 28 | // doesn't correspond to the actual name length 29 | if nameLen > 0 { 30 | // Subtract one to not require the payload but at least the name 31 | if msg.MsgBuffer.len < MinLenRequest+nameLen-1 { 32 | return fmt.Errorf( 33 | "invalid request message, too short for full name (%d)", 34 | nameLen, 35 | ) 36 | } 37 | 38 | // Take name into account 39 | msg.MsgName = dat[10 : 10+nameLen] 40 | 41 | // Read payload if any 42 | if msg.MsgBuffer.len > MinLenRequest+nameLen-1 { 43 | msg.MsgPayload = pld.Payload{ 44 | Data: dat[payloadOffset:], 45 | } 46 | } 47 | } else { 48 | // No name present, expect just the payload to be in place 49 | msg.MsgPayload = pld.Payload{ 50 | Data: dat[10:], 51 | } 52 | } 53 | 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /message/parseRequestUtf16.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // parseRequestUtf16 parses MsgRequestUtf16 messages 11 | func (msg *Message) parseRequestUtf16() error { 12 | if msg.MsgBuffer.len < MinLenRequestUtf16 { 13 | return errors.New("invalid request message, too short") 14 | } 15 | 16 | if msg.MsgBuffer.len%2 != 0 { 17 | return errors.New( 18 | "unaligned UTF16 encoded request message " + 19 | "(probably missing header padding)", 20 | ) 21 | } 22 | 23 | dat := msg.MsgBuffer.Data() 24 | 25 | // Read identifier 26 | msg.MsgIdentifierBytes = dat[1:9] 27 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 28 | 29 | // Read name length 30 | nameLen := int(dat[9]) 31 | 32 | // Determine minimum required message length. 
33 | // There's at least a 10 byte header and a 2 byte payload expected 34 | minRequiredMsgSize := 12 35 | if nameLen > 0 { 36 | // ...unless a name is given, in which case the payload isn't required 37 | minRequiredMsgSize = 10 + nameLen 38 | } 39 | 40 | // A header padding byte is only expected, when there's a payload 41 | // beyond the name. It's not required if there's just the header and a name 42 | payloadOffset := 10 + nameLen 43 | if msg.MsgBuffer.len > payloadOffset && nameLen%2 != 0 { 44 | minRequiredMsgSize++ 45 | payloadOffset++ 46 | } 47 | 48 | // Verify total message size to prevent segmentation faults caused 49 | // by inconsistent flags. This could happen if the specified name length 50 | // doesn't correspond to the actual name length 51 | if nameLen > 0 { 52 | if msg.MsgBuffer.len < minRequiredMsgSize { 53 | return fmt.Errorf( 54 | "invalid request message, too short for full name (%d)", 55 | nameLen, 56 | ) 57 | } 58 | 59 | // Take name into account 60 | msg.MsgName = dat[10 : 10+nameLen] 61 | 62 | // Read payload if any 63 | if msg.MsgBuffer.len > minRequiredMsgSize { 64 | msg.MsgPayload = pld.Payload{ 65 | Data: dat[payloadOffset:], 66 | } 67 | } 68 | } else { 69 | // No name present, just payload 70 | msg.MsgPayload = pld.Payload{ 71 | Data: dat[10:], 72 | } 73 | } 74 | 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /message/parseRestoreSession.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | 6 | pld "github.com/qbeon/webwire-go/payload" 7 | ) 8 | 9 | // parseRestoreSession parses MsgRequestRestoreSession messages 10 | func (msg *Message) parseRestoreSession() error { 11 | if msg.MsgBuffer.len < MinLenRequestRestoreSession { 12 | return errors.New( 13 | "invalid session restoration request message, too short", 14 | ) 15 | } 16 | 17 | dat := msg.MsgBuffer.Data() 18 | 19 | // Read identifier 20 | 
msg.MsgIdentifierBytes = dat[1:9] 21 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 22 | 23 | // Read payload 24 | msg.MsgPayload = pld.Payload{ 25 | Data: dat[9:], 26 | } 27 | return nil 28 | } 29 | -------------------------------------------------------------------------------- /message/parseSessionClosed.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "errors" 4 | 5 | // parseSessionClosed parses MsgNotifySessionClosed messages 6 | func (msg *Message) parseSessionClosed() error { 7 | if msg.MsgBuffer.len != MinLenNotifySessionClosed { 8 | return errors.New( 9 | "invalid session closure notification message, too short", 10 | ) 11 | } 12 | return nil 13 | } 14 | -------------------------------------------------------------------------------- /message/parseSessionCreated.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | 6 | pld "github.com/qbeon/webwire-go/payload" 7 | ) 8 | 9 | // parseSessionCreated parses MsgNotifySessionCreated messages 10 | func (msg *Message) parseSessionCreated() error { 11 | if msg.MsgBuffer.len < MinLenNotifySessionCreated { 12 | return errors.New( 13 | "invalid session creation notification message, too short", 14 | ) 15 | } 16 | 17 | msg.MsgPayload = pld.Payload{ 18 | Data: msg.MsgBuffer.Data()[1:], 19 | } 20 | 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /message/parseSignal.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // parseSignal parses MsgSignalBinary and MsgSignalUtf8 messages 11 | func (msg *Message) parseSignal() error { 12 | if msg.MsgBuffer.len < MinLenSignal { 13 | return errors.New("invalid signal message, too short") 14 | } 15 | 16 | 
dat := msg.MsgBuffer.Data() 17 | 18 | // Read name length 19 | nameLen := int(dat[1]) 20 | minMsgLen := 2 + nameLen 21 | 22 | // Verify total message size to prevent segmentation faults 23 | // caused by inconsistent flags. This could happen if the specified 24 | // name length doesn't correspond to the actual name length 25 | if msg.MsgBuffer.len < minMsgLen { 26 | return fmt.Errorf( 27 | "invalid signal message, too short for full name (%d) "+ 28 | "and the minimum payload (1)", 29 | nameLen, 30 | ) 31 | } 32 | 33 | if nameLen > 0 { 34 | // Take name into account 35 | msg.MsgName = dat[2:minMsgLen] 36 | msg.MsgPayload = pld.Payload{ 37 | Data: dat[minMsgLen:], 38 | } 39 | } else { 40 | // No name present, just payload 41 | msg.MsgPayload = pld.Payload{ 42 | Data: dat[2:], 43 | } 44 | } 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /message/parseSignalUtf16.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // parseSignalUtf16 parses MsgSignalUtf16 messages 11 | func (msg *Message) parseSignalUtf16() error { 12 | if msg.MsgBuffer.len < MinLenSignalUtf16 { 13 | return errors.New("invalid signal message, too short") 14 | } 15 | 16 | if msg.MsgBuffer.len%2 != 0 { 17 | return errors.New( 18 | "Unaligned UTF16 encoded signal message " + 19 | "(probably missing header padding)", 20 | ) 21 | } 22 | 23 | dat := msg.MsgBuffer.Data() 24 | 25 | // Read name length 26 | nameLen := int(dat[1]) 27 | 28 | // Determine minimum required message length 29 | minMsgSize := MinLenSignalUtf16 + nameLen 30 | payloadOffset := 2 + nameLen 31 | 32 | // Check whether a name padding byte is to be expected 33 | if nameLen%2 != 0 { 34 | minMsgSize++ 35 | payloadOffset++ 36 | } 37 | 38 | // Verify total message size to prevent segmentation faults 39 | // caused by inconsistent 
flags. This could happen if the specified 40 | // name length doesn't correspond to the actual name length 41 | if msg.MsgBuffer.len < minMsgSize { 42 | return fmt.Errorf( 43 | "invalid signal message, too short for full name (%d) "+ 44 | "and the minimum payload (2)", 45 | nameLen, 46 | ) 47 | } 48 | 49 | if nameLen > 0 { 50 | // Take name into account 51 | msg.MsgName = dat[2 : 2+nameLen] 52 | msg.MsgPayload = pld.Payload{ 53 | Data: dat[payloadOffset:], 54 | } 55 | } else { 56 | // No name present, just payload 57 | msg.MsgPayload = pld.Payload{ 58 | Data: dat[2:], 59 | } 60 | } 61 | return nil 62 | } 63 | -------------------------------------------------------------------------------- /message/parseSpecialReplyMessage.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "errors" 4 | 5 | // parseSpecialReplyMessage parses the following message types: 6 | // MsgReplyShutdown, MsgReplyInternalError, MsgReplySessionNotFound, 7 | // MsgReplyMaxSessConnsReached, MsgReplySessionsDisabled 8 | func (msg *Message) parseSpecialReplyMessage() error { 9 | if msg.MsgBuffer.len < 9 { 10 | return errors.New("invalid special reply message, too short") 11 | } 12 | 13 | // Read identifier 14 | msg.MsgIdentifierBytes = msg.MsgBuffer.Data()[1:9] 15 | copy(msg.MsgIdentifier[:], msg.MsgIdentifierBytes) 16 | 17 | return nil 18 | } 19 | -------------------------------------------------------------------------------- /message/parser_corruptpayload_test.go: -------------------------------------------------------------------------------- 1 | package message_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/qbeon/webwire-go/message" 7 | pld "github.com/qbeon/webwire-go/payload" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | /****************************************************************\ 12 | Parser - invalid input (corrupt payload) 13 | \****************************************************************/ 
14 | 15 | // TestMsgParseReplyUtf16CorruptInput tests parsing of 16 | // UTF16 encoded reply message with a corrupted input stream 17 | // (length not divisible by 2) 18 | func TestMsgParseReplyUtf16CorruptInput(t *testing.T) { 19 | id := genRndMsgIdentifier() 20 | payload := pld.Payload{ 21 | Encoding: pld.Utf16, 22 | Data: []byte("invalid"), // odd length (7 bytes) leaves the whole message unaligned 23 | } 24 | 25 | // Compose encoded message 26 | // Add type flag 27 | encoded := []byte{message.MsgReplyUtf16} 28 | // Add identifier 29 | encoded = append(encoded, id[:]...) 30 | // Add header padding byte due to UTF16 encoding 31 | encoded = append(encoded, byte(0)) 32 | // Add payload 33 | encoded = append(encoded, payload.Data...) 34 | 35 | // Parse 36 | _, err := tryParse(t, encoded) 37 | require.Error(t, 38 | err, 39 | "Expected Parse to return an error due to corrupt input stream", 40 | ) 41 | } 42 | 43 | // TestMsgParseRequestUtf16CorruptInput tests parsing of a named 44 | // UTF16 encoded request with a corrupted input stream 45 | // (length not divisible by 2) 46 | func TestMsgParseRequestUtf16CorruptInput(t *testing.T) { 47 | id := genRndMsgIdentifier() 48 | name := genRndName(1, 255) 49 | payload := pld.Payload{ 50 | Encoding: pld.Utf16, 51 | Data: []byte("invalid"), // odd length (7 bytes) leaves the whole message unaligned 52 | } 53 | 54 | // Compose encoded message 55 | // Add type flag 56 | encoded := []byte{message.MsgRequestUtf16} 57 | // Add identifier 58 | encoded = append(encoded, id[:]...) 59 | // Add name length flag 60 | encoded = append(encoded, byte(len(name))) 61 | // Add name 62 | encoded = append(encoded, []byte(name)...) 63 | // Add header padding if necessary 64 | if len(name)%2 != 0 { 65 | encoded = append(encoded, byte(0)) 66 | } 67 | // Add payload 68 | encoded = append(encoded, payload.Data...) 
69 | 70 | // Parse 71 | _, err := tryParse(t, encoded) 72 | require.Error(t, 73 | err, 74 | "Expected Parse to return an error due to corrupt input stream", 75 | ) 76 | } 77 | 78 | // TestMsgParseSignalUtf16CorruptInput tests parsing of a named 79 | // UTF16 encoded signal with a corrupt unaligned input stream 80 | // (length not divisible by 2) 81 | func TestMsgParseSignalUtf16CorruptInput(t *testing.T) { 82 | name := genRndName(1, 255) 83 | payload := pld.Payload{ 84 | Encoding: pld.Utf16, 85 | Data: []byte("invalid"), // odd length (7 bytes) leaves the whole message unaligned 86 | } 87 | 88 | // Compose encoded message 89 | // Add type flag 90 | encoded := []byte{message.MsgSignalUtf16} 91 | // Add name length flag 92 | encoded = append(encoded, byte(len(name))) 93 | // Add name 94 | encoded = append(encoded, []byte(name)...) 95 | // Add header padding if necessary 96 | if len(name)%2 != 0 { 97 | encoded = append(encoded, byte(0)) 98 | } 99 | // Add payload 100 | encoded = append(encoded, payload.Data...) 101 | 102 | // Parse 103 | _, err := tryParse(t, encoded) 104 | require.Error(t, 105 | err, 106 | "Expected Parse to return an error due to corrupt input stream", 107 | ) 108 | } 109 | -------------------------------------------------------------------------------- /message/parser_invtoolong_test.go: -------------------------------------------------------------------------------- 1 | package message_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/qbeon/webwire-go/message" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | /****************************************************************\ 11 | Parser - invalid messages (too long) 12 | \****************************************************************/ 13 | 14 | // TestMsgParseInvalidSessionClosedTooLong tests parsing of an invalid 15 | // session closed notification message which is too long to be considered valid 16 | func TestMsgParseInvalidSessionClosedTooLong(t *testing.T) { 17 | lenTooLong := message.MinLenNotifySessionClosed + 1 18 | invalidMessage := 
make([]byte, lenTooLong) 19 | 20 | invalidMessage[0] = message.MsgNotifySessionClosed 21 | 22 | _, err := tryParse(t, invalidMessage) 23 | require.Error(t, 24 | err, 25 | "Expected error while parsing invalid session closed message "+ 26 | "(too long: %d)", 27 | lenTooLong, 28 | ) 29 | } 30 | 31 | // TestMsgParseInvalidHeartbeatTooLong tests parsing of an invalid heartbeat 32 | // message which is too long to be considered valid 33 | func TestMsgParseInvalidHeartbeatTooLong(t *testing.T) { 34 | lenTooLong := 2 // WriteMsgHeartbeat emits only the type flag, so 2 bytes is one too many 35 | invalidMessage := make([]byte, lenTooLong) 36 | 37 | invalidMessage[0] = message.MsgHeartbeat 38 | 39 | _, err := tryParse(t, invalidMessage) 40 | require.Error(t, 41 | err, 42 | "Expected error while parsing invalid heartbeat message "+ 43 | "(too long: %d)", 44 | lenTooLong, 45 | ) 46 | } 47 | -------------------------------------------------------------------------------- /message/pool.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | // Pool defines the message buffer pool interface 4 | type Pool interface { 5 | // Get fetches a message from the pool. The caller must hand it back once 6 | // it's no longer needed (SyncPool wires this up through the message's onClose hook) 7 | Get() *Message 8 | } 9 | -------------------------------------------------------------------------------- /message/read.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | ) 7 | 8 | // ReadBytes copies the given byte slice into the message buffer and parses it. Empty input yields (false, nil) without parsing, input larger than the buffer yields an error, and a buffer still holding data is closed before being overwritten 9 | func (msg *Message) ReadBytes(bytes []byte) (typeParsed bool, err error) { 10 | if len(bytes) < 1 { 11 | return false, nil 12 | } 13 | if len(bytes) > len(msg.MsgBuffer.buf) { 14 | return false, errors.New("message buffer overflow") 15 | } 16 | if !msg.MsgBuffer.IsEmpty() { 17 | msg.MsgBuffer.Close() 18 | } 19 | copy(msg.MsgBuffer.buf[:len(bytes)], bytes) 20 | msg.MsgBuffer.len = len(bytes) 21 | return msg.parse() 22 | } 23 | 24 | 
// Read reads and parses the message from the given reader 25 | func (msg *Message) Read(reader io.Reader) (typeParsed bool, err error) { 26 | if err := msg.MsgBuffer.Read(reader); err != nil { 27 | return false, err 28 | } 29 | return msg.parse() 30 | } 31 | -------------------------------------------------------------------------------- /message/syncPool.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "sync" 4 | 5 | // SyncPool represents a thread-safe messageBuffer pool 6 | type SyncPool struct { 7 | bufferSize uint32 8 | pool *sync.Pool 9 | } 10 | 11 | // NewSyncPool initializes a new sync.Pool based message buffer pool instance 12 | func NewSyncPool(bufferSize, prealloc uint32) *SyncPool { 13 | pool := &sync.Pool{} 14 | pool.New = func() interface{} { 15 | msg := NewMessage(bufferSize) 16 | msg.onClose = func() { 17 | pool.Put(msg) 18 | } 19 | return msg 20 | } 21 | return &SyncPool{ 22 | bufferSize: bufferSize, 23 | pool: pool, 24 | } 25 | } 26 | 27 | // Get implements the Pool interface 28 | func (mbp *SyncPool) Get() *Message { 29 | return mbp.pool.Get().(*Message) 30 | } 31 | -------------------------------------------------------------------------------- /message/writeMsgErrorReply.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgReplyError writes an error reply message to the given writer 9 | // closing it eventually 10 | func WriteMsgReplyError( 11 | writer io.WriteCloser, 12 | requestIdent []byte, 13 | code, 14 | message []byte, 15 | safeMode bool, 16 | ) error { 17 | if safeMode { 18 | // Validate input 19 | if len(code) < 1 { 20 | initialErr := fmt.Errorf( 21 | "missing error code while creating a new error reply message", 22 | ) 23 | if err := writer.Close(); err != nil { 24 | return fmt.Errorf("%s: %s", initialErr, err) 25 | } 26 | return initialErr 27 | } else 
if len(code) > 255 { 28 | initialErr := fmt.Errorf( 29 | "invalid error code while creating a new error reply message,"+ 30 | "too long (%d)", 31 | len(code), 32 | ) 33 | if err := writer.Close(); err != nil { 34 | return fmt.Errorf("%s: %s", initialErr, err) 35 | } 36 | return initialErr 37 | } 38 | // Determine total message length 39 | // messageSize := 10 + len(code) + len(message) 40 | // if len(buf) < messageSize { 41 | // if closeErr := writer.Close(); closeErr != nil { 42 | // return fmt.Errorf("%s: %s", err, closeErr) 43 | // } 44 | // return 0, errors.New( 45 | // "message buffer too small to fit an error reply message", 46 | // ) 47 | // } 48 | for _, char := range code { 49 | if char < 32 || char > 126 { 50 | initialErr := fmt.Errorf( 51 | "unsupported character in reply error - error code: %s", 52 | string(char), 53 | ) 54 | if closeErr := writer.Close(); closeErr != nil { 55 | return fmt.Errorf("%s: %s", initialErr, closeErr) 56 | } 57 | return initialErr 58 | } 59 | } 60 | } 61 | 62 | // Write message type flag 63 | if _, err := writer.Write(msgTypeReplyError); err != nil { 64 | if closeErr := writer.Close(); closeErr != nil { 65 | return fmt.Errorf("%s: %s", err, closeErr) 66 | } 67 | return err 68 | } 69 | 70 | // Write request identifier 71 | if _, err := writer.Write(requestIdent); err != nil { 72 | if closeErr := writer.Close(); closeErr != nil { 73 | return fmt.Errorf("%s: %s", err, closeErr) 74 | } 75 | return err 76 | } 77 | 78 | // Write code length flag 79 | if _, err := writer.Write( 80 | msgNameLenBytes[len(code) : len(code)+1], 81 | ); err != nil { 82 | if closeErr := writer.Close(); closeErr != nil { 83 | return fmt.Errorf("%s: %s", err, closeErr) 84 | } 85 | return err 86 | } 87 | 88 | // Write error code 89 | if _, err := writer.Write(code); err != nil { 90 | if closeErr := writer.Close(); closeErr != nil { 91 | return fmt.Errorf("%s: %s", err, closeErr) 92 | } 93 | return err 94 | } 95 | 96 | // Write error message 97 | if _, err := 
writer.Write(message); err != nil { 98 | if closeErr := writer.Close(); closeErr != nil { 99 | return fmt.Errorf("%s: %s", err, closeErr) 100 | } 101 | return err 102 | } 103 | 104 | return writer.Close() 105 | } 106 | -------------------------------------------------------------------------------- /message/writeMsgHeartbeat.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgHeartbeat writes a session closure notification message to the 9 | // given writer closing it eventually 10 | func WriteMsgHeartbeat(writer io.WriteCloser) error { 11 | // Write message type flag 12 | if _, err := writer.Write(msgTypeHeartbeat); err != nil { 13 | if closeErr := writer.Close(); closeErr != nil { 14 | return fmt.Errorf("%s: %s", err, closeErr) 15 | } 16 | return err 17 | } 18 | 19 | return writer.Close() 20 | } 21 | -------------------------------------------------------------------------------- /message/writeMsgNamelessRequest.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgNamelessRequest writes a nameless (initially without a name) 9 | // request message to the given writer closing it eventually 10 | func WriteMsgNamelessRequest( 11 | writer io.WriteCloser, 12 | reqType byte, 13 | identifier []byte, 14 | binaryPayload []byte, 15 | ) error { 16 | msgType := msgTypeRequestCloseSession 17 | if reqType == MsgRequestRestoreSession { 18 | msgType = msgTypeRequestRestoreSession 19 | } else if reqType != MsgRequestCloseSession { 20 | panic(fmt.Errorf("unexpected nameless request type: %d", reqType)) 21 | } 22 | 23 | // Write message type flag 24 | if _, err := writer.Write(msgType); err != nil { 25 | if closeErr := writer.Close(); closeErr != nil { 26 | return fmt.Errorf("%s: %s", err, closeErr) 27 | } 28 | return err 29 | } 30 | 31 | // Write request 
identifier 32 | if _, err := writer.Write(identifier); err != nil { 33 | if closeErr := writer.Close(); closeErr != nil { 34 | return fmt.Errorf("%s: %s", err, closeErr) 35 | } 36 | return err 37 | } 38 | 39 | // Write payload 40 | if len(binaryPayload) > 0 { 41 | if _, err := writer.Write(binaryPayload); err != nil { 42 | if closeErr := writer.Close(); closeErr != nil { 43 | return fmt.Errorf("%s: %s", err, closeErr) 44 | } 45 | return err 46 | } 47 | } 48 | 49 | return writer.Close() 50 | } 51 | -------------------------------------------------------------------------------- /message/writeMsgReply.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // WriteMsgReply writes a reply message to the given writer closing it 11 | // eventually 12 | func WriteMsgReply( 13 | writer io.WriteCloser, 14 | requestIdentifier []byte, 15 | payloadEncoding pld.Encoding, 16 | payloadData []byte, 17 | ) error { 18 | // Verify payload data validity in case of UTF16 encoding 19 | if payloadEncoding == pld.Utf16 && len(payloadData)%2 != 0 { 20 | initialErr := fmt.Errorf( 21 | "invalid UTF16 reply payload data length: %d", 22 | len(payloadData), 23 | ) 24 | if err := writer.Close(); err != nil { 25 | return fmt.Errorf("%s: %s", initialErr, err) 26 | } 27 | return initialErr 28 | } 29 | 30 | // Determine message type from payload encoding type 31 | msgType := msgTypeReplyBinary 32 | if payloadEncoding == pld.Utf8 { 33 | msgType = msgTypeReplyUtf8 34 | } else if payloadEncoding == pld.Utf16 { 35 | msgType = msgTypeReplyUtf16 36 | } 37 | 38 | // Write message type flag 39 | if _, err := writer.Write(msgType); err != nil { 40 | if closeErr := writer.Close(); closeErr != nil { 41 | return fmt.Errorf("%s: %s", err, closeErr) 42 | } 43 | return err 44 | } 45 | 46 | // Write request identifier 47 | if _, err := 
writer.Write(requestIdentifier); err != nil { 48 | if closeErr := writer.Close(); closeErr != nil { 49 | return fmt.Errorf("%s: %s", err, closeErr) 50 | } 51 | return err 52 | } 53 | 54 | // Write header padding byte if the payload requires proper alignment 55 | if payloadEncoding == pld.Utf16 { 56 | if _, err := writer.Write(msgHeaderPadding); err != nil { 57 | if closeErr := writer.Close(); closeErr != nil { 58 | return fmt.Errorf("%s: %s", err, closeErr) 59 | } 60 | return err 61 | } 62 | } 63 | 64 | // Write payload 65 | if _, err := writer.Write(payloadData); err != nil { 66 | if closeErr := writer.Close(); closeErr != nil { 67 | return fmt.Errorf("%s: %s", err, closeErr) 68 | } 69 | return err 70 | } 71 | 72 | return writer.Close() 73 | } 74 | -------------------------------------------------------------------------------- /message/writeMsgRequest.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | 8 | pld "github.com/qbeon/webwire-go/payload" 9 | ) 10 | 11 | // WriteMsgRequest writes a named request message to the given writer 12 | // closing it eventually 13 | func WriteMsgRequest( 14 | writer io.WriteCloser, 15 | identifier []byte, 16 | name []byte, 17 | payloadEncoding pld.Encoding, 18 | payloadData []byte, 19 | safeMode bool, 20 | ) error { 21 | // Require either a name, or a payload or both, but don't allow none 22 | if len(name) < 1 && len(payloadData) < 1 { 23 | initialErr := errors.New( 24 | "request message requires either a name, or a payload, or both", 25 | ) 26 | if err := writer.Close(); err != nil { 27 | return fmt.Errorf("%s: %s", initialErr, err) 28 | } 29 | return initialErr 30 | } 31 | 32 | // Cap name length at 255 bytes 33 | if len(name) > 255 { 34 | initialErr := fmt.Errorf( 35 | "unsupported request message name length: %d", 36 | len(name), 37 | ) 38 | if err := writer.Close(); err != nil { 39 | return fmt.Errorf("%s: %s", 
initialErr, err) 40 | } 41 | return initialErr 42 | } 43 | 44 | // Verify payload data validity in case of UTF16 encoding 45 | if payloadEncoding == pld.Utf16 && len(payloadData)%2 != 0 { 46 | initialErr := fmt.Errorf( 47 | "invalid UTF16 request payload data length: %d", 48 | len(payloadData), 49 | ) 50 | if err := writer.Close(); err != nil { 51 | return fmt.Errorf("%s: %s", initialErr, err) 52 | } 53 | return initialErr 54 | } 55 | 56 | // Validate name 57 | if safeMode { 58 | for i := 0; i < len(name); i++ { 59 | char := name[i] 60 | if char < 32 || char > 126 { 61 | initialErr := fmt.Errorf( 62 | "unsupported character in request name: %s", 63 | string(char), 64 | ) 65 | if err := writer.Close(); err != nil { 66 | return fmt.Errorf("%s: %s", initialErr, err) 67 | } 68 | return initialErr 69 | } 70 | } 71 | } 72 | 73 | // Determine message type from payload encoding type 74 | msgType := msgTypeRequestBinary 75 | if payloadEncoding == pld.Utf8 { 76 | msgType = msgTypeRequestUtf8 77 | } else if payloadEncoding == pld.Utf16 { 78 | msgType = msgTypeRequestUtf16 79 | } 80 | 81 | // Write message type flag 82 | if _, err := writer.Write(msgType); err != nil { 83 | if closeErr := writer.Close(); closeErr != nil { 84 | return fmt.Errorf("%s: %s", err, closeErr) 85 | } 86 | return err 87 | } 88 | 89 | // Write request identifier 90 | if _, err := writer.Write(identifier); err != nil { 91 | if closeErr := writer.Close(); closeErr != nil { 92 | return fmt.Errorf("%s: %s", err, closeErr) 93 | } 94 | return err 95 | } 96 | 97 | // Write name length flag 98 | if _, err := writer.Write( 99 | msgNameLenBytes[len(name) : len(name)+1], 100 | ); err != nil { 101 | if closeErr := writer.Close(); closeErr != nil { 102 | return fmt.Errorf("%s: %s", err, closeErr) 103 | } 104 | return err 105 | } 106 | // Write name 107 | if _, err := writer.Write(name); err != nil { 108 | if closeErr := writer.Close(); closeErr != nil { 109 | return fmt.Errorf("%s: %s", err, closeErr) 110 | } 111 | 
return err 112 | } 113 | 114 | // Write header padding byte if the payload requires proper alignment 115 | if payloadEncoding == pld.Utf16 && len(name)%2 != 0 { 116 | if _, err := writer.Write(msgHeaderPadding); err != nil { 117 | if closeErr := writer.Close(); closeErr != nil { 118 | return fmt.Errorf("%s: %s", err, closeErr) 119 | } 120 | return err 121 | } 122 | } 123 | 124 | // Write payload 125 | if _, err := writer.Write(payloadData); err != nil { 126 | if closeErr := writer.Close(); closeErr != nil { 127 | return fmt.Errorf("%s: %s", err, closeErr) 128 | } 129 | return err 130 | } 131 | 132 | return writer.Close() 133 | } 134 | -------------------------------------------------------------------------------- /message/writeMsgSessionClosed.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgNotifySessionClosed writes a session closure notification message to the 9 | // given writer closing it eventually 10 | func WriteMsgNotifySessionClosed(writer io.WriteCloser) error { 11 | // Write message type flag 12 | if _, err := writer.Write(msgTypeSessionClosed); err != nil { 13 | if closeErr := writer.Close(); closeErr != nil { 14 | return fmt.Errorf("%s: %s", err, closeErr) 15 | } 16 | return err 17 | } 18 | 19 | return writer.Close() 20 | } 21 | -------------------------------------------------------------------------------- /message/writeMsgSessionCreated.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgNotifySessionCreated writes a session creation notification message to 9 | // the given writer closing it eventually 10 | func WriteMsgNotifySessionCreated( 11 | writer io.WriteCloser, 12 | sessionInfo []byte, 13 | ) error { 14 | // Write message type flag 15 | if _, err := writer.Write(msgTypeSessionCreated); err != nil { 16 | if 
closeErr := writer.Close(); closeErr != nil { 17 | return fmt.Errorf("%s: %s", err, closeErr) 18 | } 19 | return err 20 | } 21 | 22 | // Write the session info payload 23 | if _, err := writer.Write(sessionInfo); err != nil { 24 | if closeErr := writer.Close(); closeErr != nil { 25 | return fmt.Errorf("%s: %s", err, closeErr) 26 | } 27 | return err 28 | } 29 | 30 | return writer.Close() 31 | } 32 | -------------------------------------------------------------------------------- /message/writeMsgSignal.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | pld "github.com/qbeon/webwire-go/payload" 8 | ) 9 | 10 | // WriteMsgSignal writes a named signal message to the given writer closing 11 | // it eventually 12 | func WriteMsgSignal( 13 | writer io.WriteCloser, 14 | name []byte, 15 | payloadEncoding pld.Encoding, 16 | payloadData []byte, 17 | safeMode bool, 18 | ) error { 19 | if len(name) > 255 { 20 | initialErr := fmt.Errorf( 21 | "unsupported request message name length: %d", 22 | len(name), 23 | ) 24 | if err := writer.Close(); err != nil { 25 | return fmt.Errorf("%s: %s", initialErr, err) 26 | } 27 | return initialErr 28 | } 29 | 30 | // Verify payload data validity in case of UTF16 encoding 31 | if payloadEncoding == pld.Utf16 && len(payloadData)%2 != 0 { 32 | initialErr := fmt.Errorf( 33 | "invalid UTF16 signal payload data length: %d", 34 | len(payloadData), 35 | ) 36 | if err := writer.Close(); err != nil { 37 | return fmt.Errorf("%s: %s", initialErr, err) 38 | } 39 | return initialErr 40 | } 41 | 42 | if safeMode { 43 | for i := range name { 44 | char := name[i] 45 | if char < 32 || char > 126 { 46 | initialErr := fmt.Errorf( 47 | "unsupported character in request name: %s", 48 | string(char), 49 | ) 50 | if err := writer.Close(); err != nil { 51 | return fmt.Errorf("%s: %s", initialErr, err) 52 | } 53 | return initialErr 54 | } 55 | } 56 | } 57 | 58 | // 
Determine the message type from the payload encoding type 59 | msgType := msgTypeSignalBinary 60 | if payloadEncoding == pld.Utf8 { 61 | msgType = msgTypeSignalUtf8 62 | } else if payloadEncoding == pld.Utf16 { 63 | msgType = msgTypeSignalUtf16 64 | } 65 | 66 | // Write message type flag 67 | if _, err := writer.Write(msgType); err != nil { 68 | if closeErr := writer.Close(); closeErr != nil { 69 | return fmt.Errorf("%s: %s", err, closeErr) 70 | } 71 | return err 72 | } 73 | 74 | // Write name length flag 75 | if _, err := writer.Write( 76 | msgNameLenBytes[len(name) : len(name)+1], 77 | ); err != nil { 78 | if closeErr := writer.Close(); closeErr != nil { 79 | return fmt.Errorf("%s: %s", err, closeErr) 80 | } 81 | return err 82 | } 83 | 84 | // Write name (if any) 85 | if len(name) > 0 { 86 | if _, err := writer.Write(name); err != nil { 87 | if closeErr := writer.Close(); closeErr != nil { 88 | return fmt.Errorf("%s: %s", err, closeErr) 89 | } 90 | return err 91 | } 92 | } 93 | 94 | // Write header padding byte if the payload requires proper alignment 95 | if payloadEncoding == pld.Utf16 && len(name)%2 != 0 { 96 | if _, err := writer.Write(msgHeaderPadding); err != nil { 97 | if closeErr := writer.Close(); closeErr != nil { 98 | return fmt.Errorf("%s: %s", err, closeErr) 99 | } 100 | return err 101 | } 102 | } 103 | 104 | // Write payload 105 | if _, err := writer.Write(payloadData); err != nil { 106 | if closeErr := writer.Close(); closeErr != nil { 107 | return fmt.Errorf("%s: %s", err, closeErr) 108 | } 109 | return err 110 | } 111 | 112 | return writer.Close() 113 | } 114 | -------------------------------------------------------------------------------- /message/writeMsgSpecialRequestReply.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // WriteMsgSpecialRequestReply writes a special request reply message to the 9 | // given writer closing it eventually 
10 | func WriteMsgSpecialRequestReply( 11 | writer io.WriteCloser, 12 | reqType byte, // one of the special MsgReply* types handled by the switch below 13 | reqIdent []byte, 14 | ) error { 15 | msgType := msgTypeReplyInternalError 16 | switch reqType { 17 | case MsgReplyInternalError: // msgType already defaults to msgTypeReplyInternalError 18 | case MsgReplyMaxSessConnsReached: 19 | msgType = msgTypeReplyMaxSessConnsReached 20 | case MsgReplySessionNotFound: 21 | msgType = msgTypeReplySessionNotFound 22 | case MsgReplySessionsDisabled: 23 | msgType = msgTypeSessionsDisabled 24 | case MsgReplyShutdown: 25 | msgType = msgTypeReplyShutdown 26 | default: 27 | initialErr := fmt.Errorf( 28 | "message type (%d) doesn't represent a special reply message", 29 | reqType, 30 | ) 31 | if err := writer.Close(); err != nil { 32 | return fmt.Errorf("%s: %s", initialErr, err) 33 | } 34 | return initialErr 35 | } 36 | 37 | // Write message type flag 38 | if _, err := writer.Write(msgType); err != nil { 39 | if closeErr := writer.Close(); closeErr != nil { 40 | return fmt.Errorf("%s: %s", err, closeErr) 41 | } 42 | return err 43 | } 44 | 45 | // Write request identifier 46 | if _, err := writer.Write(reqIdent); err != nil { 47 | if closeErr := writer.Close(); closeErr != nil { 48 | return fmt.Errorf("%s: %s", err, closeErr) 49 | } 50 | return err 51 | } 52 | 53 | return writer.Close() 54 | } 55 | -------------------------------------------------------------------------------- /message/write_corruptUtf16Payload_test.go: -------------------------------------------------------------------------------- 1 | package message_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/qbeon/webwire-go/message" 7 | pld "github.com/qbeon/webwire-go/payload" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | /****************************************************************\ 12 | Writers - invalid input (corrupt UTF16 payload) 13 | \****************************************************************/ 14 | 15 | // TestWriteMsgReplyUtf16CorruptPayload tests WriteMsgReply using UTF16 payload 16 | // encoding 
passing corrupt data (length not divisible by 2 thus not UTF16 17 | // encoded) 18 | func TestWriteMsgReplyUtf16CorruptPayload(t *testing.T) { 19 | writer := &testWriter{} 20 | require.Error(t, message.WriteMsgReply( 21 | writer, 22 | genRndMsgIdentifier(), 23 | pld.Utf16, 24 | // Payload is corrupt, only 7 bytes long, not power 2 25 | []byte("invalid"), 26 | )) 27 | require.True(t, writer.closed) 28 | } 29 | 30 | // TestWriteMsgReqUtf16CorruptPayload tests WriteMsgRequest using UTF16 payload 31 | // encoding passing corrupt data (length not divisible by 2 thus not UTF16 32 | // encoded) 33 | func TestWriteMsgReqUtf16CorruptPayload(t *testing.T) { 34 | writer := &testWriter{} 35 | require.Error(t, message.WriteMsgRequest( 36 | writer, 37 | genRndMsgIdentifier(), 38 | genRndName(1, 255), 39 | pld.Utf16, 40 | // Payload is corrupt, only 7 bytes long, not power 2 41 | []byte("invalid"), 42 | true, 43 | )) 44 | require.True(t, writer.closed) 45 | } 46 | 47 | // TestWriteMsgSigUtf16CorruptPayload tests WriteMsgSignal using UTF16 48 | // payload encoding passing corrupt data (length not divisible by 2 thus not 49 | // UTF16 encoded) 50 | func TestWriteMsgSigUtf16CorruptPayload(t *testing.T) { 51 | writer := &testWriter{} 52 | require.Error(t, message.WriteMsgSignal( 53 | writer, 54 | genRndName(1, 255), 55 | pld.Utf16, 56 | // Payload is corrupt, only 7 bytes long, not power 2 57 | []byte("invalid"), 58 | true, 59 | )) 60 | require.True(t, writer.closed) 61 | } 62 | -------------------------------------------------------------------------------- /newServer.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/url" 7 | "sync" 8 | 9 | "github.com/qbeon/webwire-go/message" 10 | ) 11 | 12 | // NewServer creates a new webwire server instance 13 | func NewServer( 14 | implementation ServerImplementation, 15 | opts ServerOptions, 16 | transport Transport, 17 | ) (instance Server, 
err error) { 18 | if implementation == nil { 19 | return nil, errors.New("missing server implementation") 20 | } 21 | 22 | if transport == nil { 23 | return nil, errors.New("missing transport layer implementation") 24 | } 25 | 26 | if err := opts.Prepare(); err != nil { 27 | return nil, err 28 | } 29 | 30 | sessionsEnabled := false 31 | if opts.Sessions == Enabled { 32 | sessionsEnabled = true 33 | } 34 | 35 | // Prepare the configuration push message for the webwire accept handshake 36 | configMsg, err := message.NewAcceptConfMessage( 37 | message.ServerConfiguration{ 38 | MajorProtocolVersion: 2, 39 | MinorProtocolVersion: 0, 40 | ReadTimeout: opts.ReadTimeout, 41 | MessageBufferSize: opts.MessageBufferSize, 42 | SubProtocolName: opts.SubProtocolName, 43 | }, 44 | ) 45 | if err != nil { 46 | return nil, fmt.Errorf( 47 | "couldn't initialize server configuration-push message: %s", 48 | err, 49 | ) 50 | } 51 | 52 | // Initialize the webwire server 53 | srv := &server{ 54 | transport: transport, 55 | impl: implementation, 56 | sessionManager: opts.SessionManager, 57 | sessionKeyGen: opts.SessionKeyGenerator, 58 | sessionInfoParser: opts.SessionInfoParser, 59 | addr: url.URL{}, 60 | options: opts, 61 | configMsg: configMsg, 62 | shutdown: false, 63 | shutdownRdy: make(chan bool), 64 | currentOps: 0, 65 | opsLock: &sync.Mutex{}, 66 | connections: make([]*connection, 0), 67 | connectionsLock: &sync.Mutex{}, 68 | sessionsEnabled: sessionsEnabled, 69 | messagePool: message.NewSyncPool(opts.MessageBufferSize, 1024), 70 | warnLog: opts.WarnLog, 71 | errorLog: opts.ErrorLog, 72 | } 73 | 74 | srv.sessionRegistry = newSessionRegistry( 75 | opts.MaxSessionConnections, 76 | func(sessionKey string) { 77 | if err := srv.sessionManager.OnSessionClosed( 78 | sessionKey, 79 | ); err != nil { 80 | srv.errorLog.Printf("session registry ") 81 | } 82 | }, 83 | ) 84 | 85 | // Initialize the transport layer 86 | if err := transport.Initialize( 87 | opts, 88 | srv.isShuttingDown, 89 | 
srv.handleConnection, 90 | ); err != nil { 91 | return nil, fmt.Errorf("couldn't initialize transport layer: %s", err) 92 | } 93 | 94 | return srv, nil 95 | } 96 | -------------------------------------------------------------------------------- /payload.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import "github.com/qbeon/webwire-go/payload" 4 | 5 | // PayloadEncoding represents the type of encoding of the message payload 6 | type PayloadEncoding = payload.Encoding 7 | 8 | const ( 9 | // EncodingBinary represents unencoded binary data 10 | EncodingBinary PayloadEncoding = payload.Binary 11 | 12 | // EncodingUtf8 represents UTF8 encoding 13 | EncodingUtf8 PayloadEncoding = payload.Utf8 14 | 15 | // EncodingUtf16 represents UTF16 encoding 16 | EncodingUtf16 PayloadEncoding = payload.Utf16 17 | ) 18 | 19 | // Payload represents an encoded payload 20 | type Payload struct { 21 | // Encoding represents the encoding type of the payload which is 22 | // EncodingBinary by default 23 | Encoding PayloadEncoding 24 | 25 | // Data represents the payload data 26 | Data []byte 27 | } 28 | -------------------------------------------------------------------------------- /payload/encoding.go: -------------------------------------------------------------------------------- 1 | package payload 2 | 3 | // Encoding represents the type of encoding of the message payload 4 | type Encoding int 5 | 6 | const ( 7 | // Binary represents unencoded binary data 8 | Binary Encoding = iota 9 | 10 | // Utf8 represents UTF8 encoding 11 | Utf8 12 | 13 | // Utf16 represents UTF16 encoding 14 | Utf16 15 | ) 16 | 17 | // String stringifies the encoding type 18 | func (enc Encoding) String() string { 19 | switch enc { 20 | case Binary: 21 | return "binary" 22 | case Utf8: 23 | return "utf8" 24 | case Utf16: 25 | return "utf16" 26 | } 27 | return "" 28 | } 29 | 
-------------------------------------------------------------------------------- /payload/encoding_test.go: -------------------------------------------------------------------------------- 1 | package payload 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | // TestEncodingStringification tests the stringification method 10 | // of the Encoding enumeration type 11 | func TestEncodingStringification(t *testing.T) { 12 | binaryEncoding := Binary 13 | require.Equal(t, "binary", binaryEncoding.String()) 14 | 15 | utf8Encoding := Utf8 16 | require.Equal(t, "utf8", utf8Encoding.String()) 17 | 18 | utf16Encoding := Utf16 19 | require.Equal(t, "utf16", utf16Encoding.String()) 20 | } 21 | 22 | // TestConvertUtf8ToUtf8 tests the Utf8() payload conversion method 23 | // with a payload already encoded in UTF8 24 | func TestConvertUtf8ToUtf8(t *testing.T) { 25 | payload := Payload{ 26 | Encoding: Utf8, 27 | Data: []byte{65, 66, 67}, // "ABC" 28 | } 29 | 30 | result, err := payload.Utf8() 31 | require.NoError(t, err) 32 | require.Equal(t, []byte("ABC"), result) 33 | } 34 | 35 | // TestConvertBinaryToUtf8 tests the Utf8() payload conversion method 36 | // with a binary payload 37 | func TestConvertBinaryToUtf8(t *testing.T) { 38 | payload := Payload{ 39 | Encoding: Binary, 40 | Data: []byte("ABC ёжз φπμλβωϘ"), 41 | } 42 | 43 | result, err := payload.Utf8() 44 | require.NoError(t, err) 45 | require.Equal(t, "ABC ёжз φπμλβωϘ", string(result)) 46 | require.Len(t, result, 25) 47 | } 48 | 49 | // TestConvertUtf16ToUtf8 tests the Utf8() payload conversion method 50 | // with a UTF16 encoded payload 51 | func TestConvertUtf16ToUtf8(t *testing.T) { 52 | payload := Payload{ 53 | Encoding: Utf16, 54 | Data: []byte{ 55 | /* 0xFF 0xFE */ // byte order mark 56 | 0x41, 0x00, 57 | 0x42, 0x00, 58 | 0x43, 0x00, 59 | 0x20, 0x00, 60 | 0x51, 0x04, 61 | 0x36, 0x04, 62 | 0x37, 0x04, 63 | 0x20, 0x00, 64 | 0xC6, 0x03, 65 | 0xC0, 0x03, 66 | 0xBC, 0x03, 67 | 0xBB, 
0x03, 68 | 0xB2, 0x03, 69 | 0xC9, 0x03, 70 | 0xD8, 0x03, 71 | }, 72 | } 73 | 74 | result, err := payload.Utf8() 75 | require.NoError(t, err) 76 | require.Equal(t, "ABC ёжз φπμλβωϘ", string(result)) 77 | require.Len(t, result, 25) 78 | } 79 | 80 | // TestConvertCorruptUtf16 tests the Utf8() payload conversion method 81 | // with a corrupted UTF16 payload 82 | func TestConvertCorruptUtf16(t *testing.T) { 83 | payload := Payload{ 84 | Encoding: Utf16, 85 | Data: []byte{65, 66, 67}, // Odd number of bytes 86 | } 87 | 88 | result, err := payload.Utf8() 89 | require.Error(t, err) 90 | require.Len(t, result, 0) 91 | } 92 | -------------------------------------------------------------------------------- /payload/payload.go: -------------------------------------------------------------------------------- 1 | package payload 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "unicode/utf16" 7 | "unicode/utf8" 8 | ) 9 | 10 | // Payload represents an encoded message payload 11 | type Payload struct { 12 | Encoding Encoding 13 | Data []byte 14 | } 15 | 16 | // Utf8 returns a UTF8 representation of the payload data 17 | func (pld *Payload) Utf8() ([]byte, error) { 18 | if pld.Encoding == Utf16 { 19 | if len(pld.Data)%2 != 0 { 20 | return nil, fmt.Errorf( 21 | "Cannot convert invalid UTF16 payload data to UTF8", 22 | ) 23 | } 24 | u16str := make([]uint16, 1) 25 | utf8str := &bytes.Buffer{} 26 | utf8buf := make([]byte, 4) 27 | for i := 0; i < len(pld.Data); i += 2 { 28 | u16str[0] = uint16(pld.Data[i]) + (uint16(pld.Data[i+1]) << 8) 29 | rn := utf16.Decode(u16str) 30 | rnSize := utf8.EncodeRune(utf8buf, rn[0]) 31 | utf8str.Write(utf8buf[:rnSize]) 32 | } 33 | return utf8str.Bytes(), nil 34 | } 35 | 36 | // Binary and UTF8 encoded payloads should pass through untouched 37 | return pld.Data, nil 38 | } 39 | -------------------------------------------------------------------------------- /registerHandler.go: -------------------------------------------------------------------------------- 1 | 
package webwire

import (
	"context"

	"github.com/qbeon/webwire-go/message"
)

// registerHandler registers the execution of a handler for msg on con.
//
// It returns false straight away if the connection is no longer active.
// When the connection's concurrency limit is greater than 1 it blocks until
// a handler slot is available. Unless the server is shutting down, the
// server's operation counter is incremented. During shutdown, messages
// requiring a reply are failed via failMsgShutdown and false is returned.
// In every other case a task is registered on the connection and true is
// returned.
//
// NOTE(review): during shutdown a message that doesn't require a reply is
// still registered (returns true) without incrementing currentOps, and on
// the false path an acquired handler slot isn't released here - confirm
// that deregisterHandler accounts for both.
func (srv *server) registerHandler(
	con *connection,
	msg *message.Message,
) bool {
	if !con.IsActive() {
		return false
	}

	// Wait for a free handler slot when the concurrency is limited
	if con.options.ConcurrencyLimit > 1 {
		con.handlerSlots.Acquire(context.Background(), 1)
	}

	// Count the operation unless the server is shutting down
	srv.opsLock.Lock()
	shuttingDown := srv.shutdown
	if !shuttingDown {
		srv.currentOps++
	}
	srv.opsLock.Unlock()

	if shuttingDown && msg.RequiresReply() {
		// Refuse to process the message, fail it instead
		srv.failMsgShutdown(con, msg)
		return false
	}

	con.registerTask()
	return true
}
--------------------------------------------------------------------------------
/requestManager/reply.go:
--------------------------------------------------------------------------------
package requestmanager

import (
	"github.com/qbeon/webwire-go"
	"github.com/qbeon/webwire-go/message"
)

// reply implements the webwire.Reply interface on top of a received
// reply message
type reply struct {
	msg *message.Message
}

// PayloadEncoding returns the encoding of the reply payload
// (webwire.Reply interface)
func (r *reply) PayloadEncoding() webwire.PayloadEncoding {
	return r.msg.MsgPayload.Encoding
}

// Payload returns the raw reply payload data (webwire.Reply interface)
func (r *reply) Payload() []byte {
	return r.msg.MsgPayload.Data
}

// PayloadUtf8 returns the reply payload converted to UTF8
// (webwire.Reply interface)
func (r *reply) PayloadUtf8() ([]byte, error) {
	return r.msg.MsgPayload.Utf8()
}

// Close closes the underlying reply message (webwire.Reply interface)
func (r *reply) Close() {
	r.msg.Close()
}
--------------------------------------------------------------------------------
/requestManager/request.go:
--------------------------------------------------------------------------------
package requestmanager

import (
	"context"

	"github.com/qbeon/webwire-go"
	"github.com/qbeon/webwire-go/message"
)

// TODO: The request identifier should remain a uint64 until it's converted into
// the byte array for transmission, this would slightly increase performance

// genericReply carries the outcome of a request to its awaiting caller:
// ReplyMsg is set when the request was fulfilled, Error when it failed
type genericReply struct {
	ReplyMsg *message.Message
	Error    error
}

// Request is a request created and tracked by a RequestManager
type Request struct {
	// manager is the RequestManager tracking this request
	manager *RequestManager

	// Identifier is the unique identifier of this request,
	// IdentifierBytes holds the same identifier as a byte slice
	Identifier      [8]byte
	IdentifierBytes []byte

	// Reply is the channel the outcome of the request is delivered on
	Reply chan genericReply
}

// AwaitReply blocks the calling goroutine until the request is either
// fulfilled or failed, or until ctx is done (its deadline is exceeded or it
// is canceled). In the latter case the request is deregistered from its
// manager and the context error is translated into a webwire error.
// AwaitReply starts no timer of its own: any timeout must be set on ctx.
37 | func (req *Request) AwaitReply(ctx context.Context) (webwire.Reply, error) { 38 | // Block until either context canceled (including timeout) or reply received 39 | select { 40 | case <-ctx.Done(): 41 | req.manager.deregister(req.Identifier) 42 | return nil, webwire.TranslateContextError(ctx.Err()) 43 | 44 | case rp := <-req.Reply: 45 | if rp.Error != nil { 46 | return nil, rp.Error 47 | } 48 | return &reply{msg: rp.ReplyMsg}, nil 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /requestManager/requestManager.go: -------------------------------------------------------------------------------- 1 | package requestmanager 2 | 3 | import ( 4 | "encoding/binary" 5 | "sync" 6 | "sync/atomic" 7 | 8 | "github.com/qbeon/webwire-go/message" 9 | ) 10 | 11 | // RequestManager manages and keeps track of outgoing pending requests 12 | type RequestManager struct { 13 | lastID uint64 14 | lock *sync.RWMutex 15 | 16 | // pending represents an indexed list of all pending requests 17 | pending map[[8]byte]*Request 18 | } 19 | 20 | // NewRequestManager constructs and returns a new instance of a RequestManager 21 | func NewRequestManager() RequestManager { 22 | return RequestManager{ 23 | lastID: 0, 24 | lock: &sync.RWMutex{}, 25 | pending: make(map[[8]byte]*Request), 26 | } 27 | } 28 | 29 | // Create creates and registers a new request. 
30 | // Create doesn't start the timeout timer, 31 | // this is done in the subsequent request.AwaitReply 32 | func (manager *RequestManager) Create() *Request { 33 | // Generate unique request identifier by incrementing the last assigned id 34 | ident := atomic.AddUint64(&manager.lastID, 1) 35 | 36 | identBytes := make([]byte, 8) 37 | binary.LittleEndian.PutUint64(identBytes, ident) 38 | newRequest := &Request{ 39 | manager: manager, 40 | IdentifierBytes: identBytes, 41 | Reply: make(chan genericReply, 1), 42 | } 43 | copy(newRequest.Identifier[:], identBytes) 44 | 45 | // Register the newly created request 46 | manager.lock.Lock() 47 | manager.pending[newRequest.Identifier] = newRequest 48 | manager.lock.Unlock() 49 | 50 | return newRequest 51 | } 52 | 53 | // deregister deregisters the given clients session from the list 54 | // of currently pending requests 55 | func (manager *RequestManager) deregister(identifier [8]byte) { 56 | manager.lock.Lock() 57 | delete(manager.pending, identifier) 58 | manager.lock.Unlock() 59 | } 60 | 61 | // Fulfill fulfills the request associated with the given request identifier 62 | // with the provided reply payload. 63 | // Returns true if a pending request was fulfilled and deregistered, 64 | // otherwise returns false 65 | func (manager *RequestManager) Fulfill(msg *message.Message) bool { 66 | manager.lock.RLock() 67 | req, exists := manager.pending[msg.MsgIdentifier] 68 | manager.lock.RUnlock() 69 | 70 | if !exists { 71 | return false 72 | } 73 | 74 | manager.deregister(msg.MsgIdentifier) 75 | req.Reply <- genericReply{ 76 | ReplyMsg: msg, 77 | } 78 | return true 79 | } 80 | 81 | // Fail fails the request associated with the given request identifier 82 | // with the provided error. 
Returns true if a pending request 83 | // was failed and deregistered, otherwise returns false 84 | func (manager *RequestManager) Fail( 85 | identifier [8]byte, 86 | err error, 87 | ) bool { 88 | manager.lock.RLock() 89 | req, exists := manager.pending[identifier] 90 | manager.lock.RUnlock() 91 | 92 | if !exists { 93 | return false 94 | } 95 | 96 | manager.deregister(identifier) 97 | req.Reply <- genericReply{ 98 | Error: err, 99 | } 100 | return true 101 | } 102 | 103 | // PendingRequests returns the number of currently pending requests 104 | func (manager *RequestManager) PendingRequests() int { 105 | manager.lock.RLock() 106 | len := len(manager.pending) 107 | manager.lock.RUnlock() 108 | return len 109 | } 110 | 111 | // IsPending returns true if the request associated 112 | // with the given identifier is pending 113 | func (manager *RequestManager) IsPending(identifier [8]byte) bool { 114 | manager.lock.RLock() 115 | _, exists := manager.pending[identifier] 116 | manager.lock.RUnlock() 117 | return exists 118 | } 119 | -------------------------------------------------------------------------------- /requestManager/requestManager_test.go: -------------------------------------------------------------------------------- 1 | package requestmanager_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/qbeon/webwire-go/payload" 10 | reqman "github.com/qbeon/webwire-go/requestManager" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestFulfillRequest tests RequestManager.Create, RequestManager.Fulfill, 15 | // RequestManager.IsPending and Request.AwaitReply 16 | func TestFulfillRequest(t *testing.T) { 17 | manager := reqman.NewRequestManager() 18 | 19 | // Create request 20 | request := manager.Create() 21 | require.NotNil(t, request) 22 | 23 | require.True(t, manager.IsPending(request.Identifier)) 24 | 25 | // Fulfill the request 26 | pld := payload.Payload{ 27 | Encoding: 
payload.Binary, 28 | Data: []byte("test payload"), 29 | } 30 | require.True(t, manager.Fulfill( 31 | &message.Message{ 32 | MsgIdentifier: request.Identifier, 33 | MsgPayload: pld, 34 | }, 35 | )) 36 | 37 | require.False(t, manager.IsPending(request.Identifier)) 38 | 39 | reply, err := request.AwaitReply(context.Background()) 40 | require.NoError(t, err) 41 | require.NotNil(t, reply) 42 | require.Equal(t, pld.Encoding, reply.PayloadEncoding()) 43 | require.Equal(t, pld.Data, reply.Payload()) 44 | } 45 | 46 | // TestFailRequest tests RequestManager.Create, RequestManager.Fail, 47 | // RequestManager.IsPending and Request.AwaitReply 48 | func TestFailRequest(t *testing.T) { 49 | manager := reqman.NewRequestManager() 50 | 51 | // Create request 52 | request := manager.Create() 53 | require.NotNil(t, request) 54 | 55 | require.True(t, manager.IsPending(request.Identifier)) 56 | 57 | // Fail the request 58 | manager.Fail(request.Identifier, errors.New("test error")) 59 | 60 | require.False(t, manager.IsPending(request.Identifier)) 61 | 62 | reply, err := request.AwaitReply(context.Background()) 63 | require.Nil(t, reply) 64 | require.Error(t, err) 65 | } 66 | 67 | // TestPendingRequests tests RequestManager.PendingRequests 68 | func TestPendingRequests(t *testing.T) { 69 | manager := reqman.NewRequestManager() 70 | require.Equal(t, 0, manager.PendingRequests()) 71 | 72 | // Create first request 73 | request1 := manager.Create() 74 | require.Equal(t, 1, manager.PendingRequests()) 75 | 76 | // Create second request 77 | request2 := manager.Create() 78 | require.Equal(t, 2, manager.PendingRequests()) 79 | 80 | // Fail the first request 81 | manager.Fail(request1.Identifier, errors.New("test error")) 82 | require.Equal(t, 1, manager.PendingRequests()) 83 | 84 | // Fulfill the second request 85 | require.True(t, manager.Fulfill( 86 | &message.Message{ 87 | MsgIdentifier: request2.Identifier, 88 | MsgPayload: payload.Payload{ 89 | Encoding: payload.Binary, 90 | Data: 
[]byte("test payload"), 91 | }, 92 | }, 93 | )) 94 | require.Equal(t, 0, manager.PendingRequests()) 95 | } 96 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/url" 7 | "sync" 8 | 9 | "github.com/qbeon/webwire-go/message" 10 | ) 11 | 12 | // server represents a headless WebWire server instance, 13 | // where headless means there's no HTTP server that's hosting it 14 | type server struct { 15 | transport Transport 16 | impl ServerImplementation 17 | sessionManager SessionManager 18 | sessionKeyGen SessionKeyGenerator 19 | sessionInfoParser SessionInfoParser 20 | addr url.URL 21 | options ServerOptions 22 | configMsg []byte 23 | shutdown bool 24 | shutdownRdy chan bool 25 | currentOps uint32 26 | opsLock *sync.Mutex 27 | connectionsLock *sync.Mutex 28 | connections []*connection 29 | sessionsEnabled bool 30 | sessionRegistry *sessionRegistry 31 | messagePool message.Pool 32 | 33 | // Internals 34 | warnLog *log.Logger 35 | errorLog *log.Logger 36 | } 37 | 38 | // shutdownServer initiates the shutdown of the underlying transport layer 39 | func (srv *server) shutdownServer() error { 40 | if err := srv.transport.Shutdown(); err != nil { 41 | return fmt.Errorf("couldn't properly shutdown HTTP server: %s", err) 42 | } 43 | return nil 44 | } 45 | 46 | // Run implements the Server interface 47 | func (srv *server) Run() error { 48 | return srv.transport.Serve() 49 | } 50 | 51 | // Address implements the Server interface 52 | func (srv *server) Address() url.URL { 53 | return srv.transport.Address() 54 | } 55 | 56 | // Shutdown implements the Server interface 57 | func (srv *server) Shutdown() error { 58 | srv.opsLock.Lock() 59 | srv.shutdown = true 60 | // Don't block if there's no currently processed operations 61 | if srv.currentOps < 1 { 62 | srv.opsLock.Unlock() 63 | return 
srv.shutdownServer() 64 | } 65 | srv.opsLock.Unlock() 66 | 67 | // Wait until the server is ready for shutdown 68 | <-srv.shutdownRdy 69 | 70 | return srv.shutdownServer() 71 | } 72 | 73 | // ActiveSessionsNum implements the Server interface 74 | func (srv *server) ActiveSessionsNum() int { 75 | return srv.sessionRegistry.activeSessionsNum() 76 | } 77 | 78 | // SessionConnectionsNum implements the Server interface 79 | func (srv *server) SessionConnectionsNum(sessionKey string) int { 80 | return srv.sessionRegistry.sessionConnectionsNum(sessionKey) 81 | } 82 | 83 | // SessionConnections implements the Server interface 84 | func (srv *server) SessionConnections(sessionKey string) []Connection { 85 | return srv.sessionRegistry.sessionConnections(sessionKey) 86 | } 87 | 88 | // CloseSession implements the Server interface 89 | func (srv *server) CloseSession(sessionKey string) ( 90 | affectedConnections []Connection, 91 | errors []error, 92 | generalError error, 93 | ) { 94 | connections := srv.sessionRegistry.sessionConnections(sessionKey) 95 | 96 | errors = make([]error, len(connections)) 97 | if connections == nil { 98 | return nil, nil, nil 99 | } 100 | affectedConnections = make([]Connection, len(connections)) 101 | i := 0 102 | errNum := 0 103 | for _, connection := range connections { 104 | affectedConnections[i] = connection 105 | err := connection.CloseSession() 106 | if err != nil { 107 | errors[i] = err 108 | errNum++ 109 | } else { 110 | errors[i] = nil 111 | } 112 | i++ 113 | } 114 | 115 | if errNum > 0 { 116 | generalError = fmt.Errorf( 117 | "%d errors during the closure of a session", 118 | errNum, 119 | ) 120 | } 121 | 122 | return affectedConnections, errors, generalError 123 | } 124 | -------------------------------------------------------------------------------- /serverOptions.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "time" 8 | ) 9 | 
10 | // OptionValue represents the setting value of an option 11 | type OptionValue = int32 12 | 13 | const ( 14 | // OptionUnset represents the default unset value 15 | OptionUnset OptionValue = iota 16 | 17 | // Disabled disables an option 18 | Disabled 19 | 20 | // Enabled enables an option 21 | Enabled 22 | ) 23 | 24 | // ServerOptions represents the options 25 | // used during the creation of a new WebWire server instance 26 | type ServerOptions struct { 27 | Sessions OptionValue 28 | SessionManager SessionManager 29 | SessionKeyGenerator SessionKeyGenerator 30 | SessionInfoParser SessionInfoParser 31 | MaxSessionConnections uint 32 | WarnLog *log.Logger 33 | ErrorLog *log.Logger 34 | ReadTimeout time.Duration 35 | 36 | // SubProtocolName defines the optional name of the hosted webwire 37 | // sub-protocol 38 | SubProtocolName []byte 39 | 40 | // MessageBufferSize defines the size of the message buffer 41 | MessageBufferSize uint32 42 | } 43 | 44 | // Prepare verifies the specified options and sets the default values to 45 | // unspecified options 46 | func (op *ServerOptions) Prepare() error { 47 | // Enable sessions by default 48 | if op.Sessions == OptionUnset { 49 | op.Sessions = Enabled 50 | } 51 | 52 | if op.Sessions == Enabled && op.SessionManager == nil { 53 | // Force the default session manager 54 | // to use the default session directory 55 | op.SessionManager = NewDefaultSessionManager("") 56 | } 57 | 58 | if op.Sessions == Enabled && op.SessionKeyGenerator == nil { 59 | op.SessionKeyGenerator = NewDefaultSessionKeyGenerator() 60 | } 61 | 62 | if op.SessionInfoParser == nil { 63 | op.SessionInfoParser = GenericSessionInfoParser 64 | } 65 | 66 | if op.ReadTimeout < 1*time.Second { 67 | op.ReadTimeout = 60 * time.Second 68 | } 69 | 70 | // Create default loggers to std-out/err when no loggers are specified 71 | if op.WarnLog == nil { 72 | op.WarnLog = log.New( 73 | os.Stdout, 74 | "WWR_WARN: ", 75 | log.Ldate|log.Ltime|log.Lshortfile, 76 | ) 77 | } 
78 | if op.ErrorLog == nil { 79 | op.ErrorLog = log.New( 80 | os.Stderr, 81 | "WWR_ERR: ", 82 | log.Ldate|log.Ltime|log.Lshortfile, 83 | ) 84 | } 85 | 86 | const minMsgBufferSize = 32 87 | 88 | // Verify the message buffer size 89 | if op.MessageBufferSize == 0 { 90 | op.MessageBufferSize = 8192 // Default buffer size: 8K 91 | } else if op.MessageBufferSize < minMsgBufferSize { 92 | return fmt.Errorf( 93 | "message buffer size too small: %d bytes (min: %d bytes)", 94 | op.MessageBufferSize, 95 | minMsgBufferSize, 96 | ) 97 | } 98 | 99 | return nil 100 | } 101 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | cryptoRand "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | "time" 8 | ) 9 | 10 | // generateRandomBytes returns securely generated random bytes. 11 | // It will return an error if the system's secure random 12 | // number generator fails to function correctly, in which 13 | // case the caller should not continue. 14 | func generateRandomBytes(length uint32) (bytes []byte, err error) { 15 | bytes = make([]byte, length) 16 | _, err = cryptoRand.Read(bytes) 17 | // Note that err == nil only if we read len(b) bytes. 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | return bytes, nil 23 | } 24 | 25 | // generateSessionKey returns a URL-safe, base64 encoded 26 | // securely generated random string. 27 | // It will return an error if the system's secure random 28 | // number generator fails to function correctly, in which 29 | // case the caller should not continue. 30 | func generateSessionKey() string { 31 | bytes, err := generateRandomBytes(48) 32 | if err != nil { 33 | panic(fmt.Errorf("Could not generate a session key")) 34 | } 35 | return base64.URLEncoding.EncodeToString(bytes) 36 | } 37 | 38 | // JSONEncodedSession represents a JSON encoded session object. 
39 | // This structure is used during session restoration for unmarshalling 40 | // TODO: move to internal shared package 41 | type JSONEncodedSession struct { 42 | Key string `json:"k"` 43 | Creation time.Time `json:"c"` 44 | LastLookup time.Time `json:"l"` 45 | Info map[string]interface{} `json:"i,omitempty"` 46 | } 47 | 48 | // Session represents a session object. 49 | // If the key is empty the session is invalid. 50 | // Info can contain arbitrary attached data 51 | type Session struct { 52 | Key string 53 | Creation time.Time 54 | LastLookup time.Time 55 | Info SessionInfo 56 | } 57 | 58 | // Clone returns an exact copy of the session object 59 | func (s *Session) Clone() *Session { 60 | if s == nil { 61 | return nil 62 | } 63 | 64 | var info SessionInfo 65 | if s.Info != nil { 66 | info = s.Info.Copy() 67 | } 68 | 69 | return &Session{ 70 | Key: s.Key, 71 | Creation: s.Creation, 72 | LastLookup: s.LastLookup, 73 | Info: info, 74 | } 75 | } 76 | 77 | // NewSession generates a new session object 78 | // generating a cryptographically random secure key 79 | func NewSession(info SessionInfo, generator func() string) Session { 80 | key := generator() 81 | if len(key) < 1 { 82 | panic(fmt.Errorf( 83 | "Invalid session key returned by the session key generator (empty)", 84 | )) 85 | } 86 | timeNow := time.Now() 87 | return Session{ 88 | key, 89 | timeNow, 90 | timeNow, 91 | info, 92 | } 93 | } 94 | 95 | // DefaultSessionKeyGenerator implements 96 | // the webwire.SessionKeyGenerator interface 97 | type DefaultSessionKeyGenerator struct{} 98 | 99 | // NewDefaultSessionKeyGenerator constructs a new default 100 | // session key generator implementation 101 | func NewDefaultSessionKeyGenerator() SessionKeyGenerator { 102 | return &DefaultSessionKeyGenerator{} 103 | } 104 | 105 | // Generate implements the webwire.Sessio 106 | func (gen *DefaultSessionKeyGenerator) Generate() string { 107 | return generateSessionKey() 108 | } 109 | 
-------------------------------------------------------------------------------- /sessionInfoToVarMap.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | // SessionInfoToVarMap is a utility function that turns a 4 | // session info compliant object into a map of variants. 5 | // This is helpful for serialization of session info objects. A nil info yields a nil map. 6 | func SessionInfoToVarMap(info SessionInfo) map[string]interface{} { 7 | if info == nil { 8 | return nil 9 | } 10 | varMap := make(map[string]interface{}) 11 | for _, field := range info.Fields() { 12 | varMap[field] = info.Value(field) 13 | } 14 | return varMap 15 | } 16 | -------------------------------------------------------------------------------- /sessionInfoToVarMap_test.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | // TestSessionInfoToVarMap tests the SessionInfoToVarMap function 10 | // using the generic session info implementation 11 | func TestSessionInfoToVarMap(t *testing.T) { 12 | check := func(varMap map[string]interface{}) { // asserts varMap holds exactly the four original fields 13 | expectedStruct := struct { 14 | Name string 15 | Weight float64 16 | }{ 17 | Name: "samplename", 18 | Weight: 20.5, 19 | } 20 | 21 | require.Len(t, varMap, 4) 22 | require.Equal(t, "value1", varMap["field1"]) 23 | require.Equal(t, int(42), varMap["field2"]) 24 | require.Equal(t, expectedStruct, varMap["field3"]) 25 | require.IsType(t, []string{}, varMap["field4"]) 26 | require.ElementsMatch(t, []string{"item1", "item2"}, varMap["field4"]) 27 | } 28 | 29 | info := &GenericSessionInfo{ 30 | data: map[string]interface{}{ 31 | "field1": "value1", 32 | "field2": int(42), 33 | "field3": struct { 34 | Name string 35 | Weight float64 36 | }{ 37 | Name: "samplename", 38 | Weight: 20.5, 39 | }, 40 | "field4": []string{"item1", "item2"}, 41 | }, 42 | } 43 | 44 | varMap :=
SessionInfoToVarMap(SessionInfo(info)) 45 | check(varMap) 46 | 47 | // Test immutability: the returned map must not change 48 | // even after the original session info data is mutated 49 | info.data["field1"] = "mutated" 50 | info.data["field2"] = int(84) 51 | info.data["field3"] = struct { 52 | Name string 53 | Weight float64 54 | }{ 55 | Name: "another name", 56 | Weight: 0.75, 57 | } 58 | info.data["field4"] = []string{"item3"} 59 | 60 | check(varMap) 61 | } 62 | 63 | // TestSessionInfoToVarMapNil tests the SessionInfoToVarMap function 64 | // with a nil session info 65 | func TestSessionInfoToVarMapNil(t *testing.T) { 66 | varMap := SessionInfoToVarMap(nil) 67 | require.Nil(t, varMap) 68 | } 69 | -------------------------------------------------------------------------------- /sessionLookupResult.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import "time" 4 | 5 | // NewSessionLookupResult creates a new result of a session lookup operation; the info map is stored as-is, not copied 6 | func NewSessionLookupResult( 7 | creation time.Time, 8 | lastLookup time.Time, 9 | info map[string]interface{}, 10 | ) SessionLookupResult { 11 | return &sessionLookupResult{ 12 | creation: creation, 13 | lastLookup: lastLookup, 14 | info: info, 15 | } 16 | } 17 | 18 | // sessionLookupResult represents an implementation 19 | // of the SessionLookupResult interface 20 | type sessionLookupResult struct { 21 | creation time.Time 22 | lastLookup time.Time 23 | info map[string]interface{} 24 | } 25 | 26 | // Creation implements the SessionLookupResult interface 27 | func (slr *sessionLookupResult) Creation() time.Time { 28 | return slr.creation 29 | } 30 | 31 | // LastLookup implements the SessionLookupResult interface 32 | func (slr *sessionLookupResult) LastLookup() time.Time { 33 | return slr.lastLookup 34 | } 35 | 36 | // Info implements the SessionLookupResult interface 37 | func (slr *sessionLookupResult) Info() map[string]interface{} { 38 | return slr.info 39 |
} 40 | -------------------------------------------------------------------------------- /socket.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "time" 7 | 8 | "github.com/qbeon/webwire-go/message" 9 | ) 10 | 11 | // Socket defines the abstract socket implementation interface 12 | type Socket interface { 13 | // GetWriter returns a writer for the next message to send. The writer's 14 | // Close method flushes the written message to the network. In case of 15 | // concurrent use GetWriter will block until the previous writer is closed 16 | // and a new one is available 17 | GetWriter() (io.WriteCloser, error) 18 | 19 | // Read blocks the calling goroutine and awaits an incoming message. If 20 | // deadline is the zero time.Time then Read will never time out. In case of 21 | // concurrent use Read will block until the previous call has finished 22 | Read(into *message.Message, deadline time.Time) ErrSockRead 23 | 24 | // IsConnected returns true if the given socket maintains an open connection 25 | // and false otherwise 26 | IsConnected() bool 27 | 28 | // RemoteAddr returns the address of the remote client or nil if the client 29 | // is not connected 30 | RemoteAddr() net.Addr 31 | 32 | // Close closes the socket 33 | Close() error 34 | } 35 | 36 | // ClientSocket defines the abstract client socket implementation interface 37 | type ClientSocket interface { 38 | // Dial connects the socket to the server 39 | Dial(deadline time.Time) error 40 | 41 | Socket 42 | } 43 | -------------------------------------------------------------------------------- /test/activeSessionRegistry_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/assert" 11 |
"github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestActiveSessionRegistry verifies that the session registry 15 | // of currently active sessions is properly updated on session creation and closure 16 | func TestActiveSessionRegistry(t *testing.T) { 17 | sessionCreated := sync.WaitGroup{} 18 | sessionCreated.Add(1) 19 | sessionClosed := sync.WaitGroup{} 20 | sessionClosed.Add(1) 21 | 22 | // Initialize webwire server 23 | setup := SetupTestServer( 24 | t, 25 | &ServerImpl{ 26 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 27 | // Try to create a new session 28 | assert.NoError(t, c.CreateSession(nil)) 29 | sessionCreated.Done() 30 | }, 31 | Signal: func( 32 | _ context.Context, 33 | c wwr.Connection, 34 | msg wwr.Message, 35 | ) { 36 | // Close the session on any incoming signal (the test sends one as a logout) 37 | assert.NoError(t, c.CloseSession()) 38 | 39 | sessionClosed.Done() 40 | }, 41 | }, 42 | wwr.ServerOptions{ 43 | SessionManager: &SessionManager{ 44 | SessionCreated: func(c wwr.Connection) error { 45 | return nil 46 | }, 47 | }, 48 | }, 49 | nil, // Use the default transport implementation 50 | ) 51 | 52 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 53 | 54 | // Initialize client 55 | sock, _ := setup.NewClientSocket() 56 | 57 | readSessionCreated(t, sock) 58 | 59 | sessionCreated.Wait() 60 | 61 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 62 | 63 | // Close session 64 | signal(t, sock, []byte("s"), payload.Payload{}) 65 | 66 | readSessionClosed(t, sock) 67 | 68 | sessionClosed.Wait() 69 | 70 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 71 | } 72 | -------------------------------------------------------------------------------- /test/clientInitiatedSessionDestruction_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/assert" 11 |
"github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestClientInitiatedSessionDestruction tests client-initiated session 15 | // destruction 16 | func TestClientInitiatedSessionDestruction(t *testing.T) { 17 | sessionCreated := sync.WaitGroup{} 18 | sessionCreated.Add(1) 19 | sessionDestructionCallbackCalled := sync.WaitGroup{} 20 | sessionDestructionCallbackCalled.Add(1) 21 | 22 | var sessionKey string 23 | 24 | // Initialize webwire server 25 | setup := SetupTestServer( 26 | t, 27 | &ServerImpl{ 28 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 29 | // Create a new session 30 | assert.NoError(t, c.CreateSession(nil)) 31 | }, 32 | Request: func( 33 | _ context.Context, 34 | conn wwr.Connection, 35 | msg wwr.Message, 36 | ) (wwr.Payload, error) { 37 | // Verify session destruction 38 | assert.Nil(t, 39 | conn.Session(), 40 | "Expected the session to be destroyed", 41 | ) 42 | return wwr.Payload{}, nil 43 | }, 44 | }, 45 | wwr.ServerOptions{ 46 | SessionManager: &SessionManager{ 47 | SessionCreated: func(conn wwr.Connection) error { 48 | defer sessionCreated.Done() 49 | 50 | sessionKey = conn.SessionKey() // read by the test goroutine after sessionCreated.Wait() 51 | 52 | return nil 53 | }, 54 | SessionClosed: func(closedSessionKey string) error { 55 | defer sessionDestructionCallbackCalled.Done() 56 | 57 | // Ensure that the correct session was closed 58 | assert.Equal(t, sessionKey, closedSessionKey) 59 | 60 | return nil 61 | }, 62 | }, 63 | }, 64 | nil, // Use the default transport implementation 65 | ) 66 | 67 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 68 | 69 | sock, _ := setup.NewClientSocket() 70 | 71 | // Expect session creation notification message 72 | readSessionCreated(t, sock) 73 | 74 | sessionCreated.Wait() 75 | 76 | assert.NotEqual(t, "", sessionKey) 77 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 78 | require.Equal(t, 1, setup.Server.SessionConnectionsNum(sessionKey)) 79 | require.Equal(t, 1, len(setup.Server.SessionConnections(sessionKey))) 80 | 81 | //
Initiate session destruction 82 | requestCloseSessionSuccess(t, sock) 83 | 84 | // Wait for the server to finally destroy the session 85 | sessionDestructionCallbackCalled.Wait() 86 | 87 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 88 | require.Equal(t, -1, setup.Server.SessionConnectionsNum(sessionKey)) 89 | require.Equal(t, 0, len(setup.Server.SessionConnections(sessionKey))) 90 | 91 | // Verify session destruction (the Request handler asserts conn.Session() is nil) 92 | requestSuccess(t, sock, 32, []byte("verify"), payload.Payload{}) 93 | } 94 | -------------------------------------------------------------------------------- /test/connIsConnected_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestConnIsConnected tests the Connection.IsActive method as well as the 12 | // OnClientConnected and OnClientDisconnected server hooks 13 | func TestConnIsConnected(t *testing.T) { 14 | ready := sync.WaitGroup{} 15 | clientDisconnected := sync.WaitGroup{} 16 | finished := sync.WaitGroup{} 17 | ready.Add(1) 18 | clientDisconnected.Add(1) 19 | finished.Add(1) 20 | 21 | // Initialize webwire server 22 | setup := SetupTestServer( 23 | t, 24 | &ServerImpl{ 25 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 26 | assert.True(t, c.IsActive()) 27 | 28 | go func() { // tester goroutine: awaits the disconnect without blocking the hook 29 | ready.Done() 30 | clientDisconnected.Wait() 31 | 32 | assert.False(t, c.IsActive()) 33 | 34 | finished.Done() 35 | }() 36 | }, 37 | ClientDisconnected: func(c wwr.Connection, _ error) { 38 | assert.False(t, c.IsActive()) 39 | clientDisconnected.Done() 40 | }, 41 | }, 42 | wwr.ServerOptions{}, 43 | nil, // Use the default transport implementation 44 | ) 45 | 46 | // Initialize client 47 | sock, _ := setup.NewClientSocket() 48 | 49 | // Wait for the tester goroutine spawned by the OnClientConnected handler to start 50 | ready.Wait() 51 | 52 | //
Close the client connection and continue in the tester goroutine 53 | // spawned in the OnClientConnected handler of the server 54 | sock.Close() 55 | 56 | // Wait for the tester goroutine to finish 57 | finished.Wait() 58 | } 59 | -------------------------------------------------------------------------------- /test/connSessionNoOverride_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestSessionNoOverride tests that creating a second session on a connection 12 | // fails and leaves the existing session unchanged 13 | func TestSessionNoOverride(t *testing.T) { 14 | finished := sync.WaitGroup{} 15 | finished.Add(1) 16 | 17 | // Initialize server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 22 | defer finished.Done() 23 | 24 | assert.NoError(t, c.CreateSession(nil)) 25 | sessionKey := c.SessionKey() 26 | 27 | // Try to override the previous session 28 | assert.Error(t, c.CreateSession(nil)) 29 | 30 | // Ensure the session didn't change 31 | assert.Equal(t, sessionKey, c.SessionKey()) 32 | }, 33 | }, 34 | wwr.ServerOptions{}, 35 | nil, // Use the default transport implementation 36 | ) 37 | 38 | // Initialize client 39 | sock, _ := setup.NewClientSocket() 40 | 41 | readSessionCreated(t, sock) 42 | 43 | finished.Wait() 44 | } 45 | -------------------------------------------------------------------------------- /test/connSignalBufferOverflow_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestConnSignalBufferOverflow tests Connection.Signal with a payload (and no 12 | // name) that would overflow the 1024 byte message buffer 13 | func
TestConnSignalBufferOverflow(t *testing.T) { 14 | finished := sync.WaitGroup{} 15 | finished.Add(1) 16 | 17 | // Initialize server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 22 | defer finished.Done() 23 | 24 | payload := make([]byte, 2048) 25 | err := c.Signal( 26 | []byte(nil), // No name 27 | wwr.Payload{Data: payload}, // Payload too big (2048 > 1024 bytes) 28 | ) 29 | 30 | assert.Error(t, err) 31 | assert.IsType(t, wwr.ErrBufferOverflow{}, err) 32 | }, 33 | }, 34 | wwr.ServerOptions{ 35 | MessageBufferSize: 1024, 36 | }, 37 | nil, // Use the default transport implementation 38 | ) 39 | 40 | // Initialize client 41 | setup.NewClientSocket() 42 | 43 | finished.Wait() 44 | } 45 | -------------------------------------------------------------------------------- /test/connSignalNoNameNoPayload_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestConnSignalNoNameNoPayload tests Connection.Signal providing both a nil 12 | // name and a nil payload, expecting it to return an error 13 | func TestConnSignalNoNameNoPayload(t *testing.T) { 14 | finished := sync.WaitGroup{} 15 | finished.Add(1) 16 | 17 | // Initialize server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 22 | defer finished.Done() 23 | 24 | assert.Error(t, c.Signal( 25 | []byte(nil), // No name 26 | wwr.Payload{Data: []byte(nil)}, // No payload 27 | )) 28 | }, 29 | }, 30 | wwr.ServerOptions{}, 31 | nil, // Use the default transport implementation 32 | ) 33 | 34 | // Initialize client 35 | setup.NewClientSocket() 36 | 37 | finished.Wait() 38 | } 39 | -------------------------------------------------------------------------------- /test/connectionInfo_test.go:
-------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/transport/memchan" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestConnectionInfo tests the Connection.Info and Connection.Creation methods 14 | func TestConnectionInfo(t *testing.T) { 15 | handlerFinished := sync.WaitGroup{} 16 | handlerFinished.Add(1) 17 | 18 | // Initialize server 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{ 22 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 23 | defer handlerFinished.Done() 24 | assert.Equal(t, "samplestring", c.Info(1)) // compared without a type assertion so a mismatch fails instead of panicking 25 | assert.Equal(t, uint64(42), c.Info(2)) 26 | assert.Nil(t, c.Info(3)) 27 | assert.WithinDuration( 28 | t, 29 | time.Now(), 30 | c.Creation(), 31 | 1*time.Second, 32 | ) 33 | }, 34 | }, 35 | wwr.ServerOptions{}, 36 | &memchan.Transport{ 37 | OnBeforeCreation: func() wwr.ConnectionOptions { 38 | return wwr.ConnectionOptions{ 39 | Connection: wwr.Accept, 40 | Info: map[int]interface{}{ 41 | 1: "samplestring", 42 | 2: uint64(42), 43 | }, 44 | } 45 | }, 46 | }, 47 | ) 48 | 49 | // Initialize client 50 | setup.NewClientSocket() 51 | 52 | handlerFinished.Wait() 53 | } 54 | -------------------------------------------------------------------------------- /test/connectionSessionGetters_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | wwr "github.com/qbeon/webwire-go" 10 | "github.com/qbeon/webwire-go/payload" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | type testSessInfo struct { 15 | UserIdent string 16 | SomeNumber int 17 | } 18 | 19 | // Copy implements the webwire.SessionInfo interface 20 | func (sinf *testSessInfo) Copy() wwr.SessionInfo { 21 | return &testSessInfo{ 22 | UserIdent: sinf.UserIdent, 23 |
SomeNumber: sinf.SomeNumber, 24 | } 25 | } 26 | 27 | // Fields implements the webwire.SessionInfo interface 28 | func (sinf *testSessInfo) Fields() []string { 29 | return []string{"uid", "some-number"} 30 | } 31 | 32 | // Value implements the webwire.SessionInfo interface 33 | func (sinf *testSessInfo) Value(fieldName string) interface{} { 34 | switch fieldName { 35 | case "uid": 36 | return sinf.UserIdent 37 | case "some-number": 38 | return sinf.SomeNumber 39 | } 40 | return nil 41 | } 42 | 43 | // TestConnectionSessionGetters tests the connection session information getters 44 | func TestConnectionSessionGetters(t *testing.T) { 45 | signalDone := sync.WaitGroup{} 46 | signalDone.Add(1) 47 | 48 | compareSession := func(conn wwr.Connection) { 49 | timeNow := time.Now() 50 | sess := conn.Session() 51 | if !assert.NotNil(t, sess) { // assert does not stop execution; avoid dereferencing nil below 52 | return 53 | } 54 | assert.Equal(t, "testsessionkey", sess.Key) 55 | assert.Equal(t, &testSessInfo{ 56 | UserIdent: "clientidentifiergoeshere", // uid 57 | SomeNumber: 12345, // some-number 58 | }, sess.Info) 59 | assert.WithinDuration(t, timeNow, sess.Creation, 1*time.Second) 60 | assert.WithinDuration(t, timeNow, sess.LastLookup, 1*time.Second) 61 | 62 | assert.WithinDuration( 63 | t, 64 | timeNow, 65 | conn.SessionCreation(), 66 | 1*time.Second, 67 | ) 68 | assert.Equal(t, "testsessionkey", conn.SessionKey()) 69 | uid := conn.SessionInfo("uid") 70 | assert.NotNil(t, uid) 71 | assert.IsType(t, string(""), uid) 72 | 73 | someNumber := conn.SessionInfo("some-number") 74 | assert.NotNil(t, someNumber) 75 | assert.IsType(t, int(0), someNumber) 76 | } 77 | 78 | // Initialize server 79 | setup := SetupTestServer( 80 | t, 81 | &ServerImpl{ 82 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 83 | // Before session creation 84 | assert.Nil(t, c.Session()) 85 | assert.True(t, c.SessionCreation().IsZero()) 86 | assert.Equal(t, "", c.SessionKey()) 87 | assert.Nil(t, c.SessionInfo("uid")) 88 | assert.Nil(t, c.SessionInfo("some-number")) 89 |
90 | assert.NoError(t, c.CreateSession( 91 | &testSessInfo{ 92 | UserIdent: "clientidentifiergoeshere", // uid 93 | SomeNumber: 12345, // some-number 94 | }, 95 | )) 96 | 97 | // After session creation 98 | compareSession(c) 99 | }, 100 | Request: func( 101 | _ context.Context, 102 | c wwr.Connection, 103 | _ wwr.Message, 104 | ) (wwr.Payload, error) { 105 | compareSession(c) 106 | return wwr.Payload{}, nil 107 | }, 108 | Signal: func(_ context.Context, c wwr.Connection, _ wwr.Message) { 109 | defer signalDone.Done() 110 | compareSession(c) 111 | }, 112 | }, 113 | wwr.ServerOptions{ 114 | SessionKeyGenerator: &SessionKeyGen{ 115 | OnGenerate: func() string { 116 | return "testsessionkey" // fixed key expected by compareSession 117 | }, 118 | }, 119 | }, 120 | nil, // Use the default transport implementation 121 | ) 122 | 123 | // Connect new client 124 | sock, _ := setup.NewClientSocket() 125 | 126 | readSessionCreated(t, sock) 127 | 128 | requestSuccess(t, sock, 32, []byte("r"), payload.Payload{}) 129 | 130 | signal(t, sock, []byte("s"), payload.Payload{}) 131 | 132 | signalDone.Wait() 133 | } 134 | -------------------------------------------------------------------------------- /test/customSessKeyGenInvalid_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // TestCustomSessKeyGenInvalid tests custom session key generators returning 13 | // invalid (empty) keys 14 | func TestCustomSessKeyGenInvalid(t *testing.T) { 15 | finished := sync.WaitGroup{} 16 | finished.Add(1) 17 | 18 | // Initialize webwire server 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{ 22 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 23 | defer func() { // CreateSession is expected to panic on the empty key 24 | recoveredErr := recover() 25 | assert.NotNil(t, recoveredErr) 26 | assert.IsType(t, errors.New(""), recoveredErr) 27 | 28 | finished.Done() 29 | }()
30 | 31 | // Try to create a new session 32 | err := c.CreateSession(nil) 33 | assert.NoError(t, err) // only reached if CreateSession does not panic 34 | }, 35 | }, 36 | wwr.ServerOptions{ 37 | SessionKeyGenerator: &SessionKeyGen{ 38 | OnGenerate: func() string { 39 | // Return invalid session key 40 | return "" 41 | }, 42 | }, 43 | }, 44 | nil, // Use the default transport implementation 45 | ) 46 | 47 | // Initialize client 48 | setup.NewClientSocket() 49 | 50 | finished.Wait() 51 | } 52 | -------------------------------------------------------------------------------- /test/customSessKeyGen_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestCustomSessKeyGen tests custom session key generators 12 | func TestCustomSessKeyGen(t *testing.T) { 13 | finished := sync.WaitGroup{} 14 | finished.Add(1) 15 | expectedSessionKey := "customkey123" 16 | 17 | // Initialize webwire server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 22 | defer finished.Done() 23 | 24 | // Try to create a new session 25 | assert.NoError(t, c.CreateSession(nil)) 26 | assert.Equal(t, expectedSessionKey, c.SessionKey()) 27 | }, 28 | }, 29 | wwr.ServerOptions{ 30 | SessionKeyGenerator: &SessionKeyGen{ 31 | OnGenerate: func() string { 32 | return expectedSessionKey 33 | }, 34 | }, 35 | }, 36 | nil, // Use the default transport implementation 37 | ) 38 | 39 | // Initialize client 40 | sock, _ := setup.NewClientSocket() 41 | 42 | readSessionCreated(t, sock) 43 | 44 | finished.Wait() 45 | } 46 | -------------------------------------------------------------------------------- /test/disabledSessions_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr
"github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestDisabledSessions tests the errors returned by CreateSession and CloseSession 14 | // and the reply to a session restoration request when sessions are disabled 15 | func TestDisabledSessions(t *testing.T) { 16 | finished := sync.WaitGroup{} 17 | finished.Add(1) 18 | 19 | // Initialize webwire server 20 | setup := SetupTestServer( 21 | t, 22 | &ServerImpl{ 23 | ClientConnected: func(_ wwr.ConnectionOptions, c wwr.Connection) { 24 | assert.Nil(t, c.Session()) 25 | 26 | // Try to create a new session 27 | createErr := c.CreateSession(nil) 28 | assert.IsType(t, wwr.ErrSessionsDisabled{}, createErr) 29 | 30 | // Try to close a session 31 | closeErr := c.CloseSession() 32 | assert.IsType(t, wwr.ErrSessionsDisabled{}, closeErr) 33 | 34 | finished.Done() 35 | }, 36 | }, 37 | wwr.ServerOptions{ 38 | Sessions: wwr.Disabled, 39 | SessionManager: &SessionManager{ // hooks report via t.Error: they are not guaranteed to run on the test goroutine, where t.Fatal (FailNow) must be called 40 | SessionCreated: func(c wwr.Connection) error { 41 | t.Error("unexpected hook call") 42 | return nil 43 | }, 44 | SessionLookup: func( 45 | sessionKey string, 46 | ) (wwr.SessionLookupResult, error) { 47 | t.Error("unexpected hook call") 48 | return nil, nil 49 | }, 50 | SessionClosed: func(sessionKey string) error { 51 | t.Error("unexpected hook call") 52 | return nil 53 | }, 54 | }, 55 | }, 56 | nil, // Use the default transport implementation 57 | ) 58 | 59 | // Initialize client 60 | sock, _ := setup.NewClientSocket() 61 | 62 | finished.Wait() 63 | 64 | // Try to restore a session 65 | reply := requestRestoreSession(t, sock, []byte("testsessionkey")) 66 | require.Equal(t, message.MsgReplySessionsDisabled, reply.MsgType) 67 | } 68 | -------------------------------------------------------------------------------- /test/emptyReplyUtf16_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 |
"context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestEmptyReplyUtf16 tests returning empty UTF16 replies 14 | func TestEmptyReplyUtf16(t *testing.T) { 15 | // Initialize webwire server given only the request handler 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | _ wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Return empty reply 25 | return wwr.Payload{Encoding: wwr.EncodingUtf16}, nil 26 | }, 27 | }, 28 | wwr.ServerOptions{}, 29 | nil, // Use the default transport implementation 30 | ) 31 | 32 | // Initialize client 33 | sock, _ := setup.NewClientSocket() 34 | 35 | // Send request and await an empty UTF16 reply 36 | reply := request(t, sock, 64, []byte("r"), payload.Payload{}) 37 | require.Equal(t, message.MsgReplyUtf16, reply.MsgType) 38 | require.Equal(t, payload.Utf16, reply.MsgPayload.Encoding) 39 | require.Equal(t, []byte(nil), reply.MsgPayload.Data) 40 | } 41 | -------------------------------------------------------------------------------- /test/emptyReplyUtf8_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestEmptyReplyUtf8 tests returning empty UTF8 replies 14 | func TestEmptyReplyUtf8(t *testing.T) { 15 | // Initialize webwire server given only the request handler 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | _ wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Return empty reply 25 | return wwr.Payload{Encoding: wwr.EncodingUtf8}, nil 26 | }, 27 | }, 28
| wwr.ServerOptions{}, 29 | nil, // Use the default transport implementation 30 | ) 31 | 32 | // Initialize client 33 | sock, _ := setup.NewClientSocket() 34 | 35 | // Send request and await an empty UTF8 reply 36 | reply := request(t, sock, 64, []byte("r"), payload.Payload{}) 37 | require.Equal(t, message.MsgReplyUtf8, reply.MsgType) 38 | require.Equal(t, payload.Utf8, reply.MsgPayload.Encoding) 39 | require.Equal(t, []byte(nil), reply.MsgPayload.Data) 40 | } 41 | -------------------------------------------------------------------------------- /test/emptyReply_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestEmptyReply tests returning empty binary replies 14 | func TestEmptyReply(t *testing.T) { 15 | // Initialize webwire server given only the request handler 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | _ wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Return empty reply 25 | return wwr.Payload{}, nil 26 | }, 27 | }, 28 | wwr.ServerOptions{}, 29 | nil, // Use the default transport implementation 30 | ) 31 | 32 | // Initialize client 33 | sock, _ := setup.NewClientSocket() 34 | 35 | // Send request and await an empty binary reply 36 | reply := request(t, sock, 64, []byte("r"), payload.Payload{}) 37 | require.Equal(t, message.MsgReplyBinary, reply.MsgType) 38 | require.Equal(t, payload.Binary, reply.MsgPayload.Encoding) 39 | require.Equal(t, []byte(nil), reply.MsgPayload.Data) 40 | } 41 | -------------------------------------------------------------------------------- /test/handshake_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 |
import ( 4 | "testing" 5 | "time" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // TestHandshake tests the connection establishment handshake testing the server 13 | // configuration push message 14 | func TestHandshake(t *testing.T) { 15 | serverReadTimeout := 3 * time.Second 16 | messageBufferSize := uint32(1024 * 8) 17 | 18 | // Initialize webwire server 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{}, 22 | wwr.ServerOptions{ 23 | ReadTimeout: serverReadTimeout, 24 | MessageBufferSize: messageBufferSize, 25 | }, 26 | nil, // Use the default transport implementation 27 | ) 28 | 29 | readTimeout := 5 * time.Second 30 | 31 | socket, err := setup.NewDisconnectedClientSocket() 32 | require.NoError(t, err) 33 | 34 | require.NoError(t, socket.Dial(time.Time{})) 35 | 36 | // Await the server configuration push message 37 | msg := message.NewMessage(messageBufferSize) 38 | require.NoError(t, socket.Read(msg, time.Now().Add(readTimeout))) 39 | 40 | require.Equal(t, [8]byte{}, msg.MsgIdentifier) 41 | require.Equal(t, []byte{0, 0, 0, 0, 0, 0, 0, 0}, msg.MsgIdentifierBytes) 42 | require.Nil(t, msg.MsgName) 43 | require.Equal(t, message.ServerConfiguration{ 44 | MajorProtocolVersion: 2, 45 | MinorProtocolVersion: 0, 46 | ReadTimeout: serverReadTimeout, 47 | MessageBufferSize: messageBufferSize, 48 | }, msg.ServerConfiguration) 49 | } 50 | -------------------------------------------------------------------------------- /test/maxConcSessConn_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/qbeon/webwire-go/message" 8 | 9 | "github.com/qbeon/webwire-go" 10 | wwr "github.com/qbeon/webwire-go" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestMaxConcSessConn tests 4 maximum concurrent connections of a session 15 | func TestMaxConcSessConn(t 
*testing.T) { 16 | concurrentConns := uint(4) 17 | 18 | var sessionKey = "testsessionkey" 19 | sessionCreation := time.Now() 20 | 21 | // Initialize server 22 | setup := SetupTestServer( 23 | t, 24 | &ServerImpl{}, 25 | wwr.ServerOptions{ 26 | MaxSessionConnections: concurrentConns, 27 | SessionManager: &SessionManager{ 28 | SessionLookup: func(key string) ( 29 | webwire.SessionLookupResult, 30 | error, 31 | ) { 32 | if key != sessionKey { 33 | // Session not found 34 | return nil, nil 35 | } 36 | return webwire.NewSessionLookupResult( 37 | sessionCreation, // Creation 38 | time.Now(), // LastLookup 39 | nil, // Info 40 | ), nil 41 | }, 42 | }, 43 | }, 44 | nil, // Use the default transport implementation 45 | ) 46 | 47 | // Initialize clients 48 | clients := make([]wwr.Socket, concurrentConns) 49 | for i := uint(0); i < concurrentConns; i++ { 50 | sock, _ := setup.NewClientSocket() 51 | clients[i] = sock 52 | 53 | requestRestoreSessionSuccess(t, sock, []byte(sessionKey)) 54 | } 55 | 56 | // Ensure that the last superfluous client is rejected 57 | superfluousClient, _ := setup.NewClientSocket() 58 | 59 | // Try to restore the session one more time and expect this request to fail 60 | // due to reached limit 61 | reply := requestRestoreSession(t, superfluousClient, []byte(sessionKey)) 62 | require.Equal(t, message.MsgReplyMaxSessConnsReached, reply.MsgType) 63 | } 64 | -------------------------------------------------------------------------------- /test/namedRequest_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // TestNamedRequest tests correct handling of named requests 13 | func TestNamedRequest(t *testing.T) { 14 | currentStep := 1 15 | 16 | shortestPossibleName := []byte("s") 17 | longestPossibleName := make([]byte, 
255) 18 | for i := range longestPossibleName { 19 | longestPossibleName[i] = 'x' 20 | } 21 | 22 | // Initialize server 23 | setup := SetupTestServer( 24 | t, 25 | &ServerImpl{ 26 | Request: func( 27 | _ context.Context, 28 | _ wwr.Connection, 29 | msg wwr.Message, 30 | ) (wwr.Payload, error) { 31 | msgName := msg.Name() 32 | switch currentStep { 33 | case 1: 34 | assert.Equal(t, shortestPossibleName, msgName) 35 | case 2: 36 | assert.Equal(t, longestPossibleName, msgName) 37 | } 38 | return wwr.Payload{}, nil 39 | }, 40 | }, 41 | wwr.ServerOptions{}, 42 | nil, // Use the default transport implementation 43 | ) 44 | 45 | // Initialize client 46 | sock, _ := setup.NewClientSocket() 47 | 48 | // Send request with the shortest possible name 49 | currentStep = 1 50 | requestSuccess(t, sock, 32, shortestPossibleName, payload.Payload{}) 51 | 52 | // Send request with the longest possible name 53 | currentStep = 2 54 | requestSuccess(t, sock, 32, longestPossibleName, payload.Payload{}) 55 | } 56 | -------------------------------------------------------------------------------- /test/namedSignal_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestNamedSignal tests correct handling of named signals 14 | func TestNamedSignal(t *testing.T) { 15 | shortestNameSignalArrived := sync.WaitGroup{} 16 | shortestNameSignalArrived.Add(1) 17 | longestNameSignalArrived := sync.WaitGroup{} 18 | longestNameSignalArrived.Add(1) 19 | currentStep := 1 20 | 21 | shortestPossibleName := []byte("s") 22 | longestPossibleName := make([]byte, 255) 23 | for i := range longestPossibleName { 24 | longestPossibleName[i] = 'x' 25 | } 26 | 27 | // Initialize server 28 | setup := SetupTestServer( 29 | t, 30 | &ServerImpl{ 31 | Signal: func( 32 | _ 
context.Context, 33 | _ wwr.Connection, 34 | msg wwr.Message, 35 | ) { 36 | msgName := msg.Name() 37 | switch currentStep { 38 | case 1: 39 | assert.Equal(t, shortestPossibleName, msgName) 40 | shortestNameSignalArrived.Done() 41 | case 2: 42 | assert.Equal(t, longestPossibleName, msgName) 43 | longestNameSignalArrived.Done() 44 | } 45 | }, 46 | }, 47 | wwr.ServerOptions{}, 48 | nil, // Use the default transport implementation 49 | ) 50 | 51 | // Initialize client 52 | sock, _ := setup.NewClientSocket() 53 | 54 | // Send request with the shortest possible name 55 | currentStep = 1 56 | signal(t, sock, shortestPossibleName, payload.Payload{}) 57 | shortestNameSignalArrived.Wait() 58 | 59 | // Send request with the longest possible name 60 | currentStep = 2 61 | signal(t, sock, longestPossibleName, payload.Payload{}) 62 | longestNameSignalArrived.Wait() 63 | } 64 | -------------------------------------------------------------------------------- /test/protocolViolation_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // TestProtocolViolation tests sending messages that violate the protocol 13 | func TestProtocolViolation(t *testing.T) { 14 | // Initialize webwire server 15 | setup := SetupTestServer( 16 | t, 17 | &ServerImpl{}, 18 | wwr.ServerOptions{}, 19 | nil, // Use the default transport implementation 20 | ) 21 | 22 | defaultReadTimeout := 2 * time.Second 23 | 24 | // Setup a regular websocket connection 25 | try := func(m []byte) { 26 | socket, _ := setup.NewClientSocket() 27 | 28 | // Get writer 29 | writer, err := socket.GetWriter() 30 | require.NoError(t, err) 31 | 32 | // Write the message 33 | bytesWritten, writeErr := writer.Write(m) 34 | require.NoError(t, writeErr) 35 | require.Equal(t, len(m), bytesWritten) 36 | 
require.NoError(t, writer.Close()) 37 | 38 | emptyMsg := message.NewMessage(256) 39 | readErr := socket.Read(emptyMsg, time.Now().Add(defaultReadTimeout)) 40 | require.Error(t, readErr) 41 | 42 | require.False(t, socket.IsConnected()) 43 | } 44 | 45 | // Test a message with an invalid type identifier (200, which is undefined) 46 | // and expect the server to ignore it returning no answer 47 | try([]byte{byte(200)}) 48 | 49 | // Test a message with an invalid name length flag (bigger than name) 50 | // and expect the server to return a protocol violation error response 51 | try([]byte{ 52 | message.MsgRequestBinary, // Message type identifier 53 | 0, 0, 0, 0, 0, 0, 0, 0, // Request identifier 54 | 3, // Name length flag 55 | 0x041, // Name 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /test/refuseConnections_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/transport/memchan" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // TestRefuseConnections tests refusing connections on the transport level 13 | func TestRefuseConnections(t *testing.T) { 14 | // Initialize server 15 | setup := SetupTestServer( 16 | t, 17 | &ServerImpl{}, 18 | wwr.ServerOptions{}, 19 | &memchan.Transport{ 20 | OnBeforeCreation: func() wwr.ConnectionOptions { 21 | // Refuse all incoming connections 22 | return wwr.ConnectionOptions{ 23 | Connection: wwr.Refuse, 24 | } 25 | }, 26 | }, 27 | ) 28 | 29 | // Initialize client 30 | sock, err := setup.NewDisconnectedClientSocket() 31 | require.NoError(t, err) 32 | 33 | // Try connect 34 | require.Error(t, sock.Dial(time.Time{})) 35 | } 36 | -------------------------------------------------------------------------------- /test/requestError_test.go: 
-------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestErrRequestor tests server-side request errors properly failing the 14 | // client-side requests 15 | func TestErrRequestor(t *testing.T) { 16 | // Initialize server 17 | setup := SetupTestServer( 18 | t, 19 | &ServerImpl{ 20 | Request: func( 21 | _ context.Context, 22 | _ wwr.Connection, 23 | _ wwr.Message, 24 | ) (wwr.Payload, error) { 25 | // Fail the request by returning an error 26 | return wwr.Payload{Data: []byte("garbage")}, wwr.ErrRequest{ 27 | Code: "SAMPLE_ERROR", 28 | Message: "Sample error message", 29 | } 30 | }, 31 | }, 32 | wwr.ServerOptions{}, 33 | nil, // Use the default transport implementation 34 | ) 35 | 36 | // Initialize client 37 | sock, _ := setup.NewClientSocket() 38 | 39 | rep := request(t, sock, 192, []byte("r"), payload.Payload{}) 40 | require.Equal(t, message.MsgReplyError, rep.MsgType) 41 | require.Equal(t, []byte("SAMPLE_ERROR"), rep.MsgName) 42 | require.Equal(t, payload.Binary, rep.MsgPayload.Encoding) 43 | require.Equal(t, []byte("Sample error message"), rep.MsgPayload.Data) 44 | } 45 | -------------------------------------------------------------------------------- /test/requestInternalError_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/message" 10 | "github.com/qbeon/webwire-go/payload" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestRequestInternalError tests returning non-ReqErr errors from the request 15 | // handler 16 | func TestRequestInternalError(t *testing.T) { 17 | // Initialize 
server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | Request: func( 22 | _ context.Context, 23 | _ wwr.Connection, 24 | _ wwr.Message, 25 | ) (wwr.Payload, error) { 26 | // Fail the request by returning a non-ReqErr error 27 | return wwr.Payload{Data: []byte("garbage")}, fmt.Errorf( 28 | "don't worry, this internal error is expected", 29 | ) 30 | }, 31 | }, 32 | wwr.ServerOptions{}, 33 | nil, // Use the default transport implementation 34 | ) 35 | 36 | // Initialize client 37 | sock, _ := setup.NewClientSocket() 38 | 39 | rep := request(t, sock, 192, []byte("r"), payload.Payload{}) 40 | require.Equal(t, message.MsgReplyInternalError, rep.MsgType) 41 | require.Nil(t, rep.MsgName) 42 | require.Equal(t, payload.Binary, rep.MsgPayload.Encoding) 43 | require.Nil(t, rep.MsgPayload.Data) 44 | } 45 | -------------------------------------------------------------------------------- /test/requestNameOnly_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // TestRequestNameOnly tests named requests without a payload 13 | func TestRequestNameOnly(t *testing.T) { 14 | // Initialize server 15 | setup := SetupTestServer( 16 | t, 17 | &ServerImpl{ 18 | Request: func( 19 | _ context.Context, 20 | _ wwr.Connection, 21 | msg wwr.Message, 22 | ) (wwr.Payload, error) { 23 | // Expect a named request 24 | assert.Equal(t, []byte("name"), msg.Name()) 25 | 26 | // Expect no payload to arrive 27 | assert.Equal(t, 0, len(msg.Payload())) 28 | 29 | switch msg.PayloadEncoding() { 30 | case wwr.EncodingUtf8: 31 | return wwr.Payload{Encoding: wwr.EncodingUtf8}, nil 32 | case wwr.EncodingUtf16: 33 | return wwr.Payload{Encoding: wwr.EncodingUtf16}, nil 34 | } 35 | return wwr.Payload{}, nil 36 | }, 37 | }, 38 | wwr.ServerOptions{}, 39 | nil, // 
Use the default transport implementation 40 | ) 41 | 42 | // Initialize client 43 | sock, _ := setup.NewClientSocket() 44 | 45 | requestSuccess(t, sock, 32, []byte("name"), payload.Payload{}) 46 | 47 | // Send a named UTF8 encoded request without a payload and await reply 48 | requestSuccess(t, sock, 32, []byte("name"), payload.Payload{ 49 | Encoding: payload.Utf8, 50 | }) 51 | 52 | // Send a UTF16 encoded named binary request without a payload 53 | requestSuccess(t, sock, 32, []byte("name"), payload.Payload{ 54 | Encoding: payload.Utf16, 55 | }) 56 | } 57 | -------------------------------------------------------------------------------- /test/requestNoNameNoPayload_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/qbeon/webwire-go/message" 9 | 10 | wwr "github.com/qbeon/webwire-go" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestRequestNoNameNoPayload tests sending requests without both a name and a 15 | // payload expecting the server to reject the message 16 | func TestRequestNoNameNoPayload(t *testing.T) { 17 | // Initialize server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{ 21 | Request: func( 22 | _ context.Context, 23 | _ wwr.Connection, 24 | msg wwr.Message, 25 | ) (wwr.Payload, error) { 26 | // Expect the following request to not even arrive 27 | t.Error("Not expected but reached") 28 | return wwr.Payload{}, nil 29 | }, 30 | }, 31 | wwr.ServerOptions{}, 32 | nil, // Use the default transport implementation 33 | ) 34 | 35 | // Initialize client 36 | sock, _ := setup.NewClientSocket() 37 | 38 | // TODO: improve test by avoiding the use of the client but performing the 39 | // request over a raw socket to ensure the client doesn't filter the request 40 | // out preemtively so it never even reaches the server 41 | 42 | // Send request without a name and without a payload. 
43 | // Expect a protocol error in return not sending the invalid request off 44 | writer, err := sock.GetWriter() 45 | require.NoError(t, err) 46 | require.NotNil(t, writer) 47 | 48 | bytesWritten, err := writer.Write([]byte{ 49 | message.MsgRequestBinary, // Type 50 | 0, 0, 0, 0, 0, 0, 0, 0, // Identifier 51 | 0, // Name length 52 | }) 53 | require.NoError(t, err) 54 | require.Equal(t, 10, bytesWritten) 55 | 56 | require.NoError(t, writer.Close()) 57 | 58 | // Expect the socket to be closed by the server due to protocol violation 59 | msg := message.NewMessage(1024) 60 | readErr := sock.Read(msg, time.Time{}) 61 | require.NotNil(t, readErr) 62 | require.True(t, readErr.IsCloseErr()) 63 | require.False(t, sock.IsConnected()) 64 | } 65 | -------------------------------------------------------------------------------- /test/requestPayloadOnly_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestRequestPayloadOnly tests requests without a name but only a payload 14 | func TestRequestPayloadOnly(t *testing.T) { 15 | // Initialize server 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | msg wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Expect a named request 25 | msgName := msg.Name() 26 | assert.Nil(t, msgName) 27 | 28 | switch msg.PayloadEncoding() { 29 | case wwr.EncodingBinary: 30 | require.Equal(t, []byte("d"), msg.Payload()) 31 | return wwr.Payload{Encoding: wwr.EncodingBinary}, nil 32 | case wwr.EncodingUtf8: 33 | require.Equal(t, []byte("d"), msg.Payload()) 34 | return wwr.Payload{Encoding: wwr.EncodingUtf8}, nil 35 | case wwr.EncodingUtf16: 36 | require.Equal(t, []byte{32, 32}, msg.Payload()) 
37 | return wwr.Payload{Encoding: wwr.EncodingUtf16}, nil 38 | default: 39 | panic("unexpected message payload encoding type") 40 | } 41 | }, 42 | }, 43 | wwr.ServerOptions{}, 44 | nil, // Use the default transport implementation 45 | ) 46 | 47 | // Initialize client 48 | sock, _ := setup.NewClientSocket() 49 | 50 | // Send an unnamed binary request with a payload and await reply 51 | requestSuccess(t, sock, 32, nil, payload.Payload{Data: []byte("d")}) 52 | 53 | // Send an unnamed UTF8 encoded request with a payload 54 | requestSuccess(t, sock, 32, nil, payload.Payload{ 55 | Encoding: payload.Utf8, 56 | Data: []byte("d"), 57 | }) 58 | 59 | // Send an unnamed UTF16 encoded request with a payload 60 | requestSuccess(t, sock, 32, nil, payload.Payload{ 61 | Encoding: payload.Utf16, 62 | Data: []byte{32, 32}, 63 | }) 64 | } 65 | -------------------------------------------------------------------------------- /test/requestUtf16_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestRequestUtf16 tests requests with UTF16 encoded payloads 14 | func TestRequestUtf16(t *testing.T) { 15 | // Initialize webwire server given only the request 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | msg wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Verify request payload 25 | assert.Equal(t, wwr.EncodingUtf16, msg.PayloadEncoding()) 26 | assert.Equal(t, []byte{11, 20, 31, 40, 51, 60}, msg.Payload()) 27 | return wwr.Payload{ 28 | Encoding: wwr.EncodingUtf16, 29 | Data: []byte{80, 91, 100, 111, 120, 131}, 30 | }, nil 31 | }, 32 | }, 33 | wwr.ServerOptions{}, 34 | nil, // Use the default transport implementation 35 | ) 36 | 
37 | // Initialize client 38 | sock, _ := setup.NewClientSocket() 39 | 40 | // Send request and await reply 41 | reply := requestSuccess(t, sock, 32, nil, payload.Payload{ 42 | Encoding: wwr.EncodingUtf16, 43 | Data: []byte{11, 20, 31, 40, 51, 60}, 44 | }) 45 | 46 | // Verify reply 47 | require.Equal(t, wwr.EncodingUtf16, reply.MsgPayload.Encoding) 48 | require.Equal(t, []byte{80, 91, 100, 111, 120, 131}, reply.Payload()) 49 | } 50 | -------------------------------------------------------------------------------- /test/requestUtf8_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/payload" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestRequestUtf8 tests requests with UTF8 encoded payloads 14 | func TestRequestUtf8(t *testing.T) { 15 | // Initialize webwire server given only the request 16 | setup := SetupTestServer( 17 | t, 18 | &ServerImpl{ 19 | Request: func( 20 | _ context.Context, 21 | _ wwr.Connection, 22 | msg wwr.Message, 23 | ) (wwr.Payload, error) { 24 | // Verify request payload 25 | assert.Equal(t, wwr.EncodingUtf8, msg.PayloadEncoding()) 26 | assert.Equal(t, []byte("sample data"), msg.Payload()) 27 | return wwr.Payload{ 28 | Encoding: wwr.EncodingUtf8, 29 | Data: []byte("sample reply"), 30 | }, nil 31 | }, 32 | }, 33 | wwr.ServerOptions{}, 34 | nil, // Use the default transport implementation 35 | ) 36 | 37 | // Initialize client 38 | sock, _ := setup.NewClientSocket() 39 | 40 | // Send request and await reply 41 | reply := requestSuccess(t, sock, 32, nil, payload.Payload{ 42 | Encoding: wwr.EncodingUtf8, 43 | Data: []byte("sample data"), 44 | }) 45 | 46 | // Verify reply 47 | require.Equal(t, wwr.EncodingUtf8, reply.MsgPayload.Encoding) 48 | require.Equal(t, []byte("sample reply"), reply.Payload()) 49 | } 50 | 
-------------------------------------------------------------------------------- /test/serverImpl.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | 6 | wwr "github.com/qbeon/webwire-go" 7 | ) 8 | 9 | // ServerImpl implements the webwire.ServerImplementation interface 10 | type ServerImpl struct { 11 | ClientConnected func( 12 | connectionOptions wwr.ConnectionOptions, 13 | connection wwr.Connection, 14 | ) 15 | ClientDisconnected func(connection wwr.Connection, reason error) 16 | Signal func( 17 | ctx context.Context, 18 | connection wwr.Connection, 19 | message wwr.Message, 20 | ) 21 | Request func( 22 | ctx context.Context, 23 | connection wwr.Connection, 24 | message wwr.Message, 25 | ) (response wwr.Payload, err error) 26 | } 27 | 28 | // OnClientConnected implements the webwire.ServerImplementation interface 29 | func (srv *ServerImpl) OnClientConnected( 30 | opts wwr.ConnectionOptions, 31 | conn wwr.Connection, 32 | ) { 33 | if srv.ClientConnected != nil { 34 | srv.ClientConnected(opts, conn) 35 | } 36 | } 37 | 38 | // OnClientDisconnected implements the webwire.ServerImplementation interface 39 | func (srv *ServerImpl) OnClientDisconnected(conn wwr.Connection, reason error) { 40 | if srv.ClientDisconnected != nil { 41 | srv.ClientDisconnected(conn, reason) 42 | } 43 | } 44 | 45 | // OnSignal implements the webwire.ServerImplementation interface 46 | func (srv *ServerImpl) OnSignal( 47 | ctx context.Context, 48 | clt wwr.Connection, 49 | msg wwr.Message, 50 | ) { 51 | if srv.Signal != nil { 52 | srv.Signal(ctx, clt, msg) 53 | } 54 | } 55 | 56 | // OnRequest implements the webwire.ServerImplementation interface 57 | func (srv *ServerImpl) OnRequest( 58 | ctx context.Context, 59 | clt wwr.Connection, 60 | msg wwr.Message, 61 | ) (response wwr.Payload, err error) { 62 | if srv.Request != nil { 63 | return srv.Request(ctx, clt, msg) 64 | } 65 | return wwr.Payload{}, nil 66 | } 
67 | -------------------------------------------------------------------------------- /test/serverInitiatedSessionDestruction_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | wwr "github.com/qbeon/webwire-go" 10 | "github.com/qbeon/webwire-go/payload" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // TestServerInitiatedSessionDestruction tests client-initiated session 16 | // destruction 17 | func TestServerInitiatedSessionDestruction(t *testing.T) { 18 | sessionDestructionCallbackCalled := sync.WaitGroup{} 19 | sessionDestructionCallbackCalled.Add(1) 20 | signalReceived := sync.WaitGroup{} 21 | signalReceived.Add(1) 22 | 23 | sessionKey := "testsessionkey" 24 | sessionCreation := time.Now() 25 | 26 | // Initialize webwire server 27 | setup := SetupTestServer( 28 | t, 29 | &ServerImpl{ 30 | Request: func( 31 | _ context.Context, 32 | c wwr.Connection, 33 | msg wwr.Message, 34 | ) (wwr.Payload, error) { 35 | // Verify session destruction 36 | assert.Nil(t, c.Session()) 37 | return wwr.Payload{}, nil 38 | }, 39 | Signal: func( 40 | _ context.Context, 41 | c wwr.Connection, 42 | _ wwr.Message, 43 | ) { 44 | c.CloseSession() 45 | signalReceived.Done() 46 | }, 47 | }, 48 | wwr.ServerOptions{ 49 | SessionManager: &SessionManager{ 50 | SessionLookup: func( 51 | key string, 52 | ) (wwr.SessionLookupResult, error) { 53 | if key != sessionKey { 54 | return nil, nil 55 | } 56 | return wwr.NewSessionLookupResult( 57 | sessionCreation, // Creation 58 | time.Now(), // LastLookup 59 | nil, // Info 60 | ), nil 61 | }, 62 | SessionClosed: func(closedSessionKey string) error { 63 | defer sessionDestructionCallbackCalled.Done() 64 | 65 | // Ensure that the correct session was closed 66 | assert.Equal(t, sessionKey, closedSessionKey) 67 | 68 | return nil 69 | }, 70 | }, 71 | }, 72 | nil, // Use the 
default transport implementation 73 | ) 74 | 75 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 76 | require.Equal(t, -1, setup.Server.SessionConnectionsNum(sessionKey)) 77 | require.Nil(t, setup.Server.SessionConnections(sessionKey)) 78 | 79 | sock, _ := setup.NewClientSocket() 80 | 81 | requestRestoreSessionSuccess(t, sock, []byte(sessionKey)) 82 | 83 | assert.NotEqual(t, "", sessionKey) 84 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 85 | require.Equal(t, 1, setup.Server.SessionConnectionsNum(sessionKey)) 86 | require.Equal(t, 1, len(setup.Server.SessionConnections(sessionKey))) 87 | 88 | // Initiate session destruction 89 | signal(t, sock, []byte("close_session"), payload.Payload{}) 90 | 91 | readSessionClosed(t, sock) 92 | 93 | signalReceived.Wait() 94 | 95 | // Wait for the server to finally destroy the session 96 | sessionDestructionCallbackCalled.Wait() 97 | 98 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 99 | require.Equal(t, -1, setup.Server.SessionConnectionsNum(sessionKey)) 100 | require.Equal(t, 0, len(setup.Server.SessionConnections(sessionKey))) 101 | 102 | // Verify session destruction 103 | requestSuccess(t, sock, 32, []byte("verify"), payload.Payload{}) 104 | } 105 | -------------------------------------------------------------------------------- /test/sessionCreationOnClosedConn_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestSessionCreationOnClosedConn tests the creation of a session on a 12 | // disconnected connection 13 | func TestSessionCreationOnClosedConn(t *testing.T) { 14 | onConnectedFinished := sync.WaitGroup{} 15 | onConnectedFinished.Add(1) 16 | onDisconnectedFinished := sync.WaitGroup{} 17 | onDisconnectedFinished.Add(1) 18 | 19 | // Initialize server 20 | setup := SetupTestServer( 21 | t, 22 | 
&ServerImpl{ 23 | ClientConnected: func( 24 | _ wwr.ConnectionOptions, 25 | conn wwr.Connection, 26 | ) { 27 | defer onConnectedFinished.Done() 28 | conn.Close() 29 | 30 | err := conn.CreateSession(nil) 31 | assert.Error(t, err) 32 | assert.IsType(t, wwr.ErrDisconnected{}, err) 33 | }, 34 | ClientDisconnected: func(conn wwr.Connection, _ error) { 35 | defer onDisconnectedFinished.Done() 36 | err := conn.CreateSession(nil) 37 | assert.Error(t, err) 38 | assert.IsType(t, wwr.ErrDisconnected{}, err) 39 | }, 40 | }, 41 | wwr.ServerOptions{}, 42 | nil, // Use the default transport implementation 43 | ) 44 | 45 | // Initialize client 46 | setup.NewClientSocket() 47 | 48 | onConnectedFinished.Wait() 49 | onDisconnectedFinished.Wait() 50 | } 51 | -------------------------------------------------------------------------------- /test/sessionKeyGen.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | // SessionKeyGen implements the webwire.SessionKeyGenerator interface 4 | type SessionKeyGen struct { 5 | OnGenerate func() string 6 | } 7 | 8 | // Generate implements the webwire.SessionKeyGenerator interface 9 | func (gen *SessionKeyGen) Generate() string { 10 | if gen.OnGenerate != nil { 11 | return gen.OnGenerate() 12 | } 13 | return "" 14 | } 15 | -------------------------------------------------------------------------------- /test/sessionManagers.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | ) 9 | 10 | type session struct { 11 | Key string 12 | Creation time.Time 13 | LastLookup time.Time 14 | Info wwr.SessionInfo 15 | } 16 | 17 | // inMemSessManager is a default in-memory session manager for testing purposes 18 | type inMemSessManager struct { 19 | sessions map[string]session 20 | lock *sync.Mutex 21 | } 22 | 23 | // newInMemSessManager constructs a new default session manager 
instance 24 | // for testing purposes. 25 | func newInMemSessManager() *inMemSessManager { 26 | return &inMemSessManager{ 27 | sessions: make(map[string]session), 28 | lock: &sync.Mutex{}, 29 | } 30 | } 31 | 32 | // OnSessionCreated implements the session manager interface. 33 | // It writes the created session into a file using the session key as file name 34 | func (mng *inMemSessManager) OnSessionCreated(conn wwr.Connection) error { 35 | mng.lock.Lock() 36 | defer mng.lock.Unlock() 37 | sess := conn.Session() 38 | var sessInfo wwr.SessionInfo 39 | if sess.Info != nil { 40 | sessInfo = sess.Info.Copy() 41 | } 42 | mng.sessions[sess.Key] = session{ 43 | Key: sess.Key, 44 | Creation: sess.Creation, 45 | Info: sessInfo, 46 | } 47 | return nil 48 | } 49 | 50 | // OnSessionLookup implements the session manager interface. 51 | // It searches the session file directory for the session file and loads it 52 | func (mng *inMemSessManager) OnSessionLookup(key string) ( 53 | wwr.SessionLookupResult, 54 | error, 55 | ) { 56 | mng.lock.Lock() 57 | defer mng.lock.Unlock() 58 | if session, exists := mng.sessions[key]; exists { 59 | // Update last lookup field 60 | session.LastLookup = time.Now().UTC() 61 | mng.sessions[key] = session 62 | 63 | // Session found 64 | return wwr.NewSessionLookupResult( 65 | session.Creation, // Creation 66 | session.LastLookup, // LastLookup 67 | wwr.SessionInfoToVarMap(session.Info), // Info 68 | ), nil 69 | } 70 | 71 | // Session not found 72 | return nil, nil 73 | } 74 | 75 | // OnSessionClosed implements the session manager interface. 
76 | // It closes the session by deleting the according session file 77 | func (mng *inMemSessManager) OnSessionClosed(sessionKey string) error { 78 | mng.lock.Lock() 79 | defer mng.lock.Unlock() 80 | delete(mng.sessions, sessionKey) 81 | return nil 82 | } 83 | 84 | // SessionManager represents a callback-powered session manager 85 | // for testing purposes 86 | type SessionManager struct { 87 | SessionCreated func(client wwr.Connection) error 88 | SessionLookup func(key string) ( 89 | wwr.SessionLookupResult, 90 | error, 91 | ) 92 | SessionClosed func(sessionKey string) error 93 | } 94 | 95 | // OnSessionCreated implements the session manager interface 96 | // calling the configured callback 97 | func (mng *SessionManager) OnSessionCreated( 98 | client wwr.Connection, 99 | ) error { 100 | if mng.SessionCreated == nil { 101 | return nil 102 | } 103 | return mng.SessionCreated(client) 104 | } 105 | 106 | // OnSessionLookup implements the session manager interface 107 | // calling the configured callback 108 | func (mng *SessionManager) OnSessionLookup( 109 | key string, 110 | ) (wwr.SessionLookupResult, error) { 111 | if mng.SessionLookup == nil { 112 | return nil, nil 113 | } 114 | return mng.SessionLookup(key) 115 | } 116 | 117 | // OnSessionClosed implements the session manager interface 118 | // calling the configured callback 119 | func (mng *SessionManager) OnSessionClosed( 120 | sessionKey string, 121 | ) error { 122 | if mng.SessionClosed == nil { 123 | return nil 124 | } 125 | return mng.SessionClosed(sessionKey) 126 | } 127 | -------------------------------------------------------------------------------- /test/sessionNotFound_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/qbeon/webwire-go/message" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // TestSessionNotFound tests restoration 
requests for inexistent sessions 13 | // and expects them to fail with the corresponding error 14 | func TestSessionNotFound(t *testing.T) { 15 | lookupTriggered := sync.WaitGroup{} 16 | lookupTriggered.Add(1) 17 | 18 | // Initialize webwire server 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{}, 22 | wwr.ServerOptions{ 23 | SessionManager: &SessionManager{ 24 | SessionLookup: func( 25 | sessionKey string, 26 | ) (wwr.SessionLookupResult, error) { 27 | lookupTriggered.Done() 28 | return nil, nil 29 | }, 30 | }, 31 | }, 32 | nil, // Use the default transport implementation 33 | ) 34 | 35 | // Initialize client 36 | sock, _ := setup.NewClientSocket() 37 | 38 | // Skip manual connection establishment and rely on autoconnect instead 39 | reply := requestRestoreSession(t, sock, []byte("inexistentkey")) 40 | require.Equal(t, message.MsgReplySessionNotFound, reply.MsgType) 41 | 42 | lookupTriggered.Wait() 43 | } 44 | -------------------------------------------------------------------------------- /test/sessionRestoration_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | ) 10 | 11 | // TestSessionRestoration tests manual session restoration by key 12 | func TestSessionRestoration(t *testing.T) { 13 | lookupTriggered := sync.WaitGroup{} 14 | lookupTriggered.Add(1) 15 | var sessionKey = "testsessionkey" 16 | sessionCreation := time.Now() 17 | 18 | // Initialize server 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{}, 22 | wwr.ServerOptions{ 23 | SessionManager: &SessionManager{ 24 | SessionLookup: func(key string) ( 25 | wwr.SessionLookupResult, 26 | error, 27 | ) { 28 | defer lookupTriggered.Done() 29 | if key != sessionKey { 30 | // Session not found 31 | return nil, nil 32 | } 33 | return wwr.NewSessionLookupResult( 34 | sessionCreation, // Creation 35 | time.Now(), // LastLookup 36 | nil, // Info 37
| ), nil 38 | }, 39 | }, 40 | }, 41 | nil, // Use the default transport implementation 42 | ) 43 | 44 | // Initialize client 45 | sock, _ := setup.NewClientSocket() 46 | 47 | requestRestoreSessionSuccess(t, sock, []byte(sessionKey)) 48 | 49 | lookupTriggered.Wait() 50 | } 51 | -------------------------------------------------------------------------------- /test/sessionStatus_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // TestSessionStatus tests session monitoring methods 12 | func TestSessionStatus(t *testing.T) { 13 | sessionKey := "testsessionkey" 14 | 15 | sessionCreation := time.Now() 16 | 17 | // Initialize webwire server 18 | setup := SetupTestServer( 19 | t, 20 | &ServerImpl{}, 21 | wwr.ServerOptions{ 22 | SessionManager: &SessionManager{ 23 | SessionLookup: func(key string) ( 24 | wwr.SessionLookupResult, 25 | error, 26 | ) { 27 | if key != sessionKey { 28 | // Session not found 29 | return nil, nil 30 | } 31 | return wwr.NewSessionLookupResult( 32 | sessionCreation, // Creation 33 | time.Now(), // LastLookup 34 | nil, // Info 35 | ), nil 36 | }, 37 | }, 38 | }, 39 | nil, // Use the default transport implementation 40 | ) 41 | 42 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 43 | 44 | // Initialize client A 45 | clientA, _ := setup.NewClientSocket() 46 | 47 | requestRestoreSessionSuccess(t, clientA, []byte(sessionKey)) 48 | 49 | // Check status, expect 1 session with 1 connection 50 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 51 | require.Equal(t, 1, setup.Server.SessionConnectionsNum(sessionKey)) 52 | require.Equal(t, 1, len(setup.Server.SessionConnections(sessionKey))) 53 | 54 | // Initialize client B 55 | clientB, _ := setup.NewClientSocket() 56 | 57 | requestRestoreSessionSuccess(t, clientB, []byte(sessionKey)) 58 | 59 |
// Check status, expect 1 session with 2 connections 60 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 61 | require.Equal(t, 2, setup.Server.SessionConnectionsNum(sessionKey)) 62 | require.Equal(t, 2, len(setup.Server.SessionConnections(sessionKey))) 63 | 64 | // Close first connection 65 | require.NoError(t, clientA.Close()) 66 | 67 | // Wait for the server to close client A 68 | time.Sleep(50 * time.Millisecond) 69 | 70 | // Check status, expect 1 session with 1 connection 71 | require.Equal(t, 1, setup.Server.ActiveSessionsNum()) 72 | require.Equal(t, 1, setup.Server.SessionConnectionsNum(sessionKey)) 73 | require.Equal(t, 1, len(setup.Server.SessionConnections(sessionKey))) 74 | 75 | // Close session 76 | requestCloseSessionSuccess(t, clientB) 77 | 78 | // Wait for the server to close the session 79 | time.Sleep(50 * time.Millisecond) 80 | 81 | // Check status, expect 0 sessions 82 | require.Equal(t, 0, setup.Server.ActiveSessionsNum()) 83 | require.Equal(t, -1, setup.Server.SessionConnectionsNum(sessionKey)) 84 | require.Nil(t, setup.Server.SessionConnections(sessionKey)) 85 | } 86 | -------------------------------------------------------------------------------- /test/signalUtf16_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestSignalUtf16 tests client-side signals with UTF16 encoded payloads 14 | func TestSignalUtf16(t *testing.T) { 15 | signalArrived := sync.WaitGroup{} 16 | signalArrived.Add(1) 17 | 18 | // Initialize webwire server given only the signal handler 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{ 22 | Signal: func( 23 | _ context.Context, 24 | _ wwr.Connection, 25 | msg wwr.Message, 26 | ) { 27 | assert.Equal(t, wwr.EncodingUtf16, msg.PayloadEncoding()) 28 |
assert.Equal(t, []byte{ 29 | 00, 115, 00, 97, 00, 109, 30 | 00, 112, 00, 108, 00, 101, 31 | }, msg.Payload()) 32 | 33 | // Synchronize, notify signal arrival 34 | signalArrived.Done() 35 | }, 36 | }, 37 | wwr.ServerOptions{}, 38 | nil, // Use the default transport implementation 39 | ) 40 | 41 | // Initialize client 42 | sock, _ := setup.NewClientSocket() 43 | 44 | signal(t, sock, []byte("utf16_sig"), payload.Payload{ 45 | Encoding: wwr.EncodingUtf16, 46 | Data: []byte{ 47 | 00, 115, 00, 97, 00, 109, 48 | 00, 112, 00, 108, 00, 101, 49 | }, 50 | }) 51 | 52 | // Synchronize, await signal arrival 53 | signalArrived.Wait() 54 | } 55 | -------------------------------------------------------------------------------- /test/signalUtf8_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | 8 | wwr "github.com/qbeon/webwire-go" 9 | "github.com/qbeon/webwire-go/payload" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // TestSignalUtf8 tests client-side signals with UTF8 encoded payloads 14 | func TestSignalUtf8(t *testing.T) { 15 | signalArrived := sync.WaitGroup{} 16 | signalArrived.Add(1) 17 | 18 | // Initialize webwire server given only the signal handler 19 | setup := SetupTestServer( 20 | t, 21 | &ServerImpl{ 22 | Signal: func( 23 | _ context.Context, 24 | _ wwr.Connection, 25 | msg wwr.Message, 26 | ) { 27 | // Verify signal payload (assert rather than require: FailNow must not be called outside the test goroutine) 28 | assert.Equal(t, wwr.EncodingUtf8, msg.PayloadEncoding()) 29 | assert.Equal(t, []byte("üникод"), msg.Payload()) 30 | 31 | // Synchronize, notify signal arrival 32 | signalArrived.Done() 33 | }, 34 | }, 35 | wwr.ServerOptions{}, 36 | nil, // Use the default transport implementation 37 | ) 38 | 39 | // Initialize client 40 | sock, _ := setup.NewClientSocket() 41 | 42 | signal(t, sock, []byte("sig_utf8"), payload.Payload{ 43 | Encoding: wwr.EncodingUtf8, 44 | Data: []byte("üникод"), 45 | }) 46 | 47 | // Synchronize,
await signal arrival 48 | signalArrived.Wait() 49 | } 50 | -------------------------------------------------------------------------------- /test/simpleShutdown_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | 6 | wwr "github.com/qbeon/webwire-go" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | // TestSimpleShutdown tests simple shutdown without any pending tasks 11 | func TestSimpleShutdown(t *testing.T) { 12 | connectedClientsNum := 5 13 | 14 | // Initialize webwire server 15 | setup := SetupTestServer( 16 | t, 17 | &ServerImpl{}, 18 | wwr.ServerOptions{}, 19 | nil, // Use the default transport implementation 20 | ) 21 | 22 | clients := make([]wwr.Socket, connectedClientsNum) 23 | for i := 0; i < connectedClientsNum; i++ { 24 | sock, _ := setup.NewClientSocket() 25 | clients[i] = sock 26 | } 27 | 28 | require.NoError(t, setup.Server.Shutdown()) 29 | } 30 | -------------------------------------------------------------------------------- /translateContextError.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // TranslateContextError translates context errors to webwire error types 9 | func TranslateContextError(err error) error { 10 | switch err { 11 | case context.DeadlineExceeded: 12 | return ErrDeadlineExceeded{Cause: err} 13 | case context.Canceled: 14 | return ErrCanceled{Cause: err} 15 | } 16 | return fmt.Errorf("unexpected context error: %s", err) 17 | } 18 | -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | package webwire 2 | 3 | import ( 4 | "net/url" 5 | "time" 6 | ) 7 | 8 | // IsShuttingDown must be called when the server is accepting a new connection 9 | // and the connection must be refused if true is returned 10 | type
IsShuttingDown func() bool 11 | 12 | // OnNewConnection must be called when the connection is ready to be used by the 13 | // webwire server 14 | type OnNewConnection func( 15 | connectionOptions ConnectionOptions, 16 | socket Socket, 17 | ) 18 | 19 | // Transport defines the interface of a webwire transport 20 | type Transport interface { 21 | // Initialize initializes the server 22 | Initialize( 23 | options ServerOptions, 24 | isShuttingdown IsShuttingDown, 25 | onNewConnection OnNewConnection, 26 | ) error 27 | 28 | // Serve starts serving, blocking the calling goroutine 29 | Serve() error 30 | 31 | // Shutdown shuts the server down 32 | Shutdown() error 33 | 34 | // Address returns the URL address the server is listening on 35 | Address() url.URL 36 | } 37 | 38 | // ClientTransport defines the interface of a webwire client transport 39 | type ClientTransport interface { 40 | // NewSocket initializes a new client socket 41 | NewSocket(dialTimeout time.Duration) (ClientSocket, error) 42 | } 43 | -------------------------------------------------------------------------------- /transport/memchan/buffer.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | ) 9 | 10 | // Buffer represents a reactive outbound buffer implementation 11 | type Buffer struct { 12 | buf []byte 13 | len int 14 | onFlush func([]byte) error 15 | lock *sync.Mutex 16 | } 17 | 18 | // NewBuffer allocates a new buffer 19 | func NewBuffer( 20 | buf []byte, 21 | onFlush func([]byte) error, 22 | ) Buffer { 23 | if len(buf) < 1 { 24 | panic("empty buffer") 25 | } 26 | return Buffer{ 27 | buf: buf, 28 | len: 0, 29 | onFlush: onFlush, 30 | lock: &sync.Mutex{}, 31 | } 32 | } 33 | 34 | // reset clears the buffer 35 | func (buf *Buffer) reset() { 36 | buf.len = 0 37 | } 38 | 39 | // Write writes a portion of data to the buffer 40 | func (buf *Buffer) Write(p []byte)
(int, error) { 41 | buf.lock.Lock() 42 | if len(p) > len(buf.buf)-buf.len { 43 | // Buffer overflow: discard everything written so far 44 | buf.reset() 45 | buf.lock.Unlock() 46 | return 0, wwr.ErrBufferOverflow{} 47 | } 48 | copy(buf.buf[buf.len:], p) 49 | buf.len += len(p) 50 | buf.lock.Unlock() 51 | return len(p), nil 52 | } 53 | 54 | // Close passes the buffered data to onFlush and resets the buffer; it fails if nothing was written 55 | func (buf *Buffer) Close() (err error) { 56 | buf.lock.Lock() 57 | defer buf.lock.Unlock() 58 | if buf.len < 1 { 59 | return errors.New("no data written") 60 | } 61 | err = buf.onFlush(buf.buf[:buf.len]) 62 | buf.reset() 63 | return err 64 | } 65 | -------------------------------------------------------------------------------- /transport/memchan/buffer_test.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | wwr "github.com/qbeon/webwire-go" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // TestBufferWrite tests Buffer.Write 12 | func TestBufferWrite(t *testing.T) { 13 | buf := make([]byte, 6) 14 | buffer := NewBuffer(buf, nil) 15 | bytesWritten, err := buffer.Write([]byte{1}) 16 | require.NoError(t, err) 17 | require.Equal(t, 1, bytesWritten) 18 | 19 | bytesWritten, err = buffer.Write([]byte{2, 3, 4}) 20 | require.NoError(t, err) 21 | require.Equal(t, 3, bytesWritten) 22 | 23 | require.Equal(t, []byte{1, 2, 3, 4, 0, 0}, buf) 24 | } 25 | 26 | // TestBufferWriteOverflow tests Buffer.Write overflowing the buffer 27 | func TestBufferWriteOverflow(t *testing.T) { 28 | buf := make([]byte, 4) 29 | buffer := NewBuffer(buf, nil) 30 | 31 | bytesWritten, err := buffer.Write([]byte{1, 2, 3, 4, 5}) 32 | require.Equal(t, 0, bytesWritten) 33 | require.Error(t, err) 34 | require.IsType(t, wwr.ErrBufferOverflow{}, err) 35 | 36 | require.Equal(t, []byte{0, 0, 0, 0}, buf) 37 | } 38 | 39 | // TestBufferClose tests Buffer.Close 40 | func TestBufferClose(t *testing.T) { 41 | flushed := false 42 | 43 | buf := make([]byte, 3) 44 |
buffer := NewBuffer(buf, func(data []byte) error { 45 | flushed = true 46 | require.Equal(t, []byte{1, 1}, data) 47 | return nil 48 | }) 49 | 50 | bytesWritten, err := buffer.Write([]byte{1, 1}) 51 | require.Equal(t, 2, bytesWritten) 52 | require.NoError(t, err) 53 | 54 | require.NoError(t, buffer.Close()) 55 | require.True(t, flushed) 56 | } 57 | 58 | // TestBufferCloseError tests Buffer.Close propagating the onFlush error 59 | func TestBufferCloseError(t *testing.T) { 60 | buf := make([]byte, 3) 61 | buffer := NewBuffer(buf, func(data []byte) error { 62 | return errors.New("test error") 63 | }) 64 | 65 | // Write some data first, otherwise Close fails before onFlush is called 66 | _, err := buffer.Write([]byte{1}) 67 | require.NoError(t, err) 68 | require.EqualError(t, buffer.Close(), "test error") 69 | } 70 | 71 | // TestBufferCloseEmpty tests Buffer.Close on an empty buffer, which must 72 | // fail without leaving the buffer locked 73 | func TestBufferCloseEmpty(t *testing.T) { 74 | buffer := NewBuffer(make([]byte, 3), nil) 75 | require.Error(t, buffer.Close()) 76 | 77 | // Would block forever if Close had leaked the lock 78 | bytesWritten, err := buffer.Write([]byte{1}) 79 | require.NoError(t, err) 80 | require.Equal(t, 1, bytesWritten) 81 | } 82 | -------------------------------------------------------------------------------- /transport/memchan/clientTransport.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "time" 5 | 6 | wwr "github.com/qbeon/webwire-go" 7 | ) 8 | 9 | // ClientTransport implements the ClientTransport interface 10 | type ClientTransport struct { 11 | Server *Transport 12 | } 13 | 14 | // NewSocket implements the ClientTransport interface 15 | func (cltTrans *ClientTransport) NewSocket( 16 | dialTimeout time.Duration, 17 | ) (wwr.ClientSocket, error) { 18 | if cltTrans.Server == nil { 19 | // Create a disconnected socket instance 20 | return newDisconnectedSocket(), nil 21 | } 22 | 23 | // Create a new entangled socket pair 24 | _, clt := NewEntangledSockets(cltTrans.Server) 25 | 26 | return clt, nil 27 | } 28 | -------------------------------------------------------------------------------- /transport/memchan/entangleSockets.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | // entangleSockets connects two sockets 4 | func entangleSockets(server, client *Socket) { 5 | if server.status != nil { 6 | panic("the server socket is already entangled") 7 | } 8 | if client.status != nil { 9 | panic("the client socket is already
entangled") 10 | } 11 | 12 | // Set the socket types 13 | server.sockType = SocketServer 14 | client.sockType = SocketClient 15 | 16 | // Entangle references 17 | server.remote = client 18 | client.remote = server 19 | 20 | // Initialize shared status 21 | status := statusDisconnected 22 | server.status = &status 23 | client.status = &status 24 | } 25 | -------------------------------------------------------------------------------- /transport/memchan/newEntangledSockets.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | // NewEntangledSockets creates a new socket pair 4 | func NewEntangledSockets(server *Transport) (srv, clt *Socket) { 5 | srv = newSocket(server, server.bufferSize) 6 | clt = newSocket(server, server.bufferSize) 7 | entangleSockets(srv, clt) 8 | return 9 | } 10 | -------------------------------------------------------------------------------- /transport/memchan/newSocket.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // newSocket creates a new socket instance that must be entangled with another 9 | // socket before it can be used 10 | func newSocket(server *Transport, bufferSize uint32) *Socket { 11 | // Setup a new inactive timer 12 | readTimer := time.NewTimer(0) 13 | <-readTimer.C 14 | 15 | socket := &Socket{ 16 | sockType: SocketUninitialized, 17 | server: server, 18 | readLock: &sync.Mutex{}, 19 | writerLock: &sync.Mutex{}, 20 | reader: make(chan []byte, 1), 21 | readerLock: &sync.Mutex{}, 22 | readerErr: make(chan error), 23 | readTimer: readTimer, 24 | remote: nil, 25 | status: nil, 26 | } 27 | 28 | // Allocate the outbound buffer 29 | socket.outboundBuffer = NewBuffer( 30 | make([]byte, bufferSize), 31 | // Connect the onFlush callback to the corresponding slot method 32 | socket.onBufferFlush, 33 | ) 34 | 35 | return socket 36 | } 37 | 38 | // newDisconnectedSocket
creates a new disconnected socket instance 39 | func newDisconnectedSocket() *Socket { 40 | // Setup a new inactive timer 41 | readTimer := time.NewTimer(0) 42 | <-readTimer.C 43 | 44 | status := statusDisconnected 45 | 46 | socket := &Socket{ 47 | sockType: SocketClient, 48 | readLock: &sync.Mutex{}, 49 | writerLock: &sync.Mutex{}, 50 | reader: make(chan []byte, 1), 51 | readerLock: &sync.Mutex{}, 52 | readerErr: make(chan error), 53 | readTimer: readTimer, 54 | remote: nil, 55 | status: &status, 56 | } 57 | 58 | // Allocate the outbound buffer 59 | socket.outboundBuffer = NewBuffer( 60 | make([]byte, 1), 61 | // Connect the onFlush callback to the corresponding slot method 62 | socket.onBufferFlush, 63 | ) 64 | 65 | return socket 66 | } 67 | -------------------------------------------------------------------------------- /transport/memchan/remoteAddress.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // RemoteAddress represents a net.Addr interface implementation 8 | type RemoteAddress struct { 9 | serverSocket *Socket 10 | } 11 | 12 | // Network implements the net.Addr interface 13 | func (addr RemoteAddress) Network() string { 14 | return "memchan" 15 | } 16 | 17 | // String implements the net.Addr interface 18 | func (addr RemoteAddress) String() string { 19 | return fmt.Sprintf("%p", addr.serverSocket) 20 | } 21 | -------------------------------------------------------------------------------- /transport/memchan/sockReadErr.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import "fmt" 4 | 5 | // ErrSockRead implements the ErrSockRead interface 6 | type ErrSockRead struct { 7 | // closed is true when the error was caused by a graceful socket closure 8 | closed bool 9 | 10 | err error 11 | } 12 | 13 | // Error implements the Go error interface 14 | func (err ErrSockRead) Error() string { 15 | if
err.closed { 16 | return "socket closed" 17 | } 18 | return fmt.Sprintf("reading socket failed: %s", err.err) 19 | } 20 | 21 | // IsCloseErr implements the ErrSockRead interface 22 | func (err ErrSockRead) IsCloseErr() bool { 23 | return err.closed 24 | } 25 | -------------------------------------------------------------------------------- /transport/memchan/transport.go: -------------------------------------------------------------------------------- 1 | package memchan 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/url" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | 11 | wwr "github.com/qbeon/webwire-go" 12 | ) 13 | 14 | const serverClosed = 0 15 | const serverActive = 1 16 | 17 | // Transport implements the Transport interface 18 | type Transport struct { 19 | // OnBeforeCreation is called before the creation of a new connection and 20 | // must return the options to be assigned to the new connection 21 | OnBeforeCreation func() wwr.ConnectionOptions 22 | 23 | onNewConnection wwr.OnNewConnection 24 | isShuttingdown wwr.IsShuttingDown 25 | 26 | bufferSize uint32 27 | readTimeout time.Duration 28 | connections map[*Socket]*Socket 29 | connectionsLock *sync.Mutex 30 | status uint32 31 | shutdown chan struct{} 32 | } 33 | 34 | // Initialize implements the Transport interface 35 | func (srv *Transport) Initialize( 36 | options wwr.ServerOptions, 37 | isShuttingdown wwr.IsShuttingDown, 38 | onNewConnection wwr.OnNewConnection, 39 | ) error { 40 | srv.readTimeout = options.ReadTimeout 41 | srv.bufferSize = options.MessageBufferSize 42 | srv.isShuttingdown = isShuttingdown 43 | srv.onNewConnection = onNewConnection 44 | srv.connections = make(map[*Socket]*Socket) 45 | srv.connectionsLock = &sync.Mutex{} 46 | srv.shutdown = make(chan struct{}) 47 | srv.status = serverActive 48 | 49 | if srv.OnBeforeCreation == nil { 50 | srv.OnBeforeCreation = func() wwr.ConnectionOptions { 51 | return wwr.ConnectionOptions{} 52 | } 53 | } 54 | 55 | return nil 56 | } 57 | 58 | // Serve implements
the Transport interface 59 | func (srv *Transport) Serve() error { 60 | if atomic.LoadUint32(&srv.status) != serverActive { 61 | return errors.New("server is closed") 62 | } 63 | <-srv.shutdown 64 | return nil 65 | } 66 | 67 | // Shutdown implements the Transport interface. It unblocks Serve and closes all tracked connections 68 | func (srv *Transport) Shutdown() error { 69 | if atomic.CompareAndSwapUint32(&srv.status, serverActive, serverClosed) { 70 | close(srv.shutdown) 71 | srv.connectionsLock.Lock() 72 | conns := make([]*Socket, len(srv.connections)) 73 | index := 0 74 | for sock := range srv.connections { 75 | conns[index] = sock 76 | index++ 77 | } 78 | srv.connectionsLock.Unlock() 79 | 80 | // Close all connections without holding connectionsLock because closing 81 | // a server socket calls onDisconnect, which acquires the lock itself 82 | for _, sock := range conns { 83 | if err := sock.Close(); err != nil { 84 | return fmt.Errorf("couldn't close socket %p: %s", sock, err) 85 | } 86 | } 87 | } 88 | return nil 89 | } 90 | 91 | // Address implements the Transport interface 92 | func (srv *Transport) Address() url.URL { 93 | return url.URL{ 94 | Scheme: "memchan", 95 | } 96 | } 97 | 98 | // onConnect is called in Socket.Dial by a client-type socket on connection 99 | func (srv *Transport) onConnect( 100 | clientSocket *Socket, 101 | connectionOptions wwr.ConnectionOptions, 102 | ) error { 103 | // Reject incoming connections during server shutdown 104 | if srv.isShuttingdown() { 105 | return errors.New("server is shutting down") 106 | } 107 | 108 | if atomic.LoadUint32(&srv.status) != serverActive { 109 | return errors.New("server is closed") 110 | } 111 | 112 | if clientSocket.remote == nil || clientSocket.status == nil { 113 | return errors.New("uninitialized socket") 114 | } 115 | 116 | srv.connectionsLock.Lock() 117 | srv.connections[clientSocket.remote] = clientSocket 118 | srv.connectionsLock.Unlock() 119 | 120 | go srv.onNewConnection( 121 | connectionOptions, 122 | clientSocket.remote, 123 | ) 124 | 125 | return nil 126 | } 127 | 128 | // onDisconnect is called in Socket.Close by a
server-type socket on closure 129 | func (srv *Transport) onDisconnect(serverSocket *Socket) { 130 | srv.connectionsLock.Lock() 131 | delete(srv.connections, serverSocket) 132 | srv.connectionsLock.Unlock() 133 | } 134 | -------------------------------------------------------------------------------- /vendor/github.com/davecgh/go-spew/spew/bypasssafe.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2015-2016 Dave Collins 2 | // 3 | // Permission to use, copy, modify, and distribute this software for any 4 | // purpose with or without fee is hereby granted, provided that the above 5 | // copyright notice and this permission notice appear in all copies. 6 | // 7 | // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | // NOTE: Due to the following build constraints, this file will only be compiled 16 | // when the code is running on Google App Engine, compiled by GopherJS, or 17 | // "-tags safe" is added to the go build command line. The "disableunsafe" 18 | // tag is deprecated and thus should not be used. 19 | // +build js appengine safe disableunsafe !go1.4 20 | 21 | package spew 22 | 23 | import "reflect" 24 | 25 | const ( 26 | // UnsafeDisabled is a build-time constant which specifies whether or 27 | // not access to the unsafe package is available. 
28 | UnsafeDisabled = true 29 | ) 30 | 31 | // unsafeReflectValue typically converts the passed reflect.Value into a one 32 | // that bypasses the typical safety restrictions preventing access to 33 | // unaddressable and unexported data. However, doing this relies on access to 34 | // the unsafe package. This is a stub version which simply returns the passed 35 | // reflect.Value when the unsafe package is not available. 36 | func unsafeReflectValue(v reflect.Value) reflect.Value { 37 | return v 38 | } 39 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/assert/assertion_format.go.tmpl: -------------------------------------------------------------------------------- 1 | {{.CommentFormat}} 2 | func {{.DocInfo.Name}}f(t TestingT, {{.ParamsFormat}}) bool { 3 | if h, ok := t.(tHelper); ok { h.Helper() } 4 | return {{.DocInfo.Name}}(t, {{.ForwardedParamsFormat}}) 5 | } 6 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/assert/assertion_forward.go.tmpl: -------------------------------------------------------------------------------- 1 | {{.CommentWithoutT "a"}} 2 | func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool { 3 | if h, ok := a.t.(tHelper); ok { h.Helper() } 4 | return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}}) 5 | } 6 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/assert/doc.go: -------------------------------------------------------------------------------- 1 | // Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. 
2 | // 3 | // Example Usage 4 | // 5 | // The following is a complete example using assert in a standard test function: 6 | // import ( 7 | // "testing" 8 | // "github.com/stretchr/testify/assert" 9 | // ) 10 | // 11 | // func TestSomething(t *testing.T) { 12 | // 13 | // var a string = "Hello" 14 | // var b string = "Hello" 15 | // 16 | // assert.Equal(t, a, b, "The two words should be the same.") 17 | // 18 | // } 19 | // 20 | // if you assert many times, use the format below: 21 | // 22 | // import ( 23 | // "testing" 24 | // "github.com/stretchr/testify/assert" 25 | // ) 26 | // 27 | // func TestSomething(t *testing.T) { 28 | // assert := assert.New(t) 29 | // 30 | // var a string = "Hello" 31 | // var b string = "Hello" 32 | // 33 | // assert.Equal(a, b, "The two words should be the same.") 34 | // } 35 | // 36 | // Assertions 37 | // 38 | // Assertions allow you to easily write test code, and are global funcs in the `assert` package. 39 | // All assertion functions take, as the first argument, the `*testing.T` object provided by the 40 | // testing framework. This allows the assertion funcs to write the failings and other details to 41 | // the correct place. 42 | // 43 | // Every assertion function also takes an optional string message as the final argument, 44 | // allowing custom error messages to be appended to the message the assertion method outputs. 45 | package assert 46 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/assert/errors.go: -------------------------------------------------------------------------------- 1 | package assert 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | // AnError is an error instance useful for testing. If the code does not care 8 | // about error specifics, and only needs to return the error for example, this 9 | // error should be used to make the test code more readable. 
10 | var AnError = errors.New("assert.AnError general error for testing") 11 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/assert/forward_assertions.go: -------------------------------------------------------------------------------- 1 | package assert 2 | 3 | // Assertions provides assertion methods around the 4 | // TestingT interface. 5 | type Assertions struct { 6 | t TestingT 7 | } 8 | 9 | // New makes a new Assertions object for the specified TestingT. 10 | func New(t TestingT) *Assertions { 11 | return &Assertions{ 12 | t: t, 13 | } 14 | } 15 | 16 | //go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs 17 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/require/doc.go: -------------------------------------------------------------------------------- 1 | // Package require implements the same assertions as the `assert` package but 2 | // stops test execution when a test fails. 3 | // 4 | // Example Usage 5 | // 6 | // The following is a complete example using require in a standard test function: 7 | // import ( 8 | // "testing" 9 | // "github.com/stretchr/testify/require" 10 | // ) 11 | // 12 | // func TestSomething(t *testing.T) { 13 | // 14 | // var a string = "Hello" 15 | // var b string = "Hello" 16 | // 17 | // require.Equal(t, a, b, "The two words should be the same.") 18 | // 19 | // } 20 | // 21 | // Assertions 22 | // 23 | // The `require` package have same global functions as in the `assert` package, 24 | // but instead of returning a boolean result they call `t.FailNow()`. 25 | // 26 | // Every assertion function also takes an optional string message as the final argument, 27 | // allowing custom error messages to be appended to the message the assertion method outputs. 
28 | package require 29 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/require/forward_requirements.go: -------------------------------------------------------------------------------- 1 | package require 2 | 3 | // Assertions provides assertion methods around the 4 | // TestingT interface. 5 | type Assertions struct { 6 | t TestingT 7 | } 8 | 9 | // New makes a new Assertions object for the specified TestingT. 10 | func New(t TestingT) *Assertions { 11 | return &Assertions{ 12 | t: t, 13 | } 14 | } 15 | 16 | //go:generate go run ../_codegen/main.go -output-package=require -template=require_forward.go.tmpl -include-format-funcs 17 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/require/require.go.tmpl: -------------------------------------------------------------------------------- 1 | {{.Comment}} 2 | func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { 3 | if h, ok := t.(tHelper); ok { h.Helper() } 4 | if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return } 5 | t.FailNow() 6 | } 7 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/require/require_forward.go.tmpl: -------------------------------------------------------------------------------- 1 | {{.CommentWithoutT "a"}} 2 | func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) { 3 | if h, ok := a.t.(tHelper); ok { h.Helper() } 4 | {{.DocInfo.Name}}(a.t, {{.ForwardedParams}}) 5 | } 6 | -------------------------------------------------------------------------------- /vendor/github.com/stretchr/testify/require/requirements.go: -------------------------------------------------------------------------------- 1 | package require 2 | 3 | // TestingT is an interface wrapper around *testing.T 4 | type TestingT interface { 5 | Errorf(format string, args ...interface{}) 6 | FailNow() 7 | } 8 | 9 | type tHelper
interface { 10 | Helper() 11 | } 12 | 13 | // ComparisonAssertionFunc is a common function prototype when comparing two values. Can be useful 14 | // for table driven tests. 15 | type ComparisonAssertionFunc func(TestingT, interface{}, interface{}, ...interface{}) 16 | 17 | // ValueAssertionFunc is a common function prototype when validating a single value. Can be useful 18 | // for table driven tests. 19 | type ValueAssertionFunc func(TestingT, interface{}, ...interface{}) 20 | 21 | // BoolAssertionFunc is a common function prototype when validating a bool value. Can be useful 22 | // for table driven tests. 23 | type BoolAssertionFunc func(TestingT, bool, ...interface{}) 24 | 25 | // ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful 26 | // for table driven tests. 27 | type ErrorAssertionFunc func(TestingT, error, ...interface{}) 28 | 29 | //go:generate go run ../_codegen/main.go -output-package=require -template=require.go.tmpl -include-format-funcs 30 | -------------------------------------------------------------------------------- /vendor/golang.org/x/net/context/context.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package context defines the Context type, which carries deadlines, 6 | // cancelation signals, and other request-scoped values across API boundaries 7 | // and between processes. 8 | // As of Go 1.7 this package is available in the standard library under the 9 | // name context. https://golang.org/pkg/context. 10 | // 11 | // Incoming requests to a server should create a Context, and outgoing calls to 12 | // servers should accept a Context.
The chain of function calls between must 13 | // propagate the Context, optionally replacing it with a modified copy created 14 | // using WithDeadline, WithTimeout, WithCancel, or WithValue. 15 | // 16 | // Programs that use Contexts should follow these rules to keep interfaces 17 | // consistent across packages and enable static analysis tools to check context 18 | // propagation: 19 | // 20 | // Do not store Contexts inside a struct type; instead, pass a Context 21 | // explicitly to each function that needs it. The Context should be the first 22 | // parameter, typically named ctx: 23 | // 24 | // func DoSomething(ctx context.Context, arg Arg) error { 25 | // // ... use ctx ... 26 | // } 27 | // 28 | // Do not pass a nil Context, even if a function permits it. Pass context.TODO 29 | // if you are unsure about which Context to use. 30 | // 31 | // Use context Values only for request-scoped data that transits processes and 32 | // APIs, not for passing optional parameters to functions. 33 | // 34 | // The same Context may be passed to functions running in different goroutines; 35 | // Contexts are safe for simultaneous use by multiple goroutines. 36 | // 37 | // See https://blog.golang.org/context for example code for a server that uses 38 | // Contexts. 39 | package context // import "golang.org/x/net/context" 40 | 41 | // Background returns a non-nil, empty Context. It is never canceled, has no 42 | // values, and has no deadline. It is typically used by the main function, 43 | // initialization, and tests, and as the top-level Context for incoming 44 | // requests. 45 | func Background() Context { 46 | return background 47 | } 48 | 49 | // TODO returns a non-nil, empty Context. Code should use context.TODO when 50 | // it's unclear which Context to use or it is not yet available (because the 51 | // surrounding function has not yet been extended to accept a Context 52 | // parameter).
TODO is recognized by static analysis tools that determine 53 | // whether Contexts are propagated correctly in a program. 54 | func TODO() Context { 55 | return todo 56 | } 57 | -------------------------------------------------------------------------------- /vendor/golang.org/x/net/context/go17.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // +build go1.7 6 | 7 | package context 8 | 9 | import ( 10 | "context" // standard library's context, as of Go 1.7 11 | "time" 12 | ) 13 | 14 | var ( 15 | todo = context.TODO() 16 | background = context.Background() 17 | ) 18 | 19 | // Canceled is the error returned by Context.Err when the context is canceled. 20 | var Canceled = context.Canceled 21 | 22 | // DeadlineExceeded is the error returned by Context.Err when the context's 23 | // deadline passes. 24 | var DeadlineExceeded = context.DeadlineExceeded 25 | 26 | // WithCancel returns a copy of parent with a new Done channel. The returned 27 | // context's Done channel is closed when the returned cancel function is called 28 | // or when the parent context's Done channel is closed, whichever happens first. 29 | // 30 | // Canceling this context releases resources associated with it, so code should 31 | // call cancel as soon as the operations running in this Context complete. 32 | func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { 33 | ctx, f := context.WithCancel(parent) 34 | return ctx, CancelFunc(f) 35 | } 36 | 37 | // WithDeadline returns a copy of the parent context with the deadline adjusted 38 | // to be no later than d. If the parent's deadline is already earlier than d, 39 | // WithDeadline(parent, d) is semantically equivalent to parent.
The returned 40 | // context's Done channel is closed when the deadline expires, when the returned 41 | // cancel function is called, or when the parent context's Done channel is 42 | // closed, whichever happens first. 43 | // 44 | // Canceling this context releases resources associated with it, so code should 45 | // call cancel as soon as the operations running in this Context complete. 46 | func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { 47 | ctx, f := context.WithDeadline(parent, deadline) 48 | return ctx, CancelFunc(f) 49 | } 50 | 51 | // WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). 52 | // 53 | // Canceling this context releases resources associated with it, so code should 54 | // call cancel as soon as the operations running in this Context complete: 55 | // 56 | // func slowOperationWithTimeout(ctx context.Context) (Result, error) { 57 | // ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) 58 | // defer cancel() // releases resources if slowOperation completes before timeout elapses 59 | // return slowOperation(ctx) 60 | // } 61 | func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { 62 | return WithDeadline(parent, time.Now().Add(timeout)) 63 | } 64 | 65 | // WithValue returns a copy of parent in which the value associated with key is 66 | // val. 67 | // 68 | // Use context Values only for request-scoped data that transits processes and 69 | // APIs, not for passing optional parameters to functions. 70 | func WithValue(parent Context, key interface{}, val interface{}) Context { 71 | return context.WithValue(parent, key, val) 72 | } 73 | -------------------------------------------------------------------------------- /vendor/golang.org/x/net/context/go19.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // +build go1.9 6 | 7 | package context 8 | 9 | import "context" // standard library's context, as of Go 1.7 10 | 11 | // A Context carries a deadline, a cancelation signal, and other values across 12 | // API boundaries. 13 | // 14 | // Context's methods may be called by multiple goroutines simultaneously. 15 | type Context = context.Context 16 | 17 | // A CancelFunc tells an operation to abandon its work. 18 | // A CancelFunc does not wait for the work to stop. 19 | // After the first call, subsequent calls to a CancelFunc do nothing. 20 | type CancelFunc = context.CancelFunc 21 | --------------------------------------------------------------------------------