├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── examples ├── webchat │ ├── README.md │ ├── web │ │ ├── app.js │ │ └── index.html │ └── webchat.go └── webecho │ ├── README.md │ ├── web │ ├── app.js │ └── index.html │ └── webecho.go ├── sockjs ├── .gitignore ├── README.md ├── benchmarks_test.go ├── buffer.go ├── doc.go ├── eventsource.go ├── eventsource_test.go ├── example_handler_test.go ├── frame.go ├── frame_test.go ├── handler.go ├── handler_test.go ├── htmlfile.go ├── htmlfile_test.go ├── httpreceiver.go ├── httpreceiver_test.go ├── iframe.go ├── iframe_test.go ├── jsonp.go ├── jsonp_test.go ├── mapping.go ├── mapping_test.go ├── options.go ├── options_test.go ├── rawwebsocket.go ├── rawwebsocket_test.go ├── session.go ├── session_test.go ├── sockjs.go ├── sockjs_test.go ├── utils.go ├── utils_test.go ├── web.go ├── web_test.go ├── websocket.go ├── websocket_test.go ├── xhr.go └── xhr_test.go ├── testserver └── server.go └── v3 ├── go.mod ├── go.sum └── sockjs ├── benchmarks_test.go ├── buffer.go ├── doc.go ├── eventsource.go ├── eventsource_integration_stage_test.go ├── eventsource_intergration_test.go ├── eventsource_test.go ├── example_handler_test.go ├── frame.go ├── frame_test.go ├── handler.go ├── handler_test.go ├── htmlfile.go ├── htmlfile_integration_stage_test.go ├── htmlfile_intergration_test.go ├── htmlfile_test.go ├── httpreceiver.go ├── httpreceiver_test.go ├── iframe.go ├── iframe_test.go ├── jsonp.go ├── jsonp_test.go ├── mapping.go ├── mapping_test.go ├── options.go ├── options_test.go ├── rawwebsocket.go ├── rawwebsocket_test.go ├── receiver.go ├── session.go ├── session_test.go ├── sockjs.go ├── sockjs_test.go ├── utils.go ├── utils_test.go ├── web.go ├── web_test.go ├── websocket.go ├── websocket_test.go ├── xhr.go └── xhr_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | *.swp 6 | 
7 | # Folders 8 | _obj 9 | _test 10 | .idea 11 | .DS_Store 12 | 13 | # Architecture specific extensions/prefixes 14 | *.[568vq] 15 | [568vq].out 16 | 17 | *.cgo1.go 18 | *.cgo2.c 19 | _cgo_defun.c 20 | _cgo_gotypes.go 21 | _cgo_export.* 22 | 23 | _testmain.go 24 | 25 | *.exe 26 | *.coverprofile 27 | 28 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.14.x" 5 | 6 | before_install: 7 | - cd v3 8 | - go get golang.org/x/tools/cmd/cover 9 | - go get github.com/mattn/goveralls 10 | - go get github.com/golangci/golangci-lint/cmd/golangci-lint 11 | 12 | after_success: 13 | - go test ./... -coverprofile=profile.out -covermode=count 14 | - PATH=$HOME/gopath/bin:$PATH goveralls -coverprofile=profile.out -service=travis-ci 15 | 16 | script: 17 | - golangci-lint run 18 | - go test ./... -race 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2020, sockjs-go authors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of nor the names of its contributors may be used to 13 | endorse or promote products derived from this software without specific 14 | prior written permission. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 20 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 | POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://api.travis-ci.org/igm/sockjs-go.svg?branch=master)](https://travis-ci.org/igm/sockjs-go) 2 | [![GoDoc](https://godoc.org/github.com/igm/sockjs-go/v3/sockjs?status.svg)](https://pkg.go.dev/github.com/igm/sockjs-go/v3/sockjs?tab=doc) 3 | [![Coverage Status](https://coveralls.io/repos/github/igm/sockjs-go/badge.svg?branch=master)](https://coveralls.io/github/igm/sockjs-go?branch=master) 4 | 5 | What is SockJS? 6 | = 7 | 8 | SockJS is a JavaScript library (for browsers) that provides a WebSocket-like 9 | object. SockJS gives you a coherent, cross-browser, JavaScript API 10 | which creates a low-latency, full-duplex, cross-domain communication 11 | channel between the browser and the web server, with WebSockets or without. 12 | This necessitates the use of a server; this library is one such server, written in Go. 13 | 14 | 15 | SockJS-Go server library 16 | = 17 | 18 | SockJS-Go is a [SockJS](https://github.com/sockjs/sockjs-client) server library written in Go.
19 | 20 | For the latest **v3** version of `sockjs-go` use: 21 | 22 | github.com/igm/sockjs-go/v3/sockjs 23 | 24 | For the **v2** version of `sockjs-go` use: 25 | 26 | gopkg.in/igm/sockjs-go.v2/sockjs 27 | 28 | Using version **v1** is not recommended (DEPRECATED) 29 | 30 | gopkg.in/igm/sockjs-go.v1/sockjs 31 | 32 | Note: using `github.com/igm/sockjs-go/sockjs` is not recommended. It exists for backwards compatibility reasons and is not maintained. 33 | 34 | Versioning 35 | - 36 | 37 | The SockJS-Go project adopted the [gopkg.in](http://gopkg.in) approach for versioning. SockJS-Go library details can be found [here](https://gopkg.in/igm/sockjs-go.v2/sockjs) 38 | 39 | With the introduction of Go modules, a new version `v3` is developed and maintained in the `master` branch and has a new import path `github.com/igm/sockjs-go/v3/sockjs`. 40 | 41 | Example 42 | - 43 | 44 | A simple echo sockjs server: 45 | 46 | 47 | ```go 48 | package main 49 | 50 | import ( 51 | "log" 52 | "net/http" 53 | 54 | "github.com/igm/sockjs-go/v3/sockjs" 55 | ) 56 | 57 | func main() { 58 | handler := sockjs.NewHandler("/echo", sockjs.DefaultOptions, echoHandler) 59 | log.Fatal(http.ListenAndServe(":8081", handler)) 60 | } 61 | 62 | func echoHandler(session sockjs.Session) { 63 | for { 64 | if msg, err := session.Recv(); err == nil { 65 | session.Send(msg) 66 | continue 67 | } 68 | break 69 | } 70 | } 71 | ``` 72 | 73 | 74 | SockJS Protocol Tests Status 75 | - 76 | SockJS defines a set of [protocol tests](https://github.com/sockjs/sockjs-protocol) to guarantee server compatibility with the sockjs client library and various browsers.
SockJS-Go server library aims to provide full compatibility; however, there are a couple of tests that don't pass and probably never will, for the reasons explained in the table below: 77 | 78 | 79 | | Failing Test | Explanation | 80 | | -------------| ------------| 81 | | **XhrPolling.test_transport** | does not pass due to a feature in net/http that does not send the content-type header in the case of a StatusNoContent response code (even if explicitly set in the code), [details](https://code.google.com/p/go/source/detail?r=902dc062bff8) | 82 | | **WebSocket.** | The SockJS-Go server supports RFC 6455; the draft protocols hixie-76 and hybi-10 are not supported | 83 | | **JSONEncoding** | As mentioned in the [browser quirks](https://github.com/sockjs/sockjs-client#browser-quirks) section: "it's advisable to use only valid characters. Using invalid characters is a bit slower, and may not work with SockJS servers that have a proper Unicode support." Go has proper Unicode support | 84 | | **RawWebsocket.** | The sockjs protocol tests use an old WebSocket client library (hybi-10) that does not support RFC 6455 properly | 85 | 86 | WebSocket 87 | - 88 | As mentioned above, the sockjs-go library is compatible with RFC 6455. That means browsers that do not support RFC 6455 are not supported properly. There are no plans to support draft versions of the WebSocket protocol. The WebSocket support is based on the [Gorilla web toolkit](http://www.gorillatoolkit.org/pkg/websocket) implementation of WebSocket. 89 | 90 | For detailed information about browser versions supporting RFC 6455, see this [wiki page](http://en.wikipedia.org/wiki/WebSocket#Browser_support). 91 | -------------------------------------------------------------------------------- /examples/webchat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | Simple sockjs chat example.
4 | 5 | ## Run 6 | ```shell 7 | $ go run webchat.go 8 | ``` 9 | Navigate using web browser: http://127.0.0.1:8080 10 | Open multiple windows with the same URL and see how chat works. 11 | 12 | -------------------------------------------------------------------------------- /examples/webchat/web/app.js: -------------------------------------------------------------------------------- 1 | if (!window.location.origin) { // Some browsers (mainly IE) do not have this property, so we need to build it manually... 2 | window.location.origin = window.location.protocol + '//' + window.location.hostname + (window.location.port ? (':' + window.location.port) : ''); 3 | } 4 | 5 | 6 | var sock = new SockJS(window.location.origin+'/echo') 7 | 8 | sock.onopen = function() { 9 | // console.log('connection open'); 10 | document.getElementById("status").innerHTML = "connected"; 11 | document.getElementById("send").disabled=false; 12 | }; 13 | 14 | sock.onmessage = function(e) { 15 | document.getElementById("output").value += e.data +"\n"; 16 | }; 17 | 18 | sock.onclose = function() { 19 | // console.log('connection closed'); 20 | document.getElementById("status").innerHTML = "disconnected"; 21 | document.getElementById("send").disabled=true; 22 | }; 23 | -------------------------------------------------------------------------------- /examples/webchat/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Chat Web Example 8 | 9 | 10 | 11 |

Chat - Web Example

12 |
13 | Input text: 14 | 15 |
16 |
17 | Messages from server:
18 | 20 |
21 | status: connecting... 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /examples/webchat/webchat.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | 7 | "github.com/igm/pubsub" 8 | "github.com/igm/sockjs-go/v3/sockjs" 9 | ) 10 | 11 | var chat pubsub.Publisher 12 | 13 | func main() { 14 | http.Handle("/echo/", sockjs.NewHandler("/echo", sockjs.DefaultOptions, echoHandler)) 15 | http.Handle("/", http.FileServer(http.Dir("web/"))) 16 | log.Println("Server started on port: 8080") 17 | log.Fatal(http.ListenAndServe(":8080", nil)) 18 | } 19 | 20 | func echoHandler(session sockjs.Session) { 21 | log.Println("new sockjs session established") 22 | var closedSession = make(chan struct{}) 23 | chat.Publish("[info] new participant joined chat") 24 | defer chat.Publish("[info] participant left chat") 25 | go func() { 26 | reader, _ := chat.SubChannel(nil) 27 | for { 28 | select { 29 | case <-closedSession: 30 | return 31 | case msg := <-reader: 32 | if err := session.Send(msg.(string)); err != nil { 33 | return 34 | } 35 | } 36 | 37 | } 38 | }() 39 | for { 40 | if msg, err := session.Recv(); err == nil { 41 | chat.Publish(msg) 42 | continue 43 | } 44 | break 45 | } 46 | close(closedSession) 47 | log.Println("sockjs session closed") 48 | } 49 | -------------------------------------------------------------------------------- /examples/webecho/README.md: -------------------------------------------------------------------------------- 1 | # Echo Example 2 | 3 | Simple echo sockjs example. 
4 | 5 | ## Run 6 | ```shell 7 | $ go run webecho.go 8 | ``` 9 | Navigate using web browser: http://127.0.0.1:8080 10 | -------------------------------------------------------------------------------- /examples/webecho/web/app.js: -------------------------------------------------------------------------------- 1 | if (!window.location.origin) { // Some browsers (mainly IE) do not have this property, so we need to build it manually... 2 | window.location.origin = window.location.protocol + '//' + window.location.hostname + (window.location.port ? (':' + window.location.port) : ''); 3 | } 4 | 5 | var origin = window.location.origin; 6 | 7 | // options usage example 8 | var options = { 9 | debug: true, 10 | devel: true, 11 | protocols_whitelist: ['websocket', 'xdr-streaming', 'xhr-streaming', 'iframe-eventsource', 'iframe-htmlfile', 'xdr-polling', 'xhr-polling', 'iframe-xhr-polling', 'jsonp-polling'] 12 | }; 13 | 14 | var sock = new SockJS(origin+'/echo', undefined, options); 15 | 16 | sock.onopen = function() { 17 | //console.log('connection open'); 18 | document.getElementById("status").innerHTML = "connected"; 19 | document.getElementById("send").disabled=false; 20 | }; 21 | 22 | sock.onmessage = function(e) { 23 | document.getElementById("output").value += e.data +"\n"; 24 | }; 25 | 26 | sock.onclose = function() { 27 | document.getElementById("status").innerHTML = "connection closed"; 28 | //console.log('connection closed'); 29 | }; 30 | 31 | function send() { 32 | text = document.getElementById("input").value; 33 | sock.send(document.getElementById("input").value); return false; 34 | } 35 | -------------------------------------------------------------------------------- /examples/webecho/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Echo Web Example 8 | 9 | 10 | 11 |

Echo - Web Example

12 |
13 | Input text: 14 | 15 |
16 |
17 | Messages from server:
18 | 20 |
21 | status: connecting... 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /examples/webecho/webecho.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "net/http" 7 | 8 | "github.com/igm/sockjs-go/v3/sockjs" 9 | ) 10 | 11 | var ( 12 | websocket = flag.Bool("websocket", true, "enable/disable websocket protocol") 13 | ) 14 | 15 | func init() { 16 | flag.Parse() 17 | } 18 | 19 | func main() { 20 | opts := sockjs.DefaultOptions 21 | opts.Websocket = *websocket 22 | handler := sockjs.NewHandler("/echo", opts, echoHandler) 23 | http.Handle("/echo/", handler) 24 | http.Handle("/", http.FileServer(http.Dir("web/"))) 25 | log.Println("Server started on port: 8080") 26 | log.Fatal(http.ListenAndServe(":8080", nil)) 27 | } 28 | 29 | func echoHandler(session sockjs.Session) { 30 | log.Println("new sockjs session established") 31 | for { 32 | if msg, err := session.Recv(); err == nil { 33 | if err := session.Send(msg); err != nil { 34 | break 35 | } 36 | continue 37 | } 38 | break 39 | } 40 | log.Println("sockjs session closed") 41 | } 42 | -------------------------------------------------------------------------------- /sockjs/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | profile.out 3 | *.coverprofile 4 | -------------------------------------------------------------------------------- /sockjs/README.md: -------------------------------------------------------------------------------- 1 | see [README](../README.md) for proper import paths 2 | -------------------------------------------------------------------------------- /sockjs/benchmarks_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "bufio" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strings" 
12 | "sync" 13 | "testing" 14 | "time" 15 | 16 | "github.com/gorilla/websocket" 17 | ) 18 | 19 | func BenchmarkSimple(b *testing.B) { 20 | var messages = make(chan string, 10) 21 | h := NewHandler("/echo", DefaultOptions, func(session Session) { 22 | for m := range messages { 23 | session.Send(m) 24 | } 25 | session.Close(1024, "Close") 26 | }) 27 | server := httptest.NewServer(h) 28 | defer server.Close() 29 | 30 | req, _ := http.NewRequest("POST", server.URL+fmt.Sprintf("/echo/server/%d/xhr_streaming", 1000), nil) 31 | resp, err := http.DefaultClient.Do(req) 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | for n := 0; n < b.N; n++ { 36 | messages <- "some message" 37 | } 38 | fmt.Println(b.N) 39 | close(messages) 40 | resp.Body.Close() 41 | } 42 | 43 | func BenchmarkMessages(b *testing.B) { 44 | msg := strings.Repeat("m", 10) 45 | h := NewHandler("/echo", DefaultOptions, func(session Session) { 46 | for n := 0; n < b.N; n++ { 47 | session.Send(msg) 48 | } 49 | session.Close(1024, "Close") 50 | }) 51 | server := httptest.NewServer(h) 52 | 53 | var wg sync.WaitGroup 54 | 55 | for i := 0; i < 100; i++ { 56 | wg.Add(1) 57 | go func(session int) { 58 | reqc := 0 59 | req, _ := http.NewRequest("POST", server.URL+fmt.Sprintf("/echo/server/%d/xhr_streaming", session), nil) 60 | for { 61 | reqc++ 62 | resp, err := http.DefaultClient.Do(req) 63 | if err != nil { 64 | log.Fatal(err) 65 | } 66 | reader := bufio.NewReader(resp.Body) 67 | for { 68 | line, err := reader.ReadString('\n') 69 | if err != nil { 70 | goto AGAIN 71 | } 72 | if strings.HasPrefix(line, "data: c[1024") { 73 | resp.Body.Close() 74 | goto DONE 75 | } 76 | } 77 | AGAIN: 78 | resp.Body.Close() 79 | } 80 | DONE: 81 | wg.Done() 82 | }(i) 83 | } 84 | wg.Wait() 85 | server.Close() 86 | } 87 | 88 | var size = flag.Int("size", 4*1024, "Size of one message.") 89 | 90 | func BenchmarkMessageWebsocket(b *testing.B) { 91 | flag.Parse() 92 | 93 | msg := strings.Repeat("x", *size) 94 | wsFrame := 
[]byte(fmt.Sprintf("[%q]", msg)) 95 | 96 | opts := Options{ 97 | Websocket: true, 98 | SockJSURL: "//cdnjs.cloudflare.com/ajax/libs/sockjs-client/0.3.4/sockjs.min.js", 99 | HeartbeatDelay: time.Hour, 100 | DisconnectDelay: time.Hour, 101 | ResponseLimit: uint32(*size), 102 | } 103 | 104 | h := NewHandler("/echo", opts, func(session Session) { 105 | for { 106 | msg, err := session.Recv() 107 | if err != nil { 108 | if session.GetSessionState() != SessionActive { 109 | break 110 | } 111 | b.Fatalf("Recv()=%s", err) 112 | } 113 | 114 | if err := session.Send(msg); err != nil { 115 | b.Fatalf("Send()=%s", err) 116 | } 117 | } 118 | }) 119 | 120 | server := httptest.NewServer(h) 121 | defer server.Close() 122 | 123 | url := "ws" + server.URL[4:] + "/echo/server/0/websocket" 124 | 125 | client, _, err := websocket.DefaultDialer.Dial(url, nil) 126 | if err != nil { 127 | b.Fatalf("Dial()=%s", err) 128 | } 129 | 130 | _, p, err := client.ReadMessage() 131 | if err != nil || string(p) != "o" { 132 | b.Fatalf("failed to start new session: frame=%v, err=%v", p, err) 133 | } 134 | 135 | b.ReportAllocs() 136 | b.ResetTimer() 137 | 138 | for i := 0; i < b.N; i++ { 139 | if err := client.WriteMessage(websocket.TextMessage, wsFrame); err != nil { 140 | b.Fatalf("WriteMessage()=%s", err) 141 | } 142 | 143 | if _, _, err := client.ReadMessage(); err != nil { 144 | b.Fatalf("ReadMessage()=%s", err) 145 | } 146 | } 147 | 148 | if err := client.Close(); err != nil { 149 | b.Fatalf("Close()=%s", err) 150 | } 151 | } 152 | 153 | func BenchmarkHandler_ParseSessionID(b *testing.B) { 154 | h := handler{prefix: "/prefix"} 155 | url, _ := url.Parse("http://server:80/prefix/server/session/whatever") 156 | 157 | b.ReportAllocs() 158 | b.ResetTimer() 159 | for i := 0; i < b.N; i++ { 160 | h.parseSessionID(url) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /sockjs/buffer.go: 
-------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import "sync" 4 | 5 | // messageBuffer is an unbounded buffer that blocks on 6 | // pop if it's empty until the new element is enqueued. 7 | type messageBuffer struct { 8 | popCh chan string 9 | closeCh chan struct{} 10 | once sync.Once // for b.close() 11 | } 12 | 13 | func newMessageBuffer() *messageBuffer { 14 | return &messageBuffer{ 15 | popCh: make(chan string), 16 | closeCh: make(chan struct{}), 17 | } 18 | } 19 | 20 | func (b *messageBuffer) push(messages ...string) error { 21 | for _, message := range messages { 22 | select { 23 | case b.popCh <- message: 24 | case <-b.closeCh: 25 | return ErrSessionNotOpen 26 | } 27 | } 28 | 29 | return nil 30 | } 31 | 32 | func (b *messageBuffer) pop() (string, error) { 33 | select { 34 | case msg := <-b.popCh: 35 | return msg, nil 36 | case <-b.closeCh: 37 | return "", ErrSessionNotOpen 38 | } 39 | } 40 | 41 | func (b *messageBuffer) close() { b.once.Do(func() { close(b.closeCh) }) } 42 | -------------------------------------------------------------------------------- /sockjs/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package sockjs is a server side implementation of sockjs protocol. 
3 | */ 4 | 5 | package sockjs 6 | -------------------------------------------------------------------------------- /sockjs/eventsource.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | func (h *handler) eventSource(rw http.ResponseWriter, req *http.Request) { 10 | rw.Header().Set("content-type", "text/event-stream; charset=UTF-8") 11 | fmt.Fprintf(rw, "\r\n") 12 | rw.(http.Flusher).Flush() 13 | 14 | recv := newHTTPReceiver(rw, h.options.ResponseLimit, new(eventSourceFrameWriter)) 15 | sess, _ := h.sessionByRequest(req) 16 | if err := sess.attachReceiver(recv); err != nil { 17 | recv.sendFrame(cFrame) 18 | recv.close() 19 | return 20 | } 21 | 22 | select { 23 | case <-recv.doneNotify(): 24 | case <-recv.interruptedNotify(): 25 | } 26 | } 27 | 28 | type eventSourceFrameWriter struct{} 29 | 30 | func (*eventSourceFrameWriter) write(w io.Writer, frame string) (int, error) { 31 | return fmt.Fprintf(w, "data: %s\r\n\r\n", frame) 32 | } 33 | -------------------------------------------------------------------------------- /sockjs/eventsource_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "runtime" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestHandler_EventSource(t *testing.T) { 12 | rw := httptest.NewRecorder() 13 | req, _ := http.NewRequest("POST", "/server/session/eventsource", nil) 14 | h := newTestHandler() 15 | h.options.ResponseLimit = 1024 16 | go func() { 17 | var sess *session 18 | for exists := false; !exists; { 19 | runtime.Gosched() 20 | h.sessionsMux.Lock() 21 | sess, exists = h.sessions["session"] 22 | h.sessionsMux.Unlock() 23 | } 24 | for exists := false; !exists; { 25 | runtime.Gosched() 26 | sess.RLock() 27 | exists = sess.recv != nil 28 | sess.RUnlock() 29 | } 30 | sess.RLock() 31 | sess.recv.close() 32 | 
sess.RUnlock() 33 | }() 34 | h.eventSource(rw, req) 35 | contentType := rw.Header().Get("content-type") 36 | expected := "text/event-stream; charset=UTF-8" 37 | if contentType != expected { 38 | t.Errorf("Unexpected content type, got '%s', extected '%s'", contentType, expected) 39 | } 40 | if rw.Code != http.StatusOK { 41 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusOK) 42 | } 43 | 44 | if rw.Body.String() != "\r\ndata: o\r\n\r\n" { 45 | t.Errorf("Event stream prelude, got '%s'", rw.Body) 46 | } 47 | } 48 | 49 | func TestHandler_EventSourceMultipleConnections(t *testing.T) { 50 | h := newTestHandler() 51 | h.options.ResponseLimit = 1024 52 | rw := httptest.NewRecorder() 53 | req, _ := http.NewRequest("POST", "/server/sess/eventsource", nil) 54 | go func() { 55 | rw := &ClosableRecorder{httptest.NewRecorder(), nil} 56 | h.eventSource(rw, req) 57 | if rw.Body.String() != "\r\ndata: c[2010,\"Another connection still open\"]\r\n\r\n" { 58 | t.Errorf("wrong, got '%v'", rw.Body) 59 | } 60 | h.sessionsMux.Lock() 61 | sess := h.sessions["sess"] 62 | sess.close() 63 | h.sessionsMux.Unlock() 64 | }() 65 | h.eventSource(rw, req) 66 | } 67 | 68 | func TestHandler_EventSourceConnectionInterrupted(t *testing.T) { 69 | h := newTestHandler() 70 | sess := newTestSession() 71 | sess.state = SessionActive 72 | h.sessions["session"] = sess 73 | req, _ := http.NewRequest("POST", "/server/session/eventsource", nil) 74 | rw := newClosableRecorder() 75 | close(rw.closeNotifCh) 76 | h.eventSource(rw, req) 77 | select { 78 | case <-sess.closeCh: 79 | case <-time.After(1 * time.Second): 80 | t.Errorf("session close channel should be closed") 81 | } 82 | sess.Lock() 83 | if sess.state != SessionClosed { 84 | t.Errorf("Session should be closed") 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /sockjs/example_handler_test.go: -------------------------------------------------------------------------------- 
1 | package sockjs_test 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/igm/sockjs-go/sockjs" 7 | ) 8 | 9 | func ExampleNewHandler_simple() { 10 | handler := sockjs.NewHandler("/echo", sockjs.DefaultOptions, func(session sockjs.Session) { 11 | for { 12 | if msg, err := session.Recv(); err == nil { 13 | if session.Send(msg) != nil { 14 | break 15 | } 16 | } else { 17 | break 18 | } 19 | } 20 | }) 21 | http.ListenAndServe(":8080", handler) 22 | } 23 | 24 | func ExampleNewHandler_defaultMux() { 25 | handler := sockjs.NewHandler("/echo", sockjs.DefaultOptions, func(session sockjs.Session) { 26 | for { 27 | if msg, err := session.Recv(); err == nil { 28 | if session.Send(msg) != nil { 29 | break 30 | } 31 | } else { 32 | break 33 | } 34 | } 35 | }) 36 | // need to provide path prefix for http.Mux 37 | http.Handle("/echo/", handler) 38 | http.ListenAndServe(":8080", nil) 39 | } 40 | -------------------------------------------------------------------------------- /sockjs/frame.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | func closeFrame(status uint32, reason string) string { 9 | bytes, _ := json.Marshal([]interface{}{status, reason}) 10 | return fmt.Sprintf("c%s", string(bytes)) 11 | } 12 | -------------------------------------------------------------------------------- /sockjs/frame_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import "testing" 4 | 5 | func TestCloseFrame(t *testing.T) { 6 | cf := closeFrame(1024, "some close text") 7 | if cf != "c[1024,\"some close text\"]" { 8 | t.Errorf("Wrong close frame generated '%s'", cf) 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /sockjs/handler.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "errors" 5 | 
"net/http" 6 | "net/url" 7 | "regexp" 8 | "strings" 9 | "sync" 10 | ) 11 | 12 | var ( 13 | prefixRegexp = make(map[string]*regexp.Regexp) 14 | prefixRegexpMu sync.Mutex // protects prefixRegexp 15 | ) 16 | 17 | type handler struct { 18 | prefix string 19 | options Options 20 | handlerFunc func(Session) 21 | mappings []*mapping 22 | 23 | sessionsMux sync.Mutex 24 | sessions map[string]*session 25 | } 26 | 27 | // NewHandler creates new HTTP handler that conforms to the basic net/http.Handler interface. 28 | // It takes path prefix, options and sockjs handler function as parameters 29 | func NewHandler(prefix string, opts Options, handleFunc func(Session)) http.Handler { 30 | return newHandler(prefix, opts, handleFunc) 31 | } 32 | 33 | func newHandler(prefix string, opts Options, handlerFunc func(Session)) *handler { 34 | h := &handler{ 35 | prefix: prefix, 36 | options: opts, 37 | handlerFunc: handlerFunc, 38 | sessions: make(map[string]*session), 39 | } 40 | xhrCors := xhrCorsFactory(opts) 41 | matchPrefix := prefix 42 | if matchPrefix == "" { 43 | matchPrefix = "^" 44 | } 45 | sessionPrefix := matchPrefix + "/[^/.]+/[^/.]+" 46 | h.mappings = []*mapping{ 47 | newMapping("GET", matchPrefix+"[/]?$", welcomeHandler), 48 | newMapping("OPTIONS", matchPrefix+"/info$", opts.cookie, xhrCors, cacheFor, opts.info), 49 | newMapping("GET", matchPrefix+"/info$", xhrCors, noCache, opts.info), 50 | // XHR 51 | newMapping("POST", sessionPrefix+"/xhr_send$", opts.cookie, xhrCors, noCache, h.xhrSend), 52 | newMapping("OPTIONS", sessionPrefix+"/xhr_send$", opts.cookie, xhrCors, cacheFor, xhrOptions), 53 | newMapping("POST", sessionPrefix+"/xhr$", opts.cookie, xhrCors, noCache, h.xhrPoll), 54 | newMapping("OPTIONS", sessionPrefix+"/xhr$", opts.cookie, xhrCors, cacheFor, xhrOptions), 55 | newMapping("POST", sessionPrefix+"/xhr_streaming$", opts.cookie, xhrCors, noCache, h.xhrStreaming), 56 | newMapping("OPTIONS", sessionPrefix+"/xhr_streaming$", opts.cookie, xhrCors, cacheFor, 
xhrOptions), 57 | // EventStream 58 | newMapping("GET", sessionPrefix+"/eventsource$", opts.cookie, xhrCors, noCache, h.eventSource), 59 | // Htmlfile 60 | newMapping("GET", sessionPrefix+"/htmlfile$", opts.cookie, xhrCors, noCache, h.htmlFile), 61 | // JsonP 62 | newMapping("GET", sessionPrefix+"/jsonp$", opts.cookie, xhrCors, noCache, h.jsonp), 63 | newMapping("OPTIONS", sessionPrefix+"/jsonp$", opts.cookie, xhrCors, cacheFor, xhrOptions), 64 | newMapping("POST", sessionPrefix+"/jsonp_send$", opts.cookie, xhrCors, noCache, h.jsonpSend), 65 | // IFrame 66 | newMapping("GET", matchPrefix+"/iframe[0-9-.a-z_]*.html$", cacheFor, h.iframe), 67 | } 68 | if opts.Websocket { 69 | h.mappings = append(h.mappings, newMapping("GET", sessionPrefix+"/websocket$", h.sockjsWebsocket)) 70 | } 71 | if opts.RawWebsocket { 72 | h.mappings = append(h.mappings, newMapping("GET", matchPrefix+"/websocket$", h.rawWebsocket)) 73 | } 74 | return h 75 | } 76 | 77 | func (h *handler) Prefix() string { return h.prefix } 78 | 79 | func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 80 | // iterate over mappings 81 | allowedMethods := []string{} 82 | for _, mapping := range h.mappings { 83 | if match, method := mapping.matches(req); match == fullMatch { 84 | for _, hf := range mapping.chain { 85 | hf(rw, req) 86 | } 87 | return 88 | } else if match == pathMatch { 89 | allowedMethods = append(allowedMethods, method) 90 | } 91 | } 92 | if len(allowedMethods) > 0 { 93 | rw.Header().Set("allow", strings.Join(allowedMethods, ", ")) 94 | rw.Header().Set("Content-Type", "") 95 | rw.WriteHeader(http.StatusMethodNotAllowed) 96 | return 97 | } 98 | http.NotFound(rw, req) 99 | } 100 | 101 | func (h *handler) parseSessionID(url *url.URL) (string, error) { 102 | // cache compiled regexp objects for most used prefixes 103 | prefixRegexpMu.Lock() 104 | session, ok := prefixRegexp[h.prefix] 105 | if !ok { 106 | session = regexp.MustCompile(h.prefix + "/(?P[^/.]+)/(?P[^/.]+)/.*") 107 | 
prefixRegexp[h.prefix] = session 108 | } 109 | prefixRegexpMu.Unlock() 110 | 111 | matches := session.FindStringSubmatch(url.Path) 112 | if len(matches) == 3 { 113 | return matches[2], nil 114 | } 115 | return "", errors.New("unable to parse URL for session") 116 | } 117 | 118 | func (h *handler) sessionByRequest(req *http.Request) (*session, error) { 119 | h.sessionsMux.Lock() 120 | defer h.sessionsMux.Unlock() 121 | sessionID, err := h.parseSessionID(req.URL) 122 | if err != nil { 123 | return nil, err 124 | } 125 | sess, exists := h.sessions[sessionID] 126 | if !exists { 127 | sess = newSession(req, sessionID, h.options.DisconnectDelay, h.options.HeartbeatDelay) 128 | h.sessions[sessionID] = sess 129 | if h.handlerFunc != nil { 130 | go h.handlerFunc(sess) 131 | } 132 | go func() { 133 | <-sess.closedNotify() 134 | h.sessionsMux.Lock() 135 | delete(h.sessions, sessionID) 136 | h.sessionsMux.Unlock() 137 | }() 138 | } 139 | return sess, nil 140 | } 141 | -------------------------------------------------------------------------------- /sockjs/handler_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "net/http" 7 | "net/http/httptest" 8 | "net/url" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | var testOptions = DefaultOptions 14 | 15 | func init() { 16 | testOptions.RawWebsocket = true 17 | } 18 | 19 | func TestHandler_Create(t *testing.T) { 20 | handler := newHandler("/echo", testOptions, nil) 21 | if handler.Prefix() != "/echo" { 22 | t.Errorf("Prefix not properly set, got '%s' expected '%s'", handler.Prefix(), "/echo") 23 | } 24 | if handler.sessions == nil { 25 | t.Errorf("Handler session map not made") 26 | } 27 | server := httptest.NewServer(handler) 28 | defer server.Close() 29 | 30 | resp, err := http.Get(server.URL + "/echo") 31 | if err != nil { 32 | t.Errorf("There should not be any error, got '%s'", err) 33 | t.FailNow() 34 | } 35 | if resp 
== nil { 36 | t.Fatalf("Response should not be nil") 37 | } 38 | defer resp.Body.Close() 39 | if resp.StatusCode != http.StatusOK { 40 | t.Errorf("Unexpected status code received, got '%d' expected '%d'", resp.StatusCode, http.StatusOK) 41 | } 42 | } 43 | 44 | func TestHandler_RootPrefixInfoHandler(t *testing.T) { 45 | handler := newHandler("", testOptions, nil) 46 | if handler.Prefix() != "" { 47 | t.Errorf("Prefix not properly set, got '%s' expected '%s'", handler.Prefix(), "") 48 | } 49 | server := httptest.NewServer(handler) 50 | defer server.Close() 51 | 52 | resp, err := http.Get(server.URL + "/info") 53 | if err != nil { 54 | t.Fatalf("There should not be any error, got '%s'", err) 55 | } 56 | if resp == nil { 57 | t.Fatalf("Response should not be nil") 58 | } 59 | defer resp.Body.Close() 60 | 61 | if resp.StatusCode != http.StatusOK { 62 | t.Errorf("Unexpected status code received, got '%d' expected '%d'", resp.StatusCode, http.StatusOK) 63 | } 64 | infoData, err := ioutil.ReadAll(resp.Body) 65 | if err != nil { 66 | t.Fatalf("Error reading body: '%v'", err) 67 | } 68 | var i info 69 | if err := json.Unmarshal(infoData, &i); err != nil { 70 | t.Fatalf("Error unmarshaling info: '%v', data was: '%s'", err, string(infoData)) 71 | } 72 | if !i.Websocket { 73 | t.Fatalf("Expected websocket to be true") 74 | } 75 | } 76 | 77 | func TestHandler_ParseSessionId(t *testing.T) { 78 | h := handler{prefix: "/prefix"} 79 | u, _ := url.Parse("http://server:80/prefix/server/session/whatever") 80 | if session, err := h.parseSessionID(u); session != "session" || err != nil { 81 | t.Errorf("Wrong session parsed, got '%s' expected '%s' with error = '%v'", session, "session", err) 82 | } 83 | u, _ = url.Parse("http://server:80/asdasd/server/session/whatever") 84 | if _, err := h.parseSessionID(u); err == nil { 85 | t.Errorf("Should return error") 86 | } 87 | } 88 | 89 | func TestHandler_SessionByRequest(t *testing.T) { 90 | h := newHandler("", testOptions, nil) 91 | h.options.DisconnectDelay = 10 * time.Millisecond 92 | var handlerFuncCalled = make(chan Session) 93 | h.handlerFunc = func(conn Session) { handlerFuncCalled <- conn } 94 | req, _ := http.NewRequest("POST", "/server/sessionid/whatever/follows", nil) 95 | sess, err := h.sessionByRequest(req) 96 | if sess == nil || err != nil { 97 | t.Fatalf("Session should be returned, got error '%v'", err) 98 | } 99 | // test handlerFunc was called with the new session 100 | select { 101 | case conn := <-handlerFuncCalled: // ok 102 | if conn != sess { 103 | t.Errorf("Handler was not passed correct session") 104 | } 105 | case <-time.After(100 * time.Millisecond): 106 | t.Errorf("HandlerFunc was not called") 107 | } 108 | // test session is reused for multiple requests with same sessionID 109 | req2, _ := http.NewRequest("POST", "/server/sessionid/whatever", nil) 110 | if sess2, err := h.sessionByRequest(req2); sess2 != sess || err != nil { 111 | t.Errorf("Expected the same session to be reused, got session '%v' with error '%v'", sess2, err) 112 | } 113 | // test session expires after timeout; it is removed from the map asynchronously, so poll until a deadline 114 | deadline := time.Now().Add(time.Second) 115 | for { 116 | h.sessionsMux.Lock() 117 | _, exists := h.sessions["sessionid"] 118 | h.sessionsMux.Unlock() 119 | if !exists { 120 | break 121 | } 122 | if time.Now().After(deadline) { 123 | t.Errorf("Session should not exist in handler after timeout") 124 | break 125 | } 126 | time.Sleep(5 * time.Millisecond) 127 | } 128 | // test proper behaviour in case URL is not correct 129 | req, _ = http.NewRequest("POST", "", nil) 130 | if _, err := h.sessionByRequest(req); err == nil { 131 | t.Errorf("Expected parser sessionID from URL error, got 'nil'") 132 | } 133 | } 134 |
-------------------------------------------------------------------------------- /sockjs/htmlfile.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "regexp" 8 | "strings" 9 | ) 10 | 11 | var iframeTemplate = ` 12 | 13 | 14 | 15 |

Don't panic!

16 | 23 | ` 24 | 25 | // invalidCallback matches any character that is not allowed in a callback name (only letters, digits, '_' and '.' are); such callbacks are rejected to prevent script injection. 26 | var invalidCallback = regexp.MustCompile("[^a-zA-Z0-9\\_\\.]") 27 | 28 | func init() { 29 | // NOTE(review): pads the template to roughly 1 KiB and appends an empty line; the reason for the +14 adjustment is not visible here, so keep the test's expectedIFrame in sync when changing it. 30 | iframeTemplate += strings.Repeat(" ", 1024-len(iframeTemplate)+14) 31 | iframeTemplate += "\r\n\r\n" 32 | } 33 | 34 | // htmlFile serves the htmlfile streaming transport: it validates the "c" callback parameter, writes the padded iframeTemplate, and then streams session frames until the receiver is done or the client connection is interrupted. 35 | func (h *handler) htmlFile(rw http.ResponseWriter, req *http.Request) { 36 | rw.Header().Set("content-type", "text/html; charset=UTF-8") 37 | 38 | _ = req.ParseForm() // error ignored deliberately: unparsable query pairs are skipped and "c" is still validated below 39 | callback := req.Form.Get("c") 40 | switch { 41 | case callback == "": 42 | http.Error(rw, `"callback" parameter required`, http.StatusInternalServerError) 43 | return 44 | case invalidCallback.MatchString(callback): 45 | http.Error(rw, `invalid character in "callback" parameter`, http.StatusBadRequest) 46 | return 47 | } 48 | // Resolve the session before committing a 200, so a URL without a parsable session ID gets a 404 instead of a nil-pointer panic. 49 | sess, err := h.sessionByRequest(req) 50 | if err != nil { 51 | http.NotFound(rw, req) 52 | return 53 | } 54 | rw.WriteHeader(http.StatusOK) 55 | fmt.Fprintf(rw, iframeTemplate, callback) 56 | if flusher, ok := rw.(http.Flusher); ok { 57 | flusher.Flush() 58 | } 59 | recv := newHTTPReceiver(rw, h.options.ResponseLimit, new(htmlfileFrameWriter)) 60 | if err := sess.attachReceiver(recv); err != nil { 61 | // the receiver could not be attached (the reason is not visible here): send cFrame (presumably the close frame) and end the response 62 | recv.sendFrame(cFrame) 63 | recv.close() 64 | return 65 | } 66 | // block until the receiver is done (closed or response limit reached) or the client connection is interrupted 67 | select { 68 | case <-recv.doneNotify(): 69 | case <-recv.interruptedNotify(): 70 | } 71 | } 72 | 73 | // htmlfileFrameWriter writes session frames for the htmlfile transport. 74 | type htmlfileFrameWriter struct{} 75 | 76 | // NOTE(review): as written the format string has no verb for quote(frame), so fmt appends "%!(EXTRA ...)" to the output; the markup that should wrap the frame appears to be missing from this copy - check against upstream. 77 | func (*htmlfileFrameWriter) write(w io.Writer, frame string) (int, error) { 78 | return fmt.Fprintf(w, "\r\n", quote(frame)) 79 | } 80 | -------------------------------------------------------------------------------- /sockjs/htmlfile_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestHandler_htmlFileNoCallback(t *testing.T) { 11 | h := newTestHandler() 12 | rw := httptest.NewRecorder() 13 | req, _ := http.NewRequest("GET", "/server/session/htmlfile", nil) 14 | h.htmlFile(rw, req) 15 | if rw.Code != http.StatusInternalServerError { 16 | t.Errorf("Unexpected response code, got '%d', expected '%d'",
rw.Code, http.StatusInternalServerError) 17 | } 18 | expectedContentType := "text/plain; charset=utf-8" 19 | if rw.Header().Get("content-type") != expectedContentType { 20 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 21 | } 22 | } 23 | 24 | func TestHandler_htmlFile(t *testing.T) { 25 | h := newTestHandler() 26 | rw := httptest.NewRecorder() 27 | req, _ := http.NewRequest("GET", "/server/session/htmlfile?c=testCallback", nil) 28 | h.htmlFile(rw, req) 29 | if rw.Code != http.StatusOK { 30 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusOK) 31 | } 32 | expectedContentType := "text/html; charset=UTF-8" 33 | if rw.Header().Get("content-type") != expectedContentType { 34 | t.Errorf("Unexpected content-type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 35 | } 36 | if rw.Body.String() != expectedIFrame { 37 | t.Errorf("Unexpected response body, got '%s', expected '%s'", rw.Body, expectedIFrame) 38 | } 39 | 40 | } 41 | 42 | func TestHandler_cannotIntoXSS(t *testing.T) { 43 | h := newTestHandler() 44 | rw := httptest.NewRecorder() 45 | // test injection through ';' (%3B), which is not allowed in callback names 46 | req, _ := http.NewRequest("GET", "/server/session/htmlfile?c=fake%3Balert(1337)", nil) 47 | h.htmlFile(rw, req) 48 | if rw.Code != http.StatusBadRequest { 49 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusBadRequest) 50 | } 51 | 52 | h = newTestHandler() 53 | rw = httptest.NewRecorder() 54 | // test injection through '-' (%2D), which is not allowed in callback names either 55 | req, _ = http.NewRequest("GET", "/server/session/htmlfile?c=fake%2Dalert", nil) 56 | h.htmlFile(rw, req) 57 | if rw.Code != http.StatusBadRequest { 58 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusBadRequest) 59 | } 60 | } 61 | 62 | func init() { 63 | expectedIFrame += strings.Repeat(" ", 1024-len(expectedIFrame)+len("testCallback")+12) // mirrors iframeTemplate's padding (+14), adjusted for replacing its 2-byte %s verb with the callback 64 | expectedIFrame += "\r\n\r\n" 65 |
expectedIFrame += "\r\n" 66 | } 67 | 68 | var expectedIFrame = ` 69 | 70 | 71 | 72 |

Don't panic!

73 | 80 | ` 81 | -------------------------------------------------------------------------------- /sockjs/httpreceiver.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | type frameWriter interface { 12 | write(writer io.Writer, frame string) (int, error) 13 | } 14 | 15 | type httpReceiverState int 16 | 17 | const ( 18 | stateHTTPReceiverActive httpReceiverState = iota 19 | stateHTTPReceiverClosed 20 | ) 21 | 22 | type httpReceiver struct { 23 | sync.Mutex 24 | state httpReceiverState 25 | 26 | frameWriter frameWriter 27 | rw http.ResponseWriter 28 | maxResponseSize uint32 29 | currentResponseSize uint32 30 | doneCh chan struct{} 31 | interruptCh chan struct{} 32 | } 33 | // newHTTPReceiver returns an active receiver that writes frames to rw through frameWriter and closes itself once at least maxResponse bytes have been written. 34 | func newHTTPReceiver(rw http.ResponseWriter, maxResponse uint32, frameWriter frameWriter) *httpReceiver { 35 | recv := &httpReceiver{ 36 | rw: rw, 37 | frameWriter: frameWriter, 38 | maxResponseSize: maxResponse, 39 | doneCh: make(chan struct{}), 40 | interruptCh: make(chan struct{}), 41 | } 42 | if closeNotifier, ok := rw.(http.CloseNotifier); ok { 43 | // if rw supports it, watch for the client going away: mark the receiver closed and close interruptCh 44 | closeNotifyCh := closeNotifier.CloseNotify() 45 | go func() { 46 | select { 47 | case <-closeNotifyCh: 48 | recv.Lock() 49 | defer recv.Unlock() 50 | if recv.state < stateHTTPReceiverClosed { 51 | recv.state = stateHTTPReceiverClosed 52 | close(recv.interruptCh) 53 | } 54 | case <-recv.doneCh: 55 | // ok, no action needed here, receiver closed in correct way 56 | // just finish the routine 57 | } 58 | }() 59 | } 60 | return recv 61 | } 62 | // sendBulk sends all messages as a single "a[...]" array frame; nothing is sent when messages is empty. 63 | func (recv *httpReceiver) sendBulk(messages ...string) { 64 | if len(messages) > 0 { 65 | recv.sendFrame(fmt.Sprintf("a[%s]", 66 | strings.Join( 67 | transform(messages, quote), 68 | ",", 69 | ), 70 | )) 71 | } 72 | } 73 | // sendFrame writes value through the frame writer while the receiver is active and closes the receiver once the response limit is reached. 74 | func (recv *httpReceiver) sendFrame(value string) { 75 | recv.Lock()
76 | defer recv.Unlock() 77 | 78 | if recv.state != stateHTTPReceiverActive { 79 | return 80 | } 81 | // TODO(igm) check err, possibly act as if interrupted 82 | n, _ := recv.frameWriter.write(recv.rw, value) 83 | recv.currentResponseSize += uint32(n) 84 | if recv.currentResponseSize >= recv.maxResponseSize { 85 | // response limit reached: finish this response so the client has to open a new request 86 | recv.state = stateHTTPReceiverClosed 87 | close(recv.doneCh) 88 | return 89 | } 90 | // flush only when the ResponseWriter supports it instead of panicking on a failed type assertion 91 | if flusher, ok := recv.rw.(http.Flusher); ok { 92 | flusher.Flush() 93 | } 94 | } 95 | 96 | // doneNotify returns a channel that is closed once the receiver finished normally, via close or because the response size limit was reached. 97 | func (recv *httpReceiver) doneNotify() <-chan struct{} { return recv.doneCh } 98 | 99 | // interruptedNotify returns a channel that is closed when the client connection goes away; it can only fire if the ResponseWriter implements http.CloseNotifier. 100 | func (recv *httpReceiver) interruptedNotify() <-chan struct{} { return recv.interruptCh } 101 | 102 | // close marks the receiver closed and closes the done channel; it is a no-op if the receiver is already closed for any reason. 103 | func (recv *httpReceiver) close() { 104 | recv.Lock() 105 | defer recv.Unlock() 106 | if recv.state < stateHTTPReceiverClosed { 107 | recv.state = stateHTTPReceiverClosed 108 | close(recv.doneCh) 109 | } 110 | } 111 | 112 | // canSend reports whether the receiver is still active. 113 | func (recv *httpReceiver) canSend() bool { 114 | recv.Lock() 115 | defer recv.Unlock() 116 | return recv.state != stateHTTPReceiverClosed 117 | } 118 | -------------------------------------------------------------------------------- /sockjs/httpreceiver_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "io" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | type testFrameWriter struct { 11 | frames []string 12 | } 13 | 14 | func (t *testFrameWriter) write(w io.Writer, frame string) (int, error) { 15 | t.frames = append(t.frames, frame) 16 | return len(frame), nil 17 | } 18 | 19 | func TestHttpReceiver_Create(t *testing.T) { 20 | rec := httptest.NewRecorder() 21 | recv := newHTTPReceiver(rec, 1024, new(testFrameWriter)) 22 | if recv.doneCh != recv.doneNotify() { 23 | t.Errorf("Calling done() must return close channel, but it does not") 24 | } 25 | if recv.rw != rec { 26 | t.Errorf("Http.ResponseWriter not properly initialized") 27 | } 28 | if recv.maxResponseSize != 1024 { 29 | t.Errorf("MaxResponseSize not properly initialized") 30 | } 31 | } 32 | 33 | func
TestHttpReceiver_SendEmptyFrames(t *testing.T) { 34 | rec := httptest.NewRecorder() 35 | fw := new(testFrameWriter) 36 | newHTTPReceiver(rec, 1024, fw).sendBulk() 37 | if len(fw.frames) != 0 || rec.Body.String() != "" { // testFrameWriter never writes to rec, so the captured frames must be checked too 38 | t.Errorf("Incorrect content received from receiver, frames %q, body '%s'", fw.frames, rec.Body.String()) 39 | } 40 | } 41 | 42 | func TestHttpReceiver_SendFrame(t *testing.T) { 43 | rec := httptest.NewRecorder() 44 | fw := new(testFrameWriter) 45 | recv := newHTTPReceiver(rec, 1024, fw) 46 | var frame = "some frame content" 47 | recv.sendFrame(frame) 48 | if len(fw.frames) != 1 || fw.frames[0] != frame { 49 | t.Errorf("Incorrect body content received, got '%s', expected '%s'", fw.frames, frame) 50 | } 51 | 52 | } 53 | 54 | func TestHttpReceiver_SendBulk(t *testing.T) { 55 | rec := httptest.NewRecorder() 56 | fw := new(testFrameWriter) 57 | recv := newHTTPReceiver(rec, 1024, fw) 58 | recv.sendBulk("message 1", "message 2", "message 3") 59 | expected := "a[\"message 1\",\"message 2\",\"message 3\"]" 60 | if len(fw.frames) != 1 || fw.frames[0] != expected { 61 | t.Errorf("Incorrect body content received from receiver, got '%s' expected '%s'", fw.frames, expected) 62 | } 63 | } 64 | 65 | func TestHttpReceiver_MaximumResponseSize(t *testing.T) { 66 | rec := httptest.NewRecorder() 67 | recv := newHTTPReceiver(rec, 52, new(testFrameWriter)) 68 | recv.sendBulk("message 1", "message 2") // produces 26 bytes of response in 1 frame 69 | if recv.currentResponseSize != 26 { 70 | t.Errorf("Incorrect response size calculated, got '%d' expected '%d'", recv.currentResponseSize, 26) 71 | } 72 | select { 73 | case <-recv.doneNotify(): 74 | t.Errorf("Receiver should not be done yet") 75 | default: // ok 76 | } 77 | recv.sendBulk("message 1", "message 2") // produces another 26 bytes of response in 1 frame, reaching the max response size 78 | select { 79 | case <-recv.doneNotify(): // ok 80 | default: 81 | t.Errorf("Receiver closed channel did not close") 82 | } 83 | } 84 | 85 | func TestHttpReceiver_Close(t
*testing.T) { 86 | rec := httptest.NewRecorder() 87 | recv := newHTTPReceiver(rec, 1024, nil) 88 | recv.close() 89 | if recv.state != stateHTTPReceiverClosed { 90 | t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed) 91 | } 92 | } 93 | 94 | func TestHttpReceiver_ConnectionInterrupt(t *testing.T) { 95 | rw := newClosableRecorder() 96 | recv := newHTTPReceiver(rw, 1024, nil) 97 | rw.closeNotifCh <- true 98 | select { 99 | case <-recv.interruptCh: 100 | case <-time.After(1 * time.Second): 101 | t.Fatalf("should interrupt") 102 | } 103 | if recv.state != stateHTTPReceiverClosed { 104 | t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /sockjs/iframe.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "crypto/md5" 5 | "fmt" 6 | "net/http" 7 | "text/template" 8 | ) 9 | 10 | var tmpl = template.Must(template.New("iframe").Parse(iframeBody)) 11 | // iframe serves the SockJS iframe page, rendering iframeBody with Options.SockJSURL; the ETag is the hex MD5 of iframeBody, and a request whose If-None-Match equals it gets 304 Not Modified. 12 | func (h *handler) iframe(rw http.ResponseWriter, req *http.Request) { 13 | etagReq := req.Header.Get("If-None-Match") 14 | hash := md5.New() 15 | hash.Write([]byte(iframeBody)) 16 | etag := fmt.Sprintf("%x", hash.Sum(nil)) 17 | if etag == etagReq { 18 | rw.WriteHeader(http.StatusNotModified) 19 | return 20 | } 21 | 22 | rw.Header().Set("Content-Type", "text/html; charset=UTF-8") 23 | rw.Header().Add("ETag", etag) 24 | tmpl.Execute(rw, h.options.SockJSURL) 25 | } 26 | 27 | var iframeBody = ` 28 | 29 | 30 | 31 | 32 | 36 | 37 | 38 | 39 |

Don't panic!

40 |

This is a SockJS hidden iframe. It's used for cross domain magic.

41 | 42 | ` 43 | -------------------------------------------------------------------------------- /sockjs/iframe_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func TestHandler_iframe(t *testing.T) { 10 | h := newTestHandler() 11 | h.options.SockJSURL = "http://sockjs.com/sockjs.js" 12 | rw := httptest.NewRecorder() 13 | req, _ := http.NewRequest("GET", "/server/sess/iframe", nil) 14 | h.iframe(rw, req) 15 | if rw.Body.String() != expected { 16 | t.Errorf("Unexpected html content,\ngot:\n'%s'\n\nexpected\n'%s'", rw.Body, expected) 17 | } 18 | eTag := rw.Header().Get("etag") 19 | req.Header.Set("if-none-match", eTag) 20 | rw = httptest.NewRecorder() 21 | h.iframe(rw, req) 22 | if rw.Code != http.StatusNotModified { 23 | t.Errorf("Unexpected response, got '%d', expected '%d'", rw.Code, http.StatusNotModified) 24 | } 25 | } 26 | 27 | var expected = ` 28 | 29 | 30 | 31 | 32 | 36 | 37 | 38 | 39 |

Don't panic!

40 |

This is a SockJS hidden iframe. It's used for cross domain magic.

41 | 42 | ` 43 | -------------------------------------------------------------------------------- /sockjs/jsonp.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | ) 10 | 11 | // jsonp serves the JSONP polling transport: it validates the "c" callback parameter and attaches a receiver with a 1-byte response limit, so the request ends after the first frame, written as callback(quoted-frame);. 12 | func (h *handler) jsonp(rw http.ResponseWriter, req *http.Request) { 13 | rw.Header().Set("content-type", "application/javascript; charset=UTF-8") 14 | 15 | _ = req.ParseForm() // error ignored deliberately: unparsable query pairs are skipped and "c" is still validated below 16 | callback := req.Form.Get("c") 17 | switch { 18 | case callback == "": 19 | http.Error(rw, `"callback" parameter required`, http.StatusInternalServerError) 20 | return 21 | case invalidCallback.MatchString(callback): 22 | http.Error(rw, `invalid character in "callback" parameter`, http.StatusBadRequest) 23 | return 24 | } 25 | // Resolve the session before committing a 200, so a URL without a parsable session ID gets a 404 instead of a nil-pointer panic. 26 | sess, err := h.sessionByRequest(req) 27 | if err != nil { 28 | http.NotFound(rw, req) 29 | return 30 | } 31 | rw.WriteHeader(http.StatusOK) 32 | if flusher, ok := rw.(http.Flusher); ok { 33 | flusher.Flush() 34 | } 35 | 36 | recv := newHTTPReceiver(rw, 1, &jsonpFrameWriter{callback}) 37 | if err := sess.attachReceiver(recv); err != nil { 38 | recv.sendFrame(cFrame) 39 | recv.close() 40 | return 41 | } 42 | select { 43 | case <-recv.doneNotify(): 44 | case <-recv.interruptedNotify(): 45 | } 46 | } 47 | 48 | // jsonpSend delivers a JSON array of messages to an existing session; the payload is the "d" form value when present, otherwise the request body. 49 | // It replies 500 for a missing or malformed payload, 404 for an unknown session, and "ok" once the messages are handed to the session. 50 | func (h *handler) jsonpSend(rw http.ResponseWriter, req *http.Request) { 51 | _ = req.ParseForm() // error ignored deliberately: the payload is validated when it is decoded below 52 | var data io.Reader = req.Body 53 | if d := req.PostFormValue("d"); d != "" { 54 | data = strings.NewReader(d) 55 | } 56 | if data == nil { 57 | http.Error(rw, "Payload expected.", http.StatusInternalServerError) 58 | return 59 | } 60 | var messages []string 61 | err := json.NewDecoder(data).Decode(&messages) 62 | if err == io.EOF { 63 | http.Error(rw, "Payload expected.", http.StatusInternalServerError) 64 | return 65 | } 66 | if err != nil { 67 | http.Error(rw, "Broken JSON encoding.", http.StatusInternalServerError) 68 | return 69 | } 70 | sessionID, err := h.parseSessionID(req.URL) 71 | if err != nil { 72 | http.NotFound(rw, req) 73 | return 74 | } 75 | // Only the map lookup is done under sessionsMux: calling accept while holding it would stall session lookups of every other request if accept blocks. 76 | h.sessionsMux.Lock() 77 | sess, ok := h.sessions[sessionID] 78 | h.sessionsMux.Unlock() 79 | if !ok { 80 | http.NotFound(rw, req) 81 | return 82 | } 83 | _ = sess.accept(messages...) // TODO(igm) respond with http.StatusInternalServerError in case of err? 84 | rw.Header().Set("content-type", "text/plain; charset=UTF-8") 85 | rw.Write([]byte("ok")) 86 | } 87 | 88 | // jsonpFrameWriter writes each frame as a call to the client-supplied callback. 89 | type jsonpFrameWriter struct { 90 | callback string 91 | } 92 | 93 | func (j *jsonpFrameWriter) write(w io.Writer, frame string) (int, error) { 94 | return fmt.Fprintf(w, "%s(%s);\r\n", j.callback, quote(frame)) 95 | } 96 |
-------------------------------------------------------------------------------- /sockjs/jsonp_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestHandler_jsonpNoCallback(t *testing.T) { 12 | h := newTestHandler() 13 | rw := httptest.NewRecorder() 14 | req, _ := http.NewRequest("GET", "/server/session/jsonp", nil) 15 | h.jsonp(rw, req) 16 | if rw.Code != http.StatusInternalServerError { 17 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 18 | } 19 | expectedContentType := "text/plain; charset=utf-8" 20 | if rw.Header().Get("content-type") != expectedContentType { 21 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 22 | } 23 | } 24 | 25 | func TestHandler_jsonp(t *testing.T) { 26 | h := newTestHandler() 27 | rw := httptest.NewRecorder() 28 | req, _ := http.NewRequest("GET", "/server/session/jsonp?c=testCallback", nil) 29 | h.jsonp(rw, req) 30 | expectedContentType := "application/javascript; charset=UTF-8" 31 | if rw.Header().Get("content-type") != expectedContentType { 32 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 33 | } 34 | expectedBody := "testCallback(\"o\");\r\n" 35 | if rw.Body.String() != expectedBody { 36 | t.Errorf("Unexpected body, got '%s', expected '%s'", rw.Body, expectedBody) 37 | } 38 | } 39 | 40 | func TestHandler_jsonpSendNoPayload(t *testing.T) { 41 | h := newTestHandler() 42 | rw := httptest.NewRecorder() 43 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", nil) 44 | h.jsonpSend(rw, req) 45 | if rw.Code != http.StatusInternalServerError { 46 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 47 | } 48 | } 49 | 50 | func TestHandler_jsonpSendWrongPayload(t *testing.T) { 51 | h := newTestHandler() 52 | rw := httptest.NewRecorder() 53 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("wrong payload")) 54 | h.jsonpSend(rw, req) 55 | if rw.Code != http.StatusInternalServerError { 56 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 57 | } 58 | } 59 | 60 | func TestHandler_jsonpSendNoSession(t *testing.T) { 61 | h := newTestHandler() 62 | rw := httptest.NewRecorder() 63 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("[\"message\"]")) 64 | h.jsonpSend(rw, req) 65 | if rw.Code != http.StatusNotFound { 66 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusNotFound) 67 | } 68 | } 69 | 70 | func TestHandler_jsonpSend(t *testing.T) { 71 | h := newTestHandler() 72 | 73 | rw := httptest.NewRecorder() 74 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("[\"message\"]")) 75 | 76 | sess := newSession(req, "session", time.Second, time.Second) 77 | h.sessions["session"] = sess 78 | 79 | var done = make(chan struct{}) 80 | go func() { 81 | h.jsonpSend(rw, req) 82 | close(done) 83 | }() 84 | msg, _ := sess.Recv() 85 | if msg != "message" { 86 | t.Errorf("Incorrect message in the channel, should be '%s', was '%s'", "message", msg) 87 | } 88
| <-done 89 | if got, want := rw.Code, http.StatusOK; got != want { 90 | t.Errorf("Wrong response status received %d, should be %d", got, want) 91 | } 92 | if ct := rw.Header().Get("content-type"); ct != "text/plain; charset=UTF-8" { 93 | t.Errorf("Wrong content type received '%s'", ct) 94 | } 95 | if body := rw.Body.String(); body != "ok" { 96 | t.Errorf("Unexpected body, got '%s', expected 'ok'", body) 97 | } 98 | } 99 | 100 | func TestHandler_jsonpCannotIntoXSS(t *testing.T) { 101 | hnd := newTestHandler() 102 | rec := httptest.NewRecorder() 103 | req, _ := http.NewRequest("GET", "/server/session/jsonp?c=%3Chtml%3E%3Chead%3E%3Cscript%3Ealert(5520)%3C%2Fscript%3E", nil) 104 | hnd.jsonp(rec, req) 105 | if rec.Code != http.StatusBadRequest { 106 | t.Errorf("JsonP forwarded an exploitable response.") 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /sockjs/mapping.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | ) 7 | // mapping routes requests whose URL path matches path and whose method equals method to the handlers in chain. 8 | type mapping struct { 9 | method string 10 | path *regexp.Regexp 11 | chain []http.HandlerFunc 12 | } 13 | // newMapping compiles re (panicking if it is not a valid regexp) and returns a mapping for method with the given handler chain. 14 | func newMapping(method string, re string, handlers ...http.HandlerFunc) *mapping { 15 | return &mapping{method: method, path: regexp.MustCompile(re), chain: handlers} 16 | } 17 | // matchType classifies how a request relates to a mapping, see matches. 18 | type matchType uint32 19 | 20 | const ( 21 | fullMatch matchType = iota 22 | pathMatch 23 | noMatch 24 | ) 25 | 26 | // matches reports how req relates to m: fullMatch when both URL path and method match, pathMatch (returning m.method) when only the path matches, and noMatch (with an empty method) otherwise.
27 | func (m *mapping) matches(req *http.Request) (match matchType, method string) { 28 | switch { 29 | case !m.path.MatchString(req.URL.Path): 30 | return noMatch, "" 31 | case m.method != req.Method: 32 | // the path belongs to this mapping and only the method differs; the method is reported so the caller can list it as allowed 33 | return pathMatch, m.method 34 | default: 35 | return fullMatch, m.method 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /sockjs/mapping_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | func TestMappingMatcher(t *testing.T) { 10 | mappingPrefix := mapping{"GET", regexp.MustCompile("prefix/$"), nil} 11 | mappingPrefixRegExp := mapping{"GET", regexp.MustCompile(".*x/$"), nil} 12 | 13 | var testRequests = []struct { 14 | mapping mapping 15 | method string 16 | url string 17 | expectedMatch matchType 18 | }{ 19 | {mappingPrefix, "GET", "http://foo/prefix/", fullMatch}, 20 | {mappingPrefix, "POST", "http://foo/prefix/", pathMatch}, 21 | {mappingPrefix, "GET", "http://foo/prefix_not_mapped", noMatch}, 22 | {mappingPrefixRegExp, "GET", "http://foo/prefix/", fullMatch}, 23 | } 24 | 25 | for _, request := range testRequests { 26 | req, _ := http.NewRequest(request.method, request.url, nil) 27 | m := request.mapping 28 | match, method := m.matches(req) 29 | if match != request.expectedMatch { 30 | t.Errorf("mapping %s: %s %s gave match %d, expected %d", m.path, request.method, request.url, match, request.expectedMatch) 31 | } 32 | // path and full matches report the mapping's method, no match reports an empty one 33 | expectedMethod := m.method 34 | if request.expectedMatch == noMatch { 35 | expectedMethod = "" 36 | } 37 | if method != expectedMethod { 38 | t.Errorf("Matcher method should be %s, but got %s", expectedMethod, method) 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /sockjs/options.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "math/rand" 7 | "net/http" 8 | "sync" 9 | "time" 10 | 11 |
"github.com/gorilla/websocket" 12 | ) 13 | 14 | var ( 15 | entropy *rand.Rand 16 | entropyMutex sync.Mutex 17 | ) 18 | 19 | func init() { 20 | entropy = rand.New(rand.NewSource(time.Now().UnixNano())) 21 | } 22 | 23 | // Options holds the settings that control the transports, sessions and CORS behaviour of a sockjs handler. 24 | type Options struct { 25 | // Transports which don't support cross-domain communication natively ('eventsource' to name one) use an iframe trick. 26 | // A simple page is served from the SockJS server (using its foreign domain) and is placed in an invisible iframe. 27 | // Code run from this iframe doesn't need to worry about cross-domain issues, as it's being run from domain local to the SockJS server. 28 | // This iframe also does need to load SockJS javascript client library, and this option lets you specify its url (if you're unsure, 29 | // point it to the latest minified SockJS client release, this is the default). You must explicitly specify this url on the server 30 | // side for security reasons - we don't want the possibility of running any foreign javascript within the SockJS domain (aka cross site scripting attack). 31 | // Also, sockjs javascript library is probably already cached by the browser - it makes sense to reuse the sockjs url you're using normally. 32 | SockJSURL string 33 | // Most streaming transports save responses on the client side and don't free memory used by delivered messages. 34 | // Such transports need to be garbage-collected once in a while. ResponseLimit sets a minimum number of bytes that can be sent 35 | // over a single http streaming request before it will be closed. After that client needs to open new request. 36 | // Setting this value to one effectively disables streaming and will make streaming transports behave like polling transports. 37 | // The default value is 128K. 38 | ResponseLimit uint32 39 | // Some load balancers don't support websockets. This option can be used to disable websockets support by the server.
By default websockets are enabled. 40 | Websocket bool 41 | // This option can be used to enable raw websockets support by the server. By default raw websockets are disabled. 42 | RawWebsocket bool 43 | // Provide a custom Upgrader for Websocket connections to enable features like compression. 44 | // See https://godoc.org/github.com/gorilla/websocket#Upgrader for more details. 45 | WebsocketUpgrader *websocket.Upgrader 46 | // WebsocketWriteTimeout is a custom write timeout for Websocket underlying network connection. 47 | // A zero value means writes will not time out. 48 | WebsocketWriteTimeout time.Duration 49 | // In order to keep proxies and load balancers from closing long running http requests we need to pretend that the connection is active 50 | // and send a heartbeat packet once in a while. This setting controls how often this is done. 51 | // By default a heartbeat packet is sent every 25 seconds. 52 | HeartbeatDelay time.Duration 53 | // The server closes a session when a client receiving connection have not been seen for a while. 54 | // This delay is configured by this setting. 55 | // By default the session is closed when a receiving connection wasn't seen for 5 seconds. 56 | DisconnectDelay time.Duration 57 | // Some hosting providers enable sticky sessions only to requests that have JSessionID cookie set. 58 | // This setting controls if the server should set this cookie to a dummy value. 59 | // By default setting JSessionID cookie is disabled. More sophisticated behaviour can be achieved by supplying a function. 60 | JSessionID func(http.ResponseWriter, *http.Request) 61 | // CORS origin to be set on outgoing responses. If set to the empty string, it will default to the 62 | // incoming `Origin` header, or "*" if the Origin header isn't set. 63 | Origin string 64 | // CheckOrigin allows to dynamically decide whether server should set CORS 65 | // headers or not in case of XHR requests. 
When true returned CORS will be 66 | // configured with allowed origin equal to incoming `Origin` header, or "*" 67 | // if the request Origin header isn't set. When false returned CORS headers 68 | // won't be set at all. If this function is nil then Origin option above will 69 | // be taken into account. 70 | CheckOrigin func(*http.Request) bool 71 | } 72 | 73 | // DefaultOptions is a convenient set of options to be used for sockjs 74 | var DefaultOptions = Options{ 75 | Websocket: true, 76 | RawWebsocket: false, 77 | JSessionID: nil, 78 | SockJSURL: "//cdnjs.cloudflare.com/ajax/libs/sockjs-client/0.3.4/sockjs.min.js", 79 | HeartbeatDelay: 25 * time.Second, 80 | DisconnectDelay: 5 * time.Second, 81 | ResponseLimit: 128 * 1024, 82 | WebsocketUpgrader: nil, 83 | } 84 | 85 | type info struct { 86 | Websocket bool `json:"websocket"` 87 | CookieNeeded bool `json:"cookie_needed"` 88 | Origins []string `json:"origins"` 89 | Entropy int32 `json:"entropy"` 90 | } 91 | 92 | func (options *Options) info(rw http.ResponseWriter, req *http.Request) { 93 | switch req.Method { 94 | case "GET": 95 | rw.Header().Set("Content-Type", "application/json; charset=UTF-8") 96 | json.NewEncoder(rw).Encode(info{ 97 | Websocket: options.Websocket, 98 | CookieNeeded: options.JSessionID != nil, 99 | Origins: []string{"*:*"}, 100 | Entropy: generateEntropy(), 101 | }) 102 | case "OPTIONS": 103 | rw.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") 104 | rw.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", 365*24*60*60)) 105 | rw.WriteHeader(http.StatusNoContent) // 204 106 | default: 107 | http.NotFound(rw, req) 108 | } 109 | } 110 | 111 | // DefaultJSessionID is a default behaviour function to be used in options for JSessionID if JSESSIONID is needed 112 | func DefaultJSessionID(rw http.ResponseWriter, req *http.Request) { 113 | cookie, err := req.Cookie("JSESSIONID") 114 | if err == http.ErrNoCookie { 115 | cookie = &http.Cookie{ 116 | Name: "JSESSIONID", 117 | Value: 
"dummy", 118 | } 119 | } 120 | cookie.Path = "/" 121 | header := rw.Header() 122 | header.Add("Set-Cookie", cookie.String()) 123 | } 124 | 125 | func (options *Options) cookie(rw http.ResponseWriter, req *http.Request) { 126 | if options.JSessionID != nil { // cookie is needed 127 | options.JSessionID(rw, req) 128 | } 129 | } 130 | 131 | func generateEntropy() int32 { 132 | entropyMutex.Lock() 133 | entropy := entropy.Int31() 134 | entropyMutex.Unlock() 135 | return entropy 136 | } 137 | -------------------------------------------------------------------------------- /sockjs/options_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | ) 8 | import "testing" 9 | 10 | func TestInfoGet(t *testing.T) { 11 | recorder := httptest.NewRecorder() 12 | request, _ := http.NewRequest("GET", "", nil) 13 | DefaultOptions.info(recorder, request) 14 | 15 | if recorder.Code != http.StatusOK { 16 | t.Errorf("Wrong status code, got '%d' expected '%d'", recorder.Code, http.StatusOK) 17 | } 18 | 19 | decoder := json.NewDecoder(recorder.Body) 20 | var a info 21 | decoder.Decode(&a) 22 | if !a.Websocket { 23 | t.Errorf("Websocket field should be set true") 24 | } 25 | if a.CookieNeeded { 26 | t.Errorf("CookieNeeded should be set to false") 27 | } 28 | } 29 | 30 | func TestInfoOptions(t *testing.T) { 31 | recorder := httptest.NewRecorder() 32 | request, _ := http.NewRequest("OPTIONS", "", nil) 33 | DefaultOptions.info(recorder, request) 34 | if recorder.Code != http.StatusNoContent { 35 | t.Errorf("Incorrect status code received, got '%d' expected '%d'", recorder.Code, http.StatusNoContent) 36 | } 37 | } 38 | 39 | func TestInfoUnknown(t *testing.T) { 40 | req, _ := http.NewRequest("PUT", "", nil) 41 | rec := httptest.NewRecorder() 42 | DefaultOptions.info(rec, req) 43 | if rec.Code != http.StatusNotFound { 44 | t.Errorf("Incorrec response status, got '%d' 
expected '%d'", rec.Code, http.StatusNotFound) 45 | } 46 | } 47 | 48 | func TestCookies(t *testing.T) { 49 | rec := httptest.NewRecorder() 50 | req, _ := http.NewRequest("GET", "", nil) 51 | optionsWithCookies := DefaultOptions 52 | optionsWithCookies.JSessionID = DefaultJSessionID 53 | optionsWithCookies.cookie(rec, req) 54 | if rec.Header().Get("set-cookie") != "JSESSIONID=dummy; Path=/" { 55 | t.Errorf("Cookie not properly set in response") 56 | } 57 | // cookie value set in request 58 | req.AddCookie(&http.Cookie{Name: "JSESSIONID", Value: "some_jsession_id", Path: "/"}) 59 | rec = httptest.NewRecorder() 60 | optionsWithCookies.cookie(rec, req) 61 | if rec.Header().Get("set-cookie") != "JSESSIONID=some_jsession_id; Path=/" { 62 | t.Errorf("Cookie not properly set in response") 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /sockjs/rawwebsocket.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | func (h *handler) rawWebsocket(rw http.ResponseWriter, req *http.Request) { 12 | var conn *websocket.Conn 13 | var err error 14 | if h.options.WebsocketUpgrader != nil { 15 | conn, err = h.options.WebsocketUpgrader.Upgrade(rw, req, nil) 16 | } else { 17 | // use default as before, so that those 2 buffer size variables are used as before 18 | conn, err = websocket.Upgrade(rw, req, nil, WebSocketReadBufSize, WebSocketWriteBufSize) 19 | } 20 | 21 | if _, ok := err.(websocket.HandshakeError); ok { 22 | http.Error(rw, `Can "Upgrade" only to "WebSocket".`, http.StatusBadRequest) 23 | return 24 | } else if err != nil { 25 | rw.WriteHeader(http.StatusInternalServerError) 26 | return 27 | } 28 | 29 | sessID := "" 30 | sess := newSession(req, sessID, h.options.DisconnectDelay, h.options.HeartbeatDelay) 31 | sess.raw = true 32 | 33 | receiver := 
newRawWsReceiver(conn, h.options.WebsocketWriteTimeout) 34 | sess.attachReceiver(receiver) 35 | if h.handlerFunc != nil { 36 | go h.handlerFunc(sess) 37 | } 38 | readCloseCh := make(chan struct{}) 39 | go func() { 40 | for { 41 | frameType, p, err := conn.ReadMessage() 42 | if err != nil { 43 | close(readCloseCh) 44 | return 45 | } 46 | if frameType == websocket.TextMessage || frameType == websocket.BinaryMessage { 47 | sess.accept(string(p)) 48 | } 49 | } 50 | }() 51 | 52 | select { 53 | case <-readCloseCh: 54 | case <-receiver.doneNotify(): 55 | } 56 | sess.close() 57 | conn.Close() 58 | } 59 | 60 | type rawWsReceiver struct { 61 | conn *websocket.Conn 62 | closeCh chan struct{} 63 | writeTimeout time.Duration 64 | } 65 | 66 | func newRawWsReceiver(conn *websocket.Conn, writeTimeout time.Duration) *rawWsReceiver { 67 | return &rawWsReceiver{ 68 | conn: conn, 69 | closeCh: make(chan struct{}), 70 | writeTimeout: writeTimeout, 71 | } 72 | } 73 | 74 | func (w *rawWsReceiver) sendBulk(messages ...string) { 75 | if len(messages) > 0 { 76 | for _, m := range messages { 77 | if w.writeTimeout != 0 { 78 | w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) 79 | } 80 | err := w.conn.WriteMessage(websocket.TextMessage, []byte(m)) 81 | if err != nil { 82 | w.close() 83 | break 84 | } 85 | 86 | } 87 | } 88 | } 89 | 90 | func (w *rawWsReceiver) sendFrame(frame string) { 91 | if w.writeTimeout != 0 { 92 | w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) 93 | } 94 | var err error 95 | if frame == "h" { 96 | err = w.conn.WriteMessage(websocket.PingMessage, []byte{}) 97 | } else if len(frame) > 0 && frame[0] == 'c' { 98 | status, reason := parseCloseFrame(frame) 99 | msg := websocket.FormatCloseMessage(int(status), reason) 100 | err = w.conn.WriteMessage(websocket.CloseMessage, msg) 101 | } else { 102 | err = w.conn.WriteMessage(websocket.TextMessage, []byte(frame)) 103 | } 104 | if err != nil { 105 | w.close() 106 | } 107 | } 108 | 109 | func parseCloseFrame(frame 
string) (status uint32, reason string) { 110 | var items [2]interface{} 111 | json.Unmarshal([]byte(frame)[1:], &items) 112 | statusF, _ := items[0].(float64) 113 | status = uint32(statusF) 114 | reason, _ = items[1].(string) 115 | return 116 | } 117 | 118 | func (w *rawWsReceiver) close() { 119 | select { 120 | case <-w.closeCh: // already closed 121 | default: 122 | close(w.closeCh) 123 | } 124 | } 125 | func (w *rawWsReceiver) canSend() bool { 126 | select { 127 | case <-w.closeCh: // already closed 128 | return false 129 | default: 130 | return true 131 | } 132 | } 133 | func (w *rawWsReceiver) doneNotify() <-chan struct{} { return w.closeCh } 134 | func (w *rawWsReceiver) interruptedNotify() <-chan struct{} { return nil } 135 | -------------------------------------------------------------------------------- /sockjs/rawwebsocket_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | 9 | "github.com/gorilla/websocket" 10 | ) 11 | 12 | func TestHandler_RawWebSocketHandshakeError(t *testing.T) { 13 | h := newTestHandler() 14 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 15 | defer server.Close() 16 | req, _ := http.NewRequest("GET", server.URL, nil) 17 | req.Header.Set("origin", "https"+server.URL[4:]) 18 | resp, _ := http.DefaultClient.Do(req) 19 | if resp.StatusCode != http.StatusBadRequest { 20 | t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest) 21 | } 22 | } 23 | 24 | func TestHandler_RawWebSocket(t *testing.T) { 25 | h := newTestHandler() 26 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 27 | defer server.CloseClientConnections() 28 | url := "ws" + server.URL[4:] 29 | var connCh = make(chan Session) 30 | h.handlerFunc = func(conn Session) { connCh <- conn } 31 | conn, resp, err := websocket.DefaultDialer.Dial(url, nil) 32 | if conn == 
nil { 33 | t.Errorf("Connection should not be nil") 34 | } 35 | if err != nil { 36 | t.Errorf("Unexpected error '%v'", err) 37 | } 38 | if resp.StatusCode != http.StatusSwitchingProtocols { 39 | t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols) 40 | } 41 | select { 42 | case <-connCh: //ok 43 | case <-time.After(10 * time.Millisecond): 44 | t.Errorf("Sockjs Handler not invoked") 45 | } 46 | } 47 | 48 | func TestHandler_RawWebSocketTerminationByServer(t *testing.T) { 49 | h := newTestHandler() 50 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 51 | defer server.Close() 52 | url := "ws" + server.URL[4:] 53 | h.handlerFunc = func(conn Session) { 54 | // close the session without sending any message 55 | conn.Close(3000, "some close message") 56 | conn.Close(0, "this should be ignored") 57 | } 58 | conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 59 | if err != nil { 60 | t.Fatalf("websocket dial failed: %v", err) 61 | } 62 | for i := 0; i < 2; i++ { 63 | _, _, err := conn.ReadMessage() 64 | closeError, ok := err.(*websocket.CloseError) 65 | if !ok { 66 | t.Fatalf("expected close error but got: %v", err) 67 | } 68 | if closeError.Code != 3000 { 69 | t.Errorf("unexpected close status: %v", closeError.Code) 70 | } 71 | if closeError.Text != "some close message" { 72 | t.Errorf("unexpected close reason: '%v'", closeError.Text) 73 | } 74 | } 75 | } 76 | 77 | func TestHandler_RawWebSocketTerminationByClient(t *testing.T) { 78 | h := newTestHandler() 79 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 80 | defer server.Close() 81 | url := "ws" + server.URL[4:] 82 | var done = make(chan struct{}) 83 | h.handlerFunc = func(conn Session) { 84 | if _, err := conn.Recv(); err != ErrSessionNotOpen { 85 | t.Errorf("Recv should fail") 86 | } 87 | close(done) 88 | } 89 | conn, _, _ := websocket.DefaultDialer.Dial(url, 
map[string][]string{"Origin": []string{server.URL}}) 90 | conn.Close() 91 | <-done 92 | } 93 | 94 | func TestHandler_RawWebSocketCommunication(t *testing.T) { 95 | h := newTestHandler() 96 | h.options.WebsocketWriteTimeout = time.Second 97 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 98 | // defer server.CloseClientConnections() 99 | url := "ws" + server.URL[4:] 100 | var done = make(chan struct{}) 101 | h.handlerFunc = func(conn Session) { 102 | conn.Send("message 1") 103 | conn.Send("message 2") 104 | expected := "[\"message 3\"]\n" 105 | msg, err := conn.Recv() 106 | if msg != expected || err != nil { 107 | t.Errorf("Got '%s', expected '%s'", msg, expected) 108 | } 109 | conn.Close(123, "close") 110 | close(done) 111 | } 112 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 113 | conn.WriteJSON([]string{"message 3"}) 114 | var expected = []string{"message 1", "message 2"} 115 | for _, exp := range expected { 116 | _, msg, err := conn.ReadMessage() 117 | if string(msg) != exp || err != nil { 118 | t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 119 | } 120 | } 121 | <-done 122 | } 123 | 124 | func TestHandler_RawCustomWebSocketCommunication(t *testing.T) { 125 | h := newTestHandler() 126 | h.options.WebsocketWriteTimeout = time.Second 127 | h.options.WebsocketUpgrader = &websocket.Upgrader{ 128 | ReadBufferSize: 0, 129 | WriteBufferSize: 0, 130 | CheckOrigin: func(_ *http.Request) bool { return true }, 131 | Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {}, 132 | } 133 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 134 | url := "ws" + server.URL[4:] 135 | var done = make(chan struct{}) 136 | h.handlerFunc = func(conn Session) { 137 | conn.Send("message 1") 138 | conn.Send("message 2") 139 | expected := "[\"message 3\"]\n" 140 | msg, err := conn.Recv() 141 | if msg != expected || err != nil { 142 | 
t.Errorf("Got '%s', expected '%s'", msg, expected) 143 | } 144 | conn.Close(123, "close") 145 | close(done) 146 | } 147 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 148 | conn.WriteJSON([]string{"message 3"}) 149 | var expected = []string{"message 1", "message 2"} 150 | for _, exp := range expected { 151 | _, msg, err := conn.ReadMessage() 152 | if string(msg) != exp || err != nil { 153 | t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 154 | } 155 | } 156 | <-done 157 | } 158 | -------------------------------------------------------------------------------- /sockjs/session.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // SessionState defines the current state of the session 11 | type SessionState uint32 12 | 13 | const ( 14 | // brand new session, need to send "h" to receiver 15 | SessionOpening SessionState = iota 16 | // active session 17 | SessionActive 18 | // session being closed, sending "closeFrame" to receivers 19 | SessionClosing 20 | // closed session, no activity at all, should be removed from handler completely and not reused 21 | SessionClosed 22 | ) 23 | 24 | var ( 25 | // ErrSessionNotOpen error is used to denote session not in open state. 26 | // Recv() and Send() operations are not suppored if session is closed. 27 | ErrSessionNotOpen = errors.New("sockjs: session not in open state") 28 | errSessionReceiverAttached = errors.New("sockjs: another receiver already attached") 29 | ) 30 | 31 | type session struct { 32 | sync.RWMutex 33 | id string 34 | req *http.Request 35 | state SessionState 36 | 37 | recv receiver // protocol dependent receiver (xhr, eventsource, ...) 
38 | sendBuffer []string // messages to be sent to client 39 | recvBuffer *messageBuffer // messages received from client to be consumed by application 40 | closeFrame string // closeFrame to send after session is closed 41 | 42 | // do not use SockJS framing for raw websocket connections 43 | raw bool 44 | 45 | // internal timer used to handle session expiration if no receiver is attached, or heartbeats if recevier is attached 46 | sessionTimeoutInterval time.Duration 47 | heartbeatInterval time.Duration 48 | timer *time.Timer 49 | // once the session timeouts this channel also closes 50 | closeCh chan struct{} 51 | } 52 | 53 | type receiver interface { 54 | // sendBulk send multiple data messages in frame frame in format: a["msg 1", "msg 2", ....] 55 | sendBulk(...string) 56 | // sendFrame sends given frame over the wire (with possible chunking depending on receiver) 57 | sendFrame(string) 58 | // close closes the receiver in a "done" way (idempotent) 59 | close() 60 | canSend() bool 61 | // done notification channel gets closed whenever receiver ends 62 | doneNotify() <-chan struct{} 63 | // interrupted channel gets closed whenever receiver is interrupted (i.e. http connection drops,...) 64 | interruptedNotify() <-chan struct{} 65 | } 66 | 67 | // Session is a central component that handles receiving and sending frames. 
It maintains internal state 68 | func newSession(req *http.Request, sessionID string, sessionTimeoutInterval, heartbeatInterval time.Duration) *session { 69 | 70 | s := &session{ 71 | id: sessionID, 72 | req: req, 73 | sessionTimeoutInterval: sessionTimeoutInterval, 74 | heartbeatInterval: heartbeatInterval, 75 | recvBuffer: newMessageBuffer(), 76 | closeCh: make(chan struct{}), 77 | } 78 | 79 | s.Lock() // "go test -race" complains if ommited, not sure why as no race can happen here 80 | s.timer = time.AfterFunc(sessionTimeoutInterval, s.close) 81 | s.Unlock() 82 | return s 83 | } 84 | 85 | func (s *session) sendMessage(msg string) error { 86 | s.Lock() 87 | defer s.Unlock() 88 | if s.state > SessionActive { 89 | return ErrSessionNotOpen 90 | } 91 | s.sendBuffer = append(s.sendBuffer, msg) 92 | if s.recv != nil && s.recv.canSend() { 93 | s.recv.sendBulk(s.sendBuffer...) 94 | s.sendBuffer = nil 95 | } 96 | return nil 97 | } 98 | 99 | func (s *session) attachReceiver(recv receiver) error { 100 | s.Lock() 101 | defer s.Unlock() 102 | if s.recv != nil { 103 | return errSessionReceiverAttached 104 | } 105 | s.recv = recv 106 | go func(r receiver) { 107 | select { 108 | case <-r.doneNotify(): 109 | s.detachReceiver() 110 | case <-r.interruptedNotify(): 111 | s.detachReceiver() 112 | s.close() 113 | } 114 | }(recv) 115 | 116 | if s.state == SessionClosing { 117 | if !s.raw { 118 | s.recv.sendFrame(s.closeFrame) 119 | } 120 | s.recv.close() 121 | return nil 122 | } 123 | if s.state == SessionOpening { 124 | if !s.raw { 125 | s.recv.sendFrame("o") 126 | } 127 | s.state = SessionActive 128 | } 129 | s.recv.sendBulk(s.sendBuffer...) 
130 | s.sendBuffer = nil 131 | s.timer.Stop() 132 | if s.heartbeatInterval > 0 { 133 | s.timer = time.AfterFunc(s.heartbeatInterval, s.heartbeat) 134 | } 135 | return nil 136 | } 137 | 138 | func (s *session) detachReceiver() { 139 | s.Lock() 140 | s.timer.Stop() 141 | s.timer = time.AfterFunc(s.sessionTimeoutInterval, s.close) 142 | s.recv = nil 143 | s.Unlock() 144 | } 145 | 146 | func (s *session) heartbeat() { 147 | s.Lock() 148 | if s.recv != nil { // timer could have fired between Lock and timer.Stop in detachReceiver 149 | s.recv.sendFrame("h") 150 | s.timer = time.AfterFunc(s.heartbeatInterval, s.heartbeat) 151 | } 152 | s.Unlock() 153 | } 154 | 155 | func (s *session) accept(messages ...string) error { 156 | return s.recvBuffer.push(messages...) 157 | } 158 | 159 | // idempotent operation 160 | func (s *session) closing() { 161 | s.Lock() 162 | defer s.Unlock() 163 | if s.state < SessionClosing { 164 | s.state = SessionClosing 165 | s.recvBuffer.close() 166 | if s.recv != nil { 167 | s.recv.sendFrame(s.closeFrame) 168 | s.recv.close() 169 | } 170 | } 171 | } 172 | 173 | // idempotent operation 174 | func (s *session) close() { 175 | s.closing() 176 | s.Lock() 177 | defer s.Unlock() 178 | if s.state < SessionClosed { 179 | s.state = SessionClosed 180 | s.timer.Stop() 181 | close(s.closeCh) 182 | } 183 | } 184 | 185 | func (s *session) closedNotify() <-chan struct{} { return s.closeCh } 186 | 187 | // Conn interface implementation 188 | func (s *session) Close(status uint32, reason string) error { 189 | s.Lock() 190 | if s.state < SessionClosing { 191 | s.closeFrame = closeFrame(status, reason) 192 | s.Unlock() 193 | s.closing() 194 | return nil 195 | } 196 | s.Unlock() 197 | return ErrSessionNotOpen 198 | } 199 | 200 | func (s *session) Recv() (string, error) { 201 | return s.recvBuffer.pop() 202 | } 203 | 204 | func (s *session) Send(msg string) error { 205 | return s.sendMessage(msg) 206 | } 207 | 208 | func (s *session) ID() string { return s.id } 209 | 
210 | // GetSessionState returns the session's current state, read under the session's read lock.
211 | func (s *session) GetSessionState() SessionState {
212 | 	s.RLock()
213 | 	defer s.RUnlock()
214 | 	return s.state
215 | }
216 |
217 | // Request returns the HTTP request the session was created with.
218 | func (s *session) Request() *http.Request {
219 | 	return s.req
220 | }
221 |
--------------------------------------------------------------------------------
/sockjs/sockjs.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import "net/http"
4 |
5 | // Session represents a connection between server and client.
6 | type Session interface {
7 | 	// ID returns the session id.
8 | 	ID() string
9 | 	// Request returns the first http request.
10 | 	Request() *http.Request
11 | 	// Recv reads one text frame from session.
12 | 	Recv() (string, error)
13 | 	// Send sends one text frame to session.
14 | 	Send(string) error
15 | 	// Close closes the session with provided code and reason.
16 | 	Close(status uint32, reason string) error
17 | 	// GetSessionState returns the state of the session: SessionOpening/SessionActive/SessionClosing/SessionClosed.
18 | 	GetSessionState() SessionState
19 | }
20 |
--------------------------------------------------------------------------------
/sockjs/sockjs_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"regexp"
7 | 	"testing"
8 | )
9 |
10 | func TestSockJS_ServeHTTP(t *testing.T) {
11 | 	m := handler{mappings: make([]*mapping, 0)}
12 | 	m.mappings = []*mapping{
13 | 		&mapping{"POST", regexp.MustCompile("/foo/.*"), []http.HandlerFunc{func(http.ResponseWriter, *http.Request) {}}},
14 | 	}
15 | 	req, _ := http.NewRequest("GET", "/foo/bar", nil)
16 | 	rec := httptest.NewRecorder()
17 | 	m.ServeHTTP(rec, req)
18 | 	if rec.Code != http.StatusMethodNotAllowed {
19 | 		t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusMethodNotAllowed)
20 | 	}
21 | 	req, _ = http.NewRequest("GET", "/bar", nil)
22 | 	rec = httptest.NewRecorder()
23 | 	m.ServeHTTP(rec, req)
24 | 	if rec.Code != http.StatusNotFound {
25 | 		t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusNotFound)
26 | 	}
27 | }
28 |
--------------------------------------------------------------------------------
/sockjs/utils.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import "encoding/json"
4 |
5 | // quote returns in encoded as a JSON string literal (surrounding quotes added, special characters escaped).
6 | func quote(in string) string {
7 | 	quoted, _ := json.Marshal(in) // marshalling a plain string cannot fail
8 | 	return string(quoted)
9 | }
10 |
11 | // transform returns a new slice holding transformFn applied to each element of values, in order.
12 | func transform(values []string, transformFn func(string) string) []string {
13 | 	ret := make([]string, len(values))
14 | 	for i, msg := range values {
15 | 		ret[i] = transformFn(msg)
16 | 	}
17 | 	return ret
18 | }
19 |
--------------------------------------------------------------------------------
/sockjs/utils_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import "testing"
4 |
5 | func TestQuote(t *testing.T) {
6 | 	var quotationTests = []struct {
7 | 		input  string
8 | 		output string
9 | 	}{
10 | 		{"simple", "\"simple\""},
11 | 		{"more complex \"", "\"more complex \\\"\""},
12 | 	}
13 |
14 | 	for _, testCase := range quotationTests {
15 | 		if quote(testCase.input) != testCase.output {
16 | 			t.Errorf("Expected '%s', got '%s'", testCase.output, quote(testCase.input))
17 | 		}
18 | 	}
19 | }
20 |
--------------------------------------------------------------------------------
/sockjs/web.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import (
4 | 	"fmt"
5 | 	"net/http"
6 | 	"time"
7 | )
8 |
9 | // xhrCorsFactory returns a handler that sets CORS response headers according to opts.
10 | // When opts.CheckOrigin is set, its result decides whether any CORS headers are written,
11 | // and the allowed origin is the request's Origin header (or "*" when absent).
12 | // Otherwise CORS is always enabled with opts.Origin, falling back to the request's
13 | // Origin header and then to "*" (also used when the Origin is the literal "null").
14 | func xhrCorsFactory(opts Options) func(rw http.ResponseWriter, req *http.Request) {
15 | 	return func(rw http.ResponseWriter, req *http.Request) {
16 | 		header := rw.Header()
17 | 		var corsEnabled bool
18 | 		var corsOrigin string
19 |
20 | 		if opts.CheckOrigin != nil {
21 | 			corsEnabled = opts.CheckOrigin(req)
22 | 			if corsEnabled {
23 | 				corsOrigin = req.Header.Get("origin")
24 | 				if corsOrigin == "" {
25 | 					corsOrigin = "*"
26 | 				}
27 | 			}
28 | 		} else {
29 | 			corsEnabled = true
30 | 			corsOrigin = opts.Origin
31 | 			if corsOrigin == "" {
32 | 				corsOrigin = req.Header.Get("origin")
33 | 			}
34 | 			if corsOrigin == "" || corsOrigin == "null" {
35 | 				corsOrigin = "*"
36 | 			}
37 | 		}
38 |
39 | 		if corsEnabled {
40 | 			header.Set("Access-Control-Allow-Origin", corsOrigin)
41 | 			// Echo back whatever request headers the client asked permission for.
42 | 			if allowHeaders := req.Header.Get("Access-Control-Request-Headers"); allowHeaders != "" && allowHeaders != "null" {
43 | 				header.Add("Access-Control-Allow-Headers", allowHeaders)
44 | 			}
45 | 			header.Set("Access-Control-Allow-Credentials", "true")
46 | 		}
47 | 	}
48 | }
49 |
50 | // xhrOptions replies 204 No Content and advertises OPTIONS and POST as allowed
51 | // methods (presumably the CORS preflight answer for the XHR transports).
52 | func xhrOptions(rw http.ResponseWriter, req *http.Request) {
53 | 	rw.Header().Set("Access-Control-Allow-Methods", "OPTIONS, POST")
54 | 	rw.WriteHeader(http.StatusNoContent) // 204
55 | }
56 |
57 | // cacheFor marks the response as publicly cacheable for one year.
58 | // Expires must be an HTTP-date (IMF-fixdate, always expressed in GMT); formatting local
59 | // time with time.RFC1123 would emit a zone name such as "CET" or "UTC", which is not valid.
60 | func cacheFor(rw http.ResponseWriter, req *http.Request) {
61 | 	rw.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", 365*24*60*60))
62 | 	rw.Header().Set("Expires", time.Now().AddDate(1, 0, 0).UTC().Format(http.TimeFormat))
63 | 	rw.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", 365*24*60*60))
64 | }
65 |
66 | // noCache sets a Cache-Control header that forbids storing or reusing the response.
67 | func noCache(rw http.ResponseWriter, req *http.Request) {
68 | 	rw.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
69 | }
70 |
71 | // welcomeHandler writes the plain-text SockJS greeting.
72 | func welcomeHandler(rw http.ResponseWriter, req *http.Request) {
73 | 	rw.Header().Set("content-type", "text/plain;charset=UTF-8")
74 | 	fmt.Fprintf(rw, "Welcome to SockJS!\n")
75 | }
76 |
77 | // httpError replies with the given status code and writes error verbatim as a plain-text body.
78 | func httpError(w http.ResponseWriter, error string, code int) {
79 | 	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
80 | 	w.WriteHeader(code)
81 | 	// Fprint, not Fprintf: the message is data, so a '%' in it must not be treated as a verb.
82 | 	fmt.Fprint(w, error)
83 | }
84 |
--------------------------------------------------------------------------------
/sockjs/web_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"strings"
7 | 	"testing"
8 | )
9 |
10 | func TestXhrCors(t *testing.T) {
11 | 	recorder := httptest.NewRecorder()
12 | 	req, _ := http.NewRequest("GET", "/", nil)
13 | 	xhrCors := xhrCorsFactory(Options{})
14 | 	xhrCors(recorder, req)
15 | 	acao := recorder.Header().Get("access-control-allow-origin")
16 | 	if acao != "*" {
17 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "*")
18 | 	}
19 | 	req.Header.Set("origin", "localhost")
20 | 	xhrCors(recorder, req)
21 | 	acao = recorder.Header().Get("access-control-allow-origin")
22 | 	if acao != "localhost" {
23 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "localhost")
24 | 	}
25 | 	req.Header.Set("access-control-request-headers", "some value")
26 | 	rec := httptest.NewRecorder()
27 | 	xhrCors(rec, req)
28 | 	if rec.Header().Get("access-control-allow-headers") != "some value" {
29 | 		t.Errorf("Incorent value for ACAH, got %s", rec.Header().Get("access-control-allow-headers"))
30 | 	}
31 |
32 | 	rec = httptest.NewRecorder()
33 | 	xhrCors(rec, req)
34 | 	if rec.Header().Get("access-control-allow-credentials") != "true" {
35 | 		t.Errorf("Incorent value for ACAC, got %s", rec.Header().Get("access-control-allow-credentials"))
36 | 	}
37 |
38 | 	// verify that if Access-Control-Allow-Credentials was previously set that xhrCors() does not duplicate the value
39 | 	rec = httptest.NewRecorder()
40 | 	rec.Header().Set("Access-Control-Allow-Credentials", "true")
41 | 	xhrCors(rec, req)
42 | 	acac := rec.Header()["Access-Control-Allow-Credentials"]
43 | 	if len(acac) != 1 || acac[0] != "true" {
44 | 		t.Errorf("Incorent value for ACAC, got %s", strings.Join(acac, ","))
45 | 	}
46 | }
47 |
48 | func TestCheckOriginCORSAllowedNullOrigin(t *testing.T) {
49 | 	recorder := httptest.NewRecorder()
50 | 	req, _ := http.NewRequest("GET", "/", nil)
51 | 	xhrCors := xhrCorsFactory(Options{
52 | 		CheckOrigin: func(req *http.Request) bool {
53 | 			return true
54 | 		},
55 | 	})
56 | 	req.Header.Set("origin", "null")
57 | 	xhrCors(recorder, req)
58 | 	acao := recorder.Header().Get("access-control-allow-origin")
59 | 	if acao != "null" {
60 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "null")
61 | 	}
62 | }
63 |
64 | func TestCheckOriginCORSAllowedEmptyOrigin(t *testing.T) {
65 | 	recorder := httptest.NewRecorder()
66 | 	req, _ := http.NewRequest("GET", "/", nil)
67 | 	xhrCors := xhrCorsFactory(Options{
68 | 		CheckOrigin: func(req *http.Request) bool {
69 | 			return true
70 | 		},
71 | 	})
72 | 	xhrCors(recorder, req)
73 | 	acao := recorder.Header().Get("access-control-allow-origin")
74 | 	if acao != "*" {
75 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "*")
76 | 	}
77 | }
78 |
79 | func TestCheckOriginCORSNotAllowed(t *testing.T) {
80 | 	recorder := httptest.NewRecorder()
81 | 	req, _ := http.NewRequest("GET", "/", nil)
82 | 	xhrCors := xhrCorsFactory(Options{
83 | 		CheckOrigin: func(req *http.Request) bool {
84 | 			return false
85 | 		},
86 | 	})
87 | 	req.Header.Set("origin", "localhost")
88 | 	xhrCors(recorder, req)
89 | 	acao := recorder.Header().Get("access-control-allow-origin")
90 | 	if acao != "" {
91 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "")
92 | 	}
93 | }
94 |
95 | func TestXhrOptions(t *testing.T) {
96 | 	rec := httptest.NewRecorder()
97 | 	req, _ := http.NewRequest("GET", "/", nil)
98 | 	xhrOptions(rec, req)
99 | 	if rec.Code != http.StatusNoContent {
100 | 		t.Errorf("Wrong response status code, expected %d, got %d", http.StatusNoContent, rec.Code)
101 | 	}
102 | }
103 |
104 | func TestCacheFor(t *testing.T) {
105 | 	rec := httptest.NewRecorder()
106 | 	cacheFor(rec, nil)
107 | 	cacheControl := rec.Header().Get("cache-control")
108 | 	if cacheControl != "public, max-age=31536000" {
109 | 		t.Errorf("Incorrect cache-control header value, got '%s'", cacheControl)
110 | 	}
111 | 	expires := rec.Header().Get("expires")
112 | 	if _, err := http.ParseTime(expires); err != nil {
113 | 		t.Errorf("Expires header should be a valid HTTP-date, got '%s': %v", expires, err)
114 | 	}
115 | 	maxAge := rec.Header().Get("access-control-max-age")
116 | 	if maxAge != "31536000" {
117 | 		t.Errorf("Incorrect value for access-control-max-age, got '%s'", maxAge)
118 | 	}
119 | }
120 |
121 | func TestNoCache(t *testing.T) {
122 | 	rec := httptest.NewRecorder()
123 | 	noCache(rec, nil)
124 | 	if cc := rec.Header().Get("cache-control"); cc != "no-store, no-cache, must-revalidate, max-age=0" {
125 | 		t.Errorf("Incorrect cache-control header value, got '%s'", cc)
126 | 	}
127 | }
128 |
129 | func TestWelcomeHandler(t *testing.T) {
130 | 	rec := httptest.NewRecorder()
131 | 	welcomeHandler(rec, nil)
132 | 	if rec.Body.String() != "Welcome to SockJS!\n" {
133 | 		t.Errorf("Incorrect welcome message received, got '%s'", rec.Body.String())
134 | 	}
135 | }
136 |
137 | func TestHTTPErrorWritesMessageVerbatim(t *testing.T) {
138 | 	rec := httptest.NewRecorder()
139 | 	httpError(rec, "100% broken", http.StatusBadRequest)
140 | 	if rec.Code != http.StatusBadRequest {
141 | 		t.Errorf("Wrong response status code, expected %d, got %d", http.StatusBadRequest, rec.Code)
142 | 	}
143 | 	if rec.Body.String() != "100% broken" {
144 | 		t.Errorf("Error body should be written verbatim, got '%s'", rec.Body.String())
145 | 	}
146 | }
147 |
--------------------------------------------------------------------------------
/sockjs/websocket.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 |
3 | import (
4 | 	"fmt"
5 | 	"net/http"
6 | 	"strings"
7 | 	"time"
8 |
9 | 	"github.com/gorilla/websocket"
10 | )
11 |
12 | // WebSocketReadBufSize is a parameter that is used for WebSocket Upgrader.
13 | // https://github.com/gorilla/websocket/blob/master/server.go#L230 14 | var WebSocketReadBufSize = 4096 15 | 16 | // WebSocketWriteBufSize is a parameter that is used for WebSocket Upgrader 17 | // https://github.com/gorilla/websocket/blob/master/server.go#L230 18 | var WebSocketWriteBufSize = 4096 19 | 20 | func (h *handler) sockjsWebsocket(rw http.ResponseWriter, req *http.Request) { 21 | var conn *websocket.Conn 22 | var err error 23 | if h.options.WebsocketUpgrader != nil { 24 | conn, err = h.options.WebsocketUpgrader.Upgrade(rw, req, nil) 25 | } else { 26 | // use default as before, so that those 2 buffer size variables are used as before 27 | conn, err = websocket.Upgrade(rw, req, nil, WebSocketReadBufSize, WebSocketWriteBufSize) 28 | } 29 | if _, ok := err.(websocket.HandshakeError); ok { 30 | http.Error(rw, `Can "Upgrade" only to "WebSocket".`, http.StatusBadRequest) 31 | return 32 | } else if err != nil { 33 | rw.WriteHeader(http.StatusInternalServerError) 34 | return 35 | } 36 | sessID, _ := h.parseSessionID(req.URL) 37 | sess := newSession(req, sessID, h.options.DisconnectDelay, h.options.HeartbeatDelay) 38 | receiver := newWsReceiver(conn, h.options.WebsocketWriteTimeout) 39 | sess.attachReceiver(receiver) 40 | if h.handlerFunc != nil { 41 | go h.handlerFunc(sess) 42 | } 43 | readCloseCh := make(chan struct{}) 44 | go func() { 45 | var d []string 46 | for { 47 | err := conn.ReadJSON(&d) 48 | if err != nil { 49 | close(readCloseCh) 50 | return 51 | } 52 | sess.accept(d...) 
53 | } 54 | }() 55 | 56 | select { 57 | case <-readCloseCh: 58 | case <-receiver.doneNotify(): 59 | } 60 | sess.close() 61 | conn.Close() 62 | } 63 | 64 | type wsReceiver struct { 65 | conn *websocket.Conn 66 | closeCh chan struct{} 67 | writeTimeout time.Duration 68 | } 69 | 70 | func newWsReceiver(conn *websocket.Conn, writeTimeout time.Duration) *wsReceiver { 71 | return &wsReceiver{ 72 | conn: conn, 73 | closeCh: make(chan struct{}), 74 | writeTimeout: writeTimeout, 75 | } 76 | } 77 | 78 | func (w *wsReceiver) sendBulk(messages ...string) { 79 | if len(messages) > 0 { 80 | w.sendFrame(fmt.Sprintf("a[%s]", strings.Join(transform(messages, quote), ","))) 81 | } 82 | } 83 | 84 | func (w *wsReceiver) sendFrame(frame string) { 85 | if w.writeTimeout != 0 { 86 | w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) 87 | } 88 | if err := w.conn.WriteMessage(websocket.TextMessage, []byte(frame)); err != nil { 89 | w.close() 90 | } 91 | } 92 | 93 | func (w *wsReceiver) close() { 94 | select { 95 | case <-w.closeCh: // already closed 96 | default: 97 | close(w.closeCh) 98 | } 99 | } 100 | func (w *wsReceiver) canSend() bool { 101 | select { 102 | case <-w.closeCh: // already closed 103 | return false 104 | default: 105 | return true 106 | } 107 | } 108 | func (w *wsReceiver) doneNotify() <-chan struct{} { return w.closeCh } 109 | func (w *wsReceiver) interruptedNotify() <-chan struct{} { return nil } 110 | -------------------------------------------------------------------------------- /sockjs/websocket_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/gorilla/websocket" 11 | ) 12 | 13 | func TestHandler_WebSocketHandshakeError(t *testing.T) { 14 | h := newTestHandler() 15 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 16 | defer server.Close() 17 | req, _ := 
http.NewRequest("GET", server.URL, nil) 18 | req.Header.Set("origin", "https"+server.URL[4:]) 19 | resp, err := http.DefaultClient.Do(req) 20 | if err != nil { 21 | t.Errorf("There should not be any error, got '%s'", err) 22 | t.FailNow() 23 | } 24 | if resp == nil { 25 | t.Errorf("Response should not be nil") 26 | t.FailNow() 27 | } 28 | if resp.StatusCode != http.StatusBadRequest { 29 | t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest) 30 | } 31 | } 32 | 33 | func TestHandler_WebSocket(t *testing.T) { 34 | h := newTestHandler() 35 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 36 | defer server.CloseClientConnections() 37 | url := "ws" + server.URL[4:] 38 | var connCh = make(chan Session) 39 | h.handlerFunc = func(conn Session) { connCh <- conn } 40 | conn, resp, err := websocket.DefaultDialer.Dial(url, nil) 41 | if err != nil { 42 | t.Errorf("Unexpected error '%v'", err) 43 | t.FailNow() 44 | } 45 | if conn == nil { 46 | t.Errorf("Connection should not be nil") 47 | t.FailNow() 48 | } 49 | if resp == nil { 50 | t.Errorf("Response should not be nil") 51 | t.FailNow() 52 | } 53 | if resp.StatusCode != http.StatusSwitchingProtocols { 54 | t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols) 55 | } 56 | select { 57 | case <-connCh: //ok 58 | case <-time.After(10 * time.Millisecond): 59 | t.Errorf("Sockjs Handler not invoked") 60 | } 61 | } 62 | 63 | func TestHandler_WebSocketTerminationByServer(t *testing.T) { 64 | h := newTestHandler() 65 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 66 | defer server.Close() 67 | url := "ws" + server.URL[4:] 68 | h.handlerFunc = func(conn Session) { 69 | conn.Close(1024, "some close message") 70 | conn.Close(0, "this should be ignored") 71 | } 72 | conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 73 | if err != nil { 74 | 
t.Fatalf("websocket dial failed: %v", err) 75 | t.FailNow() 76 | } 77 | if conn == nil { 78 | t.Errorf("Connection should not be nil") 79 | t.FailNow() 80 | } 81 | _, msg, err := conn.ReadMessage() 82 | if string(msg) != "o" || err != nil { 83 | t.Errorf("Open frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, "o") 84 | } 85 | _, msg, err = conn.ReadMessage() 86 | if string(msg) != `c[1024,"some close message"]` || err != nil { 87 | t.Errorf("Close frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, `c[1024,"some close message"]`) 88 | } 89 | _, msg, err = conn.ReadMessage() 90 | // gorilla websocket keeps `errUnexpectedEOF` private so we need to introspect the error message 91 | if err != nil { 92 | if !strings.Contains(err.Error(), "unexpected EOF") { 93 | t.Errorf("Expected 'unexpected EOF' error or similar, got '%v'", err) 94 | } 95 | } 96 | } 97 | 98 | func TestHandler_WebSocketTerminationByClient(t *testing.T) { 99 | h := newTestHandler() 100 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 101 | defer server.Close() 102 | url := "ws" + server.URL[4:] 103 | var done = make(chan struct{}) 104 | h.handlerFunc = func(conn Session) { 105 | if _, err := conn.Recv(); err != ErrSessionNotOpen { 106 | t.Errorf("Recv should fail") 107 | } 108 | close(done) 109 | } 110 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 111 | if conn == nil { 112 | t.Errorf("Connection should not be nil") 113 | t.FailNow() 114 | } 115 | conn.Close() 116 | <-done 117 | } 118 | 119 | func TestHandler_WebSocketCommunication(t *testing.T) { 120 | h := newTestHandler() 121 | h.options.WebsocketWriteTimeout = time.Second 122 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 123 | // defer server.CloseClientConnections() 124 | url := "ws" + server.URL[4:] 125 | var done = make(chan struct{}) 126 | h.handlerFunc = func(conn Session) { 127 | 
conn.Send("message 1") 128 | conn.Send("message 2") 129 | msg, err := conn.Recv() 130 | if msg != "message 3" || err != nil { 131 | t.Errorf("Got '%s', expected '%s'", msg, "message 3") 132 | } 133 | conn.Close(123, "close") 134 | close(done) 135 | } 136 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 137 | conn.WriteJSON([]string{"message 3"}) 138 | var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`} 139 | for _, exp := range expected { 140 | _, msg, err := conn.ReadMessage() 141 | if string(msg) != exp || err != nil { 142 | t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 143 | } 144 | } 145 | <-done 146 | } 147 | 148 | func TestHandler_CustomWebSocketCommunication(t *testing.T) { 149 | h := newTestHandler() 150 | h.options.WebsocketUpgrader = &websocket.Upgrader{ 151 | ReadBufferSize: 0, 152 | WriteBufferSize: 0, 153 | CheckOrigin: func(_ *http.Request) bool { return true }, 154 | Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {}, 155 | } 156 | h.options.WebsocketWriteTimeout = time.Second 157 | server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket)) 158 | url := "ws" + server.URL[4:] 159 | var done = make(chan struct{}) 160 | h.handlerFunc = func(conn Session) { 161 | conn.Send("message 1") 162 | conn.Send("message 2") 163 | msg, err := conn.Recv() 164 | if msg != "message 3" || err != nil { 165 | t.Errorf("Got '%s', expected '%s'", msg, "message 3") 166 | } 167 | conn.Close(123, "close") 168 | close(done) 169 | } 170 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 171 | conn.WriteJSON([]string{"message 3"}) 172 | var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`} 173 | for _, exp := range expected { 174 | _, msg, err := conn.ReadMessage() 175 | if string(msg) != exp || err != nil { 176 | t.Errorf("Wrong 
frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 177 | } 178 | } 179 | <-done 180 | } 181 | -------------------------------------------------------------------------------- /sockjs/xhr.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | ) 10 | 11 | var ( 12 | cFrame = closeFrame(2010, "Another connection still open") 13 | xhrStreamingPrelude = strings.Repeat("h", 2048) 14 | ) 15 | 16 | func (h *handler) xhrSend(rw http.ResponseWriter, req *http.Request) { 17 | if req.Body == nil { 18 | httpError(rw, "Payload expected.", http.StatusInternalServerError) 19 | return 20 | } 21 | var messages []string 22 | err := json.NewDecoder(req.Body).Decode(&messages) 23 | if err == io.EOF { 24 | httpError(rw, "Payload expected.", http.StatusInternalServerError) 25 | return 26 | } 27 | if _, ok := err.(*json.SyntaxError); ok || err == io.ErrUnexpectedEOF { 28 | httpError(rw, "Broken JSON encoding.", http.StatusInternalServerError) 29 | return 30 | } 31 | sessionID, err := h.parseSessionID(req.URL) 32 | if err != nil { 33 | http.Error(rw, err.Error(), http.StatusInternalServerError) 34 | return 35 | } 36 | 37 | h.sessionsMux.Lock() 38 | defer h.sessionsMux.Unlock() 39 | if sess, ok := h.sessions[sessionID]; !ok { 40 | http.NotFound(rw, req) 41 | } else { 42 | _ = sess.accept(messages...) // TODO(igm) reponse with SISE in case of err? 
43 | rw.Header().Set("content-type", "text/plain; charset=UTF-8") // Ignored by net/http (but protocol test complains), see https://code.google.com/p/go/source/detail?r=902dc062bff8 44 | rw.WriteHeader(http.StatusNoContent) 45 | } 46 | } 47 | 48 | type xhrFrameWriter struct{} 49 | 50 | func (*xhrFrameWriter) write(w io.Writer, frame string) (int, error) { 51 | return fmt.Fprintf(w, "%s\n", frame) 52 | } 53 | 54 | func (h *handler) xhrPoll(rw http.ResponseWriter, req *http.Request) { 55 | rw.Header().Set("content-type", "application/javascript; charset=UTF-8") 56 | sess, _ := h.sessionByRequest(req) // TODO(igm) add err handling, although err should not happen as handler should not pass req in that case 57 | receiver := newHTTPReceiver(rw, 1, new(xhrFrameWriter)) 58 | if err := sess.attachReceiver(receiver); err != nil { 59 | receiver.sendFrame(cFrame) 60 | receiver.close() 61 | return 62 | } 63 | 64 | select { 65 | case <-receiver.doneNotify(): 66 | case <-receiver.interruptedNotify(): 67 | } 68 | } 69 | 70 | func (h *handler) xhrStreaming(rw http.ResponseWriter, req *http.Request) { 71 | rw.Header().Set("content-type", "application/javascript; charset=UTF-8") 72 | fmt.Fprintf(rw, "%s\n", xhrStreamingPrelude) 73 | rw.(http.Flusher).Flush() 74 | 75 | sess, _ := h.sessionByRequest(req) 76 | receiver := newHTTPReceiver(rw, h.options.ResponseLimit, new(xhrFrameWriter)) 77 | 78 | if err := sess.attachReceiver(receiver); err != nil { 79 | receiver.sendFrame(cFrame) 80 | receiver.close() 81 | return 82 | } 83 | 84 | select { 85 | case <-receiver.doneNotify(): 86 | case <-receiver.interruptedNotify(): 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /sockjs/xhr_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestHandler_XhrSendNilBody(t *testing.T) 
{ 12 | h := newTestHandler() 13 | rec := httptest.NewRecorder() 14 | req, _ := http.NewRequest("POST", "/server/non_existing_session/xhr_send", nil) 15 | h.xhrSend(rec, req) 16 | if rec.Code != http.StatusInternalServerError { 17 | t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusInternalServerError) 18 | } 19 | if rec.Body.String() != "Payload expected." { 20 | t.Errorf("Unexcpected body received: '%s'", rec.Body.String()) 21 | } 22 | } 23 | 24 | func TestHandler_XhrSendEmptyBody(t *testing.T) { 25 | h := newTestHandler() 26 | rec := httptest.NewRecorder() 27 | req, _ := http.NewRequest("POST", "/server/non_existing_session/xhr_send", strings.NewReader("")) 28 | h.xhrSend(rec, req) 29 | if rec.Code != http.StatusInternalServerError { 30 | t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusInternalServerError) 31 | } 32 | if rec.Body.String() != "Payload expected." { 33 | t.Errorf("Unexcpected body received: '%s'", rec.Body.String()) 34 | } 35 | } 36 | 37 | func TestHandler_XhrSendWrongUrlPath(t *testing.T) { 38 | h := newTestHandler() 39 | rec := httptest.NewRecorder() 40 | req, _ := http.NewRequest("POST", "incorrect", strings.NewReader("[\"a\"]")) 41 | h.xhrSend(rec, req) 42 | if rec.Code != http.StatusInternalServerError { 43 | t.Errorf("Unexcpected response status, got '%d', expected '%d'", rec.Code, http.StatusInternalServerError) 44 | } 45 | } 46 | 47 | func TestHandler_XhrSendToExistingSession(t *testing.T) { 48 | h := newTestHandler() 49 | rec := httptest.NewRecorder() 50 | req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]")) 51 | sess := newSession(req, "session", time.Second, time.Second) 52 | h.sessions["session"] = sess 53 | 54 | req, _ = http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]")) 55 | var done = make(chan bool) 56 | go func() { 57 | h.xhrSend(rec, req) 58 | done <- true 59 | }() 60 | 
msg, _ := sess.Recv() 61 | if msg != "some message" { 62 | t.Errorf("Incorrect message in the channel, should be '%s', was '%s'", "some message", msg) 63 | } 64 | <-done 65 | if rec.Code != http.StatusNoContent { 66 | t.Errorf("Wrong response status received %d, should be %d", rec.Code, http.StatusNoContent) 67 | } 68 | if rec.Header().Get("content-type") != "text/plain; charset=UTF-8" { 69 | t.Errorf("Wrong content type received '%s'", rec.Header().Get("content-type")) 70 | } 71 | } 72 | 73 | func TestHandler_XhrSendInvalidInput(t *testing.T) { 74 | h := newTestHandler() 75 | req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("some invalid message frame")) 76 | rec := httptest.NewRecorder() 77 | h.xhrSend(rec, req) 78 | if rec.Code != http.StatusInternalServerError || rec.Body.String() != "Broken JSON encoding." { 79 | t.Errorf("Unexpected response, got '%d,%s' expected '%d,Broken JSON encoding.'", rec.Code, rec.Body.String(), http.StatusInternalServerError) 80 | } 81 | 82 | // unexpected EOF 83 | req, _ = http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"x")) 84 | rec = httptest.NewRecorder() 85 | h.xhrSend(rec, req) 86 | if rec.Code != http.StatusInternalServerError || rec.Body.String() != "Broken JSON encoding." 
{ 87 | t.Errorf("Unexpected response, got '%d,%s' expected '%d,Broken JSON encoding.'", rec.Code, rec.Body.String(), http.StatusInternalServerError) 88 | } 89 | } 90 | 91 | func TestHandler_XhrSendSessionNotFound(t *testing.T) { 92 | h := handler{} 93 | req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]")) 94 | rec := httptest.NewRecorder() 95 | h.xhrSend(rec, req) 96 | if rec.Code != http.StatusNotFound { 97 | t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusNotFound) 98 | } 99 | } 100 | 101 | func TestHandler_XhrPoll(t *testing.T) { 102 | h := newTestHandler() 103 | rw := httptest.NewRecorder() 104 | req, _ := http.NewRequest("POST", "/server/session/xhr", nil) 105 | h.xhrPoll(rw, req) 106 | if rw.Header().Get("content-type") != "application/javascript; charset=UTF-8" { 107 | t.Errorf("Wrong content type received, got '%s'", rw.Header().Get("content-type")) 108 | } 109 | } 110 | 111 | func TestHandler_XhrPollConnectionInterrupted(t *testing.T) { 112 | h := newTestHandler() 113 | sess := newTestSession() 114 | sess.state = SessionActive 115 | h.sessions["session"] = sess 116 | req, _ := http.NewRequest("POST", "/server/session/xhr", nil) 117 | rw := newClosableRecorder() 118 | close(rw.closeNotifCh) 119 | h.xhrPoll(rw, req) 120 | time.Sleep(1 * time.Millisecond) 121 | sess.Lock() 122 | if sess.state != SessionClosed { 123 | t.Errorf("Session should be closed") 124 | } 125 | } 126 | 127 | func TestHandler_XhrPollAnotherConnectionExists(t *testing.T) { 128 | h := newTestHandler() 129 | req, _ := http.NewRequest("POST", "/server/session/xhr", nil) 130 | // turn of timeoutes and heartbeats 131 | sess := newSession(req, "session", time.Hour, time.Hour) 132 | h.sessions["session"] = sess 133 | sess.attachReceiver(newTestReceiver()) 134 | req, _ = http.NewRequest("POST", "/server/session/xhr", nil) 135 | rw2 := httptest.NewRecorder() 136 | h.xhrPoll(rw2, req) 137 | if rw2.Body.String() != 
"c[2010,\"Another connection still open\"]\n" { 138 | t.Errorf("Unexpected body, got '%s'", rw2.Body) 139 | } 140 | } 141 | 142 | func TestHandler_XhrStreaming(t *testing.T) { 143 | h := newTestHandler() 144 | rw := newClosableRecorder() 145 | req, _ := http.NewRequest("POST", "/server/session/xhr_streaming", nil) 146 | h.xhrStreaming(rw, req) 147 | expectedBody := strings.Repeat("h", 2048) + "\no\n" 148 | if rw.Body.String() != expectedBody { 149 | t.Errorf("Unexpected body, got '%s' expected '%s'", rw.Body, expectedBody) 150 | } 151 | } 152 | 153 | func TestHandler_XhrStreamingAnotherReceiver(t *testing.T) { 154 | h := newTestHandler() 155 | h.options.ResponseLimit = 4096 156 | rw1 := newClosableRecorder() 157 | req, _ := http.NewRequest("POST", "/server/session/xhr_streaming", nil) 158 | go func() { 159 | rec := httptest.NewRecorder() 160 | h.xhrStreaming(rec, req) 161 | expectedBody := strings.Repeat("h", 2048) + "\n" + "c[2010,\"Another connection still open\"]\n" 162 | if rec.Body.String() != expectedBody { 163 | t.Errorf("Unexpected body got '%s', expected '%s', ", rec.Body, expectedBody) 164 | } 165 | close(rw1.closeNotifCh) 166 | }() 167 | h.xhrStreaming(rw1, req) 168 | } 169 | 170 | // various test only structs 171 | func newTestHandler() *handler { 172 | h := &handler{sessions: make(map[string]*session)} 173 | h.options.HeartbeatDelay = time.Hour 174 | h.options.DisconnectDelay = time.Hour 175 | return h 176 | } 177 | 178 | type ClosableRecorder struct { 179 | *httptest.ResponseRecorder 180 | closeNotifCh chan bool 181 | } 182 | 183 | func newClosableRecorder() *ClosableRecorder { 184 | return &ClosableRecorder{httptest.NewRecorder(), make(chan bool)} 185 | } 186 | 187 | func (cr *ClosableRecorder) CloseNotify() <-chan bool { return cr.closeNotifCh } 188 | -------------------------------------------------------------------------------- /testserver/server.go: -------------------------------------------------------------------------------- 1 | package main 
2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/igm/sockjs-go/v3/sockjs" 9 | ) 10 | 11 | type testHandler struct { 12 | prefix string 13 | handler http.Handler 14 | } 15 | 16 | func newSockjsHandler(prefix string, options sockjs.Options, fn func(sockjs.Session)) *testHandler { 17 | return &testHandler{prefix, sockjs.NewHandler(prefix, options, fn)} 18 | } 19 | 20 | type testHandlers []*testHandler 21 | 22 | func main() { 23 | // prepare various options for tests 24 | echoOptions := sockjs.DefaultOptions 25 | echoOptions.ResponseLimit = 4096 26 | echoOptions.RawWebsocket = true 27 | 28 | disabledWebsocketOptions := sockjs.DefaultOptions 29 | disabledWebsocketOptions.Websocket = false 30 | 31 | cookieNeededOptions := sockjs.DefaultOptions 32 | cookieNeededOptions.JSessionID = sockjs.DefaultJSessionID 33 | 34 | closeOptions := sockjs.DefaultOptions 35 | closeOptions.RawWebsocket = true 36 | // register various test handlers 37 | var handlers = []*testHandler{ 38 | newSockjsHandler("/echo", echoOptions, echoHandler), 39 | newSockjsHandler("/cookie_needed_echo", cookieNeededOptions, echoHandler), 40 | newSockjsHandler("/close", closeOptions, closeHandler), 41 | newSockjsHandler("/disabled_websocket_echo", disabledWebsocketOptions, echoHandler), 42 | } 43 | log.Fatal(http.ListenAndServe("localhost:8081", testHandlers(handlers))) 44 | } 45 | 46 | func (t testHandlers) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 47 | for _, handler := range t { 48 | if strings.HasPrefix(req.URL.Path, handler.prefix) { 49 | handler.handler.ServeHTTP(rw, req) 50 | return 51 | } 52 | } 53 | http.NotFound(rw, req) 54 | } 55 | 56 | func closeHandler(conn sockjs.Session) { conn.Close(3000, "Go away!") } 57 | func echoHandler(conn sockjs.Session) { 58 | log.Println("New connection created") 59 | for { 60 | if msg, err := conn.Recv(); err != nil { 61 | break 62 | } else { 63 | if err := conn.Send(msg); err != nil { 64 | break 65 | } 66 | } 67 | } 68 | 
log.Println("Sessionection closed") 69 | } 70 | -------------------------------------------------------------------------------- /v3/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/igm/sockjs-go/v3 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/gorilla/websocket v1.4.2 7 | github.com/stretchr/testify v1.5.1 8 | ) 9 | -------------------------------------------------------------------------------- /v3/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 4 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 8 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 9 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 13 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 14 | -------------------------------------------------------------------------------- /v3/sockjs/benchmarks_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | 
"bufio" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strings" 12 | "sync" 13 | "testing" 14 | "time" 15 | 16 | "github.com/gorilla/websocket" 17 | ) 18 | 19 | func BenchmarkSimple(b *testing.B) { 20 | var messages = make(chan string, 10) 21 | h := NewHandler("/echo", DefaultOptions, func(session Session) { 22 | for m := range messages { 23 | _ = session.Send(m) 24 | } 25 | _ = session.Close(1024, "Close") 26 | }) 27 | server := httptest.NewServer(h) 28 | defer server.Close() 29 | 30 | req, _ := http.NewRequest("POST", server.URL+fmt.Sprintf("/echo/server/%d/xhr_streaming", 1000), nil) 31 | resp, err := http.DefaultClient.Do(req) 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | for n := 0; n < b.N; n++ { 36 | messages <- "some message" 37 | } 38 | fmt.Println(b.N) 39 | close(messages) 40 | resp.Body.Close() 41 | } 42 | 43 | func BenchmarkMessages(b *testing.B) { 44 | msg := strings.Repeat("m", 10) 45 | h := NewHandler("/echo", DefaultOptions, func(session Session) { 46 | for n := 0; n < b.N; n++ { 47 | _ = session.Send(msg) 48 | } 49 | _ = session.Close(1024, "Close") 50 | }) 51 | server := httptest.NewServer(h) 52 | 53 | var wg sync.WaitGroup 54 | 55 | for i := 0; i < 100; i++ { 56 | wg.Add(1) 57 | go func(session int) { 58 | reqc := 0 59 | req, _ := http.NewRequest("POST", server.URL+fmt.Sprintf("/echo/server/%d/xhr_streaming", session), nil) 60 | for { 61 | reqc++ 62 | resp, err := http.DefaultClient.Do(req) 63 | if err != nil { 64 | log.Fatal(err) 65 | } 66 | reader := bufio.NewReader(resp.Body) 67 | for { 68 | line, err := reader.ReadString('\n') 69 | if err != nil { 70 | goto AGAIN 71 | } 72 | if strings.HasPrefix(line, "data: c[1024") { 73 | resp.Body.Close() 74 | goto DONE 75 | } 76 | } 77 | AGAIN: 78 | resp.Body.Close() 79 | } 80 | DONE: 81 | wg.Done() 82 | }(i) 83 | } 84 | wg.Wait() 85 | server.Close() 86 | } 87 | 88 | var size = flag.Int("size", 4*1024, "Size of one message.") 89 | 90 | func 
BenchmarkMessageWebsocket(b *testing.B) { 91 | flag.Parse() 92 | 93 | msg := strings.Repeat("x", *size) 94 | wsFrame := []byte(fmt.Sprintf("[%q]", msg)) 95 | 96 | opts := Options{ 97 | Websocket: true, 98 | SockJSURL: "//cdnjs.cloudflare.com/ajax/libs/sockjs-client/0.3.4/sockjs.min.js", 99 | HeartbeatDelay: time.Hour, 100 | DisconnectDelay: time.Hour, 101 | ResponseLimit: uint32(*size), 102 | } 103 | 104 | h := NewHandler("/echo", opts, func(session Session) { 105 | for { 106 | msg, err := session.Recv() 107 | if err != nil { 108 | if session.GetSessionState() != SessionActive { 109 | break 110 | } 111 | b.Fatalf("Recv()=%s", err) 112 | } 113 | 114 | if err := session.Send(msg); err != nil { 115 | b.Fatalf("Send()=%s", err) 116 | } 117 | } 118 | }) 119 | 120 | server := httptest.NewServer(h) 121 | defer server.Close() 122 | 123 | url := "ws" + server.URL[4:] + "/echo/server/0/websocket" 124 | 125 | client, _, err := websocket.DefaultDialer.Dial(url, nil) 126 | if err != nil { 127 | b.Fatalf("Dial()=%s", err) 128 | } 129 | 130 | _, p, err := client.ReadMessage() 131 | if err != nil || string(p) != "o" { 132 | b.Fatalf("failed to start new session: frame=%v, err=%v", p, err) 133 | } 134 | 135 | b.ReportAllocs() 136 | b.ResetTimer() 137 | 138 | for i := 0; i < b.N; i++ { 139 | if err := client.WriteMessage(websocket.TextMessage, wsFrame); err != nil { 140 | b.Fatalf("WriteMessage()=%s", err) 141 | } 142 | 143 | if _, _, err := client.ReadMessage(); err != nil { 144 | b.Fatalf("ReadMessage()=%s", err) 145 | } 146 | } 147 | 148 | if err := client.Close(); err != nil { 149 | b.Fatalf("Close()=%s", err) 150 | } 151 | } 152 | 153 | func BenchmarkHandler_ParseSessionID(b *testing.B) { 154 | h := Handler{prefix: "/prefix"} 155 | url, _ := url.Parse("http://server:80/prefix/server/session/whatever") 156 | 157 | b.ReportAllocs() 158 | b.ResetTimer() 159 | for i := 0; i < b.N; i++ { 160 | _, _ = h.parseSessionID(url) 161 | } 162 | } 163 | 
-------------------------------------------------------------------------------- /v3/sockjs/buffer.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // messageBuffer is an unbounded buffer that blocks on 9 | // pop if it's empty until the new element is enqueued. 10 | type messageBuffer struct { 11 | popCh chan string 12 | closeCh chan struct{} 13 | once sync.Once // for b.close() 14 | } 15 | 16 | func newMessageBuffer() *messageBuffer { 17 | return &messageBuffer{ 18 | popCh: make(chan string), 19 | closeCh: make(chan struct{}), 20 | } 21 | } 22 | 23 | func (b *messageBuffer) push(messages ...string) error { 24 | for _, message := range messages { 25 | select { 26 | case b.popCh <- message: 27 | case <-b.closeCh: 28 | return ErrSessionNotOpen 29 | } 30 | } 31 | 32 | return nil 33 | } 34 | 35 | func (b *messageBuffer) pop(ctx context.Context) (string, error) { 36 | select { 37 | case msg := <-b.popCh: 38 | return msg, nil 39 | case <-b.closeCh: 40 | return "", ErrSessionNotOpen 41 | case <-ctx.Done(): 42 | return "", ctx.Err() 43 | } 44 | } 45 | 46 | func (b *messageBuffer) close() { b.once.Do(func() { close(b.closeCh) }) } 47 | -------------------------------------------------------------------------------- /v3/sockjs/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package sockjs is a server side implementation of sockjs protocol. 
3 | */ 4 | 5 | package sockjs 6 | -------------------------------------------------------------------------------- /v3/sockjs/eventsource.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "net/url" 8 | "strings" 9 | ) 10 | 11 | func (h *Handler) eventSource(rw http.ResponseWriter, req *http.Request) { 12 | rw.Header().Set("content-type", "text/event-stream; charset=UTF-8") 13 | _, _ = fmt.Fprint(rw, "\r\n") 14 | rw.(http.Flusher).Flush() 15 | 16 | recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(eventSourceFrameWriter), ReceiverTypeEventSource) 17 | sess, err := h.sessionByRequest(req) 18 | if err != nil { 19 | http.Error(rw, err.Error(), http.StatusInternalServerError) 20 | return 21 | } 22 | if err := sess.attachReceiver(recv); err != nil { 23 | if err := recv.sendFrame(cFrame); err != nil { 24 | http.Error(rw, err.Error(), http.StatusInternalServerError) 25 | return 26 | } 27 | recv.close() 28 | return 29 | } 30 | sess.startHandlerOnce.Do(func() { go h.handlerFunc(Session{sess}) }) 31 | select { 32 | case <-recv.doneNotify(): 33 | case <-recv.interruptedNotify(): 34 | } 35 | } 36 | 37 | type eventSourceFrameWriter struct{} 38 | 39 | var escaper *strings.Replacer = strings.NewReplacer( 40 | "%", url.QueryEscape("%"), 41 | "\n", url.QueryEscape("\n"), 42 | "\r", url.QueryEscape("\r"), 43 | "\x00", url.QueryEscape("\x00"), 44 | ) 45 | 46 | func (*eventSourceFrameWriter) write(w io.Writer, frame string) (int, error) { 47 | return fmt.Fprintf(w, "data: %s\r\n\r\n", escaper.Replace(frame)) 48 | } 49 | -------------------------------------------------------------------------------- /v3/sockjs/eventsource_integration_stage_test.go: -------------------------------------------------------------------------------- 1 | package sockjs_test 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 
10 | "time" 11 | 12 | "github.com/igm/sockjs-go/v3/sockjs" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | type eventSourceStage struct { 18 | t *testing.T 19 | handler *sockjs.Handler 20 | server *httptest.Server 21 | resp *http.Response 22 | err error 23 | session sockjs.Session 24 | haveSession chan struct{} 25 | receivedMessages chan string 26 | } 27 | 28 | func newEventSourceStage(t *testing.T) (*eventSourceStage, *eventSourceStage, *eventSourceStage) { 29 | stage := &eventSourceStage{ 30 | t: t, 31 | haveSession: make(chan struct{}), 32 | receivedMessages: make(chan string, 1024), 33 | } 34 | return stage, stage, stage 35 | } 36 | 37 | func (s *eventSourceStage) a_new_sockjs_handler_is_created() *eventSourceStage { 38 | s.handler = sockjs.NewHandler("/prefix", sockjs.DefaultOptions, func(sess sockjs.Session) { 39 | s.session = sess 40 | close(s.haveSession) 41 | for { 42 | msg, err := sess.Recv() 43 | if err == sockjs.ErrSessionNotOpen { 44 | return 45 | } 46 | require.NoError(s.t, err) 47 | s.receivedMessages <- msg 48 | } 49 | }) 50 | return s 51 | } 52 | 53 | func (s *eventSourceStage) a_server_is_started() *eventSourceStage { 54 | s.server = httptest.NewServer(s.handler) 55 | return s 56 | } 57 | 58 | func (s *eventSourceStage) a_sockjs_eventsource_connection_is_received() *eventSourceStage { 59 | s.resp, s.err = http.Get(s.server.URL + "/prefix/123/456/eventsource") 60 | return s 61 | } 62 | 63 | func (s *eventSourceStage) handler_is_invoked_with_session() *eventSourceStage { 64 | select { 65 | case <-s.haveSession: 66 | case <-time.After(1 * time.Second): 67 | s.t.Fatal("no session was created") 68 | } 69 | assert.Equal(s.t, sockjs.ReceiverTypeEventSource, s.session.ReceiverType()) 70 | return s 71 | } 72 | 73 | func (s *eventSourceStage) session_is_closed() *eventSourceStage { 74 | s.session.Close(1024, "Close") 75 | assert.Error(s.t, s.session.Context().Err()) 76 | select { 77 | case 
<-s.session.Context().Done(): 78 | case <-time.After(1 * time.Second): 79 | s.t.Fatal("context should have been done") 80 | } 81 | return s 82 | } 83 | 84 | func (s *eventSourceStage) valid_eventsource_frames_should_be_received() *eventSourceStage { 85 | require.NoError(s.t, s.err) 86 | assert.Equal(s.t, "text/event-stream; charset=UTF-8", s.resp.Header.Get("content-type")) 87 | assert.Equal(s.t, "true", s.resp.Header.Get("access-control-allow-credentials")) 88 | assert.Equal(s.t, "*", s.resp.Header.Get("access-control-allow-origin")) 89 | 90 | all, err := ioutil.ReadAll(s.resp.Body) 91 | require.NoError(s.t, err) 92 | expectedBody := "\r\ndata: o\r\n\r\ndata: c[1024,\"Close\"]\r\n\r\n" 93 | assert.Equal(s.t, expectedBody, string(all)) 94 | return s 95 | } 96 | 97 | func (s *eventSourceStage) a_message_is_sent_from_client(msg string) *eventSourceStage { 98 | out, err := json.Marshal([]string{msg}) 99 | require.NoError(s.t, err) 100 | r, err := http.Post(s.server.URL+"/prefix/123/456/xhr_send", "application/json", bytes.NewReader(out)) 101 | require.NoError(s.t, err) 102 | require.Equal(s.t, http.StatusNoContent, r.StatusCode) 103 | return s 104 | } 105 | 106 | func (s *eventSourceStage) same_message_should_be_received_from_session(expectredMsg string) *eventSourceStage { 107 | select { 108 | case msg := <-s.receivedMessages: 109 | assert.Equal(s.t, expectredMsg, msg) 110 | case <-time.After(1 * time.Second): 111 | s.t.Fatal("no message was received") 112 | } 113 | return s 114 | } 115 | 116 | func (s *eventSourceStage) and() *eventSourceStage { return s } 117 | 118 | func (s *eventSourceStage) a_server_is_started_with_handler() *eventSourceStage { 119 | s.a_new_sockjs_handler_is_created() 120 | s.a_server_is_started() 121 | return s 122 | } 123 | -------------------------------------------------------------------------------- /v3/sockjs/eventsource_intergration_test.go: -------------------------------------------------------------------------------- 1 | package 
sockjs_test 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestEventSource(t *testing.T) { 8 | given, when, then := newEventSourceStage(t) 9 | 10 | given. 11 | a_new_sockjs_handler_is_created().and(). 12 | a_server_is_started() 13 | 14 | when. 15 | a_sockjs_eventsource_connection_is_received().and(). 16 | handler_is_invoked_with_session().and(). 17 | session_is_closed() 18 | 19 | then. 20 | valid_eventsource_frames_should_be_received() 21 | } 22 | 23 | func TestEventSourceMessageInteraction(t *testing.T) { 24 | given, when, then := newEventSourceStage(t) 25 | 26 | given. 27 | a_server_is_started_with_handler(). 28 | a_sockjs_eventsource_connection_is_received(). 29 | handler_is_invoked_with_session() 30 | 31 | when. 32 | a_message_is_sent_from_client("Hello World!").and(). 33 | session_is_closed() 34 | 35 | then. 36 | same_message_should_be_received_from_session("Hello World!") 37 | } 38 | -------------------------------------------------------------------------------- /v3/sockjs/eventsource_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "net/http" 7 | "net/http/httptest" 8 | "runtime" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestHandler_EventSource(t *testing.T) { 14 | rw := httptest.NewRecorder() 15 | req, _ := http.NewRequest("POST", "/server/session/eventsource", nil) 16 | h := newTestHandler() 17 | h.options.ResponseLimit = 1024 18 | go func() { 19 | var sess *session 20 | for exists := false; !exists; { 21 | runtime.Gosched() 22 | h.sessionsMux.Lock() 23 | sess, exists = h.sessions["session"] 24 | h.sessionsMux.Unlock() 25 | } 26 | for exists := false; !exists; { 27 | runtime.Gosched() 28 | sess.mux.RLock() 29 | exists = sess.recv != nil 30 | sess.mux.RUnlock() 31 | } 32 | if rt := sess.ReceiverType(); rt != ReceiverTypeEventSource { 33 | t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeEventSource) 34 | } 35 | 
sess.mux.RLock() 36 | sess.recv.close() 37 | sess.mux.RUnlock() 38 | }() 39 | h.eventSource(rw, req) 40 | 41 | contentType := rw.Header().Get("content-type") 42 | expected := "text/event-stream; charset=UTF-8" 43 | if contentType != expected { 44 | t.Errorf("Unexpected content type, got '%s', extected '%s'", contentType, expected) 45 | } 46 | if rw.Code != http.StatusOK { 47 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusOK) 48 | } 49 | 50 | if rw.Body.String() != "\r\ndata: o\r\n\r\n" { 51 | t.Errorf("Event stream prelude, got '%s'", rw.Body) 52 | } 53 | } 54 | 55 | func TestHandler_EventSourceMultipleConnections(t *testing.T) { 56 | h := newTestHandler() 57 | h.options.ResponseLimit = 1024 58 | rw := httptest.NewRecorder() 59 | req, _ := http.NewRequest("POST", "/server/sess/eventsource", nil) 60 | go func() { 61 | rw := httptest.NewRecorder() 62 | h.eventSource(rw, req) 63 | if rw.Body.String() != "\r\ndata: c[2010,\"Another connection still open\"]\r\n\r\n" { 64 | t.Errorf("wrong, got '%v'", rw.Body) 65 | } 66 | h.sessionsMux.Lock() 67 | sess := h.sessions["sess"] 68 | sess.close() 69 | h.sessionsMux.Unlock() 70 | }() 71 | h.eventSource(rw, req) 72 | } 73 | 74 | func TestHandler_EventSourceConnectionInterrupted(t *testing.T) { 75 | h := newTestHandler() 76 | sess := newTestSession() 77 | sess.state = SessionActive 78 | h.sessions["session"] = sess 79 | req, _ := http.NewRequest("POST", "/server/session/eventsource", nil) 80 | ctx, cancel := context.WithCancel(req.Context()) 81 | req = req.WithContext(ctx) 82 | rw := httptest.NewRecorder() 83 | cancel() 84 | h.eventSource(rw, req) 85 | select { 86 | case <-sess.closeCh: 87 | case <-time.After(1 * time.Second): 88 | t.Errorf("session close channel should be closed") 89 | } 90 | sess.mux.Lock() 91 | if sess.state != SessionClosed { 92 | t.Errorf("session should be closed") 93 | } 94 | } 95 | 96 | func TestEventSourceFrameWriter(t *testing.T) { 97 | writer := 
new(eventSourceFrameWriter) 98 | out := new(bytes.Buffer) 99 | 100 | // Confirm that "important" characters are escaped, but others pass 101 | // through unmodified. 102 | _, err := writer.write(out, "escaped: %\r\n;unescaped: +&#") 103 | if err != nil { 104 | t.Errorf("unexpected write error: %s", err) 105 | } 106 | if out.String() != "data: escaped: %25%0D%0A;unescaped: +&#\r\n\r\n" { 107 | t.Errorf("wrong, got '%v'", out.String()) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /v3/sockjs/example_handler_test.go: -------------------------------------------------------------------------------- 1 | package sockjs_test 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/igm/sockjs-go/v3/sockjs" 7 | ) 8 | 9 | func ExampleNewHandler_simple() { 10 | handler := sockjs.NewHandler("/echo", sockjs.DefaultOptions, func(session sockjs.Session) { 11 | for { 12 | if msg, err := session.Recv(); err == nil { 13 | if session.Send(msg) != nil { 14 | break 15 | } 16 | } else { 17 | break 18 | } 19 | } 20 | }) 21 | _ = http.ListenAndServe(":8080", handler) 22 | } 23 | 24 | func ExampleNewHandler_defaultMux() { 25 | handler := sockjs.NewHandler("/echo", sockjs.DefaultOptions, func(session sockjs.Session) { 26 | for { 27 | if msg, err := session.Recv(); err == nil { 28 | if session.Send(msg) != nil { 29 | break 30 | } 31 | } else { 32 | break 33 | } 34 | } 35 | }) 36 | // need to provide path prefix for http.Mux 37 | http.Handle("/echo/", handler) 38 | _ = http.ListenAndServe(":8080", nil) 39 | } 40 | -------------------------------------------------------------------------------- /v3/sockjs/frame.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | ) 7 | 8 | func closeFrame(status uint32, reason string) string { 9 | bytes, _ := json.Marshal([]interface{}{status, reason}) 10 | return fmt.Sprintf("c%s", string(bytes)) 11 | } 12 | 
-------------------------------------------------------------------------------- /v3/sockjs/frame_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import "testing" 4 | 5 | func TestCloseFrame(t *testing.T) { 6 | cf := closeFrame(1024, "some close text") 7 | if cf != "c[1024,\"some close text\"]" { 8 | t.Errorf("Wrong close frame generated '%s'", cf) 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /v3/sockjs/handler.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "regexp" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | type Handler struct { 12 | prefix string 13 | options Options 14 | handlerFunc func(Session) 15 | mappings []*mapping 16 | 17 | sessionsMux sync.Mutex 18 | sessions map[string]*session 19 | } 20 | 21 | const sessionPrefix = "^/([^/.]+)/([^/.]+)" 22 | 23 | var sessionRegExp = regexp.MustCompile(sessionPrefix) 24 | 25 | // NewHandler creates new HTTP handler that conforms to the basic net/http.Handler interface. 
26 | // It takes path prefix, options and sockjs handler function as parameters 27 | func NewHandler(prefix string, opts Options, handlerFunc func(Session)) *Handler { 28 | if handlerFunc == nil { 29 | handlerFunc = func(s Session) {} 30 | } 31 | h := &Handler{ 32 | prefix: prefix, 33 | options: opts, 34 | handlerFunc: handlerFunc, 35 | sessions: make(map[string]*session), 36 | } 37 | 38 | h.fillMappingsWithAllowedMethods() 39 | 40 | if opts.Websocket { 41 | h.mappings = append(h.mappings, newMapping("GET", sessionPrefix+"/websocket$", h.sockjsWebsocket)) 42 | } 43 | if opts.RawWebsocket { 44 | h.mappings = append(h.mappings, newMapping("GET", "^/websocket$", h.rawWebsocket)) 45 | } 46 | return h 47 | } 48 | 49 | func (h *Handler) Prefix() string { return h.prefix } 50 | 51 | func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 52 | // iterate over mappings 53 | http.StripPrefix(h.prefix, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 54 | var allowedMethods []string 55 | for _, mapping := range h.mappings { 56 | if match, method := mapping.matches(req); match == fullMatch { 57 | for _, hf := range mapping.chain { 58 | hf(rw, req) 59 | } 60 | return 61 | } else if match == pathMatch { 62 | allowedMethods = append(allowedMethods, method) 63 | } 64 | } 65 | if len(allowedMethods) > 0 { 66 | rw.Header().Set("allow", strings.Join(allowedMethods, ", ")) 67 | rw.Header().Set("Content-Type", "") 68 | rw.WriteHeader(http.StatusMethodNotAllowed) 69 | return 70 | } 71 | http.NotFound(rw, req) 72 | })).ServeHTTP(rw, req) 73 | } 74 | 75 | func (h *Handler) parseSessionID(url *url.URL) (string, error) { 76 | matches := sessionRegExp.FindStringSubmatch(url.Path) 77 | if len(matches) == 3 { 78 | return matches[2], nil 79 | } 80 | return "", errSessionParse 81 | } 82 | 83 | func (h *Handler) sessionByRequest(req *http.Request) (*session, error) { 84 | h.sessionsMux.Lock() 85 | defer h.sessionsMux.Unlock() 86 | sessionID, err := 
h.parseSessionID(req.URL) 87 | if err != nil { 88 | return nil, err 89 | } 90 | sess, exists := h.sessions[sessionID] 91 | if !exists { 92 | sess = newSession(req, sessionID, h.options.DisconnectDelay, h.options.HeartbeatDelay) 93 | h.sessions[sessionID] = sess 94 | go func() { 95 | <-sess.closeCh 96 | h.sessionsMux.Lock() 97 | delete(h.sessions, sessionID) 98 | h.sessionsMux.Unlock() 99 | }() 100 | } 101 | sess.setCurrentRequest(req) 102 | return sess, nil 103 | } 104 | 105 | // fillMappingsWithAllowedMethods adds only allowed methods to handler.mappings, by if method is not disabled 106 | func (h *Handler) fillMappingsWithAllowedMethods() { 107 | 108 | xhrCors := xhrCorsFactory(h.options) 109 | 110 | // Default Methods 111 | h.mappings = []*mapping{ 112 | newMapping("GET", "^[/]?$", welcomeHandler), 113 | newMapping("OPTIONS", "^/info$", h.options.cookie, xhrCors, cacheFor, h.options.info), 114 | newMapping("GET", "^/info$", h.options.cookie, xhrCors, noCache, h.options.info), 115 | // IFrame 116 | newMapping("GET", "^/iframe[0-9-.a-z_]*.html$", cacheFor, h.iframe), 117 | } 118 | 119 | // Adding XHR to mapping 120 | if !h.options.DisableXHR { 121 | h.mappings = append(h.mappings, 122 | newMapping("POST", sessionPrefix+"/xhr$", h.options.cookie, xhrCors, noCache, h.xhrPoll), 123 | newMapping("OPTIONS", sessionPrefix+"/xhr$", h.options.cookie, xhrCors, cacheFor, xhrOptions), 124 | ) 125 | } 126 | 127 | // Adding XHRStreaming to mapping 128 | if !h.options.DisableXHRStreaming { 129 | h.mappings = append(h.mappings, 130 | newMapping("POST", sessionPrefix+"/xhr_streaming$", h.options.cookie, xhrCors, noCache, h.xhrStreaming), 131 | newMapping("OPTIONS", sessionPrefix+"/xhr_streaming$", h.options.cookie, xhrCors, cacheFor, xhrOptions), 132 | ) 133 | } 134 | 135 | // Adding EventSource to mapping 136 | if !h.options.DisableEventSource { 137 | h.mappings = append(h.mappings, 138 | newMapping("GET", sessionPrefix+"/eventsource$", h.options.cookie, xhrCors, noCache, 
h.eventSource), 139 | ) 140 | } 141 | 142 | // Adding HtmlFile to mapping 143 | if !h.options.DisableHtmlFile { 144 | h.mappings = append(h.mappings, 145 | newMapping("GET", sessionPrefix+"/htmlfile$", h.options.cookie, xhrCors, noCache, h.htmlFile), 146 | ) 147 | } 148 | 149 | // Adding JSONP to mapping 150 | if !h.options.DisableJSONP { 151 | h.mappings = append(h.mappings, 152 | newMapping("GET", sessionPrefix+"/jsonp$", h.options.cookie, xhrCors, noCache, h.jsonp), 153 | newMapping("OPTIONS", sessionPrefix+"/jsonp$", h.options.cookie, xhrCors, cacheFor, xhrOptions), 154 | newMapping("POST", sessionPrefix+"/jsonp_send$", h.options.cookie, xhrCors, noCache, h.jsonpSend), 155 | ) 156 | } 157 | 158 | // when adding XHRPoll or/and XHRStreaming xhr_send must be added too (only once) 159 | if !h.options.DisableXHR || !h.options.DisableXHRStreaming { 160 | h.mappings = append(h.mappings, 161 | newMapping("POST", sessionPrefix+"/xhr_send$", h.options.cookie, xhrCors, noCache, h.xhrSend), 162 | newMapping("OPTIONS", sessionPrefix+"/xhr_send$", h.options.cookie, xhrCors, cacheFor, xhrOptions), 163 | ) 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /v3/sockjs/htmlfile.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "regexp" 8 | "strings" 9 | ) 10 | 11 | var iframeTemplate = ` 12 | 13 | 14 | 15 |

Don't panic!

16 | 23 | ` 24 | 25 | var invalidCallback = regexp.MustCompile(`[^a-zA-Z0-9_.]`) 26 | 27 | func init() { 28 | iframeTemplate += strings.Repeat(" ", 1024-len(iframeTemplate)+14) 29 | iframeTemplate += "\r\n\r\n" 30 | } 31 | 32 | func (h *Handler) htmlFile(rw http.ResponseWriter, req *http.Request) { 33 | rw.Header().Set("content-type", "text/html; charset=UTF-8") 34 | 35 | if err := req.ParseForm(); err != nil { 36 | http.Error(rw, err.Error(), http.StatusBadRequest) 37 | return 38 | } 39 | callback := req.Form.Get("c") 40 | if callback == "" { 41 | http.Error(rw, `"callback" parameter required`, http.StatusBadRequest) 42 | return 43 | } else if invalidCallback.MatchString(callback) { 44 | http.Error(rw, `invalid character in "callback" parameter`, http.StatusBadRequest) 45 | return 46 | } 47 | rw.WriteHeader(http.StatusOK) 48 | fmt.Fprintf(rw, iframeTemplate, callback) 49 | rw.(http.Flusher).Flush() 50 | sess, err := h.sessionByRequest(req) 51 | if err != nil { 52 | http.Error(rw, err.Error(), http.StatusInternalServerError) 53 | return 54 | } 55 | recv := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(htmlfileFrameWriter), ReceiverTypeHtmlFile) 56 | if err := sess.attachReceiver(recv); err != nil { 57 | if err := recv.sendFrame(cFrame); err != nil { 58 | http.Error(rw, err.Error(), http.StatusInternalServerError) 59 | return 60 | } 61 | recv.close() 62 | return 63 | } 64 | sess.startHandlerOnce.Do(func() { go h.handlerFunc(Session{sess}) }) 65 | select { 66 | case <-recv.doneNotify(): 67 | case <-recv.interruptedNotify(): 68 | } 69 | } 70 | 71 | type htmlfileFrameWriter struct{} 72 | 73 | func (*htmlfileFrameWriter) write(w io.Writer, frame string) (int, error) { 74 | return fmt.Fprintf(w, "\r\n", quote(frame)) 75 | } 76 | -------------------------------------------------------------------------------- /v3/sockjs/htmlfile_integration_stage_test.go: -------------------------------------------------------------------------------- 1 | package sockjs_test 2 | 3 
| import ( 4 | "bytes" 5 | "encoding/json" 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | 12 | "github.com/igm/sockjs-go/v3/sockjs" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | type htmlFileStage struct { 18 | t *testing.T 19 | handler *sockjs.Handler 20 | server *httptest.Server 21 | resp *http.Response 22 | err error 23 | session sockjs.Session 24 | haveSession chan struct{} 25 | receivedMessages chan string 26 | } 27 | 28 | func newHtmlFileStage(t *testing.T) (*htmlFileStage, *htmlFileStage, *htmlFileStage) { 29 | stage := &htmlFileStage{ 30 | t: t, 31 | haveSession: make(chan struct{}), 32 | receivedMessages: make(chan string, 1024), 33 | } 34 | return stage, stage, stage 35 | } 36 | 37 | func (s *htmlFileStage) a_new_sockjs_handler_is_created() *htmlFileStage { 38 | s.handler = sockjs.NewHandler("/prefix", sockjs.DefaultOptions, func(sess sockjs.Session) { 39 | s.session = sess 40 | close(s.haveSession) 41 | for { 42 | msg, err := sess.Recv() 43 | if err == sockjs.ErrSessionNotOpen { 44 | return 45 | } 46 | require.NoError(s.t, err) 47 | s.receivedMessages <- msg 48 | } 49 | }) 50 | return s 51 | } 52 | 53 | func (s *htmlFileStage) a_server_is_started() *htmlFileStage { 54 | s.server = httptest.NewServer(s.handler) 55 | return s 56 | } 57 | 58 | func (s *htmlFileStage) a_sockjs_htmlfile_connection_is_received() *htmlFileStage { 59 | s.resp, s.err = http.Get(s.server.URL + "/prefix/123/123/htmlfile?c=testCallback") 60 | return s 61 | } 62 | 63 | func (s *htmlFileStage) correct_http_response_should_be_received() *htmlFileStage { 64 | require.NoError(s.t, s.err) 65 | assert.Equal(s.t, http.StatusOK, s.resp.StatusCode) 66 | assert.Equal(s.t, "text/html; charset=UTF-8", s.resp.Header.Get("content-type")) 67 | assert.Equal(s.t, "true", s.resp.Header.Get("access-control-allow-credentials")) 68 | assert.Equal(s.t, "*", s.resp.Header.Get("access-control-allow-origin")) 69 | 
return s 70 | } 71 | 72 | func (s *htmlFileStage) handler_should_be_started_with_session() *htmlFileStage { 73 | select { 74 | case <-s.haveSession: 75 | case <-time.After(1 * time.Second): 76 | s.t.Fatal("no session was created") 77 | } 78 | assert.Equal(s.t, sockjs.ReceiverTypeHtmlFile, s.session.ReceiverType()) 79 | return s 80 | } 81 | 82 | func (s *htmlFileStage) session_is_closed() *htmlFileStage { 83 | require.NoError(s.t, s.session.Close(1024, "Close")) 84 | assert.Error(s.t, s.session.Context().Err()) 85 | select { 86 | case <-s.session.Context().Done(): 87 | case <-time.After(1 * time.Second): 88 | s.t.Fatal("context should have been done") 89 | } 90 | return s 91 | } 92 | 93 | func (s *htmlFileStage) valid_htmlfile_response_should_be_received() *htmlFileStage { 94 | all, err := ioutil.ReadAll(s.resp.Body) 95 | require.NoError(s.t, err) 96 | assert.Contains(s.t, string(all), `p("o");`, string(all)) 97 | assert.Contains(s.t, string(all), `p("c[1024,\"Close\"]")`, string(all)) 98 | assert.Contains(s.t, string(all), `var c = parent.testCallback;`, string(all)) 99 | return s 100 | } 101 | 102 | func (s *htmlFileStage) and() *htmlFileStage { return s } 103 | 104 | func (s *htmlFileStage) a_server_is_started_with_handler() *htmlFileStage { 105 | s.a_new_sockjs_handler_is_created() 106 | s.a_server_is_started() 107 | return s 108 | } 109 | 110 | func (s *htmlFileStage) active_session_is_closed() *htmlFileStage { 111 | s.session_is_active() 112 | s.session_is_closed() 113 | return s 114 | } 115 | 116 | func (s *htmlFileStage) session_is_active() *htmlFileStage { 117 | s.a_sockjs_htmlfile_connection_is_received() 118 | s.handler_should_be_started_with_session() 119 | return s 120 | } 121 | 122 | func (s *htmlFileStage) a_message_is_sent_from_client(msg string) *htmlFileStage { 123 | out, err := json.Marshal([]string{msg}) 124 | require.NoError(s.t, err) 125 | r, err := http.Post(s.server.URL+"/prefix/123/123/xhr_send", "application/json", bytes.NewReader(out)) 126 
| require.NoError(s.t, err) 127 | require.Equal(s.t, http.StatusNoContent, r.StatusCode) 128 | return s 129 | } 130 | 131 | func (s *htmlFileStage) same_message_should_be_received_from_session(expectredMsg string) *htmlFileStage { 132 | select { 133 | case msg := <-s.receivedMessages: 134 | assert.Equal(s.t, expectredMsg, msg) 135 | case <-time.After(1 * time.Second): 136 | s.t.Fatal("no message was received") 137 | } 138 | return s 139 | } 140 | -------------------------------------------------------------------------------- /v3/sockjs/htmlfile_intergration_test.go: -------------------------------------------------------------------------------- 1 | package sockjs_test 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestHtmlFile_StartHandler(t *testing.T) { 8 | given, when, then := newHtmlFileStage(t) 9 | 10 | given. 11 | a_new_sockjs_handler_is_created().and(). 12 | a_server_is_started() 13 | 14 | when. 15 | a_sockjs_htmlfile_connection_is_received() 16 | 17 | then. 18 | correct_http_response_should_be_received().and(). 19 | handler_should_be_started_with_session() 20 | } 21 | 22 | func TestHtmlFile_CloseSession(t *testing.T) { 23 | given, when, then := newHtmlFileStage(t) 24 | 25 | given. 26 | a_server_is_started_with_handler() 27 | 28 | when. 29 | active_session_is_closed() 30 | 31 | then. 32 | valid_htmlfile_response_should_be_received() 33 | } 34 | 35 | func TestHtmlFile_SendMessage(t *testing.T) { 36 | given, when, then := newHtmlFileStage(t) 37 | 38 | given. 39 | a_server_is_started_with_handler() 40 | 41 | when. 42 | session_is_active().and(). 43 | a_message_is_sent_from_client("Hello World!").and(). 44 | active_session_is_closed() 45 | 46 | then. 
47 | same_message_should_be_received_from_session("Hello World!") 48 | } 49 | -------------------------------------------------------------------------------- /v3/sockjs/htmlfile_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestHandler_htmlFileNoCallback(t *testing.T) { 11 | h := newTestHandler() 12 | rw := httptest.NewRecorder() 13 | req, _ := http.NewRequest("GET", "/server/session/htmlfile", nil) 14 | h.htmlFile(rw, req) 15 | if rw.Code != http.StatusBadRequest { 16 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusBadRequest) 17 | } 18 | expectedContentType := "text/plain; charset=utf-8" 19 | if rw.Header().Get("content-type") != expectedContentType { 20 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 21 | } 22 | } 23 | 24 | func TestHandler_htmlFile(t *testing.T) { 25 | h := newTestHandler() 26 | rw := httptest.NewRecorder() 27 | req, _ := http.NewRequest("GET", "/server/session/htmlfile?c=testCallback", nil) 28 | h.htmlFile(rw, req) 29 | if rw.Code != http.StatusOK { 30 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusOK) 31 | } 32 | expectedContentType := "text/html; charset=UTF-8" 33 | if rw.Header().Get("content-type") != expectedContentType { 34 | t.Errorf("Unexpected content-type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 35 | } 36 | if rw.Body.String() != expectedIFrame { 37 | t.Errorf("Unexpected response body, got '%s', expected '%s'", rw.Body, expectedIFrame) 38 | } 39 | sess, _ := h.sessionByRequest(req) 40 | if rt := sess.ReceiverType(); rt != ReceiverTypeHtmlFile { 41 | t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeHtmlFile) 42 | } 43 | } 44 | 45 | func 
TestHandler_cannotIntoXSS(t *testing.T) { 46 | h := newTestHandler() 47 | rw := httptest.NewRecorder() 48 | // test simple injection 49 | req, _ := http.NewRequest("GET", "/server/session/htmlfile?c=fake%3Balert(1337)", nil) 50 | h.htmlFile(rw, req) 51 | if rw.Code != http.StatusBadRequest { 52 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusBadRequest) 53 | } 54 | 55 | h = newTestHandler() 56 | rw = httptest.NewRecorder() 57 | // test simple injection 58 | req, _ = http.NewRequest("GET", "/server/session/htmlfile?c=fake%2Dalert", nil) 59 | h.htmlFile(rw, req) 60 | if rw.Code != http.StatusBadRequest { 61 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusBadRequest) 62 | } 63 | } 64 | 65 | func init() { 66 | expectedIFrame += strings.Repeat(" ", 1024-len(expectedIFrame)+len("testCallack")+13) 67 | expectedIFrame += "\r\n\r\n" 68 | expectedIFrame += "\r\n" 69 | } 70 | 71 | var expectedIFrame = ` 72 | 73 | 74 | 75 |

Don't panic!

76 | 83 | ` 84 | -------------------------------------------------------------------------------- /v3/sockjs/httpreceiver.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | type frameWriter interface { 12 | write(writer io.Writer, frame string) (int, error) 13 | } 14 | 15 | type httpReceiverState int 16 | 17 | const ( 18 | stateHTTPReceiverActive httpReceiverState = iota 19 | stateHTTPReceiverClosed 20 | ) 21 | 22 | type httpReceiver struct { 23 | sync.Mutex 24 | state httpReceiverState 25 | 26 | frameWriter frameWriter 27 | rw http.ResponseWriter 28 | maxResponseSize uint32 29 | currentResponseSize uint32 30 | doneCh chan struct{} 31 | interruptCh chan struct{} 32 | recType ReceiverType 33 | } 34 | 35 | func newHTTPReceiver(rw http.ResponseWriter, req *http.Request, maxResponse uint32, frameWriter frameWriter, receiverType ReceiverType) *httpReceiver { 36 | recv := &httpReceiver{ 37 | rw: rw, 38 | frameWriter: frameWriter, 39 | maxResponseSize: maxResponse, 40 | doneCh: make(chan struct{}), 41 | interruptCh: make(chan struct{}), 42 | recType: receiverType, 43 | } 44 | ctx := req.Context() 45 | 46 | go func() { 47 | select { 48 | case <-ctx.Done(): 49 | recv.Lock() 50 | defer recv.Unlock() 51 | if recv.state < stateHTTPReceiverClosed { 52 | recv.state = stateHTTPReceiverClosed 53 | close(recv.interruptCh) 54 | } 55 | case <-recv.doneCh: 56 | // ok, no action needed here, receiver closed in correct way 57 | // just finish the routine 58 | } 59 | }() 60 | return recv 61 | } 62 | 63 | func (recv *httpReceiver) sendBulk(messages ...string) error { 64 | if len(messages) > 0 { 65 | return recv.sendFrame(fmt.Sprintf("a[%s]", 66 | strings.Join( 67 | transform(messages, quote), 68 | ",", 69 | ), 70 | )) 71 | } 72 | return nil 73 | } 74 | 75 | func (recv *httpReceiver) sendFrame(value string) error { 76 | recv.Lock() 77 | defer 
recv.Unlock() 78 | 79 | if recv.state == stateHTTPReceiverActive { 80 | n, err := recv.frameWriter.write(recv.rw, value) 81 | if err != nil { 82 | return err 83 | } 84 | recv.currentResponseSize += uint32(n) 85 | if recv.currentResponseSize >= recv.maxResponseSize { 86 | recv.state = stateHTTPReceiverClosed 87 | close(recv.doneCh) 88 | } else { 89 | recv.rw.(http.Flusher).Flush() 90 | } 91 | } 92 | return nil 93 | } 94 | 95 | func (recv *httpReceiver) doneNotify() <-chan struct{} { return recv.doneCh } 96 | func (recv *httpReceiver) interruptedNotify() <-chan struct{} { return recv.interruptCh } 97 | func (recv *httpReceiver) close() { 98 | recv.Lock() 99 | defer recv.Unlock() 100 | if recv.state < stateHTTPReceiverClosed { 101 | recv.state = stateHTTPReceiverClosed 102 | close(recv.doneCh) 103 | } 104 | } 105 | func (recv *httpReceiver) canSend() bool { 106 | recv.Lock() 107 | defer recv.Unlock() 108 | return recv.state != stateHTTPReceiverClosed 109 | } 110 | 111 | func (recv *httpReceiver) receiverType() ReceiverType { 112 | return recv.recType 113 | } 114 | -------------------------------------------------------------------------------- /v3/sockjs/httpreceiver_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | type testFrameWriter struct { 13 | frames []string 14 | } 15 | 16 | func (t *testFrameWriter) write(w io.Writer, frame string) (int, error) { 17 | t.frames = append(t.frames, frame) 18 | return len(frame), nil 19 | } 20 | 21 | func TestHttpReceiver_Create(t *testing.T) { 22 | rec := httptest.NewRecorder() 23 | req, _ := http.NewRequest("GET", "", nil) 24 | recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter), ReceiverTypeNone) 25 | if recv.doneCh != recv.doneNotify() { 26 | t.Errorf("Calling done() must return close channel, but it does not") 27 | } 28 | if recv.rw != rec 
{ 29 | t.Errorf("Http.ResponseWriter not properly initialized") 30 | } 31 | if recv.maxResponseSize != 1024 { 32 | t.Errorf("MaxResponseSize not properly initialized") 33 | } 34 | } 35 | 36 | func TestHttpReceiver_SendEmptyFrames(t *testing.T) { 37 | rec := httptest.NewRecorder() 38 | req, _ := http.NewRequest("GET", "", nil) 39 | recv := newHTTPReceiver(rec, req, 1024, new(testFrameWriter), ReceiverTypeNone) 40 | noError(t, recv.sendBulk()) 41 | if rec.Body.String() != "" { 42 | t.Errorf("Incorrect body content received from receiver '%s'", rec.Body.String()) 43 | } 44 | } 45 | 46 | func TestHttpReceiver_SendFrame(t *testing.T) { 47 | rec := httptest.NewRecorder() 48 | fw := new(testFrameWriter) 49 | req, _ := http.NewRequest("GET", "", nil) 50 | recv := newHTTPReceiver(rec, req, 1024, fw, ReceiverTypeNone) 51 | var frame = "some frame content" 52 | noError(t, recv.sendFrame(frame)) 53 | if len(fw.frames) != 1 || fw.frames[0] != frame { 54 | t.Errorf("Incorrect body content received, got '%s', expected '%s'", fw.frames, frame) 55 | } 56 | 57 | } 58 | 59 | func TestHttpReceiver_SendBulk(t *testing.T) { 60 | rec := httptest.NewRecorder() 61 | fw := new(testFrameWriter) 62 | req, _ := http.NewRequest("GET", "", nil) 63 | recv := newHTTPReceiver(rec, req, 1024, fw, ReceiverTypeNone) 64 | noError(t, recv.sendBulk("message 1", "message 2", "message 3")) 65 | expected := "a[\"message 1\",\"message 2\",\"message 3\"]" 66 | if len(fw.frames) != 1 || fw.frames[0] != expected { 67 | t.Errorf("Incorrect body content received from receiver, got '%s' expected '%s'", fw.frames, expected) 68 | } 69 | } 70 | 71 | func TestHttpReceiver_MaximumResponseSize(t *testing.T) { 72 | rec := httptest.NewRecorder() 73 | req, _ := http.NewRequest("GET", "", nil) 74 | recv := newHTTPReceiver(rec, req, 52, new(testFrameWriter), ReceiverTypeNone) 75 | noError(t, recv.sendBulk("message 1", "message 2")) // produces 26 bytes of response in 1 frame 76 | if recv.currentResponseSize != 26 { 77 | 
t.Errorf("Incorrect response size calcualated, got '%d' expected '%d'", recv.currentResponseSize, 26) 78 | } 79 | select { 80 | case <-recv.doneNotify(): 81 | t.Errorf("Receiver should not be done yet") 82 | default: // ok 83 | } 84 | noError(t, recv.sendBulk("message 1", "message 2")) // produces another 26 bytes of response in 1 frame to go over max resposne size 85 | select { 86 | case <-recv.doneNotify(): // ok 87 | default: 88 | t.Errorf("Receiver closed channel did not close") 89 | } 90 | } 91 | 92 | func TestHttpReceiver_Close(t *testing.T) { 93 | rec := httptest.NewRecorder() 94 | req, _ := http.NewRequest("GET", "", nil) 95 | recv := newHTTPReceiver(rec, req, 1024, nil, ReceiverTypeNone) 96 | recv.close() 97 | if recv.state != stateHTTPReceiverClosed { 98 | t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed) 99 | } 100 | } 101 | 102 | func TestHttpReceiver_ConnectionInterrupt(t *testing.T) { 103 | rw := httptest.NewRecorder() 104 | req, _ := http.NewRequest("GET", "", nil) 105 | ctx, cancel := context.WithCancel(req.Context()) 106 | req = req.WithContext(ctx) 107 | recv := newHTTPReceiver(rw, req, 1024, nil, ReceiverTypeNone) 108 | cancel() 109 | select { 110 | case <-recv.interruptCh: 111 | case <-time.After(1 * time.Second): 112 | t.Errorf("should interrupt") 113 | } 114 | if recv.state != stateHTTPReceiverClosed { 115 | t.Errorf("Unexpected state, got '%d', expected '%d'", recv.state, stateHTTPReceiverClosed) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /v3/sockjs/iframe.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "crypto/md5" 5 | "fmt" 6 | "net/http" 7 | "text/template" 8 | ) 9 | 10 | var tmpl = template.Must(template.New("iframe").Parse(iframeBody)) 11 | 12 | func (h *Handler) iframe(rw http.ResponseWriter, req *http.Request) { 13 | etagReq := req.Header.Get("If-None-Match") 
14 | hash := md5.New() 15 | if _, err := hash.Write([]byte(iframeBody)); err!=nil { 16 | http.Error(rw, err.Error(), http.StatusInternalServerError) 17 | return 18 | } 19 | etag := fmt.Sprintf("%x", hash.Sum(nil)) 20 | if etag == etagReq { 21 | rw.WriteHeader(http.StatusNotModified) 22 | return 23 | } 24 | 25 | rw.Header().Set("Content-Type", "text/html; charset=UTF-8") 26 | rw.Header().Add("ETag", etag) 27 | if err := tmpl.Execute(rw, h.options.SockJSURL); err!=nil { 28 | http.Error(rw, "could not render iframe content: "+err.Error(), http.StatusInternalServerError) 29 | return 30 | } 31 | } 32 | 33 | var iframeBody = ` 34 | 35 | 36 | 37 | 38 | 42 | 43 | 44 | 45 |

Don't panic!

46 |

This is a SockJS hidden iframe. It's used for cross domain magic.

47 | 48 | ` 49 | -------------------------------------------------------------------------------- /v3/sockjs/iframe_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func TestHandler_iframe(t *testing.T) { 10 | h := newTestHandler() 11 | h.options.SockJSURL = "http://sockjs.com/sockjs.js" 12 | rw := httptest.NewRecorder() 13 | req, _ := http.NewRequest("GET", "/server/sess/iframe", nil) 14 | h.iframe(rw, req) 15 | if rw.Body.String() != expected { 16 | t.Errorf("Unexpected html content,\ngot:\n'%s'\n\nexpected\n'%s'", rw.Body, expected) 17 | } 18 | eTag := rw.Header().Get("etag") 19 | req.Header.Set("if-none-match", eTag) 20 | rw = httptest.NewRecorder() 21 | h.iframe(rw, req) 22 | if rw.Code != http.StatusNotModified { 23 | t.Errorf("Unexpected response, got '%d', expected '%d'", rw.Code, http.StatusNotModified) 24 | } 25 | } 26 | 27 | var expected = ` 28 | 29 | 30 | 31 | 32 | 36 | 37 | 38 | 39 |

Don't panic!

40 |

This is a SockJS hidden iframe. It's used for cross domain magic.

41 | 42 | ` 43 | -------------------------------------------------------------------------------- /v3/sockjs/jsonp.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | ) 10 | 11 | func (h *Handler) jsonp(rw http.ResponseWriter, req *http.Request) { 12 | rw.Header().Set("content-type", "application/javascript; charset=UTF-8") 13 | 14 | if err := req.ParseForm(); err != nil { 15 | http.Error(rw, err.Error(), http.StatusBadRequest) 16 | return 17 | } 18 | callback := req.Form.Get("c") 19 | if callback == "" { 20 | http.Error(rw, `"callback" parameter required`, http.StatusInternalServerError) 21 | return 22 | } else if invalidCallback.MatchString(callback) { 23 | http.Error(rw, `invalid character in "callback" parameter`, http.StatusBadRequest) 24 | return 25 | } 26 | rw.WriteHeader(http.StatusOK) 27 | rw.(http.Flusher).Flush() 28 | 29 | sess, err := h.sessionByRequest(req) 30 | if err != nil { 31 | http.Error(rw, err.Error(), http.StatusInternalServerError) 32 | return 33 | } 34 | recv := newHTTPReceiver(rw, req, 1, &jsonpFrameWriter{callback}, ReceiverTypeJSONP) 35 | if err := sess.attachReceiver(recv); err != nil { 36 | if err := recv.sendFrame(cFrame); err != nil { 37 | http.Error(rw, err.Error(), http.StatusInternalServerError) 38 | return 39 | } 40 | recv.close() 41 | return 42 | } 43 | sess.startHandlerOnce.Do(func() { go h.handlerFunc(Session{sess}) }) 44 | select { 45 | case <-recv.doneNotify(): 46 | case <-recv.interruptedNotify(): 47 | } 48 | } 49 | 50 | func (h *Handler) jsonpSend(rw http.ResponseWriter, req *http.Request) { 51 | if err := req.ParseForm(); err != nil { 52 | http.Error(rw, err.Error(), http.StatusBadRequest) 53 | return 54 | } 55 | var data io.Reader 56 | data = req.Body 57 | 58 | formReader := strings.NewReader(req.PostFormValue("d")) 59 | if formReader.Len() != 0 { 60 | data = formReader 61 | } 62 | 
if data == nil { 63 | http.Error(rw, "Payload expected.", http.StatusBadRequest) 64 | return 65 | } 66 | var messages []string 67 | err := json.NewDecoder(data).Decode(&messages) 68 | if err == io.EOF { 69 | http.Error(rw, "Payload expected.", http.StatusBadRequest) 70 | return 71 | } 72 | if err != nil { 73 | http.Error(rw, "Broken JSON encoding.", http.StatusBadRequest) 74 | return 75 | } 76 | sessionID, _ := h.parseSessionID(req.URL) 77 | h.sessionsMux.Lock() 78 | sess, ok := h.sessions[sessionID] 79 | h.sessionsMux.Unlock() 80 | if !ok { 81 | http.NotFound(rw, req) 82 | } else { 83 | if err := sess.accept(messages...); err != nil { 84 | http.Error(rw, err.Error(), http.StatusInternalServerError) 85 | return 86 | } 87 | rw.Header().Set("content-type", "text/plain; charset=UTF-8") 88 | _, _ = rw.Write([]byte("ok")) 89 | } 90 | } 91 | 92 | type jsonpFrameWriter struct { 93 | callback string 94 | } 95 | 96 | func (j *jsonpFrameWriter) write(w io.Writer, frame string) (int, error) { 97 | return fmt.Fprintf(w, "%s(%s);\r\n", j.callback, quote(frame)) 98 | } 99 | -------------------------------------------------------------------------------- /v3/sockjs/jsonp_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestHandler_jsonpNoCallback(t *testing.T) { 12 | h := newTestHandler() 13 | rw := httptest.NewRecorder() 14 | req, _ := http.NewRequest("GET", "/server/session/jsonp", nil) 15 | h.jsonp(rw, req) 16 | if rw.Code != http.StatusInternalServerError { 17 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 18 | } 19 | expectedContentType := "text/plain; charset=utf-8" 20 | if rw.Header().Get("content-type") != expectedContentType { 21 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), 
expectedContentType) 22 | } 23 | } 24 | 25 | func TestHandler_jsonp(t *testing.T) { 26 | h := newTestHandler() 27 | rw := httptest.NewRecorder() 28 | req, _ := http.NewRequest("GET", "/server/session/jsonp?c=testCallback", nil) 29 | h.jsonp(rw, req) 30 | expectedContentType := "application/javascript; charset=UTF-8" 31 | if rw.Header().Get("content-type") != expectedContentType { 32 | t.Errorf("Unexpected content type, got '%s', expected '%s'", rw.Header().Get("content-type"), expectedContentType) 33 | } 34 | expectedBody := "testCallback(\"o\");\r\n" 35 | if rw.Body.String() != expectedBody { 36 | t.Errorf("Unexpected body, got '%s', expected '%s'", rw.Body, expectedBody) 37 | } 38 | sess, _ := h.sessionByRequest(req) 39 | if rt := sess.ReceiverType(); rt != ReceiverTypeJSONP { 40 | t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeJSONP) 41 | } 42 | } 43 | 44 | func TestHandler_jsonpSendNoPayload(t *testing.T) { 45 | h := newTestHandler() 46 | rw := httptest.NewRecorder() 47 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", nil) 48 | h.jsonpSend(rw, req) 49 | if rw.Code != http.StatusBadRequest { 50 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 51 | } 52 | } 53 | 54 | func TestHandler_jsonpSendWrongPayload(t *testing.T) { 55 | h := newTestHandler() 56 | rw := httptest.NewRecorder() 57 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("wrong payload")) 58 | h.jsonpSend(rw, req) 59 | if rw.Code != http.StatusBadRequest { 60 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusInternalServerError) 61 | } 62 | } 63 | 64 | func TestHandler_jsonpSendNoSession(t *testing.T) { 65 | h := newTestHandler() 66 | rw := httptest.NewRecorder() 67 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("[\"message\"]")) 68 | h.jsonpSend(rw, req) 69 | if rw.Code != http.StatusNotFound { 
70 | t.Errorf("Unexpected response code, got '%d', expected '%d'", rw.Code, http.StatusNotFound) 71 | } 72 | } 73 | 74 | func TestHandler_jsonpSend(t *testing.T) { 75 | h := newTestHandler() 76 | 77 | rw := httptest.NewRecorder() 78 | req, _ := http.NewRequest("POST", "/server/session/jsonp_send", strings.NewReader("[\"message\"]")) 79 | 80 | sess := newSession(req, "session", time.Second, time.Second) 81 | h.sessions["session"] = sess 82 | 83 | var done = make(chan struct{}) 84 | go func() { 85 | h.jsonpSend(rw, req) 86 | close(done) 87 | }() 88 | msg, _ := sess.Recv() 89 | if msg != "message" { 90 | t.Errorf("Incorrect message in the channel, should be '%s', was '%s'", "some message", msg) 91 | } 92 | <-done 93 | if rw.Code != http.StatusOK { 94 | t.Errorf("Wrong response status received %d, should be %d", rw.Code, http.StatusOK) 95 | } 96 | if rw.Header().Get("content-type") != "text/plain; charset=UTF-8" { 97 | t.Errorf("Wrong content type received '%s'", rw.Header().Get("content-type")) 98 | } 99 | if rw.Body.String() != "ok" { 100 | t.Errorf("Unexpected body, got '%s', expected 'ok'", rw.Body) 101 | } 102 | } 103 | 104 | func TestHandler_jsonpCannotIntoXSS(t *testing.T) { 105 | h := newTestHandler() 106 | rw := httptest.NewRecorder() 107 | req, _ := http.NewRequest("GET", "/server/session/jsonp?c=%3Chtml%3E%3Chead%3E%3Cscript%3Ealert(5520)%3C%2Fscript%3E", nil) 108 | h.jsonp(rw, req) 109 | if rw.Code != http.StatusBadRequest { 110 | t.Errorf("JsonP forwarded an exploitable response.") 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /v3/sockjs/mapping.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | ) 7 | 8 | type mapping struct { 9 | method string 10 | path *regexp.Regexp 11 | chain []http.HandlerFunc 12 | } 13 | 14 | func newMapping(method string, re string, handlers ...http.HandlerFunc) *mapping { 15 
| return &mapping{method, regexp.MustCompile(re), handlers} 16 | } 17 | 18 | type matchType uint32 19 | 20 | const ( 21 | fullMatch matchType = iota 22 | pathMatch 23 | noMatch 24 | ) 25 | 26 | // matches checks if given req.URL is a match with a mapping. Match can be either full, partial (http method mismatch) or no match. 27 | func (m *mapping) matches(req *http.Request) (match matchType, method string) { 28 | if !m.path.MatchString(req.URL.Path) { 29 | match, method = noMatch, "" 30 | } else if m.method != req.Method { 31 | match, method = pathMatch, m.method 32 | } else { 33 | match, method = fullMatch, m.method 34 | } 35 | return 36 | } 37 | -------------------------------------------------------------------------------- /v3/sockjs/mapping_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | func TestMappingMatcher(t *testing.T) { 10 | mappingPrefix := mapping{"GET", regexp.MustCompile("prefix/$"), nil} 11 | mappingPrefixRegExp := mapping{"GET", regexp.MustCompile(".*x/$"), nil} 12 | 13 | var testRequests = []struct { 14 | mapping mapping 15 | method string 16 | url string 17 | expectedMatch matchType 18 | }{ 19 | {mappingPrefix, "GET", "http://foo/prefix/", fullMatch}, 20 | {mappingPrefix, "POST", "http://foo/prefix/", pathMatch}, 21 | {mappingPrefix, "GET", "http://foo/prefix_not_mapped", noMatch}, 22 | {mappingPrefixRegExp, "GET", "http://foo/prefix/", fullMatch}, 23 | } 24 | 25 | for _, request := range testRequests { 26 | req, _ := http.NewRequest(request.method, request.url, nil) 27 | m := request.mapping 28 | match, method := m.matches(req) 29 | if match != request.expectedMatch { 30 | t.Errorf("mapping %s should match url=%s", m.path, request.url) 31 | } 32 | if request.expectedMatch == pathMatch { 33 | if method != m.method { 34 | t.Errorf("Matcher method should be %s, but got %s", m.method, method) 35 | } 36 | } 37 | } 38 | 
} 39 | -------------------------------------------------------------------------------- /v3/sockjs/options_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestInfoGet(t *testing.T) { 11 | recorder := httptest.NewRecorder() 12 | request, _ := http.NewRequest("GET", "", nil) 13 | DefaultOptions.info(recorder, request) 14 | 15 | if recorder.Code != http.StatusOK { 16 | t.Errorf("Wrong status code, got '%d' expected '%d'", recorder.Code, http.StatusOK) 17 | } 18 | 19 | decoder := json.NewDecoder(recorder.Body) 20 | var a info 21 | if err := decoder.Decode(&a); err != nil { 22 | t.Error(err) 23 | t.Fail() 24 | } 25 | if !a.Websocket { 26 | t.Errorf("Websocket field should be set true") 27 | } 28 | if a.CookieNeeded { 29 | t.Errorf("CookieNeeded should be set to false") 30 | } 31 | } 32 | 33 | func TestInfoOptions(t *testing.T) { 34 | recorder := httptest.NewRecorder() 35 | request, _ := http.NewRequest("OPTIONS", "", nil) 36 | DefaultOptions.info(recorder, request) 37 | if recorder.Code != http.StatusNoContent { 38 | t.Errorf("Incorrect status code received, got '%d' expected '%d'", recorder.Code, http.StatusNoContent) 39 | } 40 | } 41 | 42 | func TestInfoUnknown(t *testing.T) { 43 | req, _ := http.NewRequest("PUT", "", nil) 44 | rec := httptest.NewRecorder() 45 | DefaultOptions.info(rec, req) 46 | if rec.Code != http.StatusNotFound { 47 | t.Errorf("Incorrec response status, got '%d' expected '%d'", rec.Code, http.StatusNotFound) 48 | } 49 | } 50 | 51 | func TestCookies(t *testing.T) { 52 | rec := httptest.NewRecorder() 53 | req, _ := http.NewRequest("GET", "", nil) 54 | optionsWithCookies := DefaultOptions 55 | optionsWithCookies.JSessionID = DefaultJSessionID 56 | optionsWithCookies.cookie(rec, req) 57 | if rec.Header().Get("set-cookie") != "JSESSIONID=dummy; Path=/" { 58 | t.Errorf("Cookie not 
properly set in response") 59 | } 60 | // cookie value set in request 61 | req.AddCookie(&http.Cookie{Name: "JSESSIONID", Value: "some_jsession_id", Path: "/"}) 62 | rec = httptest.NewRecorder() 63 | optionsWithCookies.cookie(rec, req) 64 | if rec.Header().Get("set-cookie") != "JSESSIONID=some_jsession_id; Path=/" { 65 | t.Errorf("Cookie not properly set in response") 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /v3/sockjs/rawwebsocket.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | func (h *Handler) rawWebsocket(rw http.ResponseWriter, req *http.Request) { 12 | upgrader := h.options.WebsocketUpgrader 13 | if upgrader == nil { 14 | upgrader = new(websocket.Upgrader) 15 | } 16 | conn, err := upgrader.Upgrade(rw, req, nil) 17 | if err != nil { 18 | return 19 | } 20 | 21 | sessID := "" 22 | sess := newSession(req, sessID, h.options.DisconnectDelay, h.options.HeartbeatDelay) 23 | sess.raw = true 24 | 25 | receiver := newRawWsReceiver(conn, h.options.WebsocketWriteTimeout) 26 | if err := sess.attachReceiver(receiver); err != nil { 27 | http.Error(rw, err.Error(), http.StatusInternalServerError) 28 | return 29 | } 30 | if h.handlerFunc != nil { 31 | go h.handlerFunc(Session{sess}) 32 | } 33 | readCloseCh := make(chan struct{}) 34 | go func() { 35 | for { 36 | frameType, p, err := conn.ReadMessage() 37 | if err != nil { 38 | close(readCloseCh) 39 | return 40 | } 41 | if frameType == websocket.TextMessage || frameType == websocket.BinaryMessage { 42 | if err := sess.accept(string(p)); err != nil { 43 | close(readCloseCh) 44 | return 45 | } 46 | } 47 | } 48 | }() 49 | 50 | select { 51 | case <-readCloseCh: 52 | case <-receiver.doneNotify(): 53 | } 54 | sess.close() 55 | if err := conn.Close(); err != nil { 56 | http.Error(rw, err.Error(), 
http.StatusInternalServerError) 57 | return 58 | } 59 | } 60 | 61 | type rawWsReceiver struct { 62 | conn *websocket.Conn 63 | closeCh chan struct{} 64 | writeTimeout time.Duration 65 | } 66 | 67 | func newRawWsReceiver(conn *websocket.Conn, writeTimeout time.Duration) *rawWsReceiver { 68 | return &rawWsReceiver{ 69 | conn: conn, 70 | closeCh: make(chan struct{}), 71 | writeTimeout: writeTimeout, 72 | } 73 | } 74 | 75 | func (w *rawWsReceiver) sendBulk(messages ...string) error { 76 | if len(messages) > 0 { 77 | for _, m := range messages { 78 | if w.writeTimeout != 0 { 79 | if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil { 80 | w.close() 81 | return err 82 | } 83 | } 84 | if err := w.conn.WriteMessage(websocket.TextMessage, []byte(m)); err != nil { 85 | w.close() 86 | return err 87 | } 88 | 89 | } 90 | } 91 | return nil 92 | } 93 | 94 | func (w *rawWsReceiver) sendFrame(frame string) error { 95 | if w.writeTimeout != 0 { 96 | if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil { 97 | w.close() 98 | return err 99 | } 100 | } 101 | if frame == "h" { 102 | if err := w.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 103 | w.close() 104 | return err 105 | } 106 | } else if len(frame) > 0 && frame[0] == 'c' { 107 | status, reason, err := parseCloseFrame(frame) 108 | if err != nil { 109 | w.close() 110 | return err 111 | } 112 | msg := websocket.FormatCloseMessage(int(status), reason) 113 | if err := w.conn.WriteMessage(websocket.CloseMessage, msg); err != nil { 114 | w.close() 115 | return err 116 | } 117 | } else { 118 | if err := w.conn.WriteMessage(websocket.TextMessage, []byte(frame)); err != nil { 119 | w.close() 120 | return err 121 | } 122 | } 123 | return nil 124 | } 125 | 126 | func (w *rawWsReceiver) receiverType() ReceiverType { 127 | return ReceiverTypeRawWebsocket 128 | } 129 | 130 | func parseCloseFrame(frame string) (status uint32, reason string, err error) { 131 | var items 
[2]interface{} 132 | if err := json.Unmarshal([]byte(frame)[1:], &items); err != nil { 133 | return 0, "", err 134 | } 135 | statusF, _ := items[0].(float64) 136 | status = uint32(statusF) 137 | reason, _ = items[1].(string) 138 | return 139 | } 140 | 141 | func (w *rawWsReceiver) close() { 142 | select { 143 | case <-w.closeCh: // already closed 144 | default: 145 | close(w.closeCh) 146 | } 147 | } 148 | func (w *rawWsReceiver) canSend() bool { 149 | select { 150 | case <-w.closeCh: // already closed 151 | return false 152 | default: 153 | return true 154 | } 155 | } 156 | func (w *rawWsReceiver) doneNotify() <-chan struct{} { return w.closeCh } 157 | func (w *rawWsReceiver) interruptedNotify() <-chan struct{} { return nil } 158 | -------------------------------------------------------------------------------- /v3/sockjs/rawwebsocket_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | 9 | "github.com/gorilla/websocket" 10 | ) 11 | 12 | func TestHandler_RawWebSocketHandshakeError(t *testing.T) { 13 | h := newTestHandler() 14 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 15 | defer server.Close() 16 | req, _ := http.NewRequest("GET", server.URL, nil) 17 | req.Header.Set("origin", "https"+server.URL[4:]) 18 | resp, _ := http.DefaultClient.Do(req) 19 | if resp.StatusCode != http.StatusBadRequest { 20 | t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest) 21 | } 22 | } 23 | 24 | func TestHandler_RawWebSocket(t *testing.T) { 25 | h := newTestHandler() 26 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 27 | defer server.CloseClientConnections() 28 | url := "ws" + server.URL[4:] 29 | var connCh = make(chan Session) 30 | h.handlerFunc = func(conn Session) { connCh <- conn } 31 | conn, resp, err := websocket.DefaultDialer.Dial(url, nil) 32 | if conn 
== nil { 33 | t.Errorf("Connection should not be nil") 34 | } 35 | if err != nil { 36 | t.Errorf("Unexpected error '%v'", err) 37 | } 38 | if resp.StatusCode != http.StatusSwitchingProtocols { 39 | t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols) 40 | } 41 | select { 42 | case <-connCh: //ok 43 | case <-time.After(10 * time.Millisecond): 44 | t.Errorf("Sockjs Handler not invoked") 45 | } 46 | } 47 | 48 | func TestHandler_RawWebSocketTerminationByServer(t *testing.T) { 49 | h := newTestHandler() 50 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 51 | defer server.Close() 52 | url := "ws" + server.URL[4:] 53 | h.handlerFunc = func(conn Session) { 54 | // close the session without sending any message 55 | if rt := conn.ReceiverType(); rt != ReceiverTypeRawWebsocket { 56 | t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeRawWebsocket) 57 | } 58 | conn.Close(3000, "some close message") 59 | conn.Close(0, "this should be ignored") 60 | } 61 | conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 62 | if err != nil { 63 | t.Fatalf("websocket dial failed: %v", err) 64 | } 65 | for i := 0; i < 2; i++ { 66 | _, _, err := conn.ReadMessage() 67 | closeError, ok := err.(*websocket.CloseError) 68 | if !ok { 69 | t.Fatalf("expected close error but got: %v", err) 70 | } 71 | if closeError.Code != 3000 { 72 | t.Errorf("unexpected close status: %v", closeError.Code) 73 | } 74 | if closeError.Text != "some close message" { 75 | t.Errorf("unexpected close reason: '%v'", closeError.Text) 76 | } 77 | } 78 | } 79 | 80 | func TestHandler_RawWebSocketTerminationByClient(t *testing.T) { 81 | h := newTestHandler() 82 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 83 | defer server.Close() 84 | url := "ws" + server.URL[4:] 85 | var done = make(chan struct{}) 86 | h.handlerFunc = func(conn Session) { 87 | if _, err := 
conn.Recv(); err != ErrSessionNotOpen { 88 | t.Errorf("Recv should fail") 89 | } 90 | close(done) 91 | } 92 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 93 | conn.Close() 94 | <-done 95 | } 96 | 97 | func TestHandler_RawWebSocketCommunication(t *testing.T) { 98 | h := newTestHandler() 99 | h.options.WebsocketWriteTimeout = time.Second 100 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 101 | // defer server.CloseClientConnections() 102 | url := "ws" + server.URL[4:] 103 | var done = make(chan struct{}) 104 | h.handlerFunc = func(conn Session) { 105 | _ = conn.Send("message 1") 106 | _ = conn.Send("message 2") 107 | expected := "[\"message 3\"]\n" 108 | msg, err := conn.Recv() 109 | if msg != expected || err != nil { 110 | t.Errorf("Got '%s', expected '%s'", msg, expected) 111 | } 112 | _ = conn.Close(123, "close") 113 | close(done) 114 | } 115 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 116 | _ = conn.WriteJSON([]string{"message 3"}) 117 | var expected = []string{"message 1", "message 2"} 118 | for _, exp := range expected { 119 | _, msg, err := conn.ReadMessage() 120 | if string(msg) != exp || err != nil { 121 | t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 122 | } 123 | } 124 | <-done 125 | } 126 | 127 | func TestHandler_RawCustomWebSocketCommunication(t *testing.T) { 128 | h := newTestHandler() 129 | h.options.WebsocketWriteTimeout = time.Second 130 | h.options.WebsocketUpgrader = &websocket.Upgrader{ 131 | ReadBufferSize: 0, 132 | WriteBufferSize: 0, 133 | CheckOrigin: func(_ *http.Request) bool { return true }, 134 | Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {}, 135 | } 136 | server := httptest.NewServer(http.HandlerFunc(h.rawWebsocket)) 137 | url := "ws" + server.URL[4:] 138 | var done = make(chan struct{}) 139 | h.handlerFunc = func(conn 
Session) { 140 | _ = conn.Send("message 1") 141 | _ = conn.Send("message 2") 142 | expected := "[\"message 3\"]\n" 143 | msg, err := conn.Recv() 144 | if msg != expected || err != nil { 145 | t.Errorf("Got '%s', expected '%s'", msg, expected) 146 | } 147 | _ = conn.Close(123, "close") 148 | close(done) 149 | } 150 | conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}}) 151 | _ = conn.WriteJSON([]string{"message 3"}) 152 | var expected = []string{"message 1", "message 2"} 153 | for _, exp := range expected { 154 | _, msg, err := conn.ReadMessage() 155 | if string(msg) != exp || err != nil { 156 | t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp) 157 | } 158 | } 159 | <-done 160 | } 161 | -------------------------------------------------------------------------------- /v3/sockjs/receiver.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | type ReceiverType int 4 | 5 | const ( 6 | ReceiverTypeNone ReceiverType = iota 7 | ReceiverTypeXHR 8 | ReceiverTypeEventSource 9 | ReceiverTypeHtmlFile 10 | ReceiverTypeJSONP 11 | ReceiverTypeXHRStreaming 12 | ReceiverTypeRawWebsocket 13 | ReceiverTypeWebsocket 14 | ) 15 | 16 | type receiver interface { 17 | // sendBulk send multiple data messages in frame frame in format: a["msg 1", "msg 2", ....] 18 | sendBulk(...string) error 19 | // sendFrame sends given frame over the wire (with possible chunking depending on receiver) 20 | sendFrame(string) error 21 | // close closes the receiver in a "done" way (idempotent) 22 | close() 23 | canSend() bool 24 | // done notification channel gets closed whenever receiver ends 25 | doneNotify() <-chan struct{} 26 | // interrupted channel gets closed whenever receiver is interrupted (i.e. http connection drops,...) 
27 | interruptedNotify() <-chan struct{} 28 | // returns the type of receiver 29 | receiverType() ReceiverType 30 | } 31 | -------------------------------------------------------------------------------- /v3/sockjs/session.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | // SessionState defines the current state of the session 12 | type SessionState uint32 13 | 14 | const ( 15 | // brand new session, need to send "h" to receiver 16 | SessionOpening SessionState = iota 17 | // active session 18 | SessionActive 19 | // session being closed, sending "closeFrame" to receivers 20 | SessionClosing 21 | // closed session, no activity at all, should be removed from handler completely and not reused 22 | SessionClosed 23 | ) 24 | 25 | var ( 26 | // ErrSessionNotOpen error is used to denote session not in open state. 27 | // Recv() and Send() operations are not supported if session is closed. 28 | ErrSessionNotOpen = errors.New("sockjs: session not in open state") 29 | errSessionReceiverAttached = errors.New("sockjs: another receiver already attached") 30 | errSessionParse = errors.New("sockjs: unable to parse URL for session") 31 | ) 32 | 33 | type Session struct { 34 | *session 35 | } 36 | 37 | type session struct { 38 | mux sync.RWMutex 39 | id string 40 | req *http.Request 41 | state SessionState 42 | 43 | recv receiver // protocol dependent receiver (xhr, eventsource, ...) 
44 | receiverType ReceiverType 45 | sendBuffer []string // messages to be sent to client 46 | recvBuffer *messageBuffer // messages received from client to be consumed by application 47 | closeFrame string // closeFrame to send after session is closed 48 | 49 | // do not use SockJS framing for raw websocket connections 50 | raw bool 51 | 52 | // internal timer used to handle session expiration if no receiver is attached, or heartbeats if recevier is attached 53 | sessionTimeoutInterval time.Duration 54 | heartbeatInterval time.Duration 55 | timer *time.Timer 56 | // once the session timeouts this channel also closes 57 | closeCh chan struct{} 58 | startHandlerOnce sync.Once 59 | context context.Context 60 | cancelFunc func() 61 | } 62 | 63 | // session is a central component that handles receiving and sending frames. It maintains internal state 64 | func newSession(req *http.Request, sessionID string, sessionTimeoutInterval, heartbeatInterval time.Duration) *session { 65 | context, cancel := context.WithCancel(context.Background()) 66 | s := &session{ 67 | id: sessionID, 68 | req: req, 69 | heartbeatInterval: heartbeatInterval, 70 | recvBuffer: newMessageBuffer(), 71 | closeCh: make(chan struct{}), 72 | sessionTimeoutInterval: sessionTimeoutInterval, 73 | receiverType: ReceiverTypeNone, 74 | context: context, 75 | cancelFunc: cancel, 76 | } 77 | 78 | s.mux.Lock() 79 | s.timer = time.AfterFunc(sessionTimeoutInterval, s.close) 80 | s.mux.Unlock() 81 | return s 82 | } 83 | 84 | func (s *session) sendMessage(msg string) error { 85 | s.mux.Lock() 86 | defer s.mux.Unlock() 87 | if s.state > SessionActive { 88 | return ErrSessionNotOpen 89 | } 90 | s.sendBuffer = append(s.sendBuffer, msg) 91 | if s.recv != nil && s.recv.canSend() { 92 | if err := s.recv.sendBulk(s.sendBuffer...); err != nil { 93 | return err 94 | } 95 | s.sendBuffer = nil 96 | } 97 | return nil 98 | } 99 | 100 | func (s *session) attachReceiver(recv receiver) error { 101 | s.mux.Lock() 102 | defer 
s.mux.Unlock() 103 | if s.recv != nil { 104 | return errSessionReceiverAttached 105 | } 106 | s.recv = recv 107 | s.receiverType = recv.receiverType() 108 | go func(r receiver) { 109 | select { 110 | case <-r.doneNotify(): 111 | s.detachReceiver() 112 | case <-r.interruptedNotify(): 113 | s.detachReceiver() 114 | s.close() 115 | } 116 | }(recv) 117 | 118 | if s.state == SessionClosing { 119 | if !s.raw { 120 | if err := s.recv.sendFrame(s.closeFrame); err != nil { 121 | return err 122 | } 123 | } 124 | s.recv.close() 125 | return nil 126 | } 127 | if s.state == SessionOpening { 128 | if !s.raw { 129 | if err := s.recv.sendFrame("o"); err != nil { 130 | return err 131 | } 132 | } 133 | s.state = SessionActive 134 | } 135 | if err := s.recv.sendBulk(s.sendBuffer...); err != nil { 136 | return err 137 | } 138 | s.sendBuffer = nil 139 | s.timer.Stop() 140 | if s.heartbeatInterval > 0 { 141 | s.timer = time.AfterFunc(s.heartbeatInterval, s.heartbeat) 142 | } 143 | return nil 144 | } 145 | 146 | func (s *session) detachReceiver() { 147 | s.mux.Lock() 148 | s.timer.Stop() 149 | s.timer = time.AfterFunc(s.sessionTimeoutInterval, s.close) 150 | s.recv = nil 151 | s.mux.Unlock() 152 | } 153 | 154 | func (s *session) heartbeat() { 155 | s.mux.Lock() 156 | if s.recv != nil { // timer could have fired between Lock and timer.Stop in detachReceiver 157 | _ = s.recv.sendFrame("h") 158 | s.timer = time.AfterFunc(s.heartbeatInterval, s.heartbeat) 159 | } 160 | s.mux.Unlock() 161 | } 162 | 163 | func (s *session) accept(messages ...string) error { 164 | return s.recvBuffer.push(messages...) 
165 | } 166 | 167 | // idempotent operation 168 | func (s *session) closing() { 169 | s.mux.Lock() 170 | defer s.mux.Unlock() 171 | if s.state < SessionClosing { 172 | s.state = SessionClosing 173 | s.recvBuffer.close() 174 | if s.recv != nil { 175 | _ = s.recv.sendFrame(s.closeFrame) 176 | s.recv.close() 177 | } 178 | s.cancelFunc() 179 | } 180 | } 181 | 182 | // idempotent operation 183 | func (s *session) close() { 184 | s.closing() 185 | s.mux.Lock() 186 | defer s.mux.Unlock() 187 | if s.state < SessionClosed { 188 | s.state = SessionClosed 189 | s.timer.Stop() 190 | close(s.closeCh) 191 | s.cancelFunc() 192 | } 193 | } 194 | 195 | func (s *session) setCurrentRequest(req *http.Request) { 196 | s.mux.Lock() 197 | s.req = req 198 | s.mux.Unlock() 199 | } 200 | 201 | // Close closes the session with provided code and reason. 202 | func (s *session) Close(status uint32, reason string) error { 203 | s.mux.Lock() 204 | if s.state < SessionClosing { 205 | s.closeFrame = closeFrame(status, reason) 206 | s.mux.Unlock() 207 | s.closing() 208 | return nil 209 | } 210 | s.mux.Unlock() 211 | return ErrSessionNotOpen 212 | } 213 | 214 | // ID returns a session id 215 | func (s *session) ID() string { 216 | return s.id 217 | } 218 | 219 | // Recv reads one text frame from session 220 | func (s *session) Recv() (string, error) { 221 | return s.recvBuffer.pop(context.Background()) 222 | } 223 | 224 | // RecvCtx reads one text frame from session 225 | func (s *session) RecvCtx(ctx context.Context) (string, error) { 226 | return s.recvBuffer.pop(ctx) 227 | } 228 | 229 | // Send sends one text frame to session 230 | func (s *session) Send(msg string) error { 231 | return s.sendMessage(msg) 232 | } 233 | 234 | // Request returns the first http request 235 | func (s *session) Request() *http.Request { 236 | s.mux.RLock() 237 | defer s.mux.RUnlock() 238 | s.req.Context() 239 | return s.req 240 | } 241 | 242 | //GetSessionState returns the current state of the session 243 | func (s 
*session) GetSessionState() SessionState { 244 | s.mux.RLock() 245 | defer s.mux.RUnlock() 246 | return s.state 247 | } 248 | 249 | //ReceiverType returns receiver used in session 250 | func (s *session) ReceiverType() ReceiverType { 251 | s.mux.RLock() 252 | defer s.mux.RUnlock() 253 | return s.receiverType 254 | } 255 | 256 | // Context returns session context, the context is cancelled 257 | // whenever the session gets into closing or closed state 258 | func (s *session) Context() context.Context { 259 | return s.context 260 | } 261 | -------------------------------------------------------------------------------- /v3/sockjs/sockjs.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | -------------------------------------------------------------------------------- /v3/sockjs/sockjs_test.go: -------------------------------------------------------------------------------- 1 | package sockjs 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "regexp" 7 | "testing" 8 | ) 9 | 10 | func TestSockJS_ServeHTTP(t *testing.T) { 11 | m := Handler{mappings: make([]*mapping, 0)} 12 | m.mappings = []*mapping{ 13 | {"POST", regexp.MustCompile("/foo/.*"), []http.HandlerFunc{func(http.ResponseWriter, *http.Request) {}}}, 14 | } 15 | req, _ := http.NewRequest("GET", "/foo/bar", nil) 16 | rec := httptest.NewRecorder() 17 | m.ServeHTTP(rec, req) 18 | if rec.Code != http.StatusMethodNotAllowed { 19 | t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusMethodNotAllowed) 20 | } 21 | req, _ = http.NewRequest("GET", "/bar", nil) 22 | rec = httptest.NewRecorder() 23 | m.ServeHTTP(rec, req) 24 | if rec.Code != http.StatusNotFound { 25 | t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusNotFound) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /v3/sockjs/utils.go: 
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import "encoding/json"
4 | 
5 | // quote returns in encoded as a JSON string literal (quotes included).
6 | // json.Marshal cannot fail for a string value, so its error is discarded.
7 | func quote(in string) string {
8 | 	quoted, _ := json.Marshal(in)
9 | 	return string(quoted)
10 | }
11 | 
12 | // transform returns a new slice holding transformFn applied to each element
13 | // of values, in order; values itself is not modified.
14 | func transform(values []string, transformFn func(string) string) []string {
15 | 	ret := make([]string, len(values))
16 | 	for i, msg := range values {
17 | 		ret[i] = transformFn(msg)
18 | 	}
19 | 	return ret
20 | }
21 | 
--------------------------------------------------------------------------------
/v3/sockjs/utils_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import "testing"
4 | 
5 | func TestQuote(t *testing.T) {
6 | 	var quotationTests = []struct {
7 | 		input  string
8 | 		output string
9 | 	}{
10 | 		{"simple", "\"simple\""},
11 | 		{"more complex \"", "\"more complex \\\"\""},
12 | 	}
13 | 
14 | 	for _, testCase := range quotationTests {
15 | 		if quote(testCase.input) != testCase.output {
16 | 			t.Errorf("Expected '%s', got '%s'", testCase.output, quote(testCase.input))
17 | 		}
18 | 	}
19 | }
20 | 
--------------------------------------------------------------------------------
/v3/sockjs/web.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net/http"
6 | 	"time"
7 | )
8 | 
9 | // xhrCorsFactory returns a handler that sets the CORS response headers for
10 | // the XHR-style transports.
11 | //
12 | // When opts.CheckOrigin is set, the headers are written only if it approves
13 | // the request, and the request's Origin is echoed back ("*" if it is absent).
14 | // Otherwise CORS is always enabled: opts.Origin is used if configured, else
15 | // the request's Origin, with "*" substituted for an empty or "null" value.
16 | func xhrCorsFactory(opts Options) func(rw http.ResponseWriter, req *http.Request) {
17 | 	return func(rw http.ResponseWriter, req *http.Request) {
18 | 		header := rw.Header()
19 | 		var corsEnabled bool
20 | 		var corsOrigin string
21 | 
22 | 		if opts.CheckOrigin != nil {
23 | 			corsEnabled = opts.CheckOrigin(req)
24 | 			if corsEnabled {
25 | 				corsOrigin = req.Header.Get("origin")
26 | 				if corsOrigin == "" {
27 | 					corsOrigin = "*"
28 | 				}
29 | 			}
30 | 		} else {
31 | 			corsEnabled = true
32 | 			corsOrigin = opts.Origin
33 | 			if corsOrigin == "" {
34 | 				corsOrigin = req.Header.Get("origin")
35 | 			}
36 | 			if corsOrigin == "" || corsOrigin == "null" {
37 | 				corsOrigin = "*"
38 | 			}
39 | 		}
40 | 
41 | 		if corsEnabled {
42 | 			header.Set("Access-Control-Allow-Origin", corsOrigin)
43 | 			if allowHeaders := req.Header.Get("Access-Control-Request-Headers"); allowHeaders != "" && allowHeaders != "null" {
44 | 				// Set, not Add: running this handler twice on one response must
45 | 				// not duplicate the header (same as Allow-Credentials below).
46 | 				header.Set("Access-Control-Allow-Headers", allowHeaders)
47 | 			}
48 | 			header.Set("Access-Control-Allow-Credentials", "true")
49 | 		}
50 | 	}
51 | }
52 | 
53 | // xhrOptions replies to an OPTIONS request with the allowed methods and
54 | // 204 No Content.
55 | func xhrOptions(rw http.ResponseWriter, req *http.Request) {
56 | 	rw.Header().Set("Access-Control-Allow-Methods", "OPTIONS, POST")
57 | 	rw.WriteHeader(http.StatusNoContent) // 204
58 | }
59 | 
60 | // cacheFor marks the response as publicly cacheable for one year.
61 | func cacheFor(rw http.ResponseWriter, req *http.Request) {
62 | 	const oneYear = 365 * 24 * 60 * 60 // seconds
63 | 	rw.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", oneYear))
64 | 	// HTTP-dates must be expressed in GMT (RFC 7231, section 7.1.1.1);
65 | 	// time.RFC1123 on a local time would print the local zone name instead.
66 | 	rw.Header().Set("Expires", time.Now().AddDate(1, 0, 0).UTC().Format(http.TimeFormat))
67 | 	rw.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", oneYear))
68 | }
69 | 
70 | // noCache forbids any caching of the response.
71 | func noCache(rw http.ResponseWriter, req *http.Request) {
72 | 	rw.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
73 | }
74 | 
75 | // welcomeHandler writes the plain-text SockJS greeting.
76 | func welcomeHandler(rw http.ResponseWriter, req *http.Request) {
77 | 	rw.Header().Set("content-type", "text/plain;charset=UTF-8")
78 | 	fmt.Fprint(rw, "Welcome to SockJS!\n")
79 | }
80 | 
81 | // httpError writes msg as a text/plain body with the given status code.
82 | // Unlike http.Error it appends no trailing newline, so the body is exactly msg.
83 | func httpError(w http.ResponseWriter, msg string, code int) {
84 | 	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
85 | 	w.WriteHeader(code)
86 | 	fmt.Fprint(w, msg)
87 | }
88 | 
--------------------------------------------------------------------------------
/v3/sockjs/web_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"strings"
7 | 	"testing"
8 | )
9 | 
10 | func TestXhrCors(t *testing.T) {
11 | 	recorder := httptest.NewRecorder()
12 | 	req, _ := http.NewRequest("GET", "/", nil)
13 | 	xhrCors := xhrCorsFactory(Options{})
14 | 	xhrCors(recorder, req)
15 | 	acao := recorder.Header().Get("access-control-allow-origin")
16 | 	if acao != "*" {
17 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "*")
18 | 	}
19 | 	req.Header.Set("origin", "localhost")
20 | 	xhrCors(recorder, req)
21 | 	acao = recorder.Header().Get("access-control-allow-origin")
22 | 	if acao != "localhost" {
23 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "localhost")
24 | 	}
25 | 	req.Header.Set("access-control-request-headers", "some value")
26 | 	rec := httptest.NewRecorder()
27 | 	xhrCors(rec, req)
28 | 	if rec.Header().Get("access-control-allow-headers") != "some value" {
29 | 		t.Errorf("Incorent value for ACAH, got %s", rec.Header().Get("access-control-allow-headers"))
30 | 	}
31 | 
32 | 	rec = httptest.NewRecorder()
33 | 	xhrCors(rec, req)
34 | 	if rec.Header().Get("access-control-allow-credentials") != "true" {
35 | 		t.Errorf("Incorent value for ACAC, got %s", rec.Header().Get("access-control-allow-credentials"))
36 | 	}
37 | 
38 | 	// verify that if Access-Control-Allow-Credentials was previously set that xhrCors() does not duplicate the value
39 | 	rec = httptest.NewRecorder()
40 | 	rec.Header().Set("Access-Control-Allow-Credentials", "true")
41 | 	xhrCors(rec, req)
42 | 	acac := rec.Header()["Access-Control-Allow-Credentials"]
43 | 	if len(acac) != 1 || acac[0] != "true" {
44 | 		t.Errorf("Incorent value for ACAC, got %s", strings.Join(acac, ","))
45 | 	}
46 | 
47 | 	// likewise, running xhrCors() twice must not duplicate Access-Control-Allow-Headers
48 | 	rec = httptest.NewRecorder()
49 | 	xhrCors(rec, req)
50 | 	xhrCors(rec, req)
51 | 	acah := rec.Header()["Access-Control-Allow-Headers"]
52 | 	if len(acah) != 1 || acah[0] != "some value" {
53 | 		t.Errorf("Incorrect value for ACAH, got %s", strings.Join(acah, ","))
54 | 	}
55 | }
56 | 
57 | func TestCheckOriginCORSAllowedNullOrigin(t *testing.T) {
58 | 	recorder := httptest.NewRecorder()
59 | 	req, _ := http.NewRequest("GET", "/", nil)
60 | 	xhrCors := xhrCorsFactory(Options{
61 | 		CheckOrigin: func(req *http.Request) bool {
62 | 			return true
63 | 		},
64 | 	})
65 | 	req.Header.Set("origin", "null")
66 | 	xhrCors(recorder, req)
67 | 	acao := recorder.Header().Get("access-control-allow-origin")
68 | 	if acao != "null" {
69 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "null")
70 | 	}
71 | }
72 | 
73 | func TestCheckOriginCORSAllowedEmptyOrigin(t *testing.T) {
74 | 	recorder := httptest.NewRecorder()
75 | 	req, _ := http.NewRequest("GET", "/", nil)
76 | 	xhrCors := xhrCorsFactory(Options{
77 | 		CheckOrigin: func(req *http.Request) bool {
78 | 			return true
79 | 		},
80 | 	})
81 | 	xhrCors(recorder, req)
82 | 	acao := recorder.Header().Get("access-control-allow-origin")
83 | 	if acao != "*" {
84 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "*")
85 | 	}
86 | }
87 | 
88 | func TestCheckOriginCORSNotAllowed(t *testing.T) {
89 | 	recorder := httptest.NewRecorder()
90 | 	req, _ := http.NewRequest("GET", "/", nil)
91 | 	xhrCors := xhrCorsFactory(Options{
92 | 		CheckOrigin: func(req *http.Request) bool {
93 | 			return false
94 | 		},
95 | 	})
96 | 	req.Header.Set("origin", "localhost")
97 | 	xhrCors(recorder, req)
98 | 	acao := recorder.Header().Get("access-control-allow-origin")
99 | 	if acao != "" {
100 | 		t.Errorf("Incorrect value for access-control-allow-origin header, got %s, expected %s", acao, "")
101 | 	}
102 | }
103 | 
104 | func TestXhrOptions(t *testing.T) {
105 | 	rec := httptest.NewRecorder()
106 | 	req, _ := http.NewRequest("GET", "/", nil)
107 | 	xhrOptions(rec, req)
108 | 	if rec.Code != http.StatusNoContent {
109 | 		t.Errorf("Wrong response status code, expected %d, got %d", http.StatusNoContent, rec.Code)
110 | 	}
111 | }
112 | 
113 | func TestCacheFor(t *testing.T) {
114 | 	rec := httptest.NewRecorder()
115 | 	cacheFor(rec, nil)
116 | 	cacheControl := rec.Header().Get("cache-control")
117 | 	if cacheControl != "public, max-age=31536000" {
118 | 		t.Errorf("Incorrect cache-control header value, got '%s'", cacheControl)
119 | 	}
120 | 	expires := rec.Header().Get("expires")
121 | 	if expires == "" {
122 | 		t.Errorf("Expires header should not be empty")
123 | 	} else if _, err := http.ParseTime(expires); err != nil {
124 | 		t.Errorf("Expires header should be an HTTP-date in GMT, got '%s'", expires)
125 | 	}
126 | 	maxAge := rec.Header().Get("access-control-max-age")
127 | 	if maxAge != "31536000" {
128 | 		t.Errorf("Incorrect value for access-control-max-age, got '%s'", maxAge)
129 | 	}
130 | }
131 | 
132 | func TestNoCache(t *testing.T) {
133 | 	rec := httptest.NewRecorder()
134 | 	noCache(rec, nil)
135 | 	if cacheControl := rec.Header().Get("cache-control"); cacheControl != "no-store, no-cache, must-revalidate, max-age=0" {
136 | 		t.Errorf("Incorrect cache-control header value, got '%s'", cacheControl)
137 | 	}
138 | }
139 | 
140 | func TestWelcomeHandler(t *testing.T) {
141 | 	rec := httptest.NewRecorder()
142 | 	welcomeHandler(rec, nil)
143 | 	if rec.Body.String() != "Welcome to SockJS!\n" {
144 | 		t.Errorf("Incorrect welcome message received, got '%s'", rec.Body.String())
145 | 	}
146 | }
147 | 
--------------------------------------------------------------------------------
/v3/sockjs/websocket.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net/http"
6 | 	"strings"
7 | 	"time"
8 | 
9 | 	"github.com/gorilla/websocket"
10 | )
11 | 
12 | // sockjsWebsocket serves the SockJS websocket transport. It upgrades the
13 | // connection, binds a fresh session to a wsReceiver, starts the user handler
14 | // and feeds every incoming JSON array of messages into the session until the
15 | // client stops sending or the receiver is closed.
16 | //
17 | // Once Upgrade succeeds the ResponseWriter has been hijacked, so failures
18 | // after that point are handled by closing the websocket, never by writing an
19 | // HTTP error response.
20 | func (h *Handler) sockjsWebsocket(rw http.ResponseWriter, req *http.Request) {
21 | 	upgrader := h.options.WebsocketUpgrader
22 | 	if upgrader == nil {
23 | 		upgrader = new(websocket.Upgrader)
24 | 	}
25 | 	conn, err := upgrader.Upgrade(rw, req, nil)
26 | 	if err != nil {
27 | 		return // Upgrade has already replied to the client with an HTTP error
28 | 	}
29 | 	sessID, _ := h.parseSessionID(req.URL)
30 | 	sess := newSession(req, sessID, h.options.DisconnectDelay, h.options.HeartbeatDelay)
31 | 	receiver := newWsReceiver(conn, h.options.WebsocketWriteTimeout)
32 | 	if err := sess.attachReceiver(receiver); err != nil {
33 | 		// e.g. the open frame could not be written: tear everything down
34 | 		// instead of leaking the connection.
35 | 		sess.close()
36 | 		_ = conn.Close()
37 | 		return
38 | 	}
39 | 	if h.handlerFunc != nil {
40 | 		go h.handlerFunc(Session{sess})
41 | 	}
42 | 	readCloseCh := make(chan struct{})
43 | 	go func() {
44 | 		defer close(readCloseCh)
45 | 		var d []string
46 | 		for {
47 | 			if err := conn.ReadJSON(&d); err != nil {
48 | 				return
49 | 			}
50 | 			if err := sess.accept(d...); err != nil {
51 | 				return
52 | 			}
53 | 		}
54 | 	}()
55 | 
56 | 	select {
57 | 	case <-readCloseCh:
58 | 	case <-receiver.doneNotify():
59 | 	}
60 | 	sess.close()
61 | 	// The close error is not actionable: the connection is hijacked, so there
62 | 	// is no HTTP response left to report it on.
63 | 	_ = conn.Close()
64 | }
65 | 
66 | // wsReceiver delivers SockJS frames to a client over a websocket connection.
67 | type wsReceiver struct {
68 | 	conn         *websocket.Conn
69 | 	closeCh      chan struct{} // closed once the receiver is done
70 | 	writeTimeout time.Duration // per-frame write deadline; 0 disables it
71 | }
72 | 
73 | func newWsReceiver(conn *websocket.Conn, writeTimeout time.Duration) *wsReceiver {
74 | 	return &wsReceiver{
75 | 		conn:         conn,
76 | 		closeCh:      make(chan struct{}),
77 | 		writeTimeout: writeTimeout,
78 | 	}
79 | }
80 | 
81 | // sendBulk sends messages as a single SockJS array frame ("a[...]");
82 | // nothing is written when messages is empty.
83 | func (w *wsReceiver) sendBulk(messages ...string) error {
84 | 	if len(messages) > 0 {
85 | 		return w.sendFrame(fmt.Sprintf("a[%s]", strings.Join(transform(messages, quote), ",")))
86 | 	}
87 | 	return nil
88 | }
89 | 
90 | // sendFrame writes frame as one websocket text message. Any write failure
91 | // closes the receiver before the error is returned.
92 | func (w *wsReceiver) sendFrame(frame string) error {
93 | 	if w.writeTimeout != 0 {
94 | 		if err := w.conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)); err != nil {
95 | 			w.close()
96 | 			return err
97 | 		}
98 | 	}
99 | 	if err := w.conn.WriteMessage(websocket.TextMessage, []byte(frame)); err != nil {
100 | 		w.close()
101 | 		return err
102 | 	}
103 | 	return nil
104 | }
105 | 
106 | // close marks the receiver done; calling it more than once is a no-op.
107 | // NOTE(review): the check-then-close is not atomic, so callers are assumed
108 | // to serialize calls (presumably via the session mutex) — confirm.
109 | func (w *wsReceiver) close() {
110 | 	select {
111 | 	case <-w.closeCh: // already closed
112 | 	default:
113 | 		close(w.closeCh)
114 | 	}
115 | }
116 | 
117 | // canSend reports whether the receiver has not been closed yet.
118 | func (w *wsReceiver) canSend() bool {
119 | 	select {
120 | 	case <-w.closeCh: // already closed
121 | 		return false
122 | 	default:
123 | 		return true
124 | 	}
125 | }
126 | 
127 | func (w *wsReceiver) doneNotify() <-chan struct{} { return w.closeCh }
128 | 
129 | // interruptedNotify returns nil: a nil channel never fires, so the
130 | // websocket transport never reports an interruption this way.
131 | func (w *wsReceiver) interruptedNotify() <-chan struct{} { return nil }
132 | 
133 | func (w *wsReceiver) receiverType() ReceiverType { return ReceiverTypeWebsocket }
134 | 
--------------------------------------------------------------------------------
/v3/sockjs/websocket_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"net/http"
5 | 	"net/http/httptest"
6 | 	"strings"
7 | 	"testing"
8 | 	"time"
9 | 
10 | 	"github.com/gorilla/websocket"
11 | )
12 | 
13 | func TestHandler_WebSocketHandshakeError(t *testing.T) {
14 | 	h := newTestHandler()
15 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
16 | 	defer server.Close()
17 | 	req, _ := http.NewRequest("GET", server.URL, nil)
18 | 	req.Header.Set("origin", "https"+server.URL[4:])
19 | 	resp, err := http.DefaultClient.Do(req)
20 | 	if err != nil {
21 | 		t.Errorf("There should not be any error, got '%s'", err)
22 | 		t.FailNow()
23 | 	}
24 | 	if resp == nil {
25 | 		t.Errorf("Response should not be nil")
26 | 		t.FailNow()
27 | 	}
28 | 	defer resp.Body.Close()
29 | 	if resp.StatusCode != http.StatusBadRequest {
30 | 		t.Errorf("Unexpected response code, got '%d', expected '%d'", resp.StatusCode, http.StatusBadRequest)
31 | 	}
32 | }
33 | 
34 | func TestHandler_WebSocket(t *testing.T) {
35 | 	h := newTestHandler()
36 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
37 | 	defer server.CloseClientConnections()
38 | 	url := "ws" + server.URL[4:]
39 | 	// buffered so the handler goroutine cannot block forever if the test gives up
40 | 	var connCh = make(chan Session, 1)
41 | 	h.handlerFunc = func(conn Session) {
42 | 		if rt := conn.ReceiverType(); rt != ReceiverTypeWebsocket {
43 | 			t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeWebsocket)
44 | 		}
45 | 		connCh <- conn
46 | 	}
47 | 	conn, resp, err := websocket.DefaultDialer.Dial(url, nil)
48 | 	if err != nil {
49 | 		t.Errorf("Unexpected error '%v'", err)
50 | 		t.FailNow()
51 | 	}
52 | 	if conn == nil {
53 | 		t.Errorf("Connection should not be nil")
54 | 		t.FailNow()
55 | 	}
56 | 	if resp == nil {
57 | 		t.Errorf("Response should not be nil")
58 | 		t.FailNow()
59 | 	}
60 | 	if resp.StatusCode != http.StatusSwitchingProtocols {
61 | 		t.Errorf("Wrong response code returned, got '%d', expected '%d'", resp.StatusCode, http.StatusSwitchingProtocols)
62 | 	}
63 | 	select {
64 | 	case <-connCh: //ok
65 | 	case <-time.After(time.Second):
66 | 		t.Errorf("Sockjs Handler not invoked")
67 | 	}
68 | }
69 | 
70 | func TestHandler_WebSocketTerminationByServer(t *testing.T) {
71 | 	h := newTestHandler()
72 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
73 | 	defer server.Close()
74 | 	url := "ws" + server.URL[4:]
75 | 	h.handlerFunc = func(conn Session) {
76 | 		conn.Close(1024, "some close message")
77 | 		conn.Close(0, "this should be ignored")
78 | 	}
79 | 	conn, _, err := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
80 | 	if err != nil {
81 | 		t.Fatalf("websocket dial failed: %v", err)
82 | 	}
83 | 	if conn == nil {
84 | 		t.Errorf("Connection should not be nil")
85 | 		t.FailNow()
86 | 	}
87 | 	_, msg, err := conn.ReadMessage()
88 | 	if string(msg) != "o" || err != nil {
89 | 		t.Errorf("Open frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, "o")
90 | 	}
91 | 	_, msg, err = conn.ReadMessage()
92 | 	if string(msg) != `c[1024,"some close message"]` || err != nil {
93 | 		t.Errorf("Close frame expected, got '%s' and error '%v', expected '%s' without error", msg, err, `c[1024,"some close message"]`)
94 | 	}
95 | 	_, _, err = conn.ReadMessage()
96 | 	// gorilla websocket keeps `errUnexpectedEOF` private so we need to introspect the error message
97 | 	if err != nil {
98 | 		if !strings.Contains(err.Error(), "unexpected EOF") {
99 | 			t.Errorf("Expected 'unexpected EOF' error or similar, got '%v'", err)
100 | 		}
101 | 	}
102 | }
103 | 
104 | func TestHandler_WebSocketTerminationByClient(t *testing.T) {
105 | 	h := newTestHandler()
106 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
107 | 	defer server.Close()
108 | 	url := "ws" + server.URL[4:]
109 | 	var done = make(chan struct{})
110 | 	h.handlerFunc = func(conn Session) {
111 | 		if _, err := conn.Recv(); err != ErrSessionNotOpen {
112 | 			t.Errorf("Recv should fail")
113 | 		}
114 | 		select {
115 | 		case <-conn.Context().Done():
116 | 		case <-time.After(1 * time.Second):
117 | 			t.Errorf("context should have been done")
118 | 		}
119 | 		close(done)
120 | 	}
121 | 	conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
122 | 	if conn == nil {
123 | 		t.Errorf("Connection should not be nil")
124 | 		t.FailNow()
125 | 	}
126 | 	conn.Close()
127 | 	<-done
128 | }
129 | 
130 | func TestHandler_WebSocketCommunication(t *testing.T) {
131 | 	h := newTestHandler()
132 | 	h.options.WebsocketWriteTimeout = time.Second
133 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
134 | 	// defer server.CloseClientConnections()
135 | 	url := "ws" + server.URL[4:]
136 | 	var done = make(chan struct{})
137 | 	h.handlerFunc = func(conn Session) {
138 | 		noError(t, conn.Send("message 1"))
139 | 		noError(t, conn.Send("message 2"))
140 | 		msg, err := conn.Recv()
141 | 		if msg != "message 3" || err != nil {
142 | 			t.Errorf("Got '%s', expected '%s'", msg, "message 3")
143 | 		}
144 | 		noError(t, conn.Close(123, "close"))
145 | 		close(done)
146 | 	}
147 | 	conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
148 | 	noError(t, conn.WriteJSON([]string{"message 3"}))
149 | 	var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
150 | 	for _, exp := range expected {
151 | 		_, msg, err := conn.ReadMessage()
152 | 		if string(msg) != exp || err != nil {
153 | 			t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp)
154 | 		}
155 | 	}
156 | 	<-done
157 | }
158 | 
159 | func TestHandler_CustomWebSocketCommunication(t *testing.T) {
160 | 	h := newTestHandler()
161 | 	h.options.WebsocketUpgrader = &websocket.Upgrader{
162 | 		ReadBufferSize:  0,
163 | 		WriteBufferSize: 0,
164 | 		CheckOrigin:     func(_ *http.Request) bool { return true },
165 | 		Error:           func(w http.ResponseWriter, r *http.Request, status int, reason error) {},
166 | 	}
167 | 	h.options.WebsocketWriteTimeout = time.Second
168 | 	server := httptest.NewServer(http.HandlerFunc(h.sockjsWebsocket))
169 | 	url := "ws" + server.URL[4:]
170 | 	var done = make(chan struct{})
171 | 	h.handlerFunc = func(conn Session) {
172 | 		noError(t, conn.Send("message 1"))
173 | 		noError(t, conn.Send("message 2"))
174 | 		msg, err := conn.Recv()
175 | 		if msg != "message 3" || err != nil {
176 | 			t.Errorf("Got '%s', expected '%s'", msg, "message 3")
177 | 		}
178 | 		noError(t, conn.Close(123, "close"))
179 | 		close(done)
180 | 	}
181 | 	conn, _, _ := websocket.DefaultDialer.Dial(url, map[string][]string{"Origin": []string{server.URL}})
182 | 	noError(t, conn.WriteJSON([]string{"message 3"}))
183 | 	var expected = []string{"o", `a["message 1"]`, `a["message 2"]`, `c[123,"close"]`}
184 | 	for _, exp := range expected {
185 | 		_, msg, err := conn.ReadMessage()
186 | 		if string(msg) != exp || err != nil {
187 | 			t.Errorf("Wrong frame, got '%s' and error '%v', expected '%s' without error", msg, err, exp)
188 | 		}
189 | 	}
190 | 	<-done
191 | }
192 | 
--------------------------------------------------------------------------------
/v3/sockjs/xhr.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"io"
7 | 	"net/http"
8 | 	"strings"
9 | )
10 | 
11 | var (
12 | 	cFrame              = closeFrame(2010, "Another connection still open")
13 | 	xhrStreamingPrelude = strings.Repeat("h", 2048)
14 | )
15 | 
16 | // xhrSend handles the xhr_send endpoint: it decodes the body as a JSON array
17 | // of strings and queues the messages on the addressed session. It replies 204
18 | // on success, 400 for a missing or malformed payload or an unparseable session
19 | // URL, 404 for an unknown session and 500 when the session rejects the
20 | // messages.
21 | func (h *Handler) xhrSend(rw http.ResponseWriter, req *http.Request) {
22 | 	if req.Body == nil {
23 | 		httpError(rw, "Payload expected.", http.StatusBadRequest)
24 | 		return
25 | 	}
26 | 	var messages []string
27 | 	switch err := json.NewDecoder(req.Body).Decode(&messages); {
28 | 	case err == io.EOF:
29 | 		httpError(rw, "Payload expected.", http.StatusBadRequest)
30 | 		return
31 | 	case err != nil:
32 | 		// Syntax errors, truncated input (io.ErrUnexpectedEOF) and valid JSON that
33 | 		// is not an array of strings (*json.UnmarshalTypeError) are all rejected
34 | 		// instead of queuing a partially decoded payload.
35 | 		httpError(rw, "Broken JSON encoding.", http.StatusBadRequest)
36 | 		return
37 | 	}
38 | 	sessionID, err := h.parseSessionID(req.URL)
39 | 	if err != nil {
40 | 		http.Error(rw, err.Error(), http.StatusBadRequest)
41 | 		return
42 | 	}
43 | 
44 | 	h.sessionsMux.Lock()
45 | 	sess, ok := h.sessions[sessionID]
46 | 	h.sessionsMux.Unlock()
47 | 	if !ok {
48 | 		http.NotFound(rw, req)
49 | 		return
50 | 	}
51 | 	if err := sess.accept(messages...); err != nil {
52 | 		http.Error(rw, err.Error(), http.StatusInternalServerError)
53 | 		return
54 | 	}
55 | 	rw.Header().Set("content-type", "text/plain; charset=UTF-8") // Ignored by net/http (but protocol test complains), see https://code.google.com/p/go/source/detail?r=902dc062bff8
56 | 	rw.WriteHeader(http.StatusNoContent)
57 | }
58 | 
59 | // xhrFrameWriter writes each frame on its own line.
60 | type xhrFrameWriter struct{}
61 | 
62 | func (*xhrFrameWriter) write(w io.Writer, frame string) (int, error) {
63 | 	return fmt.Fprintf(w, "%s\n", frame)
64 | }
65 | 
66 | // xhrPoll handles the xhr polling transport. It attaches a receiver created
67 | // with a response limit of 1 (presumably so it finishes after the first frame)
68 | // and blocks until that receiver is done or interrupted. If the session refuses
69 | // the receiver (e.g. another one is already attached), the client gets the
70 | // 2010 "Another connection still open" close frame instead.
71 | func (h *Handler) xhrPoll(rw http.ResponseWriter, req *http.Request) {
72 | 	rw.Header().Set("content-type", "application/javascript; charset=UTF-8")
73 | 	sess, err := h.sessionByRequest(req)
74 | 	if err != nil {
75 | 		http.Error(rw, err.Error(), http.StatusInternalServerError)
76 | 		return
77 | 	}
78 | 	receiver := newHTTPReceiver(rw, req, 1, new(xhrFrameWriter), ReceiverTypeXHR)
79 | 	if err := sess.attachReceiver(receiver); err != nil {
80 | 		if err := receiver.sendFrame(cFrame); err != nil {
81 | 			http.Error(rw, err.Error(), http.StatusInternalServerError)
82 | 			return
83 | 		}
84 | 		receiver.close()
85 | 		return
86 | 	}
87 | 
88 | 	sess.startHandlerOnce.Do(func() {
89 | 		if h.handlerFunc != nil {
90 | 			go h.handlerFunc(Session{sess})
91 | 		}
92 | 	})
93 | 
94 | 	select {
95 | 	case <-receiver.doneNotify():
96 | 	case <-receiver.interruptedNotify():
97 | 	}
98 | }
99 | 
100 | // xhrStreaming handles the xhr streaming transport. After writing the 2 KiB
101 | // prelude it attaches a receiver bounded by h.options.ResponseLimit and blocks
102 | // until that receiver is done or interrupted.
103 | func (h *Handler) xhrStreaming(rw http.ResponseWriter, req *http.Request) {
104 | 	rw.Header().Set("content-type", "application/javascript; charset=UTF-8")
105 | 	fmt.Fprintf(rw, "%s\n", xhrStreamingPrelude)
106 | 	if flusher, ok := rw.(http.Flusher); ok {
107 | 		flusher.Flush()
108 | 	}
109 | 
110 | 	sess, err := h.sessionByRequest(req)
111 | 	if err != nil {
112 | 		http.Error(rw, err.Error(), http.StatusInternalServerError)
113 | 		return
114 | 	}
115 | 	receiver := newHTTPReceiver(rw, req, h.options.ResponseLimit, new(xhrFrameWriter), ReceiverTypeXHRStreaming)
116 | 
117 | 	if err := sess.attachReceiver(receiver); err != nil {
118 | 		if err := receiver.sendFrame(cFrame); err != nil {
119 | 			http.Error(rw, err.Error(), http.StatusInternalServerError)
120 | 			return
121 | 		}
122 | 		receiver.close()
123 | 		return
124 | 	}
125 | 	// Guard against a nil handler as xhrPoll does: `go` on a nil func value is
126 | 	// a fatal, unrecoverable runtime error.
127 | 	sess.startHandlerOnce.Do(func() {
128 | 		if h.handlerFunc != nil {
129 | 			go h.handlerFunc(Session{sess})
130 | 		}
131 | 	})
132 | 
133 | 	select {
134 | 	case <-receiver.doneNotify():
135 | 	case <-receiver.interruptedNotify():
136 | 	}
137 | }
138 | 
--------------------------------------------------------------------------------
/v3/sockjs/xhr_test.go:
--------------------------------------------------------------------------------
1 | package sockjs
2 | 
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"strings"
8 | 	"testing"
9 | 	"time"
10 | )
11 | 
12 | func TestHandler_XhrSendNilBody(t *testing.T) {
13 | 	h := newTestHandler()
14 | 	rec := httptest.NewRecorder()
15 | 	req, _ := http.NewRequest("POST", "/server/non_existing_session/xhr_send", nil)
16 | 	h.xhrSend(rec, req)
17 | 	if rec.Code != http.StatusBadRequest {
18 | 		t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusBadRequest)
19 | 	}
20 | 	if rec.Body.String() != "Payload expected." {
21 | 		t.Errorf("Unexcpected body received: '%s'", rec.Body.String())
22 | 	}
23 | }
24 | 
25 | func TestHandler_XhrSendEmptyBody(t *testing.T) {
26 | 	h := newTestHandler()
27 | 	rec := httptest.NewRecorder()
28 | 	req, _ := http.NewRequest("POST", "/server/non_existing_session/xhr_send", strings.NewReader(""))
29 | 	h.xhrSend(rec, req)
30 | 	if rec.Code != http.StatusBadRequest {
31 | 		t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusBadRequest)
32 | 	}
33 | 	if rec.Body.String() != "Payload expected." {
34 | 		t.Errorf("Unexcpected body received: '%s'", rec.Body.String())
35 | 	}
36 | }
37 | 
38 | func TestHandler_XhrSendWrongUrlPath(t *testing.T) {
39 | 	h := newTestHandler()
40 | 	rec := httptest.NewRecorder()
41 | 	req, _ := http.NewRequest("POST", "incorrect", strings.NewReader("[\"a\"]"))
42 | 	h.xhrSend(rec, req)
43 | 	if rec.Code != http.StatusBadRequest {
44 | 		t.Errorf("Unexcpected response status, got '%d', expected '%d'", rec.Code, http.StatusBadRequest)
45 | 	}
46 | }
47 | 
48 | func TestHandler_XhrSendToExistingSession(t *testing.T) {
49 | 	h := newTestHandler()
50 | 	rec := httptest.NewRecorder()
51 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]"))
52 | 	sess := newSession(req, "session", time.Second, time.Second)
53 | 	h.sessions["session"] = sess
54 | 
55 | 	req, _ = http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]"))
56 | 	var done = make(chan bool)
57 | 	go func() {
58 | 		h.xhrSend(rec, req)
59 | 		done <- true
60 | 	}()
61 | 	msg, _ := sess.Recv()
62 | 	if msg != "some message" {
63 | 		t.Errorf("Incorrect message in the channel, should be '%s', was '%s'", "some message", msg)
64 | 	}
65 | 	<-done
66 | 	if rec.Code != http.StatusNoContent {
67 | 		t.Errorf("Wrong response status received %d, should be %d", rec.Code, http.StatusNoContent)
68 | 	}
69 | 	if rec.Header().Get("content-type") != "text/plain; charset=UTF-8" {
70 | 		t.Errorf("Wrong content type received '%s'", rec.Header().Get("content-type"))
71 | 	}
72 | }
73 | 
74 | func TestHandler_XhrSendInvalidInput(t *testing.T) {
75 | 	h := newTestHandler()
76 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("some invalid message frame"))
77 | 	rec := httptest.NewRecorder()
78 | 	h.xhrSend(rec, req)
79 | 	if rec.Code != http.StatusBadRequest || rec.Body.String() != "Broken JSON encoding." {
80 | 		t.Errorf("Unexpected response, got '%d,%s' expected '%d,Broken JSON encoding.'", rec.Code, rec.Body.String(), http.StatusBadRequest)
81 | 	}
82 | 
83 | 	// unexpected EOF
84 | 	req, _ = http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"x"))
85 | 	rec = httptest.NewRecorder()
86 | 	h.xhrSend(rec, req)
87 | 	if rec.Code != http.StatusBadRequest || rec.Body.String() != "Broken JSON encoding." {
88 | 		t.Errorf("Unexpected response, got '%d,%s' expected '%d,Broken JSON encoding.'", rec.Code, rec.Body.String(), http.StatusBadRequest)
89 | 	}
90 | 
91 | 	// valid JSON, but not an array of strings
92 | 	req, _ = http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[1]"))
93 | 	rec = httptest.NewRecorder()
94 | 	h.xhrSend(rec, req)
95 | 	if rec.Code != http.StatusBadRequest || rec.Body.String() != "Broken JSON encoding." {
96 | 		t.Errorf("Unexpected response, got '%d,%s' expected '%d,Broken JSON encoding.'", rec.Code, rec.Body.String(), http.StatusBadRequest)
97 | 	}
98 | }
99 | 
100 | func TestHandler_XhrSendSessionNotFound(t *testing.T) {
101 | 	h := Handler{}
102 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_send", strings.NewReader("[\"some message\"]"))
103 | 	rec := httptest.NewRecorder()
104 | 	h.xhrSend(rec, req)
105 | 	if rec.Code != http.StatusNotFound {
106 | 		t.Errorf("Unexpected response status, got '%d' expected '%d'", rec.Code, http.StatusNotFound)
107 | 	}
108 | }
109 | 
110 | func TestHandler_XhrPoll(t *testing.T) {
111 | 	h := newTestHandler()
112 | 	rw := httptest.NewRecorder()
113 | 	req, _ := http.NewRequest("POST", "/server/session/xhr", nil)
114 | 	h.xhrPoll(rw, req)
115 | 	if rw.Header().Get("content-type") != "application/javascript; charset=UTF-8" {
116 | 		t.Errorf("Wrong content type received, got '%s'", rw.Header().Get("content-type"))
117 | 	}
118 | 	sess, _ := h.sessionByRequest(req)
119 | 	if rt := sess.ReceiverType(); rt != ReceiverTypeXHR {
120 | 		t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeXHR)
121 | 	}
122 | }
123 | 
124 | func TestHandler_XhrPollConnectionInterrupted(t *testing.T) {
125 | 	h := newTestHandler()
126 | 	sess := newTestSession()
127 | 	sess.state = SessionActive
128 | 	h.sessions["session"] = sess
129 | 	req, _ := http.NewRequest("POST", "/server/session/xhr", nil)
130 | 	ctx, cancel := context.WithCancel(req.Context())
131 | 	req = req.WithContext(ctx)
132 | 	rw := httptest.NewRecorder()
133 | 	cancel()
134 | 	h.xhrPoll(rw, req)
135 | 	time.Sleep(1 * time.Millisecond)
136 | 	sess.mux.Lock()
137 | 	defer sess.mux.Unlock()
138 | 	if sess.state != SessionClosed {
139 | 		t.Errorf("session should be closed")
140 | 	}
141 | }
142 | 
143 | func TestHandler_XhrPollAnotherConnectionExists(t *testing.T) {
144 | 	h := newTestHandler()
145 | 	req, _ := http.NewRequest("POST", "/server/session/xhr", nil)
146 | 	// turn off timeouts and heartbeats
147 | 	sess := newSession(req, "session", time.Hour, time.Hour)
148 | 	h.sessions["session"] = sess
149 | 	noError(t, sess.attachReceiver(newTestReceiver()))
150 | 	req, _ = http.NewRequest("POST", "/server/session/xhr", nil)
151 | 	rw2 := httptest.NewRecorder()
152 | 	h.xhrPoll(rw2, req)
153 | 	if rw2.Body.String() != "c[2010,\"Another connection still open\"]\n" {
154 | 		t.Errorf("Unexpected body, got '%s'", rw2.Body)
155 | 	}
156 | }
157 | 
158 | func TestHandler_XhrStreaming(t *testing.T) {
159 | 	h := newTestHandler()
160 | 	rw := httptest.NewRecorder()
161 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_streaming", nil)
162 | 	h.xhrStreaming(rw, req)
163 | 	expectedBody := strings.Repeat("h", 2048) + "\no\n"
164 | 	if rw.Body.String() != expectedBody {
165 | 		t.Errorf("Unexpected body, got '%s' expected '%s'", rw.Body, expectedBody)
166 | 	}
167 | 	sess, _ := h.sessionByRequest(req)
168 | 	if rt := sess.ReceiverType(); rt != ReceiverTypeXHRStreaming {
169 | 		t.Errorf("Unexpected recevier type, got '%v', extected '%v'", rt, ReceiverTypeXHRStreaming)
170 | 	}
171 | }
172 | 
173 | func TestHandler_XhrStreamingNilHandlerFunc(t *testing.T) {
174 | 	h := newTestHandler()
175 | 	h.handlerFunc = nil
176 | 	rw := httptest.NewRecorder()
177 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_streaming", nil)
178 | 	// must not crash the process with "go of nil func value"
179 | 	h.xhrStreaming(rw, req)
180 | 	expectedBody := strings.Repeat("h", 2048) + "\no\n"
181 | 	if rw.Body.String() != expectedBody {
182 | 		t.Errorf("Unexpected body, got '%s' expected '%s'", rw.Body, expectedBody)
183 | 	}
184 | }
185 | 
186 | func TestHandler_XhrStreamingAnotherReceiver(t *testing.T) {
187 | 	h := newTestHandler()
188 | 	h.options.ResponseLimit = 4096
189 | 	rw1 := httptest.NewRecorder()
190 | 	req, _ := http.NewRequest("POST", "/server/session/xhr_streaming", nil)
191 | 	ctx, cancel := context.WithCancel(req.Context())
192 | 	req = req.WithContext(ctx)
193 | 	go func() {
194 | 		rec := httptest.NewRecorder()
195 | 		h.xhrStreaming(rec, req)
196 | 		expectedBody := strings.Repeat("h", 2048) + "\n" + "c[2010,\"Another connection still open\"]\n"
197 | 		if rec.Body.String() != expectedBody {
198 | 			t.Errorf("Unexpected body got '%s', expected '%s', ", rec.Body, expectedBody)
199 | 		}
200 | 		cancel()
201 | 	}()
202 | 	h.xhrStreaming(rw1, req)
203 | }
204 | 
205 | // various test only structs
206 | func newTestHandler() *Handler {
207 | 	h := &Handler{sessions: make(map[string]*session)}
208 | 	h.options.HeartbeatDelay = time.Hour
209 | 	h.options.DisconnectDelay = time.Hour
210 | 	h.handlerFunc = func(s Session) {}
211 | 	return h
212 | }
213 | 
--------------------------------------------------------------------------------