├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── alloc.go ├── alloc_test.go ├── assets ├── curve.jpg ├── mux.jpg └── smux.png ├── frame.go ├── go.mod ├── go.sum ├── mux.go ├── mux_test.go ├── pkg.go ├── session.go ├── session_test.go ├── shaper.go ├── shaper_test.go └── stream.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | arch: 2 | - amd64 3 | - ppc64le 4 | language: go 5 | go: 6 | - 1.9.x 7 | - 1.10.x 8 | - 1.11.x 9 | 10 | before_install: 11 | - go get -t -v ./... 12 | 13 | install: 14 | - go get github.com/xtaci/smux 15 | 16 | script: 17 | - go test -coverprofile=coverage.txt -covermode=atomic -bench . 
18 | 19 | after_success: 20 | - bash <(curl -s https://codecov.io/bash) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016-2017 xtaci 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | smux 2 | 3 | [![GoDoc][1]][2] [![MIT licensed][3]][4] [![Build Status][5]][6] [![Go Report Card][7]][8] [![Coverage Status][9]][10] [![Sourcegraph][11]][12] 4 | 5 | smux 6 | 7 | [1]: https://godoc.org/github.com/xtaci/smux?status.svg 8 | [2]: https://godoc.org/github.com/xtaci/smux 9 | [3]: https://img.shields.io/badge/license-MIT-blue.svg 10 | [4]: LICENSE 11 | [5]: https://img.shields.io/github/created-at/xtaci/smux 12 | [6]: https://img.shields.io/github/created-at/xtaci/smux 13 | [7]: https://goreportcard.com/badge/github.com/xtaci/smux 14 | [8]: https://goreportcard.com/report/github.com/xtaci/smux 15 | [9]: https://codecov.io/gh/xtaci/smux/branch/master/graph/badge.svg 16 | [10]: https://codecov.io/gh/xtaci/smux 17 | [11]: https://sourcegraph.com/github.com/xtaci/smux/-/badge.svg 18 | [12]: https://sourcegraph.com/github.com/xtaci/smux?badge 19 | 20 | ## Introduction 21 | 22 | Smux (**S**imple **MU**ltiple**X**ing) is a multiplexing library for Golang. It relies on an underlying connection to provide reliability and ordering, such as TCP or [KCP](https://github.com/xtaci/kcp-go), and provides stream-oriented multiplexing. The original intention of this library is to power the connection management for [kcp-go](https://github.com/xtaci/kcp-go). 23 | 24 | ## Features 25 | 26 | 1. ***Token bucket***-controlled receiving, which provides a smoother bandwidth graph (see picture below). 27 | 2. Session-wide receive buffer, shared among streams, for **fully controlled** overall memory usage. 28 | 3. Minimized header (8 bytes), maximized payload. 29 | 4. Well-tested on millions of devices in [kcptun](https://github.com/xtaci/kcptun). 30 | 5. Built-in fair-queue traffic shaping. 31 | 6. Per-stream sliding window to control congestion (protocol version 2+).
32 | 33 | ![smooth bandwidth curve](assets/curve.jpg) 34 | 35 | ## Documentation 36 | 37 | For complete documentation, see the associated [Godoc](https://godoc.org/github.com/xtaci/smux). 38 | 39 | ## Benchmark 40 | ``` 41 | $ go test -v -run=^$ -bench . 42 | goos: darwin 43 | goarch: amd64 44 | pkg: github.com/xtaci/smux 45 | BenchmarkMSB-4 30000000 51.8 ns/op 46 | BenchmarkAcceptClose-4 50000 36783 ns/op 47 | BenchmarkConnSmux-4 30000 58335 ns/op 2246.88 MB/s 1208 B/op 19 allocs/op 48 | BenchmarkConnTCP-4 50000 25579 ns/op 5124.04 MB/s 0 B/op 0 allocs/op 49 | PASS 50 | ok github.com/xtaci/smux 7.811s 51 | ``` 52 | 53 | ## Specification 54 | 55 | ``` 56 | VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH) 57 | 58 | VALUES FOR LATEST VERSION: 59 | VERSION: 60 | 1/2 61 | 62 | CMD: 63 | cmdSYN(0) 64 | cmdFIN(1) 65 | cmdPSH(2) 66 | cmdNOP(3) 67 | cmdUPD(4) // only supported on version 2 68 | 69 | STREAMID: 70 | client uses odd numbers, starting from 1 71 | server uses even numbers, starting from 0 72 | 73 | cmdUPD: 74 | | CONSUMED(4B) | WINDOW(4B) | 75 | ``` 76 | 77 | ## Usage 78 | 79 | ```go 80 | 81 | func client() { 82 | // Get a TCP connection 83 | conn, err := net.Dial(...)
84 | if err != nil { 85 | panic(err) 86 | } 87 | 88 | // Setup client side of smux 89 | session, err := smux.Client(conn, nil) 90 | if err != nil { 91 | panic(err) 92 | } 93 | 94 | // Open a new stream 95 | stream, err := session.OpenStream() 96 | if err != nil { 97 | panic(err) 98 | } 99 | 100 | // Stream implements io.ReadWriteCloser 101 | stream.Write([]byte("ping")) 102 | stream.Close() 103 | session.Close() 104 | } 105 | 106 | func server() { 107 | // Accept a TCP connection 108 | conn, err := listener.Accept() 109 | if err != nil { 110 | panic(err) 111 | } 112 | 113 | // Setup server side of smux 114 | session, err := smux.Server(conn, nil) 115 | if err != nil { 116 | panic(err) 117 | } 118 | 119 | // Accept a stream 120 | stream, err := session.AcceptStream() 121 | if err != nil { 122 | panic(err) 123 | } 124 | 125 | // Listen for a message 126 | buf := make([]byte, 4) 127 | stream.Read(buf) 128 | stream.Close() 129 | session.Close() 130 | } 131 | 132 | ``` 133 | 134 | ## Status 135 | 136 | Stable 137 | -------------------------------------------------------------------------------- /alloc.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 
14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | import ( 26 | "errors" 27 | "sync" 28 | ) 29 | 30 | var ( 31 | defaultAllocator *Allocator 32 | debruijinPos = [...]byte{0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31} // de Bruijn lookup table used by msb 33 | ) 34 | 35 | func init() { 36 | defaultAllocator = NewAllocator() 37 | } 38 | 39 | // Allocator pools []byte buffers for incoming frames, optimized to prevent overwriting after zeroing 40 | type Allocator struct { 41 | buffers []sync.Pool 42 | } 43 | 44 | // NewAllocator initiates a []byte allocator for frames of up to 65536 bytes; 45 | // the waste (memory fragmentation) of space allocation is guaranteed to be 46 | // no more than 50%.
47 | func NewAllocator() *Allocator { 48 | alloc := new(Allocator) 49 | alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K 50 | for k := range alloc.buffers { 51 | i := k 52 | alloc.buffers[k].New = func() interface{} { 53 | b := make([]byte, 1< 65536 { 63 | return nil 64 | } 65 | 66 | bits := msb(size) 67 | if size == 1< 65536 || cap(*p) != 1<> 1 96 | v |= v >> 2 97 | v |= v >> 4 98 | v |= v >> 8 99 | v |= v >> 16 100 | return debruijinPos[(v*0x07C4ACDD)>>27] 101 | } 102 | -------------------------------------------------------------------------------- /alloc_test.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE.
22 | 23 | package smux 24 | 25 | import ( 26 | "math/bits" 27 | "math/rand" 28 | "testing" 29 | ) 30 | 31 | func TestAllocGet(t *testing.T) { 32 | alloc := NewAllocator() 33 | if alloc.Get(0) != nil { 34 | t.Fatal(0) 35 | } 36 | if len(*alloc.Get(1)) != 1 { 37 | t.Fatal(1) 38 | } 39 | if len(*alloc.Get(2)) != 2 { 40 | t.Fatal(2) 41 | } 42 | if len(*alloc.Get(3)) != 3 || cap(*alloc.Get(3)) != 4 { 43 | t.Fatal(3) 44 | } 45 | if len(*alloc.Get(4)) != 4 { 46 | t.Fatal(4) 47 | } 48 | if len(*alloc.Get(1023)) != 1023 || cap(*alloc.Get(1023)) != 1024 { 49 | t.Fatal(1023) 50 | } 51 | if len(*alloc.Get(1024)) != 1024 { 52 | t.Fatal(1024) 53 | } 54 | if len(*alloc.Get(65536)) != 65536 { 55 | t.Fatal(65536) 56 | } 57 | if alloc.Get(65537) != nil { 58 | t.Fatal(65537) 59 | } 60 | } 61 | 62 | func TestAllocPut(t *testing.T) { 63 | alloc := NewAllocator() 64 | if err := alloc.Put(nil); err == nil { 65 | t.Fatal("put nil misbehavior") 66 | } 67 | b := make([]byte, 3) 68 | if err := alloc.Put(&b); err == nil { 69 | t.Fatal("put elem:3 []bytes misbehavior") 70 | } 71 | b = make([]byte, 4) 72 | if err := alloc.Put(&b); err != nil { 73 | t.Fatal("put elem:4 []bytes misbehavior") 74 | } 75 | b = make([]byte, 1023, 1024) 76 | if err := alloc.Put(&b); err != nil { 77 | t.Fatal("put elem:1024 []bytes misbehavior") 78 | } 79 | b = make([]byte, 65536) 80 | if err := alloc.Put(&b); err != nil { 81 | t.Fatal("put elem:65536 []bytes misbehavior") 82 | } 83 | b = make([]byte, 65537) 84 | if err := alloc.Put(&b); err == nil { 85 | t.Fatal("put elem:65537 []bytes misbehavior") 86 | } 87 | } 88 | 89 | func TestAllocPutThenGet(t *testing.T) { 90 | alloc := NewAllocator() 91 | data := alloc.Get(4) 92 | alloc.Put(data) 93 | newData := alloc.Get(4) 94 | if cap(*data) != cap(*newData) { 95 | t.Fatal("different cap while alloc.Get()") 96 | } 97 | } 98 | 99 | func BenchmarkMSB(b *testing.B) { 100 | for i := 0; i < b.N; i++ { 101 | msb(rand.Int()) 102 | } 103 | } 104 | 105 | func BenchmarkAlloc(b
*testing.B) { 106 | for i := 0; i < b.N; i++ { 107 | pbuf := defaultAllocator.Get(i%65536 + 1) // sizes 1..65536; Get(0) would return nil and make Put fail 108 | defaultAllocator.Put(pbuf) 109 | } 110 | } 111 | 112 | func TestDebrujin(t *testing.T) { 113 | for i := 1; i < 65536; i++ { 114 | a := int(msb(i)) 115 | b := bits.Len(uint(i)) 116 | if a+1 != b { 117 | t.Fatalf("msb(%d) = %d, want %d", i, a, b-1) 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /assets/curve.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/curve.jpg -------------------------------------------------------------------------------- /assets/mux.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/mux.jpg -------------------------------------------------------------------------------- /assets/smux.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/smux.png -------------------------------------------------------------------------------- /frame.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be
included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | import ( 26 | "encoding/binary" 27 | "fmt" 28 | ) 29 | 30 | const ( // cmds 31 | // protocol version 1: 32 | cmdSYN byte = iota // stream open 33 | cmdFIN // stream close, a.k.a. EOF mark 34 | cmdPSH // data push 35 | cmdNOP // no operation 36 | 37 | // protocol version 2 extra commands 38 | // notify the peer of bytes consumed and window size 39 | cmdUPD 40 | ) 41 | 42 | const ( 43 | // data size of cmdUPD, format: 44 | // |4B data consumed(ACK)| 4B window size(WINDOW) | 45 | szCmdUPD = 8 46 | ) 47 | 48 | const ( 49 | // initial peer window guess, a slow-start 50 | initialPeerWindow = 262144 51 | ) 52 | 53 | const ( 54 | sizeOfVer = 1 55 | sizeOfCmd = 1 56 | sizeOfLength = 2 57 | sizeOfSid = 4 58 | headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength 59 | ) 60 | 61 | // Frame defines a packet to be multiplexed into, or demultiplexed from, a single connection 62 | type Frame struct { 63 | ver byte // version 64 | cmd byte // command 65 | sid uint32 // stream id 66 | data []byte // payload 67 | } 68 | 69 | // newFrame creates a new frame with given version, command and stream id 70 | func newFrame(version byte, cmd byte, sid uint32) Frame { 71 | return Frame{ver: version, cmd: cmd, sid: sid} 72 | } 73 | 74 | // rawHeader is the on-wire byte layout of a Frame header (multi-byte fields are little-endian) 75 | type rawHeader [headerSize]byte 76 | 77 | func (h rawHeader) Version() byte { 78 | return h[0]
79 | } 80 | 81 | func (h rawHeader) Cmd() byte { 82 | return h[1] 83 | } 84 | 85 | func (h rawHeader) Length() uint16 { 86 | return binary.LittleEndian.Uint16(h[2:]) 87 | } 88 | 89 | func (h rawHeader) StreamID() uint32 { 90 | return binary.LittleEndian.Uint32(h[4:]) 91 | } 92 | 93 | func (h rawHeader) String() string { 94 | return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d", 95 | h.Version(), h.Cmd(), h.StreamID(), h.Length()) 96 | } 97 | 98 | // updHeader is the byte array payload of a cmdUPD frame: consumed bytes, then window size 99 | type updHeader [szCmdUPD]byte 100 | 101 | func (h updHeader) Consumed() uint32 { 102 | return binary.LittleEndian.Uint32(h[:]) 103 | } 104 | func (h updHeader) Window() uint32 { 105 | return binary.LittleEndian.Uint32(h[4:]) 106 | } 107 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xtaci/smux 2 | 3 | go 1.13 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/go.sum -------------------------------------------------------------------------------- /mux.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be
notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | import ( 26 | "errors" 27 | "fmt" 28 | "io" 29 | "math" 30 | "time" 31 | ) 32 | 33 | // Config is used to tune the Smux session 34 | type Config struct { 35 | // Version is the SMUX protocol version; 1 and 2 are supported 36 | Version int 37 | 38 | // KeepAliveDisabled turns keep-alive off (KeepAliveInterval/KeepAliveTimeout are then not validated) 39 | KeepAliveDisabled bool 40 | 41 | // KeepAliveInterval is how often to send a NOP command to the remote 42 | KeepAliveInterval time.Duration 43 | 44 | // KeepAliveTimeout is how long the session may go without 45 | // receiving any data before it is closed 46 | KeepAliveTimeout time.Duration 47 | 48 | // MaxFrameSize is used to control the maximum 49 | // frame size to be sent to the remote 50 | MaxFrameSize int 51 | 52 | // MaxReceiveBuffer is used to control the maximum 53 | // amount of data in the buffer pool 54 | MaxReceiveBuffer int 55 | 56 | // MaxStreamBuffer is used to control the maximum 57 | // amount of data buffered per stream 58 | MaxStreamBuffer int 59 | } 60 | 61 | // DefaultConfig returns a default configuration 62 | func DefaultConfig() *Config { 63 | return &Config{ 64 | Version: 1, 65 | KeepAliveInterval: 10 * time.Second, 66 | KeepAliveTimeout: 30 * time.Second, 67 | MaxFrameSize: 32768, 68 | MaxReceiveBuffer: 4194304, 69 | MaxStreamBuffer: 65536, 70 | } 71 | } 72 | 73 | // VerifyConfig checks the sanity of a configuration and returns the first problem found 74 | func
VerifyConfig(config *Config) error { 75 | if !(config.Version == 1 || config.Version == 2) { 76 | return errors.New("unsupported protocol version") 77 | } 78 | if !config.KeepAliveDisabled { 79 | if config.KeepAliveInterval <= 0 { 80 | return errors.New("keep-alive interval must be positive") 81 | } 82 | if config.KeepAliveTimeout < config.KeepAliveInterval { 83 | return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval") 84 | } 85 | } 86 | if config.MaxFrameSize <= 0 { 87 | return errors.New("max frame size must be positive") 88 | } 89 | if config.MaxFrameSize > 65535 { 90 | return errors.New("max frame size must not be larger than 65535") 91 | } 92 | if config.MaxReceiveBuffer <= 0 { 93 | return errors.New("max receive buffer must be positive") 94 | } 95 | if config.MaxStreamBuffer <= 0 { 96 | return errors.New("max stream buffer must be positive") 97 | } 98 | if config.MaxStreamBuffer > config.MaxReceiveBuffer { 99 | return errors.New("max stream buffer must not be larger than max receive buffer") 100 | } 101 | if config.MaxStreamBuffer > math.MaxInt32 { 102 | return errors.New("max stream buffer cannot be larger than 2147483647") 103 | } 104 | return nil 105 | } 106 | 107 | // Server is used to initialize a new server-side connection. 108 | func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { 109 | if config == nil { 110 | config = DefaultConfig() 111 | } 112 | if err := VerifyConfig(config); err != nil { 113 | return nil, err 114 | } 115 | return newSession(config, conn, false), nil 116 | } 117 | 118 | // Client is used to initialize a new client-side connection.
119 | func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { 120 | if config == nil { 121 | config = DefaultConfig() 122 | } 123 | 124 | if err := VerifyConfig(config); err != nil { 125 | return nil, err 126 | } 127 | return newSession(config, conn, true), nil 128 | } 129 | -------------------------------------------------------------------------------- /mux_test.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE.
22 | 23 | package smux 24 | 25 | import ( 26 | "bytes" 27 | "testing" 28 | ) 29 | 30 | type buffer struct { 31 | bytes.Buffer 32 | } 33 | 34 | func (b *buffer) Close() error { 35 | b.Buffer.Reset() 36 | return nil 37 | } 38 | 39 | func TestConfig(t *testing.T) { 40 | VerifyConfig(DefaultConfig()) 41 | 42 | config := DefaultConfig() 43 | config.KeepAliveInterval = 0 44 | err := VerifyConfig(config) 45 | t.Log(err) 46 | if err == nil { 47 | t.Fatal("zero keep-alive interval was accepted") 48 | } 49 | 50 | config = DefaultConfig() 51 | config.KeepAliveInterval = 10 52 | config.KeepAliveTimeout = 5 53 | err = VerifyConfig(config) 54 | t.Log(err) 55 | if err == nil { 56 | t.Fatal("keep-alive timeout shorter than interval was accepted") 57 | } 58 | 59 | config = DefaultConfig() 60 | config.MaxFrameSize = 0 61 | err = VerifyConfig(config) 62 | t.Log(err) 63 | if err == nil { 64 | t.Fatal("zero max frame size was accepted") 65 | } 66 | 67 | config = DefaultConfig() 68 | config.MaxFrameSize = 65536 69 | err = VerifyConfig(config) 70 | t.Log(err) 71 | if err == nil { 72 | t.Fatal("max frame size 65536 was accepted") 73 | } 74 | 75 | config = DefaultConfig() 76 | config.MaxReceiveBuffer = 0 77 | err = VerifyConfig(config) 78 | t.Log(err) 79 | if err == nil { 80 | t.Fatal("zero max receive buffer was accepted") 81 | } 82 | 83 | config = DefaultConfig() 84 | config.MaxStreamBuffer = 0 85 | err = VerifyConfig(config) 86 | t.Log(err) 87 | if err == nil { 88 | t.Fatal("zero max stream buffer was accepted") 89 | } 90 | 91 | config = DefaultConfig() 92 | config.MaxStreamBuffer = 100 93 | config.MaxReceiveBuffer = 99 94 | err = VerifyConfig(config) 95 | t.Log(err) 96 | if err == nil { 97 | t.Fatal("max stream buffer larger than max receive buffer was accepted") 98 | } 99 | 100 | var bts buffer 101 | if _, err := Server(&bts, config); err == nil { 102 | t.Fatal("server started with wrong config") 103 | } 104 | 105 | if _, err := Client(&bts, config); err == nil { 106 | t.Fatal("client started with wrong config") 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /pkg.go: -------------------------------------------------------------------------------- 1 | // Package smux is a multiplexing library for Golang.
2 | // 3 | // It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP, 4 | // and provides stream-oriented multiplexing over a single channel. 5 | package smux 6 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE.
22 | 23 | package smux 24 | 25 | import ( 26 | "container/heap" 27 | "encoding/binary" 28 | "errors" 29 | "io" 30 | "net" 31 | "runtime" 32 | "sync" 33 | "sync/atomic" 34 | "time" 35 | ) 36 | 37 | const ( 38 | defaultAcceptBacklog = 1024 39 | maxShaperSize = 1024 40 | openCloseTimeout = 30 * time.Second // Timeout for opening/closing streams 41 | ) 42 | 43 | // CLASSID represents the class of a frame 44 | type CLASSID int 45 | 46 | const ( 47 | CLSCTRL CLASSID = iota // prioritized control signal 48 | CLSDATA 49 | ) 50 | 51 | // timeoutError represents timeouts for operations such as accept, read and write. 52 | // 53 | // To better cooperate with the standard library, timeoutError implements the standard library's `net.Error` interface. 54 | // 55 | // For example, when smux is used to implement a net.Listener served by http.Server, keep-alive connections (*smux.Stream) would otherwise be closed unexpectedly. 56 | // For more details, see https://github.com/xtaci/smux/pull/99. 57 | type timeoutError struct{} 58 | 59 | func (timeoutError) Error() string { return "timeout" } 60 | func (timeoutError) Temporary() bool { return true } 61 | func (timeoutError) Timeout() bool { return true } 62 | 63 | var ( 64 | ErrInvalidProtocol = errors.New("invalid protocol") 65 | ErrConsumed = errors.New("peer consumed more than sent") 66 | ErrGoAway = errors.New("stream id overflows, should start a new connection") 67 | ErrTimeout net.Error = &timeoutError{} 68 | ErrWouldBlock = errors.New("operation would block on IO") 69 | ) 70 | 71 | // writeRequest represents a request to write a frame 72 | type writeRequest struct { 73 | class CLASSID 74 | frame Frame 75 | seq uint32 76 | result chan writeResult 77 | } 78 | 79 | // writeResult represents the result of a write request 80 | type writeResult struct { 81 | n int 82 | err error 83 | } 84 | 85 | // Session defines a multiplexed connection for streams 86 | type Session struct { 87 | conn io.ReadWriteCloser 88 | 89 | config *Config 90 | nextStreamID
uint32 // next stream identifier 91 | nextStreamIDLock sync.Mutex 92 | 93 | bucket int32 // token bucket 94 | bucketNotify chan struct{} // used for waiting for tokens 95 | 96 | streams map[uint32]*stream // all streams in this session 97 | streamLock sync.Mutex // locks streams 98 | 99 | die chan struct{} // flag session has died 100 | dieOnce sync.Once 101 | 102 | // socket error handling 103 | socketReadError atomic.Value 104 | socketWriteError atomic.Value 105 | chSocketReadError chan struct{} 106 | chSocketWriteError chan struct{} 107 | socketReadErrorOnce sync.Once 108 | socketWriteErrorOnce sync.Once 109 | 110 | // smux protocol errors 111 | protoError atomic.Value 112 | chProtoError chan struct{} 113 | protoErrorOnce sync.Once 114 | 115 | chAccepts chan *stream 116 | 117 | dataReady int32 // flag data has arrived 118 | 119 | goAway int32 // flag: stream IDs exhausted 120 | 121 | deadline atomic.Value 122 | 123 | requestID uint32 // monotonically increasing write request ID 124 | shaper chan writeRequest // a shaper for writing 125 | writes chan writeRequest 126 | } 127 | 128 | func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { 129 | s := new(Session) 130 | s.die = make(chan struct{}) 131 | s.conn = conn 132 | s.config = config 133 | s.streams = make(map[uint32]*stream) 134 | s.chAccepts = make(chan *stream, defaultAcceptBacklog) 135 | s.bucket = int32(config.MaxReceiveBuffer) 136 | s.bucketNotify = make(chan struct{}, 1) 137 | s.shaper = make(chan writeRequest) 138 | s.writes = make(chan writeRequest) 139 | s.chSocketReadError = make(chan struct{}) 140 | s.chSocketWriteError = make(chan struct{}) 141 | s.chProtoError = make(chan struct{}) 142 | 143 | if client { 144 | s.nextStreamID = 1 145 | } else { 146 | s.nextStreamID = 0 147 | } 148 | 149 | go s.shaperLoop() 150 | go s.recvLoop() 151 | go s.sendLoop() 152 | if !config.KeepAliveDisabled { 153 | go s.keepalive() 154 | } 155 | return s 156 | } 157 | 158 | // OpenStream is used to create a
new stream 159 | func (s *Session) OpenStream() (*Stream, error) { 160 | if s.IsClosed() { 161 | return nil, io.ErrClosedPipe 162 | } 163 | 164 | // generate stream id 165 | s.nextStreamIDLock.Lock() 166 | if s.goAway > 0 { 167 | s.nextStreamIDLock.Unlock() 168 | return nil, ErrGoAway 169 | } 170 | 171 | s.nextStreamID += 2 172 | sid := s.nextStreamID 173 | if sid == sid%2 { // stream-id overflow: the counter wrapped around to 0 or 1 174 | s.goAway = 1 175 | s.nextStreamIDLock.Unlock() 176 | return nil, ErrGoAway 177 | } 178 | s.nextStreamIDLock.Unlock() 179 | 180 | stream := newStream(sid, s.config.MaxFrameSize, s) 181 | 182 | if _, err := s.writeControlFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { 183 | return nil, err 184 | } 185 | 186 | s.streamLock.Lock() 187 | defer s.streamLock.Unlock() 188 | select { 189 | case <-s.chSocketReadError: 190 | return nil, s.socketReadError.Load().(error) 191 | case <-s.chSocketWriteError: 192 | return nil, s.socketWriteError.Load().(error) 193 | case <-s.die: 194 | return nil, io.ErrClosedPipe 195 | default: 196 | s.streams[sid] = stream 197 | wrapper := &Stream{stream: stream} 198 | // NOTE(x): disabled finalizer for issue #997 199 | /* 200 | runtime.SetFinalizer(wrapper, func(s *Stream) { 201 | s.Close() 202 | }) 203 | */ 204 | return wrapper, nil 205 | } 206 | } 207 | 208 | // Open returns a generic ReadWriteCloser 209 | func (s *Session) Open() (io.ReadWriteCloser, error) { 210 | return s.OpenStream() 211 | } 212 | 213 | // AcceptStream is used to block until the next available stream 214 | // is ready to be accepted. Unlike OpenStream, the returned wrapper gets a finalizer that closes it.
215 | func (s *Session) AcceptStream() (*Stream, error) { 216 | var deadline <-chan time.Time 217 | if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { 218 | timer := time.NewTimer(time.Until(d)) 219 | defer timer.Stop() 220 | deadline = timer.C 221 | } 222 | 223 | select { 224 | case stream := <-s.chAccepts: 225 | wrapper := &Stream{stream: stream} 226 | runtime.SetFinalizer(wrapper, func(s *Stream) { 227 | s.Close() 228 | }) 229 | return wrapper, nil 230 | case <-deadline: 231 | return nil, ErrTimeout 232 | case <-s.chSocketReadError: 233 | return nil, s.socketReadError.Load().(error) 234 | case <-s.chProtoError: 235 | return nil, s.protoError.Load().(error) 236 | case <-s.die: 237 | return nil, io.ErrClosedPipe 238 | } 239 | } 240 | 241 | // Accept Returns a generic ReadWriteCloser instead of smux.Stream 242 | func (s *Session) Accept() (io.ReadWriteCloser, error) { 243 | return s.AcceptStream() 244 | } 245 | 246 | // Close is used to close the session and all streams. 247 | func (s *Session) Close() error { 248 | var once bool 249 | s.dieOnce.Do(func() { 250 | close(s.die) 251 | once = true 252 | }) 253 | 254 | if once { 255 | s.streamLock.Lock() 256 | for k := range s.streams { 257 | s.streams[k].sessionClose() 258 | } 259 | s.streamLock.Unlock() 260 | return s.conn.Close() 261 | } else { 262 | return io.ErrClosedPipe 263 | } 264 | } 265 | 266 | // CloseChan can be used by someone who wants to be notified immediately when this 267 | // session is closed 268 | func (s *Session) CloseChan() <-chan struct{} { 269 | return s.die 270 | } 271 | 272 | // notifyBucket notifies recvLoop that bucket is available 273 | func (s *Session) notifyBucket() { 274 | select { 275 | case s.bucketNotify <- struct{}{}: 276 | default: 277 | } 278 | } 279 | 280 | func (s *Session) notifyReadError(err error) { 281 | s.socketReadErrorOnce.Do(func() { 282 | s.socketReadError.Store(err) 283 | close(s.chSocketReadError) 284 | }) 285 | } 286 | 287 | func (s *Session) 
notifyWriteError(err error) { 288 | s.socketWriteErrorOnce.Do(func() { 289 | s.socketWriteError.Store(err) 290 | close(s.chSocketWriteError) 291 | }) 292 | } 293 | 294 | func (s *Session) notifyProtoError(err error) { 295 | s.protoErrorOnce.Do(func() { 296 | s.protoError.Store(err) 297 | close(s.chProtoError) 298 | }) 299 | } 300 | 301 | // IsClosed does a safe check to see if we have shutdown 302 | func (s *Session) IsClosed() bool { 303 | select { 304 | case <-s.die: 305 | return true 306 | default: 307 | return false 308 | } 309 | } 310 | 311 | // NumStreams returns the number of currently open streams 312 | func (s *Session) NumStreams() int { 313 | if s.IsClosed() { 314 | return 0 315 | } 316 | s.streamLock.Lock() 317 | defer s.streamLock.Unlock() 318 | return len(s.streams) 319 | } 320 | 321 | // SetDeadline sets a deadline used by Accept* calls. 322 | // A zero time value disables the deadline. 323 | func (s *Session) SetDeadline(t time.Time) error { 324 | s.deadline.Store(t) 325 | return nil 326 | } 327 | 328 | // LocalAddr satisfies net.Conn interface 329 | func (s *Session) LocalAddr() net.Addr { 330 | if ts, ok := s.conn.(interface { 331 | LocalAddr() net.Addr 332 | }); ok { 333 | return ts.LocalAddr() 334 | } 335 | return nil 336 | } 337 | 338 | // RemoteAddr satisfies net.Conn interface 339 | func (s *Session) RemoteAddr() net.Addr { 340 | if ts, ok := s.conn.(interface { 341 | RemoteAddr() net.Addr 342 | }); ok { 343 | return ts.RemoteAddr() 344 | } 345 | return nil 346 | } 347 | 348 | // notify the session that a stream has closed 349 | func (s *Session) streamClosed(sid uint32) { 350 | s.streamLock.Lock() 351 | if stream, ok := s.streams[sid]; ok { 352 | n := stream.recycleTokens() 353 | if n > 0 { // return remaining tokens to the bucket 354 | if atomic.AddInt32(&s.bucket, int32(n)) > 0 { 355 | s.notifyBucket() 356 | } 357 | } 358 | delete(s.streams, sid) 359 | } 360 | s.streamLock.Unlock() 361 | } 362 | 363 | // returnTokens is called by stream to 
return token after read 364 | func (s *Session) returnTokens(n int) { 365 | if atomic.AddInt32(&s.bucket, int32(n)) > 0 { 366 | s.notifyBucket() 367 | } 368 | } 369 | 370 | // recvLoop keeps on reading from underlying connection if tokens are available 371 | func (s *Session) recvLoop() { 372 | var hdr rawHeader 373 | var updHdr updHeader 374 | 375 | for { 376 | for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { 377 | select { 378 | case <-s.bucketNotify: 379 | case <-s.die: 380 | return 381 | } 382 | } 383 | 384 | // read header first 385 | if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { 386 | atomic.StoreInt32(&s.dataReady, 1) 387 | if hdr.Version() != byte(s.config.Version) { 388 | s.notifyProtoError(ErrInvalidProtocol) 389 | return 390 | } 391 | sid := hdr.StreamID() 392 | switch hdr.Cmd() { 393 | case cmdNOP: 394 | case cmdSYN: // stream opening 395 | s.streamLock.Lock() 396 | if _, ok := s.streams[sid]; !ok { 397 | stream := newStream(sid, s.config.MaxFrameSize, s) 398 | s.streams[sid] = stream 399 | select { 400 | case s.chAccepts <- stream: 401 | case <-s.die: 402 | } 403 | } 404 | s.streamLock.Unlock() 405 | case cmdFIN: // stream closing 406 | s.streamLock.Lock() 407 | if stream, ok := s.streams[sid]; ok { 408 | stream.fin() 409 | stream.notifyReadEvent() 410 | } 411 | s.streamLock.Unlock() 412 | case cmdPSH: // data frame 413 | if hdr.Length() > 0 { 414 | pNewbuf := defaultAllocator.Get(int(hdr.Length())) 415 | if written, err := io.ReadFull(s.conn, *pNewbuf); err == nil { 416 | s.streamLock.Lock() 417 | if stream, ok := s.streams[sid]; ok { 418 | stream.pushBytes(pNewbuf) 419 | // a stream used some token 420 | atomic.AddInt32(&s.bucket, -int32(written)) 421 | stream.notifyReadEvent() 422 | } else { 423 | // data directed to a missing/closed stream, recycle the buffer immediately. 
424 | defaultAllocator.Put(pNewbuf) 425 | } 426 | s.streamLock.Unlock() 427 | } else { 428 | s.notifyReadError(err) 429 | return 430 | } 431 | } 432 | case cmdUPD: // a window update signal 433 | if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { 434 | s.streamLock.Lock() 435 | if stream, ok := s.streams[sid]; ok { 436 | stream.update(updHdr.Consumed(), updHdr.Window()) 437 | } 438 | s.streamLock.Unlock() 439 | } else { 440 | s.notifyReadError(err) 441 | return 442 | } 443 | default: 444 | s.notifyProtoError(ErrInvalidProtocol) 445 | return 446 | } 447 | } else { 448 | s.notifyReadError(err) 449 | return 450 | } 451 | } 452 | } 453 | 454 | // keepalive sends NOP frame to peer to keep the connection alive, and detect dead peers 455 | func (s *Session) keepalive() { 456 | tickerPing := time.NewTicker(s.config.KeepAliveInterval) 457 | tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) 458 | defer tickerPing.Stop() 459 | defer tickerTimeout.Stop() 460 | for { 461 | select { 462 | case <-tickerPing.C: 463 | s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, CLSCTRL) 464 | s.notifyBucket() // force a signal to the recvLoop 465 | case <-tickerTimeout.C: 466 | if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { 467 | // recvLoop may block while bucket is 0, in this case, 468 | // session should not be closed. 
469 | if atomic.LoadInt32(&s.bucket) > 0 { 470 | s.Close() 471 | return 472 | } 473 | } 474 | case <-s.die: 475 | return 476 | } 477 | } 478 | } 479 | 480 | // shaperLoop implements a priority queue for write requests, 481 | // some control messages are prioritized over data messages 482 | func (s *Session) shaperLoop() { 483 | var reqs shaperHeap 484 | var next writeRequest 485 | var chWrite chan writeRequest 486 | var chShaper chan writeRequest 487 | 488 | for { 489 | // chWrite is not available until it has packet to send 490 | if len(reqs) > 0 { 491 | chWrite = s.writes 492 | next = heap.Pop(&reqs).(writeRequest) 493 | } else { 494 | chWrite = nil 495 | } 496 | 497 | // control heap size, chShaper is not available until packets are less than maximum allowed 498 | if len(reqs) >= maxShaperSize { 499 | chShaper = nil 500 | } else { 501 | chShaper = s.shaper 502 | } 503 | 504 | // assertion on non nil 505 | if chShaper == nil && chWrite == nil { 506 | panic("both channel are nil") 507 | } 508 | 509 | select { 510 | case <-s.die: 511 | return 512 | case r := <-chShaper: 513 | if chWrite != nil { // next is valid, reshape 514 | heap.Push(&reqs, next) 515 | } 516 | heap.Push(&reqs, r) 517 | case chWrite <- next: 518 | } 519 | } 520 | } 521 | 522 | // sendLoop sends frames to the underlying connection 523 | func (s *Session) sendLoop() { 524 | var buf []byte 525 | var n int 526 | var err error 527 | var vec [][]byte // vector for writeBuffers 528 | 529 | bw, ok := s.conn.(interface { 530 | WriteBuffers(v [][]byte) (n int, err error) 531 | }) 532 | 533 | if ok { 534 | buf = make([]byte, headerSize) 535 | vec = make([][]byte, 2) 536 | } else { 537 | buf = make([]byte, (1<<16)+headerSize) 538 | } 539 | 540 | for { 541 | select { 542 | case <-s.die: 543 | return 544 | case request := <-s.writes: 545 | buf[0] = request.frame.ver 546 | buf[1] = request.frame.cmd 547 | binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) 548 | 
binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) 549 | 550 | // support for scatter-gather I/O 551 | if len(vec) > 0 { 552 | vec[0] = buf[:headerSize] 553 | vec[1] = request.frame.data 554 | n, err = bw.WriteBuffers(vec) 555 | } else { 556 | copy(buf[headerSize:], request.frame.data) 557 | n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) 558 | } 559 | 560 | n -= headerSize 561 | if n < 0 { 562 | n = 0 563 | } 564 | 565 | result := writeResult{ 566 | n: n, 567 | err: err, 568 | } 569 | 570 | request.result <- result 571 | close(request.result) 572 | 573 | // store conn error 574 | if err != nil { 575 | s.notifyWriteError(err) 576 | return 577 | } 578 | } 579 | } 580 | } 581 | 582 | // writeControlFrame writes the control frame to the underlying connection 583 | // and returns the number of bytes written if successful 584 | func (s *Session) writeControlFrame(f Frame) (n int, err error) { 585 | timer := time.NewTimer(openCloseTimeout) 586 | defer timer.Stop() 587 | 588 | return s.writeFrameInternal(f, timer.C, CLSCTRL) 589 | } 590 | 591 | // internal writeFrame version to support deadline used in keepalive 592 | func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, class CLASSID) (int, error) { 593 | req := writeRequest{ 594 | class: class, 595 | frame: f, 596 | seq: atomic.AddUint32(&s.requestID, 1), 597 | result: make(chan writeResult, 1), 598 | } 599 | select { 600 | case s.shaper <- req: 601 | case <-s.die: 602 | return 0, io.ErrClosedPipe 603 | case <-s.chSocketWriteError: 604 | return 0, s.socketWriteError.Load().(error) 605 | case <-deadline: 606 | return 0, ErrTimeout 607 | } 608 | 609 | select { 610 | case result := <-req.result: 611 | return result.n, result.err 612 | case <-s.die: 613 | return 0, io.ErrClosedPipe 614 | case <-s.chSocketWriteError: 615 | return 0, s.socketWriteError.Load().(error) 616 | case <-deadline: 617 | return 0, ErrTimeout 618 | } 619 | } 620 | 
-------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | import ( 26 | "bytes" 27 | crand "crypto/rand" 28 | "encoding/binary" 29 | "fmt" 30 | "io" 31 | "log" 32 | "math/rand" 33 | "net" 34 | "net/http" 35 | _ "net/http/pprof" 36 | "strings" 37 | "sync" 38 | "testing" 39 | "time" 40 | ) 41 | 42 | func init() { 43 | go func() { 44 | log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) 45 | }() 46 | } 47 | 48 | // setupServer starts new server listening on a random localhost port and 49 | // returns address of the server, function to stop the server, new client 50 | // connection to this server or an error. 
51 | func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { 52 | ln, err := net.Listen("tcp", "localhost:0") 53 | if err != nil { 54 | return "", nil, nil, err 55 | } 56 | go func() { 57 | conn, err := ln.Accept() 58 | if err != nil { 59 | return 60 | } 61 | go handleConnection(conn) 62 | }() 63 | addr = ln.Addr().String() 64 | conn, err := net.Dial("tcp", addr) 65 | if err != nil { 66 | ln.Close() 67 | return "", nil, nil, err 68 | } 69 | return ln.Addr().String(), func() { ln.Close() }, conn, nil 70 | } 71 | 72 | func handleConnection(conn net.Conn) { 73 | session, _ := Server(conn, nil) 74 | for { 75 | if stream, err := session.AcceptStream(); err == nil { 76 | go func(s io.ReadWriteCloser) { 77 | buf := make([]byte, 65536) 78 | for { 79 | n, err := s.Read(buf) 80 | if err != nil { 81 | return 82 | } 83 | s.Write(buf[:n]) 84 | } 85 | }(stream) 86 | } else { 87 | return 88 | } 89 | } 90 | } 91 | 92 | // setupServer starts new server listening on a random localhost port and 93 | // returns address of the server, function to stop the server, new client 94 | // connection to this server or an error. 
95 | func setupServerV2(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { 96 | ln, err := net.Listen("tcp", "localhost:0") 97 | if err != nil { 98 | return "", nil, nil, err 99 | } 100 | go func() { 101 | conn, err := ln.Accept() 102 | if err != nil { 103 | return 104 | } 105 | go handleConnectionV2(conn) 106 | }() 107 | addr = ln.Addr().String() 108 | conn, err := net.Dial("tcp", addr) 109 | if err != nil { 110 | ln.Close() 111 | return "", nil, nil, err 112 | } 113 | return ln.Addr().String(), func() { ln.Close() }, conn, nil 114 | } 115 | 116 | func handleConnectionV2(conn net.Conn) { 117 | config := DefaultConfig() 118 | config.Version = 2 119 | session, _ := Server(conn, config) 120 | for { 121 | if stream, err := session.AcceptStream(); err == nil { 122 | go func(s io.ReadWriteCloser) { 123 | buf := make([]byte, 65536) 124 | for { 125 | n, err := s.Read(buf) 126 | if err != nil { 127 | return 128 | } 129 | s.Write(buf[:n]) 130 | } 131 | }(stream) 132 | } else { 133 | return 134 | } 135 | } 136 | } 137 | 138 | func TestEcho(t *testing.T) { 139 | _, stop, cli, err := setupServer(t) 140 | if err != nil { 141 | t.Fatal(err) 142 | } 143 | defer stop() 144 | session, _ := Client(cli, nil) 145 | stream, _ := session.OpenStream() 146 | const N = 100 147 | buf := make([]byte, 10) 148 | var sent string 149 | var received string 150 | for i := 0; i < N; i++ { 151 | msg := fmt.Sprintf("hello%v", i) 152 | stream.Write([]byte(msg)) 153 | sent += msg 154 | if n, err := stream.Read(buf); err != nil { 155 | t.Fatal(err) 156 | } else { 157 | received += string(buf[:n]) 158 | } 159 | } 160 | if sent != received { 161 | t.Fatal("data mimatch") 162 | } 163 | session.Close() 164 | } 165 | 166 | func TestWriteTo(t *testing.T) { 167 | const N = 1 << 20 168 | // server 169 | ln, err := net.Listen("tcp", "localhost:0") 170 | if err != nil { 171 | t.Fatal(err) 172 | } 173 | defer ln.Close() 174 | 175 | go func() { 176 | conn, err := ln.Accept() 177 | if err != 
nil { 178 | return 179 | } 180 | session, _ := Server(conn, nil) 181 | for { 182 | if stream, err := session.AcceptStream(); err == nil { 183 | go func(s io.ReadWriteCloser) { 184 | numBytes := 0 185 | buf := make([]byte, 65536) 186 | for { 187 | n, err := s.Read(buf) 188 | if err != nil { 189 | return 190 | } 191 | s.Write(buf[:n]) 192 | numBytes += n 193 | 194 | if numBytes == N { 195 | s.Close() 196 | return 197 | } 198 | } 199 | }(stream) 200 | } else { 201 | return 202 | } 203 | } 204 | }() 205 | 206 | addr := ln.Addr().String() 207 | conn, err := net.Dial("tcp", addr) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | defer conn.Close() 212 | 213 | // client 214 | session, _ := Client(conn, nil) 215 | stream, _ := session.OpenStream() 216 | sndbuf := make([]byte, N) 217 | for i := range sndbuf { 218 | sndbuf[i] = byte(rand.Int()) 219 | } 220 | 221 | go stream.Write(sndbuf) 222 | 223 | var rcvbuf bytes.Buffer 224 | nw, ew := stream.WriteTo(&rcvbuf) 225 | if ew != io.EOF { 226 | t.Fatal(ew) 227 | } 228 | 229 | if nw != N { 230 | t.Fatal("WriteTo nw mismatch", nw) 231 | } 232 | 233 | if !bytes.Equal(sndbuf, rcvbuf.Bytes()) { 234 | t.Fatal("mismatched echo bytes") 235 | } 236 | t.Log(stream) 237 | } 238 | 239 | func TestWriteToV2(t *testing.T) { 240 | config := DefaultConfig() 241 | config.Version = 2 242 | const N = 1 << 20 243 | // server 244 | ln, err := net.Listen("tcp", "localhost:0") 245 | if err != nil { 246 | t.Fatal(err) 247 | } 248 | defer ln.Close() 249 | 250 | go func() { 251 | conn, err := ln.Accept() 252 | if err != nil { 253 | return 254 | } 255 | session, _ := Server(conn, config) 256 | for { 257 | if stream, err := session.AcceptStream(); err == nil { 258 | go func(s io.ReadWriteCloser) { 259 | numBytes := 0 260 | buf := make([]byte, 65536) 261 | for { 262 | n, err := s.Read(buf) 263 | if err != nil { 264 | return 265 | } 266 | s.Write(buf[:n]) 267 | numBytes += n 268 | 269 | if numBytes == N { 270 | s.Close() 271 | return 272 | } 273 | } 274 
| }(stream) 275 | } else { 276 | return 277 | } 278 | } 279 | }() 280 | 281 | addr := ln.Addr().String() 282 | conn, err := net.Dial("tcp", addr) 283 | if err != nil { 284 | t.Fatal(err) 285 | } 286 | defer conn.Close() 287 | 288 | // client 289 | session, _ := Client(conn, config) 290 | stream, _ := session.OpenStream() 291 | sndbuf := make([]byte, N) 292 | for i := range sndbuf { 293 | sndbuf[i] = byte(rand.Int()) 294 | } 295 | 296 | go stream.Write(sndbuf) 297 | 298 | var rcvbuf bytes.Buffer 299 | nw, ew := stream.WriteTo(&rcvbuf) 300 | if ew != io.EOF { 301 | t.Fatal(ew) 302 | } 303 | 304 | if nw != N { 305 | t.Fatal("WriteTo nw mismatch", nw, N) 306 | } 307 | 308 | if !bytes.Equal(sndbuf, rcvbuf.Bytes()) { 309 | t.Fatal("mismatched echo bytes") 310 | } 311 | 312 | t.Log(stream) 313 | } 314 | 315 | func TestGetDieCh(t *testing.T) { 316 | cs, ss, err := getSmuxStreamPair() 317 | if err != nil { 318 | t.Fatal(err) 319 | } 320 | defer ss.Close() 321 | dieCh := ss.GetDieCh() 322 | go func() { 323 | select { 324 | case <-dieCh: 325 | case <-time.Tick(time.Second): 326 | t.Fatal("wait die chan timeout") 327 | } 328 | }() 329 | cs.Close() 330 | } 331 | 332 | func TestSpeed(t *testing.T) { 333 | _, stop, cli, err := setupServer(t) 334 | if err != nil { 335 | t.Fatal(err) 336 | } 337 | defer stop() 338 | session, _ := Client(cli, nil) 339 | stream, _ := session.OpenStream() 340 | t.Log(stream.LocalAddr(), stream.RemoteAddr()) 341 | 342 | start := time.Now() 343 | var wg sync.WaitGroup 344 | wg.Add(1) 345 | go func() { 346 | buf := make([]byte, 1024*1024) 347 | nrecv := 0 348 | for { 349 | n, err := stream.Read(buf) 350 | if err != nil { 351 | t.Error(err) 352 | break 353 | } else { 354 | nrecv += n 355 | if nrecv == 4096*4096 { 356 | break 357 | } 358 | } 359 | } 360 | stream.Close() 361 | t.Log("time for 16MB rtt", time.Since(start)) 362 | wg.Done() 363 | }() 364 | msg := make([]byte, 8192) 365 | for i := 0; i < 2048; i++ { 366 | stream.Write(msg) 367 | } 368 | 
wg.Wait() 369 | session.Close() 370 | } 371 | 372 | func TestParallel(t *testing.T) { 373 | _, stop, cli, err := setupServer(t) 374 | if err != nil { 375 | t.Fatal(err) 376 | } 377 | defer stop() 378 | session, _ := Client(cli, nil) 379 | 380 | par := 1000 381 | messages := 100 382 | var wg sync.WaitGroup 383 | wg.Add(par) 384 | for i := 0; i < par; i++ { 385 | stream, _ := session.OpenStream() 386 | go func(s *Stream) { 387 | buf := make([]byte, 20) 388 | for j := 0; j < messages; j++ { 389 | msg := fmt.Sprintf("hello%v", j) 390 | s.Write([]byte(msg)) 391 | if _, err := s.Read(buf); err != nil { 392 | break 393 | } 394 | } 395 | s.Close() 396 | wg.Done() 397 | }(stream) 398 | } 399 | t.Log("created", session.NumStreams(), "streams") 400 | wg.Wait() 401 | session.Close() 402 | } 403 | 404 | func TestParallelV2(t *testing.T) { 405 | config := DefaultConfig() 406 | config.Version = 2 407 | _, stop, cli, err := setupServerV2(t) 408 | if err != nil { 409 | t.Fatal(err) 410 | } 411 | defer stop() 412 | session, _ := Client(cli, config) 413 | 414 | par := 1000 415 | messages := 100 416 | var wg sync.WaitGroup 417 | wg.Add(par) 418 | for i := 0; i < par; i++ { 419 | stream, _ := session.OpenStream() 420 | go func(s *Stream) { 421 | buf := make([]byte, 20) 422 | for j := 0; j < messages; j++ { 423 | msg := fmt.Sprintf("hello%v", j) 424 | s.Write([]byte(msg)) 425 | if _, err := s.Read(buf); err != nil { 426 | break 427 | } 428 | } 429 | s.Close() 430 | wg.Done() 431 | }(stream) 432 | } 433 | t.Log("created", session.NumStreams(), "streams") 434 | wg.Wait() 435 | session.Close() 436 | } 437 | 438 | func TestCloseThenOpen(t *testing.T) { 439 | _, stop, cli, err := setupServer(t) 440 | if err != nil { 441 | t.Fatal(err) 442 | } 443 | defer stop() 444 | session, _ := Client(cli, nil) 445 | session.Close() 446 | if _, err := session.OpenStream(); err == nil { 447 | t.Fatal("opened after close") 448 | } 449 | } 450 | 451 | func TestSessionDoubleClose(t *testing.T) { 452 | _, 
stop, cli, err := setupServer(t) 453 | if err != nil { 454 | t.Fatal(err) 455 | } 456 | defer stop() 457 | session, _ := Client(cli, nil) 458 | session.Close() 459 | if err := session.Close(); err == nil { 460 | t.Fatal("session double close doesn't return error") 461 | } 462 | } 463 | 464 | func TestStreamDoubleClose(t *testing.T) { 465 | _, stop, cli, err := setupServer(t) 466 | if err != nil { 467 | t.Fatal(err) 468 | } 469 | defer stop() 470 | session, _ := Client(cli, nil) 471 | stream, _ := session.OpenStream() 472 | stream.Close() 473 | if err := stream.Close(); err == nil { 474 | t.Fatal("stream double close doesn't return error") 475 | } 476 | session.Close() 477 | } 478 | 479 | func TestConcurrentClose(t *testing.T) { 480 | _, stop, cli, err := setupServer(t) 481 | if err != nil { 482 | t.Fatal(err) 483 | } 484 | defer stop() 485 | session, _ := Client(cli, nil) 486 | numStreams := 100 487 | streams := make([]*Stream, 0, numStreams) 488 | var wg sync.WaitGroup 489 | wg.Add(numStreams) 490 | for i := 0; i < 100; i++ { 491 | stream, _ := session.OpenStream() 492 | streams = append(streams, stream) 493 | } 494 | for _, s := range streams { 495 | stream := s 496 | go func() { 497 | stream.Close() 498 | wg.Done() 499 | }() 500 | } 501 | session.Close() 502 | wg.Wait() 503 | } 504 | 505 | func TestTinyReadBuffer(t *testing.T) { 506 | _, stop, cli, err := setupServer(t) 507 | if err != nil { 508 | t.Fatal(err) 509 | } 510 | defer stop() 511 | session, _ := Client(cli, nil) 512 | stream, _ := session.OpenStream() 513 | const N = 100 514 | tinybuf := make([]byte, 6) 515 | var sent string 516 | var received string 517 | for i := 0; i < N; i++ { 518 | msg := fmt.Sprintf("hello%v", i) 519 | sent += msg 520 | nsent, err := stream.Write([]byte(msg)) 521 | if err != nil { 522 | t.Fatal("cannot write") 523 | } 524 | nrecv := 0 525 | for nrecv < nsent { 526 | if n, err := stream.Read(tinybuf); err == nil { 527 | nrecv += n 528 | received += string(tinybuf[:n]) 529 | } 
else { 530 | t.Fatal("cannot read with tiny buffer") 531 | } 532 | } 533 | } 534 | 535 | if sent != received { 536 | t.Fatal("data mimatch") 537 | } 538 | session.Close() 539 | } 540 | 541 | func TestIsClose(t *testing.T) { 542 | _, stop, cli, err := setupServer(t) 543 | if err != nil { 544 | t.Fatal(err) 545 | } 546 | defer stop() 547 | session, _ := Client(cli, nil) 548 | session.Close() 549 | if !session.IsClosed() { 550 | t.Fatal("still open after close") 551 | } 552 | } 553 | 554 | func TestKeepAliveTimeout(t *testing.T) { 555 | ln, err := net.Listen("tcp", "localhost:0") 556 | if err != nil { 557 | t.Fatal(err) 558 | } 559 | defer ln.Close() 560 | go func() { 561 | ln.Accept() 562 | }() 563 | 564 | cli, err := net.Dial("tcp", ln.Addr().String()) 565 | if err != nil { 566 | t.Fatal(err) 567 | } 568 | defer cli.Close() 569 | 570 | config := DefaultConfig() 571 | config.KeepAliveInterval = time.Second 572 | config.KeepAliveTimeout = 2 * time.Second 573 | session, _ := Client(cli, config) 574 | time.Sleep(3 * time.Second) 575 | if !session.IsClosed() { 576 | t.Fatal("keepalive-timeout failed") 577 | } 578 | } 579 | 580 | type blockWriteConn struct { 581 | net.Conn 582 | } 583 | 584 | func (c *blockWriteConn) Write(b []byte) (n int, err error) { 585 | forever := time.Hour * 24 586 | time.Sleep(forever) 587 | return c.Conn.Write(b) 588 | } 589 | 590 | func TestKeepAliveBlockWriteTimeout(t *testing.T) { 591 | ln, err := net.Listen("tcp", "localhost:0") 592 | if err != nil { 593 | t.Fatal(err) 594 | } 595 | defer ln.Close() 596 | go func() { 597 | ln.Accept() 598 | }() 599 | 600 | cli, err := net.Dial("tcp", ln.Addr().String()) 601 | if err != nil { 602 | t.Fatal(err) 603 | } 604 | defer cli.Close() 605 | //when writeFrame block, keepalive in old version never timeout 606 | blockWriteCli := &blockWriteConn{cli} 607 | 608 | config := DefaultConfig() 609 | config.KeepAliveInterval = time.Second 610 | config.KeepAliveTimeout = 2 * time.Second 611 | session, _ := 
Client(blockWriteCli, config) 612 | time.Sleep(3 * time.Second) 613 | if !session.IsClosed() { 614 | t.Fatal("keepalive-timeout failed") 615 | } 616 | } 617 | 618 | func TestServerEcho(t *testing.T) { 619 | ln, err := net.Listen("tcp", "localhost:0") 620 | if err != nil { 621 | t.Fatal(err) 622 | } 623 | defer ln.Close() 624 | go func() { 625 | err := func() error { 626 | conn, err := ln.Accept() 627 | if err != nil { 628 | return err 629 | } 630 | defer conn.Close() 631 | session, err := Server(conn, nil) 632 | if err != nil { 633 | return err 634 | } 635 | defer session.Close() 636 | buf := make([]byte, 10) 637 | stream, err := session.OpenStream() 638 | if err != nil { 639 | return err 640 | } 641 | defer stream.Close() 642 | for i := 0; i < 100; i++ { 643 | msg := fmt.Sprintf("hello%v", i) 644 | stream.Write([]byte(msg)) 645 | n, err := stream.Read(buf) 646 | if err != nil { 647 | return err 648 | } 649 | if got := string(buf[:n]); got != msg { 650 | return fmt.Errorf("got: %q, want: %q", got, msg) 651 | } 652 | } 653 | return nil 654 | }() 655 | if err != nil { 656 | t.Error(err) 657 | } 658 | }() 659 | 660 | cli, err := net.Dial("tcp", ln.Addr().String()) 661 | if err != nil { 662 | t.Fatal(err) 663 | } 664 | defer cli.Close() 665 | if session, err := Client(cli, nil); err == nil { 666 | if stream, err := session.AcceptStream(); err == nil { 667 | buf := make([]byte, 65536) 668 | for { 669 | n, err := stream.Read(buf) 670 | if err != nil { 671 | break 672 | } 673 | stream.Write(buf[:n]) 674 | } 675 | } else { 676 | t.Fatal(err) 677 | } 678 | } else { 679 | t.Fatal(err) 680 | } 681 | } 682 | 683 | func TestSendWithoutRecv(t *testing.T) { 684 | _, stop, cli, err := setupServer(t) 685 | if err != nil { 686 | t.Fatal(err) 687 | } 688 | defer stop() 689 | session, _ := Client(cli, nil) 690 | stream, _ := session.OpenStream() 691 | const N = 100 692 | for i := 0; i < N; i++ { 693 | msg := fmt.Sprintf("hello%v", i) 694 | stream.Write([]byte(msg)) 695 | } 696 | buf 
:= make([]byte, 1) 697 | if _, err := stream.Read(buf); err != nil { 698 | t.Fatal(err) 699 | } 700 | stream.Close() 701 | } 702 | 703 | func TestWriteAfterClose(t *testing.T) { 704 | _, stop, cli, err := setupServer(t) 705 | if err != nil { 706 | t.Fatal(err) 707 | } 708 | defer stop() 709 | session, _ := Client(cli, nil) 710 | stream, _ := session.OpenStream() 711 | stream.Close() 712 | if _, err := stream.Write([]byte("write after close")); err == nil { 713 | t.Fatal("write after close failed") 714 | } 715 | } 716 | 717 | func TestReadStreamAfterSessionClose(t *testing.T) { 718 | _, stop, cli, err := setupServer(t) 719 | if err != nil { 720 | t.Fatal(err) 721 | } 722 | defer stop() 723 | session, _ := Client(cli, nil) 724 | stream, _ := session.OpenStream() 725 | session.Close() 726 | buf := make([]byte, 10) 727 | if _, err := stream.Read(buf); err != nil { 728 | t.Log(err) 729 | } else { 730 | t.Fatal("read stream after session close succeeded") 731 | } 732 | } 733 | 734 | func TestWriteStreamAfterConnectionClose(t *testing.T) { 735 | _, stop, cli, err := setupServer(t) 736 | if err != nil { 737 | t.Fatal(err) 738 | } 739 | defer stop() 740 | session, _ := Client(cli, nil) 741 | stream, _ := session.OpenStream() 742 | session.conn.Close() 743 | if _, err := stream.Write([]byte("write after connection close")); err == nil { 744 | t.Fatal("write after connection close failed") 745 | } 746 | } 747 | 748 | func TestNumStreamAfterClose(t *testing.T) { 749 | _, stop, cli, err := setupServer(t) 750 | if err != nil { 751 | t.Fatal(err) 752 | } 753 | defer stop() 754 | session, _ := Client(cli, nil) 755 | if _, err := session.OpenStream(); err == nil { 756 | if session.NumStreams() != 1 { 757 | t.Fatal("wrong number of streams after opened") 758 | } 759 | session.Close() 760 | if session.NumStreams() != 0 { 761 | t.Fatal("wrong number of streams after session closed") 762 | } 763 | } else { 764 | t.Fatal(err) 765 | } 766 | cli.Close() 767 | } 768 | 769 | func 
TestRandomFrame(t *testing.T) { 770 | addr, stop, cli, err := setupServer(t) 771 | if err != nil { 772 | t.Fatal(err) 773 | } 774 | defer stop() 775 | // pure random 776 | session, _ := Client(cli, nil) 777 | for i := 0; i < 100; i++ { 778 | rnd := make([]byte, rand.Uint32()%1024) 779 | io.ReadFull(crand.Reader, rnd) 780 | session.conn.Write(rnd) 781 | } 782 | cli.Close() 783 | 784 | // double syn 785 | cli, err = net.Dial("tcp", addr) 786 | if err != nil { 787 | t.Fatal(err) 788 | } 789 | session, _ = Client(cli, nil) 790 | for i := 0; i < 100; i++ { 791 | f := newFrame(1, cmdSYN, 1000) 792 | session.writeControlFrame(f) 793 | } 794 | cli.Close() 795 | 796 | // random cmds 797 | cli, err = net.Dial("tcp", addr) 798 | if err != nil { 799 | t.Fatal(err) 800 | } 801 | allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} 802 | session, _ = Client(cli, nil) 803 | for i := 0; i < 100; i++ { 804 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) 805 | session.writeControlFrame(f) 806 | } 807 | cli.Close() 808 | 809 | // random cmds & sids 810 | cli, err = net.Dial("tcp", addr) 811 | if err != nil { 812 | t.Fatal(err) 813 | } 814 | session, _ = Client(cli, nil) 815 | for i := 0; i < 100; i++ { 816 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 817 | session.writeControlFrame(f) 818 | } 819 | cli.Close() 820 | 821 | // random version 822 | cli, err = net.Dial("tcp", addr) 823 | if err != nil { 824 | t.Fatal(err) 825 | } 826 | session, _ = Client(cli, nil) 827 | for i := 0; i < 100; i++ { 828 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 829 | f.ver = byte(rand.Uint32()) 830 | session.writeControlFrame(f) 831 | } 832 | cli.Close() 833 | 834 | // incorrect size 835 | cli, err = net.Dial("tcp", addr) 836 | if err != nil { 837 | t.Fatal(err) 838 | } 839 | session, _ = Client(cli, nil) 840 | 841 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 842 | rnd := make([]byte, rand.Uint32()%1024) 843 | io.ReadFull(crand.Reader, rnd) 844 | f.data = 
rnd 845 | 846 | buf := make([]byte, headerSize+len(f.data)) 847 | buf[0] = f.ver 848 | buf[1] = f.cmd 849 | binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) /// incorrect size 850 | binary.LittleEndian.PutUint32(buf[4:], f.sid) 851 | copy(buf[headerSize:], f.data) 852 | 853 | session.conn.Write(buf) 854 | cli.Close() 855 | 856 | // writeFrame after die 857 | cli, err = net.Dial("tcp", addr) 858 | if err != nil { 859 | t.Fatal(err) 860 | } 861 | session, _ = Client(cli, nil) 862 | //close first 863 | session.Close() 864 | for i := 0; i < 100; i++ { 865 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 866 | session.writeControlFrame(f) 867 | } 868 | } 869 | 870 | func TestWriteFrameInternal(t *testing.T) { 871 | addr, stop, cli, err := setupServer(t) 872 | if err != nil { 873 | t.Fatal(err) 874 | } 875 | defer stop() 876 | // pure random 877 | session, _ := Client(cli, nil) 878 | for i := 0; i < 100; i++ { 879 | rnd := make([]byte, rand.Uint32()%1024) 880 | io.ReadFull(crand.Reader, rnd) 881 | session.conn.Write(rnd) 882 | } 883 | cli.Close() 884 | 885 | // writeFrame after die 886 | cli, err = net.Dial("tcp", addr) 887 | if err != nil { 888 | t.Fatal(err) 889 | } 890 | session, _ = Client(cli, nil) 891 | //close first 892 | session.Close() 893 | for i := 0; i < 100; i++ { 894 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 895 | 896 | timer := time.NewTimer(session.config.KeepAliveTimeout) 897 | defer timer.Stop() 898 | 899 | session.writeFrameInternal(f, timer.C, CLSDATA) 900 | } 901 | 902 | // random cmds 903 | cli, err = net.Dial("tcp", addr) 904 | if err != nil { 905 | t.Fatal(err) 906 | } 907 | allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} 908 | session, _ = Client(cli, nil) 909 | for i := 0; i < 100; i++ { 910 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) 911 | 912 | timer := time.NewTimer(session.config.KeepAliveTimeout) 913 | defer timer.Stop() 914 | 915 | session.writeFrameInternal(f, timer.C, CLSDATA) 916 | 
} 917 | //deadline occur 918 | { 919 | c := make(chan time.Time) 920 | close(c) 921 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) 922 | _, err := session.writeFrameInternal(f, c, CLSDATA) 923 | if !strings.Contains(err.Error(), "timeout") { 924 | t.Fatal("write frame with deadline failed", err) 925 | } 926 | } 927 | cli.Close() 928 | 929 | { 930 | cli, err = net.Dial("tcp", addr) 931 | if err != nil { 932 | t.Fatal(err) 933 | } 934 | config := DefaultConfig() 935 | config.KeepAliveInterval = time.Second 936 | config.KeepAliveTimeout = 2 * time.Second 937 | session, _ = Client(&blockWriteConn{cli}, config) 938 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) 939 | c := make(chan time.Time) 940 | go func() { 941 | //die first, deadline second, better for coverage 942 | time.Sleep(time.Second) 943 | session.Close() 944 | time.Sleep(time.Second) 945 | close(c) 946 | }() 947 | _, err = session.writeFrameInternal(f, c, CLSDATA) 948 | if !strings.Contains(err.Error(), "closed pipe") { 949 | t.Fatal("write frame with to closed conn failed", err) 950 | } 951 | } 952 | } 953 | 954 | func TestReadDeadline(t *testing.T) { 955 | _, stop, cli, err := setupServer(t) 956 | if err != nil { 957 | t.Fatal(err) 958 | } 959 | defer stop() 960 | session, _ := Client(cli, nil) 961 | stream, _ := session.OpenStream() 962 | const N = 100 963 | buf := make([]byte, 10) 964 | var readErr error 965 | for i := 0; i < N; i++ { 966 | stream.SetReadDeadline(time.Now().Add(-1 * time.Minute)) 967 | if _, readErr = stream.Read(buf); readErr != nil { 968 | break 969 | } 970 | } 971 | if readErr != nil { 972 | if !strings.Contains(readErr.Error(), "timeout") { 973 | t.Fatalf("Wrong error: %v", readErr) 974 | } 975 | } else { 976 | t.Fatal("No error when reading with past deadline") 977 | } 978 | session.Close() 979 | } 980 | 981 | func TestWriteDeadline(t *testing.T) { 982 | _, stop, cli, err := setupServer(t) 983 | if err != nil { 984 | t.Fatal(err) 985 | } 986 | defer 
stop() 987 | session, _ := Client(cli, nil) 988 | stream, _ := session.OpenStream() 989 | buf := make([]byte, 10) 990 | var writeErr error 991 | for { 992 | stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute)) 993 | if _, writeErr = stream.Write(buf); writeErr != nil { 994 | if !strings.Contains(writeErr.Error(), "timeout") { 995 | t.Fatalf("Wrong error: %v", writeErr) 996 | } 997 | break 998 | } 999 | } 1000 | session.Close() 1001 | } 1002 | 1003 | func BenchmarkAcceptClose(b *testing.B) { 1004 | _, stop, cli, err := setupServer(b) 1005 | if err != nil { 1006 | b.Fatal(err) 1007 | } 1008 | defer stop() 1009 | session, _ := Client(cli, nil) 1010 | for i := 0; i < b.N; i++ { 1011 | if stream, err := session.OpenStream(); err == nil { 1012 | stream.Close() 1013 | } else { 1014 | b.Fatal(err) 1015 | } 1016 | } 1017 | } 1018 | func BenchmarkConnSmux(b *testing.B) { 1019 | cs, ss, err := getSmuxStreamPair() 1020 | if err != nil { 1021 | b.Fatal(err) 1022 | } 1023 | defer cs.Close() 1024 | defer ss.Close() 1025 | bench(b, cs, ss) 1026 | } 1027 | 1028 | func BenchmarkConnTCP(b *testing.B) { 1029 | cs, ss, err := getTCPConnectionPair() 1030 | if err != nil { 1031 | b.Fatal(err) 1032 | } 1033 | defer cs.Close() 1034 | defer ss.Close() 1035 | bench(b, cs, ss) 1036 | } 1037 | 1038 | func getSmuxStreamPair() (*Stream, *Stream, error) { 1039 | c1, c2, err := getTCPConnectionPair() 1040 | if err != nil { 1041 | return nil, nil, err 1042 | } 1043 | 1044 | s, err := Server(c2, nil) 1045 | if err != nil { 1046 | return nil, nil, err 1047 | } 1048 | c, err := Client(c1, nil) 1049 | if err != nil { 1050 | return nil, nil, err 1051 | } 1052 | var ss *Stream 1053 | done := make(chan error) 1054 | go func() { 1055 | var rerr error 1056 | ss, rerr = s.AcceptStream() 1057 | done <- rerr 1058 | close(done) 1059 | }() 1060 | cs, err := c.OpenStream() 1061 | if err != nil { 1062 | return nil, nil, err 1063 | } 1064 | err = <-done 1065 | if err != nil { 1066 | return nil, nil, err 1067 | 
} 1068 | 1069 | return cs, ss, nil 1070 | } 1071 | 1072 | func getTCPConnectionPair() (net.Conn, net.Conn, error) { 1073 | lst, err := net.Listen("tcp", "localhost:0") 1074 | if err != nil { 1075 | return nil, nil, err 1076 | } 1077 | defer lst.Close() 1078 | 1079 | var conn0 net.Conn 1080 | var err0 error 1081 | done := make(chan struct{}) 1082 | go func() { 1083 | conn0, err0 = lst.Accept() 1084 | close(done) 1085 | }() 1086 | 1087 | conn1, err := net.Dial("tcp", lst.Addr().String()) 1088 | if err != nil { 1089 | return nil, nil, err 1090 | } 1091 | 1092 | <-done 1093 | if err0 != nil { 1094 | return nil, nil, err0 1095 | } 1096 | return conn0, conn1, nil 1097 | } 1098 | 1099 | func bench(b *testing.B, rd io.Reader, wr io.Writer) { 1100 | buf := make([]byte, 128*1024) 1101 | buf2 := make([]byte, 128*1024) 1102 | b.SetBytes(128 * 1024) 1103 | b.ResetTimer() 1104 | b.ReportAllocs() 1105 | 1106 | var wg sync.WaitGroup 1107 | wg.Add(1) 1108 | go func() { 1109 | defer wg.Done() 1110 | count := 0 1111 | for { 1112 | n, _ := rd.Read(buf2) 1113 | count += n 1114 | if count == 128*1024*b.N { 1115 | return 1116 | } 1117 | } 1118 | }() 1119 | for i := 0; i < b.N; i++ { 1120 | wr.Write(buf) 1121 | // invalidate L3 Cache 1122 | buf = make([]byte, 128*1024) 1123 | } 1124 | wg.Wait() 1125 | } 1126 | -------------------------------------------------------------------------------- /shaper.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject 
to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | // _itimediff returns the time difference between two uint32 values. 26 | // The result is a signed 32-bit integer representing the difference between 'later' and 'earlier'. 27 | func _itimediff(later, earlier uint32) int32 { 28 | return (int32)(later - earlier) 29 | } 30 | 31 | // shaperHeap is a min-heap of writeRequest. 32 | // It orders writeRequests by class first, then by sequence number within the same class. 33 | type shaperHeap []writeRequest 34 | 35 | func (h shaperHeap) Len() int { return len(h) } 36 | 37 | // Less determines the ordering of elements in the heap. 38 | // Requests are ordered by their class first. If two requests have the same class, 39 | // they are ordered by their sequence numbers. 
40 | func (h shaperHeap) Less(i, j int) bool { 41 | if h[i].class != h[j].class { 42 | return h[i].class < h[j].class 43 | } 44 | return _itimediff(h[j].seq, h[i].seq) > 0 45 | } 46 | 47 | func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 48 | func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) } 49 | 50 | func (h *shaperHeap) Pop() interface{} { 51 | old := *h 52 | n := len(old) 53 | x := old[n-1] 54 | *h = old[0 : n-1] 55 | return x 56 | } 57 | -------------------------------------------------------------------------------- /shaper_test.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 
22 | 23 | package smux 24 | 25 | import ( 26 | "container/heap" 27 | "testing" 28 | ) 29 | 30 | func TestShaper(t *testing.T) { 31 | w1 := writeRequest{seq: 1} 32 | w2 := writeRequest{seq: 2} 33 | w3 := writeRequest{seq: 3} 34 | w4 := writeRequest{seq: 4} 35 | w5 := writeRequest{seq: 5} 36 | 37 | var reqs shaperHeap 38 | heap.Push(&reqs, w5) 39 | heap.Push(&reqs, w4) 40 | heap.Push(&reqs, w3) 41 | heap.Push(&reqs, w2) 42 | heap.Push(&reqs, w1) 43 | 44 | for len(reqs) > 0 { 45 | w := heap.Pop(&reqs).(writeRequest) 46 | t.Log("sid:", w.frame.sid, "seq:", w.seq) 47 | } 48 | } 49 | 50 | func TestShaper2(t *testing.T) { 51 | w1 := writeRequest{class: CLSDATA, seq: 1} // stream 0 52 | w2 := writeRequest{class: CLSDATA, seq: 2} 53 | w3 := writeRequest{class: CLSDATA, seq: 3} 54 | w4 := writeRequest{class: CLSDATA, seq: 4} 55 | w5 := writeRequest{class: CLSDATA, seq: 5} 56 | w6 := writeRequest{class: CLSCTRL, seq: 6, frame: Frame{sid: 10}} // ctrl 1 57 | w7 := writeRequest{class: CLSCTRL, seq: 7, frame: Frame{sid: 11}} // ctrl 2 58 | 59 | var reqs shaperHeap 60 | heap.Push(&reqs, w6) 61 | heap.Push(&reqs, w5) 62 | heap.Push(&reqs, w4) 63 | heap.Push(&reqs, w3) 64 | heap.Push(&reqs, w2) 65 | heap.Push(&reqs, w1) 66 | heap.Push(&reqs, w7) 67 | 68 | for len(reqs) > 0 { 69 | w := heap.Pop(&reqs).(writeRequest) 70 | t.Log("sid:", w.frame.sid, "seq:", w.seq) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /stream.go: -------------------------------------------------------------------------------- 1 | // MIT License 2 | // 3 | // Copyright (c) 2016-2017 xtaci 4 | // 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and 
to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | // 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | // 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | package smux 24 | 25 | import ( 26 | "encoding/binary" 27 | "io" 28 | "net" 29 | "sync" 30 | "sync/atomic" 31 | "time" 32 | ) 33 | 34 | // wrapper for GC 35 | type Stream struct { 36 | *stream 37 | } 38 | 39 | // Stream implements net.Conn 40 | type stream struct { 41 | id uint32 // Stream identifier 42 | sess *Session 43 | 44 | buffers []*[]byte // the sequential buffers of stream 45 | heads []*[]byte // slice heads of the buffers above, kept for recycle 46 | 47 | bufferLock sync.Mutex // Mutex to protect access to buffers 48 | frameSize int // Maximum frame size for the stream 49 | 50 | // notify a read event 51 | chReadEvent chan struct{} 52 | 53 | // flag the stream has closed 54 | die chan struct{} 55 | dieOnce sync.Once // Ensures die channel is closed only once 56 | 57 | // FIN command 58 | chFinEvent chan struct{} 59 | finEventOnce sync.Once // Ensures chFinEvent is closed only once 60 | 61 | // deadlines 62 | readDeadline atomic.Value 63 | writeDeadline atomic.Value 64 | 65 | // per stream sliding window control 66 | numRead uint32 // count num of bytes read 67 | numWritten uint32 // count num of bytes written 68 | incr uint32 // bytes sent since last window update 69 | 70 | // 
UPD command 71 | peerConsumed uint32 // num of bytes the peer has consumed 72 | peerWindow uint32 // peer window, initialized to 256KB, updated by peer 73 | chUpdate chan struct{} // notify of remote data consuming and window update 74 | } 75 | 76 | // newStream initializes and returns a new Stream. 77 | func newStream(id uint32, frameSize int, sess *Session) *stream { 78 | s := new(stream) 79 | s.id = id 80 | s.chReadEvent = make(chan struct{}, 1) 81 | s.chUpdate = make(chan struct{}, 1) 82 | s.frameSize = frameSize 83 | s.sess = sess 84 | s.die = make(chan struct{}) 85 | s.chFinEvent = make(chan struct{}) 86 | s.peerWindow = initialPeerWindow // set to initial window size 87 | 88 | return s 89 | } 90 | 91 | // ID returns the stream's unique identifier. 92 | func (s *stream) ID() uint32 { 93 | return s.id 94 | } 95 | 96 | // Read reads data from the stream into the provided buffer. 97 | func (s *stream) Read(b []byte) (n int, err error) { 98 | for { 99 | n, err = s.tryRead(b) 100 | if err == ErrWouldBlock { 101 | if ew := s.waitRead(); ew != nil { 102 | return 0, ew 103 | } 104 | } else { 105 | return n, err 106 | } 107 | } 108 | } 109 | 110 | // tryRead attempts to read data from the stream without blocking. 
111 | func (s *stream) tryRead(b []byte) (n int, err error) { 112 | if s.sess.config.Version == 2 { 113 | return s.tryReadv2(b) 114 | } 115 | 116 | if len(b) == 0 { 117 | return 0, nil 118 | } 119 | 120 | // A critical section to copy data from buffers to 121 | s.bufferLock.Lock() 122 | if len(s.buffers) > 0 { 123 | n = copy(b, *s.buffers[0]) 124 | s.buffers[0] = s.buffers[0] 125 | *s.buffers[0] = (*s.buffers[0])[n:] 126 | if len(*s.buffers[0]) == 0 { 127 | s.buffers[0] = nil 128 | s.buffers = s.buffers[1:] 129 | // full recycle 130 | defaultAllocator.Put(s.heads[0]) 131 | s.heads = s.heads[1:] 132 | } 133 | } 134 | s.bufferLock.Unlock() 135 | 136 | if n > 0 { 137 | s.sess.returnTokens(n) 138 | return n, nil 139 | } 140 | 141 | select { 142 | case <-s.die: 143 | return 0, io.EOF 144 | default: 145 | return 0, ErrWouldBlock 146 | } 147 | } 148 | 149 | // tryReadv2 is the non-blocking version of Read for version 2 streams. 150 | func (s *stream) tryReadv2(b []byte) (n int, err error) { 151 | if len(b) == 0 { 152 | return 0, nil 153 | } 154 | 155 | var notifyConsumed uint32 156 | s.bufferLock.Lock() 157 | if len(s.buffers) > 0 { 158 | n = copy(b, *s.buffers[0]) 159 | s.buffers[0] = s.buffers[0] 160 | *s.buffers[0] = (*s.buffers[0])[n:] 161 | if len(*s.buffers[0]) == 0 { 162 | s.buffers[0] = nil 163 | s.buffers = s.buffers[1:] 164 | // full recycle 165 | defaultAllocator.Put(s.heads[0]) 166 | s.heads = s.heads[1:] 167 | } 168 | } 169 | 170 | // in an ideal environment: 171 | // if more than half of buffer has consumed, send read ack to peer 172 | // based on round-trip time of ACK, continous flowing data 173 | // won't slow down due to waiting for ACK, as long as the 174 | // consumer keeps on reading data. 
175 | // 176 | // s.numRead == n implies that it's the initial reading 177 | s.numRead += uint32(n) 178 | s.incr += uint32(n) 179 | 180 | // for initial reading, send window update 181 | if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) { 182 | notifyConsumed = s.numRead 183 | s.incr = 0 // reset couting for next window update 184 | } 185 | s.bufferLock.Unlock() 186 | 187 | if n > 0 { 188 | s.sess.returnTokens(n) 189 | 190 | // send window update if necessary 191 | if notifyConsumed > 0 { 192 | err := s.sendWindowUpdate(notifyConsumed) 193 | return n, err 194 | } else { 195 | return n, nil 196 | } 197 | } 198 | 199 | select { 200 | case <-s.die: 201 | return 0, io.EOF 202 | default: 203 | return 0, ErrWouldBlock 204 | } 205 | } 206 | 207 | // WriteTo implements io.WriteTo 208 | // WriteTo writes data to w until there's no more data to write or when an error occurs. 209 | // The return value n is the number of bytes written. Any error encountered during the write is also returned. 210 | // WriteTo calls Write in a loop until there is no more data to write or when an error occurs. 211 | // If the underlying stream is a v2 stream, it will send window update to peer when necessary. 212 | // If the underlying stream is a v1 stream, it will not send window update to peer. 
func (s *stream) WriteTo(w io.Writer) (n int64, err error) {
	if s.sess.config.Version == 2 {
		return s.writeTov2(w)
	}

	for {
		s.bufferLock.Lock()
		pbuf, head := s.dequeueBufferLocked()
		s.bufferLock.Unlock()

		if pbuf == nil {
			if ew := s.waitRead(); ew != nil {
				return n, ew
			}
			continue
		}

		nw, ew := w.Write(*pbuf)
		// NOTE: WriteTo is a reader, so we need to return tokens here
		s.sess.returnTokens(len(*pbuf))
		// Recycle the original allocation: pbuf may be a resliced view left
		// behind by an earlier partial Read.
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}
	}
}

// writeTov2 is WriteTo for version 2 streams. It also tracks consumption and
// sends window updates to the peer (see tryReadv2 for the policy).
func (s *stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		var notifyConsumed uint32
		notify := false

		s.bufferLock.Lock()
		pbuf, head := s.dequeueBufferLocked()
		var bufLen uint32
		if pbuf != nil {
			bufLen = uint32(len(*pbuf))
		}
		s.numRead += bufLen
		s.incr += bufLen
		if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == bufLen {
			// numRead wraps, so 0 is a valid consumed count; track the
			// decision separately rather than treating 0 as "no update".
			notifyConsumed = s.numRead
			notify = true
			s.incr = 0
		}
		s.bufferLock.Unlock()

		if pbuf == nil {
			if ew := s.waitRead(); ew != nil {
				return n, ew
			}
			continue
		}

		nw, ew := w.Write(*pbuf)
		// NOTE: WriteTo is a reader, so we need to return tokens here
		s.sess.returnTokens(len(*pbuf))
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}

		if notify {
			if err := s.sendWindowUpdate(notifyConsumed); err != nil {
				return n, err
			}
		}
	}
}

// dequeueBufferLocked removes the oldest buffered chunk. It returns the unread
// view of that chunk together with the original allocation behind it (the
// pointer to hand back to defaultAllocator), or two nils when nothing is
// buffered. The caller must hold s.bufferLock.
func (s *stream) dequeueBufferLocked() (view, head *[]byte) {
	if len(s.buffers) == 0 {
		return nil, nil
	}
	view, head = s.buffers[0], s.heads[0]
	s.buffers[0], s.heads[0] = nil, nil
	s.buffers = s.buffers[1:]
	s.heads = s.heads[1:]
	return view, head
}

// sendWindowUpdate sends a cmdUPD frame to the peer carrying the number of
// bytes consumed so far and this side's MaxStreamBuffer as the window, and
// honors the stream's read deadline while doing so.
func (s *stream) sendWindowUpdate(consumed uint32) error {
	// A zero or unset read deadline leaves deadline nil, which never fires.
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		t := time.NewTimer(time.Until(d))
		defer t.Stop()
		deadline = t.C
	}

	// Payload: consumed bytes, then the advertised window.
	var hdr updHeader
	binary.LittleEndian.PutUint32(hdr[:], consumed)
	binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))

	frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
	frame.data = hdr[:]
	if _, err := s.sess.writeFrameInternal(frame, deadline, CLSCTRL); err != nil {
		return err
	}
	return nil
}

// waitRead blocks until the stream may have something to report to a reader.
// It returns nil on a read event, or on FIN while data is still buffered.
// It returns io.EOF on FIN with an empty buffer, and the session's socket read
// error or protocol error when either is signalled. It returns ErrTimeout when
// the read deadline passes, and io.ErrClosedPipe once the stream is closed.
func (s *stream) waitRead() error {
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		t := time.NewTimer(time.Until(d))
		defer t.Stop()
		deadline = t.C
	}

	select {
	case <-s.chReadEvent: // data has arrived, or the stream was poked
		return nil
	case <-s.chFinEvent:
		// BUGFIX(xtaci): Fix for https://github.com/xtaci/smux/issues/82
		// Data that arrived before the FIN must still be delivered.
		s.bufferLock.Lock()
		pending := len(s.buffers)
		s.bufferLock.Unlock()
		if pending > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}

// Write implements net.Conn. The data is split into frames of at most
// frameSize bytes each.
//
// Concurrent writers are not coordinated: frames written by different
// goroutines may interleave in an arbitrary order.
func (s *stream) Write(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.writeV2(b)
	}

	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// check if stream has closed
	select {
	case <-s.chFinEvent: // passive closing
		return 0, io.EOF
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// frame split and transmit
	sent := 0
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
	for bts := b; len(bts) > 0; {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		frame.data = bts[:sz]
		bts = bts[sz:]
		nw, ew := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
		// v1 counts frames here; only writeV2 uses numWritten (as bytes) for flow control.
		s.numWritten++
		sent += nw
		if ew != nil {
			return sent, ew
		}
	}
	return sent, nil
}

// writeV2 writes b on a version 2 stream, respecting the peer's sliding
// window. It blocks while the window is exhausted and returns the number of
// bytes actually transmitted together with any error, as io.Writer requires.
func (s *stream) writeV2(b []byte) (n int, err error) {
	// check empty input
	if len(b) == 0 {
		return 0, nil
	}

	// check if stream has closed
	select {
	case <-s.chFinEvent:
		return 0, io.EOF
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// create write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// frame split and transmit process
	sent := 0
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)

	for {
		// per stream sliding window control
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		//
		// Even if uint32 overflows, this math still works:
		// eg1: uint32(0) - uint32(math.MaxUint32) = 1
		// eg2: int32(uint32(0) - uint32(1)) = -1
		// i.e. it is MODULAR ARITHMETIC (2^32).
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 { // security check for malformed data
			// Report bytes already sent in earlier iterations (io.Writer contract).
			return sent, ErrConsumed
		}

		// 'win' is also computed in modular arithmetic (2^32 (4GB)).
		win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight

		if win > 0 {
			// determine how many bytes to send; compare as int so a huge b
			// cannot overflow an int32 conversion.
			var bts []byte
			if int(win) >= len(b) {
				bts, b = b, nil
			} else {
				bts, b = b[:win], b[win:]
			}

			// frame split and transmit
			for len(bts) > 0 {
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.data = bts[:sz]
				bts = bts[sz:]

				nw, ew := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
				atomic.AddUint32(&s.numWritten, uint32(sz))
				sent += nw
				if ew != nil {
					return sent, ew
				}
			}
		}

		if len(b) == 0 {
			return sent, nil
		}

		// Data is left over: wait until the stream closes, the window changes
		// or the deadline is reached. This blocking propagates flow control
		// back to the upper layer.
		select {
		case <-s.chFinEvent:
			return sent, io.EOF
		case <-s.die:
			return sent, io.ErrClosedPipe
		case <-deadline:
			return sent, ErrTimeout
		case <-s.sess.chSocketWriteError:
			return sent, s.sess.socketWriteError.Load().(error)
		case <-s.chUpdate: // remote consumed data or updated its window; retry
		}
	}
}

// Close implements net.Conn. The first call marks the stream as dead, sends
// a FIN (bounded by openCloseTimeout) and notifies the session. Later calls
// return io.ErrClosedPipe.
func (s *stream) Close() error {
	var once bool
	s.dieOnce.Do(func() {
		close(s.die)
		once = true
	})
	if !once {
		return io.ErrClosedPipe
	}

	// send FIN in order
	f := newFrame(byte(s.sess.config.Version), cmdFIN, s.id)

	timer := time.NewTimer(openCloseTimeout)
	defer timer.Stop()

	_, err := s.sess.writeFrameInternal(f, timer.C, CLSDATA)
	s.sess.streamClosed(s.id)
	return err
}

// GetDieCh returns a readonly chan which can be readable
// when the stream is to be closed.
func (s *stream) GetDieCh() <-chan struct{} {
	return s.die
}

// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline.
// A zero time value disables the deadline.
func (s *stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	// Wake a blocked reader so it re-evaluates the new deadline.
	s.notifyReadEvent()
	return nil
}

// SetWriteDeadline sets the write deadline as defined by
// net.Conn.SetWriteDeadline.
// A zero time value disables the deadline.
func (s *stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	return nil
}

// SetDeadline sets both read and write deadlines as defined by
// net.Conn.SetDeadline.
// A zero time value disables the deadlines.
541 | func (s *stream) SetDeadline(t time.Time) error { 542 | if err := s.SetReadDeadline(t); err != nil { 543 | return err 544 | } 545 | if err := s.SetWriteDeadline(t); err != nil { 546 | return err 547 | } 548 | return nil 549 | } 550 | 551 | // session closes 552 | func (s *stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) } 553 | 554 | // LocalAddr satisfies net.Conn interface 555 | func (s *stream) LocalAddr() net.Addr { 556 | if ts, ok := s.sess.conn.(interface { 557 | LocalAddr() net.Addr 558 | }); ok { 559 | return ts.LocalAddr() 560 | } 561 | return nil 562 | } 563 | 564 | // RemoteAddr satisfies net.Conn interface 565 | func (s *stream) RemoteAddr() net.Addr { 566 | if ts, ok := s.sess.conn.(interface { 567 | RemoteAddr() net.Addr 568 | }); ok { 569 | return ts.RemoteAddr() 570 | } 571 | return nil 572 | } 573 | 574 | // pushBytes append buf to buffers 575 | func (s *stream) pushBytes(pbuf *[]byte) (written int, err error) { 576 | s.bufferLock.Lock() 577 | s.buffers = append(s.buffers, pbuf) 578 | s.heads = append(s.heads, pbuf) 579 | s.bufferLock.Unlock() 580 | return 581 | } 582 | 583 | // recycleTokens transform remaining bytes to tokens(will truncate buffer) 584 | func (s *stream) recycleTokens() (n int) { 585 | s.bufferLock.Lock() 586 | for k := range s.buffers { 587 | n += len(*s.buffers[k]) 588 | defaultAllocator.Put(s.heads[k]) 589 | } 590 | s.buffers = nil 591 | s.heads = nil 592 | s.bufferLock.Unlock() 593 | return 594 | } 595 | 596 | // notify read event 597 | func (s *stream) notifyReadEvent() { 598 | select { 599 | case s.chReadEvent <- struct{}{}: 600 | default: 601 | } 602 | } 603 | 604 | // update command 605 | func (s *stream) update(consumed uint32, window uint32) { 606 | atomic.StoreUint32(&s.peerConsumed, consumed) 607 | atomic.StoreUint32(&s.peerWindow, window) 608 | select { 609 | case s.chUpdate <- struct{}{}: 610 | default: 611 | } 612 | } 613 | 614 | // mark this stream has been closed in protocol 615 | func (s 
*stream) fin() { 616 | s.finEventOnce.Do(func() { 617 | close(s.chFinEvent) 618 | }) 619 | } 620 | --------------------------------------------------------------------------------