├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── alloc.go
├── alloc_test.go
├── assets
├── curve.jpg
├── mux.jpg
└── smux.png
├── frame.go
├── go.mod
├── go.sum
├── mux.go
├── mux_test.go
├── pkg.go
├── session.go
├── session_test.go
├── shaper.go
├── shaper_test.go
└── stream.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | # Architecture specific extensions/prefixes
11 | *.[568vq]
12 | [568vq].out
13 |
14 | *.cgo1.go
15 | *.cgo2.c
16 | _cgo_defun.c
17 | _cgo_gotypes.go
18 | _cgo_export.*
19 |
20 | _testmain.go
21 |
22 | *.exe
23 | *.test
24 | *.prof
25 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | arch:
2 | - amd64
3 | - ppc64le
4 | language: go
5 | go:
6 | - 1.9.x
7 | - 1.10.x
8 | - 1.11.x
9 |
10 | before_install:
11 | - go get -t -v ./...
12 |
13 | install:
14 | - go get github.com/xtaci/smux
15 |
16 | script:
17 | - go test -coverprofile=coverage.txt -covermode=atomic -bench .
18 |
19 | after_success:
20 | - bash <(curl -s https://codecov.io/bash)
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2016-2017 xtaci
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [![GoDoc][1]][2] [![MIT licensed][3]][4] [![Build Status][5]][6] [![Go Report Card][7]][8] [![Coverage Status][9]][10] [![Sourcegraph][11]][12]
4 |
5 |
6 |
7 | [1]: https://godoc.org/github.com/xtaci/smux?status.svg
8 | [2]: https://godoc.org/github.com/xtaci/smux
9 | [3]: https://img.shields.io/badge/license-MIT-blue.svg
10 | [4]: LICENSE
11 | [5]: https://img.shields.io/github/created-at/xtaci/smux
12 | [6]: https://img.shields.io/github/created-at/xtaci/smux
13 | [7]: https://goreportcard.com/badge/github.com/xtaci/smux
14 | [8]: https://goreportcard.com/report/github.com/xtaci/smux
15 | [9]: https://codecov.io/gh/xtaci/smux/branch/master/graph/badge.svg
16 | [10]: https://codecov.io/gh/xtaci/smux
17 | [11]: https://sourcegraph.com/github.com/xtaci/smux/-/badge.svg
18 | [12]: https://sourcegraph.com/github.com/xtaci/smux?badge
19 |
20 | ## Introduction
21 |
22 | Smux ( **S**imple **MU**ltiple**X**ing) is a multiplexing library for Golang. It relies on an underlying connection to provide reliability and ordering, such as TCP or [KCP](https://github.com/xtaci/kcp-go), and provides stream-oriented multiplexing. The original intention of this library is to power the connection management for [kcp-go](https://github.com/xtaci/kcp-go).
23 |
24 | ## Features
25 |
26 | 1. ***Token bucket*** controlled receiving, which provides a smoother bandwidth graph (see picture below).
27 | 2. Session-wide receive buffer, shared among streams, for **fully controlled** overall memory usage.
28 | 3. Minimized header (8 bytes), maximized payload.
29 | 4. Well-tested on millions of devices in [kcptun](https://github.com/xtaci/kcptun).
30 | 5. Built-in fair-queue traffic shaping.
31 | 6. Per-stream sliding window to control congestion (protocol version 2+).
32 |
33 | 
34 |
35 | ## Documentation
36 |
37 | For complete documentation, see the associated [Godoc](https://godoc.org/github.com/xtaci/smux).
38 |
39 | ## Benchmark
40 | ```
41 | $ go test -v -run=^$ -bench .
42 | goos: darwin
43 | goarch: amd64
44 | pkg: github.com/xtaci/smux
45 | BenchmarkMSB-4 30000000 51.8 ns/op
46 | BenchmarkAcceptClose-4 50000 36783 ns/op
47 | BenchmarkConnSmux-4 30000 58335 ns/op 2246.88 MB/s 1208 B/op 19 allocs/op
48 | BenchmarkConnTCP-4 50000 25579 ns/op 5124.04 MB/s 0 B/op 0 allocs/op
49 | PASS
50 | ok github.com/xtaci/smux 7.811s
51 | ```
52 |
53 | ## Specification
54 |
55 | ```
56 | VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH)
57 |
58 | VALUES FOR LATEST VERSION:
59 | VERSION:
60 | 1/2
61 |
62 | CMD:
63 | cmdSYN(0)
64 | cmdFIN(1)
65 | cmdPSH(2)
66 | cmdNOP(3)
67 | cmdUPD(4) // only supported on version 2
68 |
69 | STREAMID:
70 | clients use odd numbers, starting from 1
71 | servers use even numbers, starting from 0
72 |
73 | cmdUPD:
74 | | CONSUMED(4B) | WINDOW(4B) |
75 | ```
76 |
77 | ## Usage
78 |
79 | ```go
80 |
81 | func client() {
82 | // Get a TCP connection
83 | conn, err := net.Dial(...)
84 | if err != nil {
85 | panic(err)
86 | }
87 |
88 | // Setup client side of smux
89 | session, err := smux.Client(conn, nil)
90 | if err != nil {
91 | panic(err)
92 | }
93 |
94 | // Open a new stream
95 | stream, err := session.OpenStream()
96 | if err != nil {
97 | panic(err)
98 | }
99 |
100 | // Stream implements io.ReadWriteCloser
101 | stream.Write([]byte("ping"))
102 | stream.Close()
103 | session.Close()
104 | }
105 |
106 | func server() {
107 | // Accept a TCP connection
108 | conn, err := listener.Accept()
109 | if err != nil {
110 | panic(err)
111 | }
112 |
113 | // Setup server side of smux
114 | session, err := smux.Server(conn, nil)
115 | if err != nil {
116 | panic(err)
117 | }
118 |
119 | // Accept a stream
120 | stream, err := session.AcceptStream()
121 | if err != nil {
122 | panic(err)
123 | }
124 |
125 | // Listen for a message
126 | buf := make([]byte, 4)
127 | stream.Read(buf)
128 | stream.Close()
129 | session.Close()
130 | }
131 |
132 | ```
133 |
134 | ## Status
135 |
136 | Stable
137 |
--------------------------------------------------------------------------------
/alloc.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "errors"
27 | "sync"
28 | )
29 |
var (
	// defaultAllocator is the package-wide frame buffer allocator, created in init.
	defaultAllocator *Allocator

	// debruijinPos is the De Bruijn lookup table used by msb: after all bits
	// below the most significant one are set, the top 5 bits of
	// v*0x07C4ACDD index the position of that bit.
	debruijinPos = [...]byte{0, 9, 1, 10, 13, 21, 2, 29, 11, 14, 16, 18, 22, 25, 3, 30, 8, 12, 20, 28, 15, 17, 24, 7, 19, 27, 23, 6, 26, 5, 4, 31}
)

func init() {
	defaultAllocator = NewAllocator()
}

// Allocator for incoming frames, optimized to prevent overwriting after zeroing.
//
// It keeps one sync.Pool per power-of-two size class, from 2^0 (1 byte) up to
// 2^16 (64 KiB).
type Allocator struct {
	buffers []sync.Pool
}

// NewAllocator initiates a []byte allocator for frames less than 65536 bytes,
// the waste(memory fragmentation) of space allocation is guaranteed to be
// no more than 50%.
func NewAllocator() *Allocator {
	alloc := new(Allocator)
	alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
	for k := range alloc.buffers {
		i := k // capture the per-iteration value for the closure (pre-Go 1.22 semantics)
		alloc.buffers[k].New = func() interface{} {
			b := make([]byte, 1<<uint32(i))
			return &b
		}
	}
	return alloc
}

// Get returns a *[]byte of length size whose capacity is the smallest power
// of two >= size. It returns nil when size is outside the range [1, 65536].
func (alloc *Allocator) Get(size int) *[]byte {
	if size <= 0 || size > 65536 {
		return nil
	}

	bits := msb(size)
	if size != 1<<bits {
		// size is not a power of two: round up to the next size class.
		bits++
	}

	p := alloc.buffers[bits].Get().(*[]byte)
	*p = (*p)[:size]
	return p
}

// Put returns a *[]byte to the pool for future use. The capacity of *p must
// be exactly 2^n with 1 <= 2^n <= 65536; otherwise, or when p is nil, an
// error is returned and the buffer is not pooled.
func (alloc *Allocator) Put(p *[]byte) error {
	if p == nil {
		return errors.New("allocator Put() incorrect buffer size")
	}
	c := cap(*p)
	if c == 0 || c > 65536 {
		return errors.New("allocator Put() incorrect buffer size")
	}
	bits := msb(c)
	if c != 1<<bits {
		return errors.New("allocator Put() incorrect buffer size")
	}
	alloc.buffers[bits].Put(p)
	return nil
}

// msb returns the position of the most significant set bit of size, which
// must be in the range (0, 2^32). It smears the top bit downwards and then
// uses a De Bruijn multiply-and-lookup.
// http://supertech.csail.mit.edu/papers/debruijn.pdf
func msb(size int) byte {
	v := uint32(size)
	v |= v >> 1
	v |= v >> 2
	v |= v >> 4
	v |= v >> 8
	v |= v >> 16
	return debruijinPos[(v*0x07C4ACDD)>>27]
}
102 |
--------------------------------------------------------------------------------
/alloc_test.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "math/bits"
27 | "math/rand"
28 | "testing"
29 | )
30 |
31 | func TestAllocGet(t *testing.T) {
32 | alloc := NewAllocator()
33 | if alloc.Get(0) != nil {
34 | t.Fatal(0)
35 | }
36 | if len(*alloc.Get(1)) != 1 {
37 | t.Fatal(1)
38 | }
39 | if len(*alloc.Get(2)) != 2 {
40 | t.Fatal(2)
41 | }
42 | if len(*alloc.Get(3)) != 3 || cap(*alloc.Get(3)) != 4 {
43 | t.Fatal(3)
44 | }
45 | if len(*alloc.Get(4)) != 4 {
46 | t.Fatal(4)
47 | }
48 | if len(*alloc.Get(1023)) != 1023 || cap(*alloc.Get(1023)) != 1024 {
49 | t.Fatal(1023)
50 | }
51 | if len(*alloc.Get(1024)) != 1024 {
52 | t.Fatal(1024)
53 | }
54 | if len(*alloc.Get(65536)) != 65536 {
55 | t.Fatal(65536)
56 | }
57 | if alloc.Get(65537) != nil {
58 | t.Fatal(65537)
59 | }
60 | }
61 |
62 | func TestAllocPut(t *testing.T) {
63 | alloc := NewAllocator()
64 | if err := alloc.Put(nil); err == nil {
65 | t.Fatal("put nil misbehavior")
66 | }
67 | b := make([]byte, 3)
68 | if err := alloc.Put(&b); err == nil {
69 | t.Fatal("put elem:3 []bytes misbehavior")
70 | }
71 | b = make([]byte, 4)
72 | if err := alloc.Put(&b); err != nil {
73 | t.Fatal("put elem:4 []bytes misbehavior")
74 | }
75 | b = make([]byte, 1023, 1024)
76 | if err := alloc.Put(&b); err != nil {
77 | t.Fatal("put elem:1024 []bytes misbehavior")
78 | }
79 | b = make([]byte, 65536)
80 | if err := alloc.Put(&b); err != nil {
81 | t.Fatal("put elem:65536 []bytes misbehavior")
82 | }
83 | b = make([]byte, 65537)
84 | if err := alloc.Put(&b); err == nil {
85 | t.Fatal("put elem:65537 []bytes misbehavior")
86 | }
87 | }
88 |
89 | func TestAllocPutThenGet(t *testing.T) {
90 | alloc := NewAllocator()
91 | data := alloc.Get(4)
92 | alloc.Put(data)
93 | newData := alloc.Get(4)
94 | if cap(*data) != cap(*newData) {
95 | t.Fatal("different cap while alloc.Get()")
96 | }
97 | }
98 |
99 | func BenchmarkMSB(b *testing.B) {
100 | for i := 0; i < b.N; i++ {
101 | msb(rand.Int())
102 | }
103 | }
104 |
105 | func BenchmarkAlloc(b *testing.B) {
106 | for i := 0; i < b.N; i++ {
107 | pbuf := defaultAllocator.Get(i % 65536)
108 | defaultAllocator.Put(pbuf)
109 | }
110 | }
111 |
112 | func TestDebrujin(t *testing.T) {
113 | for i := 1; i < 65536; i++ {
114 | a := int(msb(i))
115 | b := bits.Len(uint(i))
116 | if a+1 != b {
117 | t.Fatal("debrujin")
118 | }
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/assets/curve.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/curve.jpg
--------------------------------------------------------------------------------
/assets/mux.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/mux.jpg
--------------------------------------------------------------------------------
/assets/smux.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/assets/smux.png
--------------------------------------------------------------------------------
/frame.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "encoding/binary"
27 | "fmt"
28 | )
29 |
// Frame commands. Protocol version 1 defines SYN, FIN, PSH and NOP;
// protocol version 2 adds UPD.
const (
	cmdSYN byte = 0 // stream open
	cmdFIN byte = 1 // stream close, a.k.a EOF mark
	cmdPSH byte = 2 // data push
	cmdNOP byte = 3 // no operation
	cmdUPD byte = 4 // (v2) notify bytes consumed by remote peer-end
)

const (
	// szCmdUPD is the payload size of a cmdUPD frame, laid out as
	// | 4B data consumed (ACK) | 4B window size (WINDOW) |
	szCmdUPD = 8

	// initialPeerWindow is the initial guess of the peer's window (a slow start).
	initialPeerWindow = 262144
)

// Sizes of the frame header fields, listed in wire order:
// VERSION | CMD | LENGTH | STREAMID.
const (
	sizeOfVer    = 1
	sizeOfCmd    = 1
	sizeOfLength = 2
	sizeOfSid    = 4
	headerSize   = sizeOfVer + sizeOfCmd + sizeOfLength + sizeOfSid
)
60 |
// Frame is one packet multiplexed onto, or demultiplexed from, the single
// underlying connection.
type Frame struct {
	ver  byte   // protocol version
	cmd  byte   // command (cmdSYN, cmdFIN, ...)
	sid  uint32 // stream id
	data []byte // payload
}

// newFrame builds a Frame carrying the given version, command and stream id;
// its payload is left nil.
func newFrame(version byte, cmd byte, sid uint32) Frame {
	var f Frame
	f.ver, f.cmd, f.sid = version, cmd, sid
	return f
}
73 |
74 | // rawHeader is a byte array representation of Frame header
75 | type rawHeader [headerSize]byte
76 |
77 | func (h rawHeader) Version() byte {
78 | return h[0]
79 | }
80 |
81 | func (h rawHeader) Cmd() byte {
82 | return h[1]
83 | }
84 |
85 | func (h rawHeader) Length() uint16 {
86 | return binary.LittleEndian.Uint16(h[2:])
87 | }
88 |
89 | func (h rawHeader) StreamID() uint32 {
90 | return binary.LittleEndian.Uint32(h[4:])
91 | }
92 |
93 | func (h rawHeader) String() string {
94 | return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d",
95 | h.Version(), h.Cmd(), h.StreamID(), h.Length())
96 | }
97 |
98 | // updHeader is a byte array representation of cmdUPD
99 | type updHeader [szCmdUPD]byte
100 |
101 | func (h updHeader) Consumed() uint32 {
102 | return binary.LittleEndian.Uint32(h[:])
103 | }
104 | func (h updHeader) Window() uint32 {
105 | return binary.LittleEndian.Uint32(h[4:])
106 | }
107 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xtaci/smux
2 |
3 | go 1.13
4 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xtaci/smux/e6b0586a4539b2fa60cf4c5bd3e4883c83ce4f9f/go.sum
--------------------------------------------------------------------------------
/mux.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "errors"
27 | "fmt"
28 | "io"
29 | "math"
30 | "time"
31 | )
32 |
// Config is used to tune the Smux session.
type Config struct {
	// Version is the SMUX protocol version; VerifyConfig accepts 1 and 2.
	Version int

	// KeepAliveDisabled turns keep-alive off. When set, VerifyConfig skips
	// the KeepAliveInterval/KeepAliveTimeout checks.
	KeepAliveDisabled bool

	// KeepAliveInterval is how often to send a NOP command to the remote.
	// It must be positive unless keep-alive is disabled.
	KeepAliveInterval time.Duration

	// KeepAliveTimeout is how long the session may go without receiving any
	// data before it is closed. It must not be shorter than KeepAliveInterval.
	KeepAliveTimeout time.Duration

	// MaxFrameSize is the maximum frame size sent to the remote; it must be
	// in the range [1, 65535].
	MaxFrameSize int

	// MaxReceiveBuffer is the maximum amount of data buffered for the whole
	// session; it must be positive.
	MaxReceiveBuffer int

	// MaxStreamBuffer is the maximum amount of data buffered per stream; it
	// must be positive and no larger than MaxReceiveBuffer.
	MaxStreamBuffer int
}

// DefaultConfig returns a fresh default configuration: protocol version 1,
// a 10s keep-alive interval with a 30s timeout, 32 KiB frames, a 4 MiB
// session receive buffer and a 64 KiB per-stream buffer.
func DefaultConfig() *Config {
	return &Config{
		Version:           1,
		KeepAliveInterval: 10 * time.Second,
		KeepAliveTimeout:  30 * time.Second,
		MaxFrameSize:      32768,
		MaxReceiveBuffer:  4194304,
		MaxStreamBuffer:   65536,
	}
}

// VerifyConfig is used to verify the sanity of configuration.
//
// It returns nil when config is usable, or an error describing the first
// invalid setting found (including a nil config).
func VerifyConfig(config *Config) error {
	if config == nil {
		return errors.New("config must not be nil")
	}
	if config.Version != 1 && config.Version != 2 {
		return errors.New("unsupported protocol version")
	}
	if !config.KeepAliveDisabled {
		// Reject negative intervals too, not just zero: the requirement is
		// that the interval is strictly positive.
		if config.KeepAliveInterval <= 0 {
			return errors.New("keep-alive interval must be positive")
		}
		if config.KeepAliveTimeout < config.KeepAliveInterval {
			return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval")
		}
	}
	if config.MaxFrameSize <= 0 {
		return errors.New("max frame size must be positive")
	}
	// The frame LENGTH header field is 16 bits wide.
	if config.MaxFrameSize > 65535 {
		return errors.New("max frame size must not be larger than 65535")
	}
	if config.MaxReceiveBuffer <= 0 {
		return errors.New("max receive buffer must be positive")
	}
	if config.MaxStreamBuffer <= 0 {
		return errors.New("max stream buffer must be positive")
	}
	if config.MaxStreamBuffer > config.MaxReceiveBuffer {
		return errors.New("max stream buffer must not be larger than max receive buffer")
	}
	if config.MaxStreamBuffer > math.MaxInt32 {
		return errors.New("max stream buffer cannot be larger than 2147483647")
	}
	return nil
}
106 |
107 | // Server is used to initialize a new server-side connection.
108 | func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
109 | if config == nil {
110 | config = DefaultConfig()
111 | }
112 | if err := VerifyConfig(config); err != nil {
113 | return nil, err
114 | }
115 | return newSession(config, conn, false), nil
116 | }
117 |
118 | // Client is used to initialize a new client-side connection.
119 | func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
120 | if config == nil {
121 | config = DefaultConfig()
122 | }
123 |
124 | if err := VerifyConfig(config); err != nil {
125 | return nil, err
126 | }
127 | return newSession(config, conn, true), nil
128 | }
129 |
--------------------------------------------------------------------------------
/mux_test.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "bytes"
27 | "testing"
28 | )
29 |
// buffer is an in-memory io.ReadWriteCloser for tests: reads and writes go
// to the embedded bytes.Buffer, and Close discards any unread data.
type buffer struct {
	bytes.Buffer
}

// Close implements io.Closer by resetting the embedded buffer; it never fails.
func (buf *buffer) Close() error {
	buf.Buffer.Reset()
	return nil
}
38 |
39 | func TestConfig(t *testing.T) {
40 | VerifyConfig(DefaultConfig())
41 |
42 | config := DefaultConfig()
43 | config.KeepAliveInterval = 0
44 | err := VerifyConfig(config)
45 | t.Log(err)
46 | if err == nil {
47 | t.Fatal(err)
48 | }
49 |
50 | config = DefaultConfig()
51 | config.KeepAliveInterval = 10
52 | config.KeepAliveTimeout = 5
53 | err = VerifyConfig(config)
54 | t.Log(err)
55 | if err == nil {
56 | t.Fatal(err)
57 | }
58 |
59 | config = DefaultConfig()
60 | config.MaxFrameSize = 0
61 | err = VerifyConfig(config)
62 | t.Log(err)
63 | if err == nil {
64 | t.Fatal(err)
65 | }
66 |
67 | config = DefaultConfig()
68 | config.MaxFrameSize = 65536
69 | err = VerifyConfig(config)
70 | t.Log(err)
71 | if err == nil {
72 | t.Fatal(err)
73 | }
74 |
75 | config = DefaultConfig()
76 | config.MaxReceiveBuffer = 0
77 | err = VerifyConfig(config)
78 | t.Log(err)
79 | if err == nil {
80 | t.Fatal(err)
81 | }
82 |
83 | config = DefaultConfig()
84 | config.MaxStreamBuffer = 0
85 | err = VerifyConfig(config)
86 | t.Log(err)
87 | if err == nil {
88 | t.Fatal(err)
89 | }
90 |
91 | config = DefaultConfig()
92 | config.MaxStreamBuffer = 100
93 | config.MaxReceiveBuffer = 99
94 | err = VerifyConfig(config)
95 | t.Log(err)
96 | if err == nil {
97 | t.Fatal(err)
98 | }
99 |
100 | var bts buffer
101 | if _, err := Server(&bts, config); err == nil {
102 | t.Fatal("server started with wrong config")
103 | }
104 |
105 | if _, err := Client(&bts, config); err == nil {
106 | t.Fatal("client started with wrong config")
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/pkg.go:
--------------------------------------------------------------------------------
// Package smux is a multiplexing library for Golang.
//
// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP,
// and provides stream-oriented multiplexing over a single channel.
package smux
7 |
--------------------------------------------------------------------------------
/session.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "container/heap"
27 | "encoding/binary"
28 | "errors"
29 | "io"
30 | "net"
31 | "runtime"
32 | "sync"
33 | "sync/atomic"
34 | "time"
35 | )
36 |
const (
	// defaultAcceptBacklog is the capacity of the queue of streams opened by
	// the peer and waiting for AcceptStream.
	defaultAcceptBacklog = 1024

	// maxShaperSize is a size limit for the write shaper.
	// NOTE(review): its use is outside this view — confirm the exact semantics.
	maxShaperSize = 1024

	// openCloseTimeout is the timeout for opening/closing streams.
	openCloseTimeout = 30 * time.Second
)

// CLASSID represents the class of a frame.
type CLASSID int

const (
	// CLSCTRL is the prioritized control-signal class.
	CLSCTRL CLASSID = 0
	// CLSDATA is the data class (not prioritized like CLSCTRL).
	CLSDATA CLASSID = 1
)
50 |
// timeoutError is the error for timeouts of operations such as accept, read
// and write.
//
// It implements the standard library's net.Error so that smux cooperates with
// code that inspects Timeout()/Temporary(). For example, when smux backs a
// net.Listener served by http.Server, a plain error here would make the
// server unexpectedly close keep-alive connections (*smux.Stream).
// For more details, see https://github.com/xtaci/smux/pull/99.
type timeoutError struct{}

// Error implements the error interface.
func (timeoutError) Error() string {
	return "timeout"
}

// Temporary reports that a timeout is temporary (net.Error).
func (timeoutError) Temporary() bool {
	return true
}

// Timeout reports that this error is a timeout (net.Error).
func (timeoutError) Timeout() bool {
	return true
}

var (
	// ErrInvalidProtocol reports a malformed or unexpected frame from the peer.
	ErrInvalidProtocol = errors.New("invalid protocol")
	// ErrConsumed reports that the peer acknowledged more data than was sent.
	ErrConsumed = errors.New("peer consumed more than sent")
	// ErrGoAway reports that the stream id space is exhausted; OpenStream
	// returns it once ids overflow.
	ErrGoAway = errors.New("stream id overflows, should start a new connection")
	// ErrTimeout is returned when a deadline expires; it satisfies net.Error.
	ErrTimeout net.Error = &timeoutError{}
	// ErrWouldBlock reports that an operation would block on IO.
	ErrWouldBlock = errors.New("operation would block on IO")
)
70 |
// writeRequest represents a request to write a frame.
type writeRequest struct {
	class  CLASSID          // CLSCTRL or CLSDATA
	frame  Frame            // frame to put on the wire
	seq    uint32           // request sequence number (presumably from Session.requestID — TODO confirm)
	result chan writeResult // channel on which the outcome of the write is delivered
}

// writeResult represents the result of a write request.
type writeResult struct {
	n   int   // number of bytes written
	err error // write error, if any
}
84 |
// Session defines a multiplexed connection for streams.
//
// It is created by newSession (via Server or Client) on top of a single
// io.ReadWriteCloser.
type Session struct {
	conn io.ReadWriteCloser // underlying connection shared by all streams

	config           *Config    // verified configuration for this session
	nextStreamID     uint32     // next stream identifier
	nextStreamIDLock sync.Mutex // held by OpenStream while allocating ids and checking/setting goAway

	bucket       int32         // token bucket; starts at config.MaxReceiveBuffer
	bucketNotify chan struct{} // used for waiting for tokens (buffered, capacity 1)

	streams    map[uint32]*stream // all streams in this session, keyed by stream id
	streamLock sync.Mutex         // locks streams

	die     chan struct{} // flag session has died; selected on by OpenStream/AcceptStream
	dieOnce sync.Once

	// socket error handling: each error is stored in its atomic.Value and
	// signalled on the matching channel (guarded by the matching Once).
	socketReadError      atomic.Value
	socketWriteError     atomic.Value
	chSocketReadError    chan struct{}
	chSocketWriteError   chan struct{}
	socketReadErrorOnce  sync.Once
	socketWriteErrorOnce sync.Once

	// smux protocol errors, stored and signalled the same way
	protoError     atomic.Value
	chProtoError   chan struct{}
	protoErrorOnce sync.Once

	chAccepts chan *stream // streams opened by the peer, received by AcceptStream

	dataReady int32 // flag data has arrived

	goAway int32 // flag id exhausted; set by OpenStream on stream-id overflow

	deadline atomic.Value // time.Time deadline used by AcceptStream; unset/zero means none

	requestID uint32            // Monotonic increasing write request ID
	shaper    chan writeRequest // a shaper for writing (unbuffered)
	writes    chan writeRequest // write requests (unbuffered); consumer not shown here — presumably sendLoop
}
127 |
128 | func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
129 | s := new(Session)
130 | s.die = make(chan struct{})
131 | s.conn = conn
132 | s.config = config
133 | s.streams = make(map[uint32]*stream)
134 | s.chAccepts = make(chan *stream, defaultAcceptBacklog)
135 | s.bucket = int32(config.MaxReceiveBuffer)
136 | s.bucketNotify = make(chan struct{}, 1)
137 | s.shaper = make(chan writeRequest)
138 | s.writes = make(chan writeRequest)
139 | s.chSocketReadError = make(chan struct{})
140 | s.chSocketWriteError = make(chan struct{})
141 | s.chProtoError = make(chan struct{})
142 |
143 | if client {
144 | s.nextStreamID = 1
145 | } else {
146 | s.nextStreamID = 0
147 | }
148 |
149 | go s.shaperLoop()
150 | go s.recvLoop()
151 | go s.sendLoop()
152 | if !config.KeepAliveDisabled {
153 | go s.keepalive()
154 | }
155 | return s
156 | }
157 |
// OpenStream is used to create a new stream.
//
// It allocates the next local stream id, sends a cmdSYN control frame for
// it, and then registers the stream in the session. Ids step by 2, so the
// client's odd ids and the server's even ids never collide.
//
// It returns io.ErrClosedPipe if the session is closed, ErrGoAway once the
// 32-bit id space is exhausted, the error from writing the SYN frame, or the
// stored socket read/write error if one was signalled before registration.
func (s *Session) OpenStream() (*Stream, error) {
	if s.IsClosed() {
		return nil, io.ErrClosedPipe
	}

	// generate stream id
	s.nextStreamIDLock.Lock()
	if s.goAway > 0 {
		s.nextStreamIDLock.Unlock()
		return nil, ErrGoAway
	}

	s.nextStreamID += 2
	sid := s.nextStreamID
	// sid%2 is 0 or 1, so this holds only when the counter has wrapped back
	// to its initial value (0 for a server, 1 for a client).
	if sid == sid%2 { // stream-id overflows
		s.goAway = 1
		s.nextStreamIDLock.Unlock()
		return nil, ErrGoAway
	}
	s.nextStreamIDLock.Unlock()

	stream := newStream(sid, s.config.MaxFrameSize, s)

	if _, err := s.writeControlFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
		return nil, err
	}

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	// Re-check session state under streamLock; the stream is only published
	// in s.streams if no socket error or close has been signalled.
	select {
	case <-s.chSocketReadError:
		return nil, s.socketReadError.Load().(error)
	case <-s.chSocketWriteError:
		return nil, s.socketWriteError.Load().(error)
	case <-s.die:
		return nil, io.ErrClosedPipe
	default:
		s.streams[sid] = stream
		wrapper := &Stream{stream: stream}
		// NOTE(x): disabled finalizer for issue #997
		/*
			runtime.SetFinalizer(wrapper, func(s *Stream) {
				s.Close()
			})
		*/
		return wrapper, nil
	}
}
207 |
208 | // Open returns a generic ReadWriteCloser
209 | func (s *Session) Open() (io.ReadWriteCloser, error) {
210 | return s.OpenStream()
211 | }
212 |
213 | // AcceptStream is used to block until the next available stream
214 | // is ready to be accepted.
215 | func (s *Session) AcceptStream() (*Stream, error) {
216 | var deadline <-chan time.Time
217 | if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
218 | timer := time.NewTimer(time.Until(d))
219 | defer timer.Stop()
220 | deadline = timer.C
221 | }
222 |
223 | select {
224 | case stream := <-s.chAccepts:
225 | wrapper := &Stream{stream: stream}
226 | runtime.SetFinalizer(wrapper, func(s *Stream) {
227 | s.Close()
228 | })
229 | return wrapper, nil
230 | case <-deadline:
231 | return nil, ErrTimeout
232 | case <-s.chSocketReadError:
233 | return nil, s.socketReadError.Load().(error)
234 | case <-s.chProtoError:
235 | return nil, s.protoError.Load().(error)
236 | case <-s.die:
237 | return nil, io.ErrClosedPipe
238 | }
239 | }
240 |
241 | // Accept Returns a generic ReadWriteCloser instead of smux.Stream
242 | func (s *Session) Accept() (io.ReadWriteCloser, error) {
243 | return s.AcceptStream()
244 | }
245 |
246 | // Close is used to close the session and all streams.
247 | func (s *Session) Close() error {
248 | var once bool
249 | s.dieOnce.Do(func() {
250 | close(s.die)
251 | once = true
252 | })
253 |
254 | if once {
255 | s.streamLock.Lock()
256 | for k := range s.streams {
257 | s.streams[k].sessionClose()
258 | }
259 | s.streamLock.Unlock()
260 | return s.conn.Close()
261 | } else {
262 | return io.ErrClosedPipe
263 | }
264 | }
265 |
266 | // CloseChan can be used by someone who wants to be notified immediately when this
267 | // session is closed
268 | func (s *Session) CloseChan() <-chan struct{} {
269 | return s.die
270 | }
271 |
272 | // notifyBucket notifies recvLoop that bucket is available
273 | func (s *Session) notifyBucket() {
274 | select {
275 | case s.bucketNotify <- struct{}{}:
276 | default:
277 | }
278 | }
279 |
280 | func (s *Session) notifyReadError(err error) {
281 | s.socketReadErrorOnce.Do(func() {
282 | s.socketReadError.Store(err)
283 | close(s.chSocketReadError)
284 | })
285 | }
286 |
287 | func (s *Session) notifyWriteError(err error) {
288 | s.socketWriteErrorOnce.Do(func() {
289 | s.socketWriteError.Store(err)
290 | close(s.chSocketWriteError)
291 | })
292 | }
293 |
294 | func (s *Session) notifyProtoError(err error) {
295 | s.protoErrorOnce.Do(func() {
296 | s.protoError.Store(err)
297 | close(s.chProtoError)
298 | })
299 | }
300 |
301 | // IsClosed does a safe check to see if we have shutdown
302 | func (s *Session) IsClosed() bool {
303 | select {
304 | case <-s.die:
305 | return true
306 | default:
307 | return false
308 | }
309 | }
310 |
311 | // NumStreams returns the number of currently open streams
312 | func (s *Session) NumStreams() int {
313 | if s.IsClosed() {
314 | return 0
315 | }
316 | s.streamLock.Lock()
317 | defer s.streamLock.Unlock()
318 | return len(s.streams)
319 | }
320 |
321 | // SetDeadline sets a deadline used by Accept* calls.
322 | // A zero time value disables the deadline.
323 | func (s *Session) SetDeadline(t time.Time) error {
324 | s.deadline.Store(t)
325 | return nil
326 | }
327 |
328 | // LocalAddr satisfies net.Conn interface
329 | func (s *Session) LocalAddr() net.Addr {
330 | if ts, ok := s.conn.(interface {
331 | LocalAddr() net.Addr
332 | }); ok {
333 | return ts.LocalAddr()
334 | }
335 | return nil
336 | }
337 |
338 | // RemoteAddr satisfies net.Conn interface
339 | func (s *Session) RemoteAddr() net.Addr {
340 | if ts, ok := s.conn.(interface {
341 | RemoteAddr() net.Addr
342 | }); ok {
343 | return ts.RemoteAddr()
344 | }
345 | return nil
346 | }
347 |
348 | // notify the session that a stream has closed
349 | func (s *Session) streamClosed(sid uint32) {
350 | s.streamLock.Lock()
351 | if stream, ok := s.streams[sid]; ok {
352 | n := stream.recycleTokens()
353 | if n > 0 { // return remaining tokens to the bucket
354 | if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
355 | s.notifyBucket()
356 | }
357 | }
358 | delete(s.streams, sid)
359 | }
360 | s.streamLock.Unlock()
361 | }
362 |
363 | // returnTokens is called by stream to return token after read
364 | func (s *Session) returnTokens(n int) {
365 | if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
366 | s.notifyBucket()
367 | }
368 | }
369 |
370 | // recvLoop keeps on reading from underlying connection if tokens are available
371 | func (s *Session) recvLoop() {
372 | var hdr rawHeader
373 | var updHdr updHeader
374 |
375 | for {
376 | for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
377 | select {
378 | case <-s.bucketNotify:
379 | case <-s.die:
380 | return
381 | }
382 | }
383 |
384 | // read header first
385 | if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
386 | atomic.StoreInt32(&s.dataReady, 1)
387 | if hdr.Version() != byte(s.config.Version) {
388 | s.notifyProtoError(ErrInvalidProtocol)
389 | return
390 | }
391 | sid := hdr.StreamID()
392 | switch hdr.Cmd() {
393 | case cmdNOP:
394 | case cmdSYN: // stream opening
395 | s.streamLock.Lock()
396 | if _, ok := s.streams[sid]; !ok {
397 | stream := newStream(sid, s.config.MaxFrameSize, s)
398 | s.streams[sid] = stream
399 | select {
400 | case s.chAccepts <- stream:
401 | case <-s.die:
402 | }
403 | }
404 | s.streamLock.Unlock()
405 | case cmdFIN: // stream closing
406 | s.streamLock.Lock()
407 | if stream, ok := s.streams[sid]; ok {
408 | stream.fin()
409 | stream.notifyReadEvent()
410 | }
411 | s.streamLock.Unlock()
412 | case cmdPSH: // data frame
413 | if hdr.Length() > 0 {
414 | pNewbuf := defaultAllocator.Get(int(hdr.Length()))
415 | if written, err := io.ReadFull(s.conn, *pNewbuf); err == nil {
416 | s.streamLock.Lock()
417 | if stream, ok := s.streams[sid]; ok {
418 | stream.pushBytes(pNewbuf)
419 | // a stream used some token
420 | atomic.AddInt32(&s.bucket, -int32(written))
421 | stream.notifyReadEvent()
422 | } else {
423 | // data directed to a missing/closed stream, recycle the buffer immediately.
424 | defaultAllocator.Put(pNewbuf)
425 | }
426 | s.streamLock.Unlock()
427 | } else {
428 | s.notifyReadError(err)
429 | return
430 | }
431 | }
432 | case cmdUPD: // a window update signal
433 | if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
434 | s.streamLock.Lock()
435 | if stream, ok := s.streams[sid]; ok {
436 | stream.update(updHdr.Consumed(), updHdr.Window())
437 | }
438 | s.streamLock.Unlock()
439 | } else {
440 | s.notifyReadError(err)
441 | return
442 | }
443 | default:
444 | s.notifyProtoError(ErrInvalidProtocol)
445 | return
446 | }
447 | } else {
448 | s.notifyReadError(err)
449 | return
450 | }
451 | }
452 | }
453 |
454 | // keepalive sends NOP frame to peer to keep the connection alive, and detect dead peers
455 | func (s *Session) keepalive() {
456 | tickerPing := time.NewTicker(s.config.KeepAliveInterval)
457 | tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
458 | defer tickerPing.Stop()
459 | defer tickerTimeout.Stop()
460 | for {
461 | select {
462 | case <-tickerPing.C:
463 | s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, CLSCTRL)
464 | s.notifyBucket() // force a signal to the recvLoop
465 | case <-tickerTimeout.C:
466 | if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
467 | // recvLoop may block while bucket is 0, in this case,
468 | // session should not be closed.
469 | if atomic.LoadInt32(&s.bucket) > 0 {
470 | s.Close()
471 | return
472 | }
473 | }
474 | case <-s.die:
475 | return
476 | }
477 | }
478 | }
479 |
// shaperLoop implements a priority queue for write requests,
// some control messages are prioritized over data messages.
//
// Requests submitted on s.shaper are kept in reqs, a shaperHeap managed with
// heap.Push/heap.Pop. On each iteration the top of the heap is popped into
// next and offered to sendLoop on s.writes. If another request arrives on
// s.shaper before next is taken, next is pushed back together with the
// newcomer so the heap can re-order them. The ordering comes from
// shaperHeap's Less, which is not shown here; it presumably uses class and
// then seq (TODO confirm).
//
// The loop exits when the session dies.
func (s *Session) shaperLoop() {
	var reqs shaperHeap
	var next writeRequest
	var chWrite chan writeRequest
	var chShaper chan writeRequest

	for {
		// chWrite is not available until it has packet to send
		// (a nil channel disables its select case)
		if len(reqs) > 0 {
			chWrite = s.writes
			next = heap.Pop(&reqs).(writeRequest)
		} else {
			chWrite = nil
		}

		// control heap size, chShaper is not available until packets are less than maximum allowed
		// (measured after next has been popped)
		if len(reqs) >= maxShaperSize {
			chShaper = nil
		} else {
			chShaper = s.shaper
		}

		// assertion on non nil: with both channels nil the select below
		// could only ever wake on s.die
		if chShaper == nil && chWrite == nil {
			panic("both channel are nil")
		}

		select {
		case <-s.die:
			return
		case r := <-chShaper:
			if chWrite != nil { // next is valid, reshape
				heap.Push(&reqs, next)
			}
			heap.Push(&reqs, r)
		case chWrite <- next:
		}
	}
}
521 |
// sendLoop serialises the requests handed over on s.writes onto the
// underlying connection, one frame per request, and reports each outcome on
// the request's result channel (which it then closes).
//
// If conn implements WriteBuffers, the header and payload are passed as two
// separate buffers without copying. Otherwise both are copied into one
// buffer sized for a 64 KiB payload plus the header. The loop exits when the
// session dies, or after the first write error, which is recorded with
// notifyWriteError.
func (s *Session) sendLoop() {
	var buf []byte
	var n int
	var err error
	var vec [][]byte // vector for writeBuffers

	bw, ok := s.conn.(interface {
		WriteBuffers(v [][]byte) (n int, err error)
	})

	if ok {
		buf = make([]byte, headerSize)
		vec = make([][]byte, 2)
	} else {
		buf = make([]byte, (1<<16)+headerSize)
	}

	for {
		select {
		case <-s.die:
			return
		case request := <-s.writes:
			// frame header: version, command, payload length (little-endian
			// uint16) and stream id (little-endian uint32)
			buf[0] = request.frame.ver
			buf[1] = request.frame.cmd
			binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
			binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)

			// support for scatter-gather I/O
			if len(vec) > 0 {
				vec[0] = buf[:headerSize]
				vec[1] = request.frame.data
				n, err = bw.WriteBuffers(vec)
			} else {
				copy(buf[headerSize:], request.frame.data)
				n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
			}

			// report payload bytes only; the header is not counted
			n -= headerSize
			if n < 0 {
				n = 0
			}

			result := writeResult{
				n:   n,
				err: err,
			}

			request.result <- result
			close(request.result)

			// store conn error
			if err != nil {
				s.notifyWriteError(err)
				return
			}
		}
	}
}
581 |
582 | // writeControlFrame writes the control frame to the underlying connection
583 | // and returns the number of bytes written if successful
584 | func (s *Session) writeControlFrame(f Frame) (n int, err error) {
585 | timer := time.NewTimer(openCloseTimeout)
586 | defer timer.Stop()
587 |
588 | return s.writeFrameInternal(f, timer.C, CLSCTRL)
589 | }
590 |
591 | // internal writeFrame version to support deadline used in keepalive
592 | func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, class CLASSID) (int, error) {
593 | req := writeRequest{
594 | class: class,
595 | frame: f,
596 | seq: atomic.AddUint32(&s.requestID, 1),
597 | result: make(chan writeResult, 1),
598 | }
599 | select {
600 | case s.shaper <- req:
601 | case <-s.die:
602 | return 0, io.ErrClosedPipe
603 | case <-s.chSocketWriteError:
604 | return 0, s.socketWriteError.Load().(error)
605 | case <-deadline:
606 | return 0, ErrTimeout
607 | }
608 |
609 | select {
610 | case result := <-req.result:
611 | return result.n, result.err
612 | case <-s.die:
613 | return 0, io.ErrClosedPipe
614 | case <-s.chSocketWriteError:
615 | return 0, s.socketWriteError.Load().(error)
616 | case <-deadline:
617 | return 0, ErrTimeout
618 | }
619 | }
620 |
--------------------------------------------------------------------------------
/session_test.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "bytes"
27 | crand "crypto/rand"
28 | "encoding/binary"
29 | "fmt"
30 | "io"
31 | "log"
32 | "math/rand"
33 | "net"
34 | "net/http"
35 | _ "net/http/pprof"
36 | "strings"
37 | "sync"
38 | "testing"
39 | "time"
40 | )
41 |
// init exposes the net/http/pprof handlers while the tests run, for
// profiling. The listener is bound to the loopback interface only, so the
// profiling endpoints are not reachable from other hosts. If the port is
// taken, the error is logged and the tests continue.
func init() {
	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()
}
47 |
48 | // setupServer starts new server listening on a random localhost port and
49 | // returns address of the server, function to stop the server, new client
50 | // connection to this server or an error.
51 | func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
52 | ln, err := net.Listen("tcp", "localhost:0")
53 | if err != nil {
54 | return "", nil, nil, err
55 | }
56 | go func() {
57 | conn, err := ln.Accept()
58 | if err != nil {
59 | return
60 | }
61 | go handleConnection(conn)
62 | }()
63 | addr = ln.Addr().String()
64 | conn, err := net.Dial("tcp", addr)
65 | if err != nil {
66 | ln.Close()
67 | return "", nil, nil, err
68 | }
69 | return ln.Addr().String(), func() { ln.Close() }, conn, nil
70 | }
71 |
72 | func handleConnection(conn net.Conn) {
73 | session, _ := Server(conn, nil)
74 | for {
75 | if stream, err := session.AcceptStream(); err == nil {
76 | go func(s io.ReadWriteCloser) {
77 | buf := make([]byte, 65536)
78 | for {
79 | n, err := s.Read(buf)
80 | if err != nil {
81 | return
82 | }
83 | s.Write(buf[:n])
84 | }
85 | }(stream)
86 | } else {
87 | return
88 | }
89 | }
90 | }
91 |
92 | // setupServer starts new server listening on a random localhost port and
93 | // returns address of the server, function to stop the server, new client
94 | // connection to this server or an error.
95 | func setupServerV2(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
96 | ln, err := net.Listen("tcp", "localhost:0")
97 | if err != nil {
98 | return "", nil, nil, err
99 | }
100 | go func() {
101 | conn, err := ln.Accept()
102 | if err != nil {
103 | return
104 | }
105 | go handleConnectionV2(conn)
106 | }()
107 | addr = ln.Addr().String()
108 | conn, err := net.Dial("tcp", addr)
109 | if err != nil {
110 | ln.Close()
111 | return "", nil, nil, err
112 | }
113 | return ln.Addr().String(), func() { ln.Close() }, conn, nil
114 | }
115 |
116 | func handleConnectionV2(conn net.Conn) {
117 | config := DefaultConfig()
118 | config.Version = 2
119 | session, _ := Server(conn, config)
120 | for {
121 | if stream, err := session.AcceptStream(); err == nil {
122 | go func(s io.ReadWriteCloser) {
123 | buf := make([]byte, 65536)
124 | for {
125 | n, err := s.Read(buf)
126 | if err != nil {
127 | return
128 | }
129 | s.Write(buf[:n])
130 | }
131 | }(stream)
132 | } else {
133 | return
134 | }
135 | }
136 | }
137 |
138 | func TestEcho(t *testing.T) {
139 | _, stop, cli, err := setupServer(t)
140 | if err != nil {
141 | t.Fatal(err)
142 | }
143 | defer stop()
144 | session, _ := Client(cli, nil)
145 | stream, _ := session.OpenStream()
146 | const N = 100
147 | buf := make([]byte, 10)
148 | var sent string
149 | var received string
150 | for i := 0; i < N; i++ {
151 | msg := fmt.Sprintf("hello%v", i)
152 | stream.Write([]byte(msg))
153 | sent += msg
154 | if n, err := stream.Read(buf); err != nil {
155 | t.Fatal(err)
156 | } else {
157 | received += string(buf[:n])
158 | }
159 | }
160 | if sent != received {
161 | t.Fatal("data mimatch")
162 | }
163 | session.Close()
164 | }
165 |
166 | func TestWriteTo(t *testing.T) {
167 | const N = 1 << 20
168 | // server
169 | ln, err := net.Listen("tcp", "localhost:0")
170 | if err != nil {
171 | t.Fatal(err)
172 | }
173 | defer ln.Close()
174 |
175 | go func() {
176 | conn, err := ln.Accept()
177 | if err != nil {
178 | return
179 | }
180 | session, _ := Server(conn, nil)
181 | for {
182 | if stream, err := session.AcceptStream(); err == nil {
183 | go func(s io.ReadWriteCloser) {
184 | numBytes := 0
185 | buf := make([]byte, 65536)
186 | for {
187 | n, err := s.Read(buf)
188 | if err != nil {
189 | return
190 | }
191 | s.Write(buf[:n])
192 | numBytes += n
193 |
194 | if numBytes == N {
195 | s.Close()
196 | return
197 | }
198 | }
199 | }(stream)
200 | } else {
201 | return
202 | }
203 | }
204 | }()
205 |
206 | addr := ln.Addr().String()
207 | conn, err := net.Dial("tcp", addr)
208 | if err != nil {
209 | t.Fatal(err)
210 | }
211 | defer conn.Close()
212 |
213 | // client
214 | session, _ := Client(conn, nil)
215 | stream, _ := session.OpenStream()
216 | sndbuf := make([]byte, N)
217 | for i := range sndbuf {
218 | sndbuf[i] = byte(rand.Int())
219 | }
220 |
221 | go stream.Write(sndbuf)
222 |
223 | var rcvbuf bytes.Buffer
224 | nw, ew := stream.WriteTo(&rcvbuf)
225 | if ew != io.EOF {
226 | t.Fatal(ew)
227 | }
228 |
229 | if nw != N {
230 | t.Fatal("WriteTo nw mismatch", nw)
231 | }
232 |
233 | if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
234 | t.Fatal("mismatched echo bytes")
235 | }
236 | t.Log(stream)
237 | }
238 |
239 | func TestWriteToV2(t *testing.T) {
240 | config := DefaultConfig()
241 | config.Version = 2
242 | const N = 1 << 20
243 | // server
244 | ln, err := net.Listen("tcp", "localhost:0")
245 | if err != nil {
246 | t.Fatal(err)
247 | }
248 | defer ln.Close()
249 |
250 | go func() {
251 | conn, err := ln.Accept()
252 | if err != nil {
253 | return
254 | }
255 | session, _ := Server(conn, config)
256 | for {
257 | if stream, err := session.AcceptStream(); err == nil {
258 | go func(s io.ReadWriteCloser) {
259 | numBytes := 0
260 | buf := make([]byte, 65536)
261 | for {
262 | n, err := s.Read(buf)
263 | if err != nil {
264 | return
265 | }
266 | s.Write(buf[:n])
267 | numBytes += n
268 |
269 | if numBytes == N {
270 | s.Close()
271 | return
272 | }
273 | }
274 | }(stream)
275 | } else {
276 | return
277 | }
278 | }
279 | }()
280 |
281 | addr := ln.Addr().String()
282 | conn, err := net.Dial("tcp", addr)
283 | if err != nil {
284 | t.Fatal(err)
285 | }
286 | defer conn.Close()
287 |
288 | // client
289 | session, _ := Client(conn, config)
290 | stream, _ := session.OpenStream()
291 | sndbuf := make([]byte, N)
292 | for i := range sndbuf {
293 | sndbuf[i] = byte(rand.Int())
294 | }
295 |
296 | go stream.Write(sndbuf)
297 |
298 | var rcvbuf bytes.Buffer
299 | nw, ew := stream.WriteTo(&rcvbuf)
300 | if ew != io.EOF {
301 | t.Fatal(ew)
302 | }
303 |
304 | if nw != N {
305 | t.Fatal("WriteTo nw mismatch", nw, N)
306 | }
307 |
308 | if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
309 | t.Fatal("mismatched echo bytes")
310 | }
311 |
312 | t.Log(stream)
313 | }
314 |
315 | func TestGetDieCh(t *testing.T) {
316 | cs, ss, err := getSmuxStreamPair()
317 | if err != nil {
318 | t.Fatal(err)
319 | }
320 | defer ss.Close()
321 | dieCh := ss.GetDieCh()
322 | go func() {
323 | select {
324 | case <-dieCh:
325 | case <-time.Tick(time.Second):
326 | t.Fatal("wait die chan timeout")
327 | }
328 | }()
329 | cs.Close()
330 | }
331 |
332 | func TestSpeed(t *testing.T) {
333 | _, stop, cli, err := setupServer(t)
334 | if err != nil {
335 | t.Fatal(err)
336 | }
337 | defer stop()
338 | session, _ := Client(cli, nil)
339 | stream, _ := session.OpenStream()
340 | t.Log(stream.LocalAddr(), stream.RemoteAddr())
341 |
342 | start := time.Now()
343 | var wg sync.WaitGroup
344 | wg.Add(1)
345 | go func() {
346 | buf := make([]byte, 1024*1024)
347 | nrecv := 0
348 | for {
349 | n, err := stream.Read(buf)
350 | if err != nil {
351 | t.Error(err)
352 | break
353 | } else {
354 | nrecv += n
355 | if nrecv == 4096*4096 {
356 | break
357 | }
358 | }
359 | }
360 | stream.Close()
361 | t.Log("time for 16MB rtt", time.Since(start))
362 | wg.Done()
363 | }()
364 | msg := make([]byte, 8192)
365 | for i := 0; i < 2048; i++ {
366 | stream.Write(msg)
367 | }
368 | wg.Wait()
369 | session.Close()
370 | }
371 |
372 | func TestParallel(t *testing.T) {
373 | _, stop, cli, err := setupServer(t)
374 | if err != nil {
375 | t.Fatal(err)
376 | }
377 | defer stop()
378 | session, _ := Client(cli, nil)
379 |
380 | par := 1000
381 | messages := 100
382 | var wg sync.WaitGroup
383 | wg.Add(par)
384 | for i := 0; i < par; i++ {
385 | stream, _ := session.OpenStream()
386 | go func(s *Stream) {
387 | buf := make([]byte, 20)
388 | for j := 0; j < messages; j++ {
389 | msg := fmt.Sprintf("hello%v", j)
390 | s.Write([]byte(msg))
391 | if _, err := s.Read(buf); err != nil {
392 | break
393 | }
394 | }
395 | s.Close()
396 | wg.Done()
397 | }(stream)
398 | }
399 | t.Log("created", session.NumStreams(), "streams")
400 | wg.Wait()
401 | session.Close()
402 | }
403 |
404 | func TestParallelV2(t *testing.T) {
405 | config := DefaultConfig()
406 | config.Version = 2
407 | _, stop, cli, err := setupServerV2(t)
408 | if err != nil {
409 | t.Fatal(err)
410 | }
411 | defer stop()
412 | session, _ := Client(cli, config)
413 |
414 | par := 1000
415 | messages := 100
416 | var wg sync.WaitGroup
417 | wg.Add(par)
418 | for i := 0; i < par; i++ {
419 | stream, _ := session.OpenStream()
420 | go func(s *Stream) {
421 | buf := make([]byte, 20)
422 | for j := 0; j < messages; j++ {
423 | msg := fmt.Sprintf("hello%v", j)
424 | s.Write([]byte(msg))
425 | if _, err := s.Read(buf); err != nil {
426 | break
427 | }
428 | }
429 | s.Close()
430 | wg.Done()
431 | }(stream)
432 | }
433 | t.Log("created", session.NumStreams(), "streams")
434 | wg.Wait()
435 | session.Close()
436 | }
437 |
438 | func TestCloseThenOpen(t *testing.T) {
439 | _, stop, cli, err := setupServer(t)
440 | if err != nil {
441 | t.Fatal(err)
442 | }
443 | defer stop()
444 | session, _ := Client(cli, nil)
445 | session.Close()
446 | if _, err := session.OpenStream(); err == nil {
447 | t.Fatal("opened after close")
448 | }
449 | }
450 |
451 | func TestSessionDoubleClose(t *testing.T) {
452 | _, stop, cli, err := setupServer(t)
453 | if err != nil {
454 | t.Fatal(err)
455 | }
456 | defer stop()
457 | session, _ := Client(cli, nil)
458 | session.Close()
459 | if err := session.Close(); err == nil {
460 | t.Fatal("session double close doesn't return error")
461 | }
462 | }
463 |
464 | func TestStreamDoubleClose(t *testing.T) {
465 | _, stop, cli, err := setupServer(t)
466 | if err != nil {
467 | t.Fatal(err)
468 | }
469 | defer stop()
470 | session, _ := Client(cli, nil)
471 | stream, _ := session.OpenStream()
472 | stream.Close()
473 | if err := stream.Close(); err == nil {
474 | t.Fatal("stream double close doesn't return error")
475 | }
476 | session.Close()
477 | }
478 |
479 | func TestConcurrentClose(t *testing.T) {
480 | _, stop, cli, err := setupServer(t)
481 | if err != nil {
482 | t.Fatal(err)
483 | }
484 | defer stop()
485 | session, _ := Client(cli, nil)
486 | numStreams := 100
487 | streams := make([]*Stream, 0, numStreams)
488 | var wg sync.WaitGroup
489 | wg.Add(numStreams)
490 | for i := 0; i < 100; i++ {
491 | stream, _ := session.OpenStream()
492 | streams = append(streams, stream)
493 | }
494 | for _, s := range streams {
495 | stream := s
496 | go func() {
497 | stream.Close()
498 | wg.Done()
499 | }()
500 | }
501 | session.Close()
502 | wg.Wait()
503 | }
504 |
505 | func TestTinyReadBuffer(t *testing.T) {
506 | _, stop, cli, err := setupServer(t)
507 | if err != nil {
508 | t.Fatal(err)
509 | }
510 | defer stop()
511 | session, _ := Client(cli, nil)
512 | stream, _ := session.OpenStream()
513 | const N = 100
514 | tinybuf := make([]byte, 6)
515 | var sent string
516 | var received string
517 | for i := 0; i < N; i++ {
518 | msg := fmt.Sprintf("hello%v", i)
519 | sent += msg
520 | nsent, err := stream.Write([]byte(msg))
521 | if err != nil {
522 | t.Fatal("cannot write")
523 | }
524 | nrecv := 0
525 | for nrecv < nsent {
526 | if n, err := stream.Read(tinybuf); err == nil {
527 | nrecv += n
528 | received += string(tinybuf[:n])
529 | } else {
530 | t.Fatal("cannot read with tiny buffer")
531 | }
532 | }
533 | }
534 |
535 | if sent != received {
536 | t.Fatal("data mimatch")
537 | }
538 | session.Close()
539 | }
540 |
541 | func TestIsClose(t *testing.T) {
542 | _, stop, cli, err := setupServer(t)
543 | if err != nil {
544 | t.Fatal(err)
545 | }
546 | defer stop()
547 | session, _ := Client(cli, nil)
548 | session.Close()
549 | if !session.IsClosed() {
550 | t.Fatal("still open after close")
551 | }
552 | }
553 |
554 | func TestKeepAliveTimeout(t *testing.T) {
555 | ln, err := net.Listen("tcp", "localhost:0")
556 | if err != nil {
557 | t.Fatal(err)
558 | }
559 | defer ln.Close()
560 | go func() {
561 | ln.Accept()
562 | }()
563 |
564 | cli, err := net.Dial("tcp", ln.Addr().String())
565 | if err != nil {
566 | t.Fatal(err)
567 | }
568 | defer cli.Close()
569 |
570 | config := DefaultConfig()
571 | config.KeepAliveInterval = time.Second
572 | config.KeepAliveTimeout = 2 * time.Second
573 | session, _ := Client(cli, config)
574 | time.Sleep(3 * time.Second)
575 | if !session.IsClosed() {
576 | t.Fatal("keepalive-timeout failed")
577 | }
578 | }
579 |
// blockWriteConn wraps a net.Conn so that every Write stalls for a day before
// reaching the real connection. Tests use it to simulate a peer whose socket
// never drains.
type blockWriteConn struct {
	net.Conn
}

// Write sleeps for 24 hours and only then forwards b to the embedded conn.
func (c *blockWriteConn) Write(b []byte) (n int, err error) {
	const stall = 24 * time.Hour
	time.Sleep(stall)
	return c.Conn.Write(b)
}
589 |
590 | func TestKeepAliveBlockWriteTimeout(t *testing.T) {
591 | ln, err := net.Listen("tcp", "localhost:0")
592 | if err != nil {
593 | t.Fatal(err)
594 | }
595 | defer ln.Close()
596 | go func() {
597 | ln.Accept()
598 | }()
599 |
600 | cli, err := net.Dial("tcp", ln.Addr().String())
601 | if err != nil {
602 | t.Fatal(err)
603 | }
604 | defer cli.Close()
605 | //when writeFrame block, keepalive in old version never timeout
606 | blockWriteCli := &blockWriteConn{cli}
607 |
608 | config := DefaultConfig()
609 | config.KeepAliveInterval = time.Second
610 | config.KeepAliveTimeout = 2 * time.Second
611 | session, _ := Client(blockWriteCli, config)
612 | time.Sleep(3 * time.Second)
613 | if !session.IsClosed() {
614 | t.Fatal("keepalive-timeout failed")
615 | }
616 | }
617 |
618 | func TestServerEcho(t *testing.T) {
619 | ln, err := net.Listen("tcp", "localhost:0")
620 | if err != nil {
621 | t.Fatal(err)
622 | }
623 | defer ln.Close()
624 | go func() {
625 | err := func() error {
626 | conn, err := ln.Accept()
627 | if err != nil {
628 | return err
629 | }
630 | defer conn.Close()
631 | session, err := Server(conn, nil)
632 | if err != nil {
633 | return err
634 | }
635 | defer session.Close()
636 | buf := make([]byte, 10)
637 | stream, err := session.OpenStream()
638 | if err != nil {
639 | return err
640 | }
641 | defer stream.Close()
642 | for i := 0; i < 100; i++ {
643 | msg := fmt.Sprintf("hello%v", i)
644 | stream.Write([]byte(msg))
645 | n, err := stream.Read(buf)
646 | if err != nil {
647 | return err
648 | }
649 | if got := string(buf[:n]); got != msg {
650 | return fmt.Errorf("got: %q, want: %q", got, msg)
651 | }
652 | }
653 | return nil
654 | }()
655 | if err != nil {
656 | t.Error(err)
657 | }
658 | }()
659 |
660 | cli, err := net.Dial("tcp", ln.Addr().String())
661 | if err != nil {
662 | t.Fatal(err)
663 | }
664 | defer cli.Close()
665 | if session, err := Client(cli, nil); err == nil {
666 | if stream, err := session.AcceptStream(); err == nil {
667 | buf := make([]byte, 65536)
668 | for {
669 | n, err := stream.Read(buf)
670 | if err != nil {
671 | break
672 | }
673 | stream.Write(buf[:n])
674 | }
675 | } else {
676 | t.Fatal(err)
677 | }
678 | } else {
679 | t.Fatal(err)
680 | }
681 | }
682 |
683 | func TestSendWithoutRecv(t *testing.T) {
684 | _, stop, cli, err := setupServer(t)
685 | if err != nil {
686 | t.Fatal(err)
687 | }
688 | defer stop()
689 | session, _ := Client(cli, nil)
690 | stream, _ := session.OpenStream()
691 | const N = 100
692 | for i := 0; i < N; i++ {
693 | msg := fmt.Sprintf("hello%v", i)
694 | stream.Write([]byte(msg))
695 | }
696 | buf := make([]byte, 1)
697 | if _, err := stream.Read(buf); err != nil {
698 | t.Fatal(err)
699 | }
700 | stream.Close()
701 | }
702 |
703 | func TestWriteAfterClose(t *testing.T) {
704 | _, stop, cli, err := setupServer(t)
705 | if err != nil {
706 | t.Fatal(err)
707 | }
708 | defer stop()
709 | session, _ := Client(cli, nil)
710 | stream, _ := session.OpenStream()
711 | stream.Close()
712 | if _, err := stream.Write([]byte("write after close")); err == nil {
713 | t.Fatal("write after close failed")
714 | }
715 | }
716 |
717 | func TestReadStreamAfterSessionClose(t *testing.T) {
718 | _, stop, cli, err := setupServer(t)
719 | if err != nil {
720 | t.Fatal(err)
721 | }
722 | defer stop()
723 | session, _ := Client(cli, nil)
724 | stream, _ := session.OpenStream()
725 | session.Close()
726 | buf := make([]byte, 10)
727 | if _, err := stream.Read(buf); err != nil {
728 | t.Log(err)
729 | } else {
730 | t.Fatal("read stream after session close succeeded")
731 | }
732 | }
733 |
734 | func TestWriteStreamAfterConnectionClose(t *testing.T) {
735 | _, stop, cli, err := setupServer(t)
736 | if err != nil {
737 | t.Fatal(err)
738 | }
739 | defer stop()
740 | session, _ := Client(cli, nil)
741 | stream, _ := session.OpenStream()
742 | session.conn.Close()
743 | if _, err := stream.Write([]byte("write after connection close")); err == nil {
744 | t.Fatal("write after connection close failed")
745 | }
746 | }
747 |
748 | func TestNumStreamAfterClose(t *testing.T) {
749 | _, stop, cli, err := setupServer(t)
750 | if err != nil {
751 | t.Fatal(err)
752 | }
753 | defer stop()
754 | session, _ := Client(cli, nil)
755 | if _, err := session.OpenStream(); err == nil {
756 | if session.NumStreams() != 1 {
757 | t.Fatal("wrong number of streams after opened")
758 | }
759 | session.Close()
760 | if session.NumStreams() != 0 {
761 | t.Fatal("wrong number of streams after session closed")
762 | }
763 | } else {
764 | t.Fatal(err)
765 | }
766 | cli.Close()
767 | }
768 |
769 | func TestRandomFrame(t *testing.T) {
770 | addr, stop, cli, err := setupServer(t)
771 | if err != nil {
772 | t.Fatal(err)
773 | }
774 | defer stop()
775 | // pure random
776 | session, _ := Client(cli, nil)
777 | for i := 0; i < 100; i++ {
778 | rnd := make([]byte, rand.Uint32()%1024)
779 | io.ReadFull(crand.Reader, rnd)
780 | session.conn.Write(rnd)
781 | }
782 | cli.Close()
783 |
784 | // double syn
785 | cli, err = net.Dial("tcp", addr)
786 | if err != nil {
787 | t.Fatal(err)
788 | }
789 | session, _ = Client(cli, nil)
790 | for i := 0; i < 100; i++ {
791 | f := newFrame(1, cmdSYN, 1000)
792 | session.writeControlFrame(f)
793 | }
794 | cli.Close()
795 |
796 | // random cmds
797 | cli, err = net.Dial("tcp", addr)
798 | if err != nil {
799 | t.Fatal(err)
800 | }
801 | allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
802 | session, _ = Client(cli, nil)
803 | for i := 0; i < 100; i++ {
804 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
805 | session.writeControlFrame(f)
806 | }
807 | cli.Close()
808 |
809 | // random cmds & sids
810 | cli, err = net.Dial("tcp", addr)
811 | if err != nil {
812 | t.Fatal(err)
813 | }
814 | session, _ = Client(cli, nil)
815 | for i := 0; i < 100; i++ {
816 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
817 | session.writeControlFrame(f)
818 | }
819 | cli.Close()
820 |
821 | // random version
822 | cli, err = net.Dial("tcp", addr)
823 | if err != nil {
824 | t.Fatal(err)
825 | }
826 | session, _ = Client(cli, nil)
827 | for i := 0; i < 100; i++ {
828 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
829 | f.ver = byte(rand.Uint32())
830 | session.writeControlFrame(f)
831 | }
832 | cli.Close()
833 |
834 | // incorrect size
835 | cli, err = net.Dial("tcp", addr)
836 | if err != nil {
837 | t.Fatal(err)
838 | }
839 | session, _ = Client(cli, nil)
840 |
841 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
842 | rnd := make([]byte, rand.Uint32()%1024)
843 | io.ReadFull(crand.Reader, rnd)
844 | f.data = rnd
845 |
846 | buf := make([]byte, headerSize+len(f.data))
847 | buf[0] = f.ver
848 | buf[1] = f.cmd
849 | binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) /// incorrect size
850 | binary.LittleEndian.PutUint32(buf[4:], f.sid)
851 | copy(buf[headerSize:], f.data)
852 |
853 | session.conn.Write(buf)
854 | cli.Close()
855 |
856 | // writeFrame after die
857 | cli, err = net.Dial("tcp", addr)
858 | if err != nil {
859 | t.Fatal(err)
860 | }
861 | session, _ = Client(cli, nil)
862 | //close first
863 | session.Close()
864 | for i := 0; i < 100; i++ {
865 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
866 | session.writeControlFrame(f)
867 | }
868 | }
869 |
870 | func TestWriteFrameInternal(t *testing.T) {
871 | addr, stop, cli, err := setupServer(t)
872 | if err != nil {
873 | t.Fatal(err)
874 | }
875 | defer stop()
876 | // pure random
877 | session, _ := Client(cli, nil)
878 | for i := 0; i < 100; i++ {
879 | rnd := make([]byte, rand.Uint32()%1024)
880 | io.ReadFull(crand.Reader, rnd)
881 | session.conn.Write(rnd)
882 | }
883 | cli.Close()
884 |
885 | // writeFrame after die
886 | cli, err = net.Dial("tcp", addr)
887 | if err != nil {
888 | t.Fatal(err)
889 | }
890 | session, _ = Client(cli, nil)
891 | //close first
892 | session.Close()
893 | for i := 0; i < 100; i++ {
894 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
895 |
896 | timer := time.NewTimer(session.config.KeepAliveTimeout)
897 | defer timer.Stop()
898 |
899 | session.writeFrameInternal(f, timer.C, CLSDATA)
900 | }
901 |
902 | // random cmds
903 | cli, err = net.Dial("tcp", addr)
904 | if err != nil {
905 | t.Fatal(err)
906 | }
907 | allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
908 | session, _ = Client(cli, nil)
909 | for i := 0; i < 100; i++ {
910 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
911 |
912 | timer := time.NewTimer(session.config.KeepAliveTimeout)
913 | defer timer.Stop()
914 |
915 | session.writeFrameInternal(f, timer.C, CLSDATA)
916 | }
917 | //deadline occur
918 | {
919 | c := make(chan time.Time)
920 | close(c)
921 | f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
922 | _, err := session.writeFrameInternal(f, c, CLSDATA)
923 | if !strings.Contains(err.Error(), "timeout") {
924 | t.Fatal("write frame with deadline failed", err)
925 | }
926 | }
927 | cli.Close()
928 |
929 | {
930 | cli, err = net.Dial("tcp", addr)
931 | if err != nil {
932 | t.Fatal(err)
933 | }
934 | config := DefaultConfig()
935 | config.KeepAliveInterval = time.Second
936 | config.KeepAliveTimeout = 2 * time.Second
937 | session, _ = Client(&blockWriteConn{cli}, config)
938 | f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
939 | c := make(chan time.Time)
940 | go func() {
941 | //die first, deadline second, better for coverage
942 | time.Sleep(time.Second)
943 | session.Close()
944 | time.Sleep(time.Second)
945 | close(c)
946 | }()
947 | _, err = session.writeFrameInternal(f, c, CLSDATA)
948 | if !strings.Contains(err.Error(), "closed pipe") {
949 | t.Fatal("write frame with to closed conn failed", err)
950 | }
951 | }
952 | }
953 |
954 | func TestReadDeadline(t *testing.T) {
955 | _, stop, cli, err := setupServer(t)
956 | if err != nil {
957 | t.Fatal(err)
958 | }
959 | defer stop()
960 | session, _ := Client(cli, nil)
961 | stream, _ := session.OpenStream()
962 | const N = 100
963 | buf := make([]byte, 10)
964 | var readErr error
965 | for i := 0; i < N; i++ {
966 | stream.SetReadDeadline(time.Now().Add(-1 * time.Minute))
967 | if _, readErr = stream.Read(buf); readErr != nil {
968 | break
969 | }
970 | }
971 | if readErr != nil {
972 | if !strings.Contains(readErr.Error(), "timeout") {
973 | t.Fatalf("Wrong error: %v", readErr)
974 | }
975 | } else {
976 | t.Fatal("No error when reading with past deadline")
977 | }
978 | session.Close()
979 | }
980 |
981 | func TestWriteDeadline(t *testing.T) {
982 | _, stop, cli, err := setupServer(t)
983 | if err != nil {
984 | t.Fatal(err)
985 | }
986 | defer stop()
987 | session, _ := Client(cli, nil)
988 | stream, _ := session.OpenStream()
989 | buf := make([]byte, 10)
990 | var writeErr error
991 | for {
992 | stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute))
993 | if _, writeErr = stream.Write(buf); writeErr != nil {
994 | if !strings.Contains(writeErr.Error(), "timeout") {
995 | t.Fatalf("Wrong error: %v", writeErr)
996 | }
997 | break
998 | }
999 | }
1000 | session.Close()
1001 | }
1002 |
1003 | func BenchmarkAcceptClose(b *testing.B) {
1004 | _, stop, cli, err := setupServer(b)
1005 | if err != nil {
1006 | b.Fatal(err)
1007 | }
1008 | defer stop()
1009 | session, _ := Client(cli, nil)
1010 | for i := 0; i < b.N; i++ {
1011 | if stream, err := session.OpenStream(); err == nil {
1012 | stream.Close()
1013 | } else {
1014 | b.Fatal(err)
1015 | }
1016 | }
1017 | }
1018 | func BenchmarkConnSmux(b *testing.B) {
1019 | cs, ss, err := getSmuxStreamPair()
1020 | if err != nil {
1021 | b.Fatal(err)
1022 | }
1023 | defer cs.Close()
1024 | defer ss.Close()
1025 | bench(b, cs, ss)
1026 | }
1027 |
1028 | func BenchmarkConnTCP(b *testing.B) {
1029 | cs, ss, err := getTCPConnectionPair()
1030 | if err != nil {
1031 | b.Fatal(err)
1032 | }
1033 | defer cs.Close()
1034 | defer ss.Close()
1035 | bench(b, cs, ss)
1036 | }
1037 |
1038 | func getSmuxStreamPair() (*Stream, *Stream, error) {
1039 | c1, c2, err := getTCPConnectionPair()
1040 | if err != nil {
1041 | return nil, nil, err
1042 | }
1043 |
1044 | s, err := Server(c2, nil)
1045 | if err != nil {
1046 | return nil, nil, err
1047 | }
1048 | c, err := Client(c1, nil)
1049 | if err != nil {
1050 | return nil, nil, err
1051 | }
1052 | var ss *Stream
1053 | done := make(chan error)
1054 | go func() {
1055 | var rerr error
1056 | ss, rerr = s.AcceptStream()
1057 | done <- rerr
1058 | close(done)
1059 | }()
1060 | cs, err := c.OpenStream()
1061 | if err != nil {
1062 | return nil, nil, err
1063 | }
1064 | err = <-done
1065 | if err != nil {
1066 | return nil, nil, err
1067 | }
1068 |
1069 | return cs, ss, nil
1070 | }
1071 |
// getTCPConnectionPair returns the two ends of a loopback TCP connection: the
// server-side (accepted) conn first and the client-side (dialed) conn second.
// The temporary listener is closed before returning. On error, no connection
// is left open.
func getTCPConnectionPair() (net.Conn, net.Conn, error) {
	lst, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return nil, nil, err
	}
	defer lst.Close() // also unblocks Accept if Dial fails

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the goroutine never blocks on send, whichever path we take.
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, aerr := lst.Accept()
		accepted <- acceptResult{conn, aerr}
	}()

	conn1, err := net.Dial("tcp", lst.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	res := <-accepted
	if res.err != nil {
		conn1.Close() // don't leak the dialed side
		return nil, nil, res.err
	}
	return res.conn, conn1, nil
}
1098 |
// bench measures one-directional throughput. It writes b.N chunks of 128 KiB
// to wr while a goroutine drains rd, and waits until every byte has been read.
// A fresh buffer is allocated after every write so consecutive writes do not
// reuse warm cache lines; that allocation is included in ReportAllocs.
// Read and write errors are reported via b.Errorf instead of spinning forever.
func bench(b *testing.B, rd io.Reader, wr io.Writer) {
	const chunk = 128 * 1024
	buf := make([]byte, chunk)
	buf2 := make([]byte, chunk)
	b.SetBytes(chunk)
	b.ResetTimer()
	b.ReportAllocs()

	total := chunk * b.N
	var (
		wg      sync.WaitGroup
		got     int
		readErr error
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		// The reader only records errors; reporting happens on the benchmark
		// goroutine so nothing is logged after the benchmark has returned.
		for got < total {
			n, err := rd.Read(buf2)
			got += n
			if err != nil {
				readErr = err
				return
			}
		}
	}()
	for i := 0; i < b.N; i++ {
		if _, err := wr.Write(buf); err != nil {
			// Don't wait for the reader: it may stay blocked until the caller
			// closes the underlying connections.
			b.Errorf("write: %v", err)
			return
		}
		// invalidate L3 Cache
		buf = make([]byte, chunk)
	}
	wg.Wait()
	if got < total {
		b.Errorf("read %d of %d bytes: %v", got, total, readErr)
	}
}
1126 |
--------------------------------------------------------------------------------
/shaper.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
// _itimediff reports how far 'later' is ahead of 'earlier' as a signed 32-bit
// distance. The subtraction wraps modulo 2^32 and the result is reinterpreted
// as int32. Counters that have wrapped around therefore still compare
// correctly, as long as they are less than 2^31 apart.
func _itimediff(later, earlier uint32) int32 {
	diff := later - earlier
	return int32(diff)
}
30 |
31 | // shaperHeap is a min-heap of writeRequest.
32 | // It orders writeRequests by class first, then by sequence number within the same class.
33 | type shaperHeap []writeRequest
34 |
35 | func (h shaperHeap) Len() int { return len(h) }
36 |
37 | // Less determines the ordering of elements in the heap.
38 | // Requests are ordered by their class first. If two requests have the same class,
39 | // they are ordered by their sequence numbers.
40 | func (h shaperHeap) Less(i, j int) bool {
41 | if h[i].class != h[j].class {
42 | return h[i].class < h[j].class
43 | }
44 | return _itimediff(h[j].seq, h[i].seq) > 0
45 | }
46 |
47 | func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
48 | func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }
49 |
50 | func (h *shaperHeap) Pop() interface{} {
51 | old := *h
52 | n := len(old)
53 | x := old[n-1]
54 | *h = old[0 : n-1]
55 | return x
56 | }
57 |
--------------------------------------------------------------------------------
/shaper_test.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "container/heap"
27 | "testing"
28 | )
29 |
30 | func TestShaper(t *testing.T) {
31 | w1 := writeRequest{seq: 1}
32 | w2 := writeRequest{seq: 2}
33 | w3 := writeRequest{seq: 3}
34 | w4 := writeRequest{seq: 4}
35 | w5 := writeRequest{seq: 5}
36 |
37 | var reqs shaperHeap
38 | heap.Push(&reqs, w5)
39 | heap.Push(&reqs, w4)
40 | heap.Push(&reqs, w3)
41 | heap.Push(&reqs, w2)
42 | heap.Push(&reqs, w1)
43 |
44 | for len(reqs) > 0 {
45 | w := heap.Pop(&reqs).(writeRequest)
46 | t.Log("sid:", w.frame.sid, "seq:", w.seq)
47 | }
48 | }
49 |
50 | func TestShaper2(t *testing.T) {
51 | w1 := writeRequest{class: CLSDATA, seq: 1} // stream 0
52 | w2 := writeRequest{class: CLSDATA, seq: 2}
53 | w3 := writeRequest{class: CLSDATA, seq: 3}
54 | w4 := writeRequest{class: CLSDATA, seq: 4}
55 | w5 := writeRequest{class: CLSDATA, seq: 5}
56 | w6 := writeRequest{class: CLSCTRL, seq: 6, frame: Frame{sid: 10}} // ctrl 1
57 | w7 := writeRequest{class: CLSCTRL, seq: 7, frame: Frame{sid: 11}} // ctrl 2
58 |
59 | var reqs shaperHeap
60 | heap.Push(&reqs, w6)
61 | heap.Push(&reqs, w5)
62 | heap.Push(&reqs, w4)
63 | heap.Push(&reqs, w3)
64 | heap.Push(&reqs, w2)
65 | heap.Push(&reqs, w1)
66 | heap.Push(&reqs, w7)
67 |
68 | for len(reqs) > 0 {
69 | w := heap.Pop(&reqs).(writeRequest)
70 | t.Log("sid:", w.frame.sid, "seq:", w.seq)
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/stream.go:
--------------------------------------------------------------------------------
1 | // MIT License
2 | //
3 | // Copyright (c) 2016-2017 xtaci
4 | //
5 | // Permission is hereby granted, free of charge, to any person obtaining a copy
6 | // of this software and associated documentation files (the "Software"), to deal
7 | // in the Software without restriction, including without limitation the rights
8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | // copies of the Software, and to permit persons to whom the Software is
10 | // furnished to do so, subject to the following conditions:
11 | //
12 | // The above copyright notice and this permission notice shall be included in all
13 | // copies or substantial portions of the Software.
14 | //
15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | // SOFTWARE.
22 |
23 | package smux
24 |
25 | import (
26 | "encoding/binary"
27 | "io"
28 | "net"
29 | "sync"
30 | "sync/atomic"
31 | "time"
32 | )
33 |
// Stream is the exported handle to a multiplexed stream. It embeds *stream,
// so all of stream's methods are promoted onto it.
// NOTE(review): the original note called this a "wrapper for GC". Presumably
// a finalizer is attached to this outer value elsewhere so an unreferenced
// Stream can be cleaned up — confirm in session.go.
type Stream struct {
	*stream
}
38 |
// stream holds the state behind a Stream. Its methods (Read, Write, Close,
// the SetDeadline family, LocalAddr/RemoteAddr) provide the net.Conn surface.
type stream struct {
	id   uint32   // Stream identifier
	sess *Session // owning session: config, token accounting and frame writes

	// Received payloads waiting to be read, oldest first. pushBytes appends
	// the same pointer to both slices. Readers consume through buffers, and
	// heads entries are what get passed to defaultAllocator.Put.
	buffers []*[]byte // the sequential buffers of stream
	heads   []*[]byte // slice heads of the buffers above, kept for recycle

	bufferLock sync.Mutex // protects buffers and heads (and numRead/incr in the v2 read paths)
	frameSize  int        // Maximum frame size for the stream

	// notify a read event; newStream gives it capacity 1 so notifyReadEvent never blocks
	chReadEvent chan struct{}

	// flag the stream has closed (closed by Close or sessionClose)
	die     chan struct{}
	dieOnce sync.Once // Ensures die channel is closed only once

	// FIN command: closed by fin() once the stream is closed at protocol level
	chFinEvent   chan struct{}
	finEventOnce sync.Once // Ensures chFinEvent is closed only once

	// deadlines; each stores a time.Time, the zero time meaning "no deadline"
	readDeadline  atomic.Value
	writeDeadline atomic.Value

	// per stream sliding window control
	numRead    uint32 // count num of bytes read
	numWritten uint32 // bytes written in v2 (atomic); v1 Write bumps it once per frame
	incr       uint32 // bytes read since the last window update was sent

	// UPD command
	peerConsumed uint32        // num of bytes the peer has consumed
	peerWindow   uint32        // peer window; starts at initialPeerWindow, updated by peer
	chUpdate     chan struct{} // notify of remote data consuming and window update
}
75 |
76 | // newStream initializes and returns a new Stream.
77 | func newStream(id uint32, frameSize int, sess *Session) *stream {
78 | s := new(stream)
79 | s.id = id
80 | s.chReadEvent = make(chan struct{}, 1)
81 | s.chUpdate = make(chan struct{}, 1)
82 | s.frameSize = frameSize
83 | s.sess = sess
84 | s.die = make(chan struct{})
85 | s.chFinEvent = make(chan struct{})
86 | s.peerWindow = initialPeerWindow // set to initial window size
87 |
88 | return s
89 | }
90 |
91 | // ID returns the stream's unique identifier.
92 | func (s *stream) ID() uint32 {
93 | return s.id
94 | }
95 |
96 | // Read reads data from the stream into the provided buffer.
97 | func (s *stream) Read(b []byte) (n int, err error) {
98 | for {
99 | n, err = s.tryRead(b)
100 | if err == ErrWouldBlock {
101 | if ew := s.waitRead(); ew != nil {
102 | return 0, ew
103 | }
104 | } else {
105 | return n, err
106 | }
107 | }
108 | }
109 |
110 | // tryRead attempts to read data from the stream without blocking.
111 | func (s *stream) tryRead(b []byte) (n int, err error) {
112 | if s.sess.config.Version == 2 {
113 | return s.tryReadv2(b)
114 | }
115 |
116 | if len(b) == 0 {
117 | return 0, nil
118 | }
119 |
120 | // A critical section to copy data from buffers to
121 | s.bufferLock.Lock()
122 | if len(s.buffers) > 0 {
123 | n = copy(b, *s.buffers[0])
124 | s.buffers[0] = s.buffers[0]
125 | *s.buffers[0] = (*s.buffers[0])[n:]
126 | if len(*s.buffers[0]) == 0 {
127 | s.buffers[0] = nil
128 | s.buffers = s.buffers[1:]
129 | // full recycle
130 | defaultAllocator.Put(s.heads[0])
131 | s.heads = s.heads[1:]
132 | }
133 | }
134 | s.bufferLock.Unlock()
135 |
136 | if n > 0 {
137 | s.sess.returnTokens(n)
138 | return n, nil
139 | }
140 |
141 | select {
142 | case <-s.die:
143 | return 0, io.EOF
144 | default:
145 | return 0, ErrWouldBlock
146 | }
147 | }
148 |
149 | // tryReadv2 is the non-blocking version of Read for version 2 streams.
150 | func (s *stream) tryReadv2(b []byte) (n int, err error) {
151 | if len(b) == 0 {
152 | return 0, nil
153 | }
154 |
155 | var notifyConsumed uint32
156 | s.bufferLock.Lock()
157 | if len(s.buffers) > 0 {
158 | n = copy(b, *s.buffers[0])
159 | s.buffers[0] = s.buffers[0]
160 | *s.buffers[0] = (*s.buffers[0])[n:]
161 | if len(*s.buffers[0]) == 0 {
162 | s.buffers[0] = nil
163 | s.buffers = s.buffers[1:]
164 | // full recycle
165 | defaultAllocator.Put(s.heads[0])
166 | s.heads = s.heads[1:]
167 | }
168 | }
169 |
170 | // in an ideal environment:
171 | // if more than half of buffer has consumed, send read ack to peer
172 | // based on round-trip time of ACK, continous flowing data
173 | // won't slow down due to waiting for ACK, as long as the
174 | // consumer keeps on reading data.
175 | //
176 | // s.numRead == n implies that it's the initial reading
177 | s.numRead += uint32(n)
178 | s.incr += uint32(n)
179 |
180 | // for initial reading, send window update
181 | if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
182 | notifyConsumed = s.numRead
183 | s.incr = 0 // reset couting for next window update
184 | }
185 | s.bufferLock.Unlock()
186 |
187 | if n > 0 {
188 | s.sess.returnTokens(n)
189 |
190 | // send window update if necessary
191 | if notifyConsumed > 0 {
192 | err := s.sendWindowUpdate(notifyConsumed)
193 | return n, err
194 | } else {
195 | return n, nil
196 | }
197 | }
198 |
199 | select {
200 | case <-s.die:
201 | return 0, io.EOF
202 | default:
203 | return 0, ErrWouldBlock
204 | }
205 | }
206 |
207 | // WriteTo implements io.WriteTo
208 | // WriteTo writes data to w until there's no more data to write or when an error occurs.
209 | // The return value n is the number of bytes written. Any error encountered during the write is also returned.
210 | // WriteTo calls Write in a loop until there is no more data to write or when an error occurs.
211 | // If the underlying stream is a v2 stream, it will send window update to peer when necessary.
212 | // If the underlying stream is a v1 stream, it will not send window update to peer.
213 | func (s *stream) WriteTo(w io.Writer) (n int64, err error) {
214 | if s.sess.config.Version == 2 {
215 | return s.writeTov2(w)
216 | }
217 |
218 | for {
219 | var pbuf *[]byte
220 | s.bufferLock.Lock()
221 | if len(s.buffers) > 0 {
222 | pbuf = s.buffers[0]
223 | s.buffers = s.buffers[1:]
224 | s.heads = s.heads[1:]
225 | }
226 | s.bufferLock.Unlock()
227 |
228 | if pbuf != nil {
229 | nw, ew := w.Write(*pbuf)
230 | // NOTE: WriteTo is a reader, so we need to return tokens here
231 | s.sess.returnTokens(len(*pbuf))
232 | defaultAllocator.Put(pbuf)
233 | if nw > 0 {
234 | n += int64(nw)
235 | }
236 |
237 | if ew != nil {
238 | return n, ew
239 | }
240 | } else if ew := s.waitRead(); ew != nil {
241 | return n, ew
242 | }
243 | }
244 | }
245 |
246 | // check comments in WriteTo
247 | func (s *stream) writeTov2(w io.Writer) (n int64, err error) {
248 | for {
249 | var notifyConsumed uint32
250 | var pbuf *[]byte
251 | s.bufferLock.Lock()
252 | if len(s.buffers) > 0 {
253 | pbuf = s.buffers[0]
254 | s.buffers = s.buffers[1:]
255 | s.heads = s.heads[1:]
256 | }
257 | var bufLen uint32
258 | if pbuf != nil {
259 | bufLen = uint32(len(*pbuf))
260 | }
261 | s.numRead += bufLen
262 | s.incr += bufLen
263 | if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == bufLen {
264 | notifyConsumed = s.numRead
265 | s.incr = 0
266 | }
267 | s.bufferLock.Unlock()
268 |
269 | if pbuf != nil {
270 | nw, ew := w.Write(*pbuf)
271 | // NOTE: WriteTo is a reader, so we need to return tokens here
272 | s.sess.returnTokens(len(*pbuf))
273 | defaultAllocator.Put(pbuf)
274 | if nw > 0 {
275 | n += int64(nw)
276 | }
277 |
278 | if ew != nil {
279 | return n, ew
280 | }
281 |
282 | if notifyConsumed > 0 {
283 | if err := s.sendWindowUpdate(notifyConsumed); err != nil {
284 | return n, err
285 | }
286 | }
287 | } else if ew := s.waitRead(); ew != nil {
288 | return n, ew
289 | }
290 | }
291 | }
292 |
293 | // sendWindowUpdate sends a window update frame to the peer.
294 | func (s *stream) sendWindowUpdate(consumed uint32) error {
295 | var timer *time.Timer
296 | var deadline <-chan time.Time
297 | if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
298 | timer = time.NewTimer(time.Until(d))
299 | defer timer.Stop()
300 | deadline = timer.C
301 | }
302 |
303 | frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
304 | var hdr updHeader
305 | binary.LittleEndian.PutUint32(hdr[:], consumed)
306 | binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
307 | frame.data = hdr[:]
308 | _, err := s.sess.writeFrameInternal(frame, deadline, CLSCTRL)
309 | return err
310 | }
311 |
312 | // waitRead blocks until a read event occurs or a deadline is reached.
313 | func (s *stream) waitRead() error {
314 | var timer *time.Timer
315 | var deadline <-chan time.Time
316 | if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
317 | timer = time.NewTimer(time.Until(d))
318 | defer timer.Stop()
319 | deadline = timer.C
320 | }
321 |
322 | select {
323 | case <-s.chReadEvent: // notify some data has arrived, or closed
324 | return nil
325 | case <-s.chFinEvent:
326 | // BUGFIX(xtaci): Fix for https://github.com/xtaci/smux/issues/82
327 | s.bufferLock.Lock()
328 | defer s.bufferLock.Unlock()
329 | if len(s.buffers) > 0 {
330 | return nil
331 | }
332 | return io.EOF
333 | case <-s.sess.chSocketReadError:
334 | return s.sess.socketReadError.Load().(error)
335 | case <-s.sess.chProtoError:
336 | return s.sess.protoError.Load().(error)
337 | case <-deadline:
338 | return ErrTimeout
339 | case <-s.die:
340 | return io.ErrClosedPipe
341 | }
342 |
343 | }
344 |
345 | // Write implements net.Conn
346 | //
347 | // Note that the behavior when multiple goroutines write concurrently is not deterministic,
348 | // frames may interleave in random way.
349 | func (s *stream) Write(b []byte) (n int, err error) {
350 | if s.sess.config.Version == 2 {
351 | return s.writeV2(b)
352 | }
353 |
354 | var deadline <-chan time.Time
355 | if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
356 | timer := time.NewTimer(time.Until(d))
357 | defer timer.Stop()
358 | deadline = timer.C
359 | }
360 |
361 | // check if stream has closed
362 | select {
363 | case <-s.chFinEvent: // passive closing
364 | return 0, io.EOF
365 | case <-s.die:
366 | return 0, io.ErrClosedPipe
367 | default:
368 | }
369 |
370 | // frame split and transmit
371 | sent := 0
372 | frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
373 | bts := b
374 | for len(bts) > 0 {
375 | sz := len(bts)
376 | if sz > s.frameSize {
377 | sz = s.frameSize
378 | }
379 | frame.data = bts[:sz]
380 | bts = bts[sz:]
381 | n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
382 | s.numWritten++
383 | sent += n
384 | if err != nil {
385 | return sent, err
386 | }
387 | }
388 |
389 | return sent, nil
390 | }
391 |
392 | // writeV2 writes data to the stream for version 2 streams.
393 | func (s *stream) writeV2(b []byte) (n int, err error) {
394 | // check empty input
395 | if len(b) == 0 {
396 | return 0, nil
397 | }
398 |
399 | // check if stream has closed
400 | select {
401 | case <-s.chFinEvent:
402 | return 0, io.EOF
403 | case <-s.die:
404 | return 0, io.ErrClosedPipe
405 | default:
406 | }
407 |
408 | // create write deadline timer
409 | var deadline <-chan time.Time
410 | if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
411 | timer := time.NewTimer(time.Until(d))
412 | defer timer.Stop()
413 | deadline = timer.C
414 | }
415 |
416 | // frame split and transmit process
417 | sent := 0
418 | frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
419 |
420 | for {
421 | // per stream sliding window control
422 | // [.... [consumed... numWritten] ... win... ]
423 | // [.... [consumed...................+rmtwnd]]
424 | var bts []byte
425 | // note:
426 | // even if uint32 overflow, this math still works:
427 | // eg1: uint32(0) - uint32(math.MaxUint32) = 1
428 | // eg2: int32(uint32(0) - uint32(1)) = -1
429 | //
430 | // basicially, you can take it as a MODULAR ARITHMETIC
431 | inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
432 | if inflight < 0 { // security check for malformed data
433 | return 0, ErrConsumed
434 | }
435 |
436 | // make sure you understand 'win' is calculated in modular arithmetic(2^32(4GB))
437 | win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
438 |
439 | if win > 0 {
440 | // determine how many bytes to send
441 | if win > int32(len(b)) {
442 | bts = b
443 | b = nil
444 | } else {
445 | bts = b[:win]
446 | b = b[win:]
447 | }
448 |
449 | // frame split and transmit
450 | for len(bts) > 0 {
451 | // splitting frame
452 | sz := len(bts)
453 | if sz > s.frameSize {
454 | sz = s.frameSize
455 | }
456 | frame.data = bts[:sz]
457 | bts = bts[sz:]
458 |
459 | // transmit of frame
460 | n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
461 | atomic.AddUint32(&s.numWritten, uint32(sz))
462 | sent += n
463 | if err != nil {
464 | return sent, err
465 | }
466 | }
467 | }
468 |
469 | // if there is any data left to be sent,
470 | // wait until stream closes, window changes or deadline reached
471 | // this blocking behavior will back propagate flow control to upper layer.
472 | if len(b) > 0 {
473 | select {
474 | case <-s.chFinEvent:
475 | return 0, io.EOF
476 | case <-s.die:
477 | return sent, io.ErrClosedPipe
478 | case <-deadline:
479 | return sent, ErrTimeout
480 | case <-s.sess.chSocketWriteError:
481 | return sent, s.sess.socketWriteError.Load().(error)
482 | case <-s.chUpdate: // notify of remote data consuming and window update
483 | continue
484 | }
485 | } else {
486 | return sent, nil
487 | }
488 | }
489 | }
490 |
491 | // Close implements net.Conn
492 | func (s *stream) Close() error {
493 | var once bool
494 | var err error
495 | s.dieOnce.Do(func() {
496 | close(s.die)
497 | once = true
498 | })
499 |
500 | if once {
501 | // send FIN in order
502 | f := newFrame(byte(s.sess.config.Version), cmdFIN, s.id)
503 |
504 | timer := time.NewTimer(openCloseTimeout)
505 | defer timer.Stop()
506 |
507 | _, err = s.sess.writeFrameInternal(f, timer.C, CLSDATA)
508 | s.sess.streamClosed(s.id)
509 | return err
510 | } else {
511 | return io.ErrClosedPipe
512 | }
513 | }
514 |
515 | // GetDieCh returns a readonly chan which can be readable
516 | // when the stream is to be closed.
517 | func (s *stream) GetDieCh() <-chan struct{} {
518 | return s.die
519 | }
520 |
521 | // SetReadDeadline sets the read deadline as defined by
522 | // net.Conn.SetReadDeadline.
523 | // A zero time value disables the deadline.
524 | func (s *stream) SetReadDeadline(t time.Time) error {
525 | s.readDeadline.Store(t)
526 | s.notifyReadEvent()
527 | return nil
528 | }
529 |
530 | // SetWriteDeadline sets the write deadline as defined by
531 | // net.Conn.SetWriteDeadline.
532 | // A zero time value disables the deadline.
533 | func (s *stream) SetWriteDeadline(t time.Time) error {
534 | s.writeDeadline.Store(t)
535 | return nil
536 | }
537 |
538 | // SetDeadline sets both read and write deadlines as defined by
539 | // net.Conn.SetDeadline.
540 | // A zero time value disables the deadlines.
541 | func (s *stream) SetDeadline(t time.Time) error {
542 | if err := s.SetReadDeadline(t); err != nil {
543 | return err
544 | }
545 | if err := s.SetWriteDeadline(t); err != nil {
546 | return err
547 | }
548 | return nil
549 | }
550 |
551 | // session closes
552 | func (s *stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
553 |
554 | // LocalAddr satisfies net.Conn interface
555 | func (s *stream) LocalAddr() net.Addr {
556 | if ts, ok := s.sess.conn.(interface {
557 | LocalAddr() net.Addr
558 | }); ok {
559 | return ts.LocalAddr()
560 | }
561 | return nil
562 | }
563 |
564 | // RemoteAddr satisfies net.Conn interface
565 | func (s *stream) RemoteAddr() net.Addr {
566 | if ts, ok := s.sess.conn.(interface {
567 | RemoteAddr() net.Addr
568 | }); ok {
569 | return ts.RemoteAddr()
570 | }
571 | return nil
572 | }
573 |
574 | // pushBytes append buf to buffers
575 | func (s *stream) pushBytes(pbuf *[]byte) (written int, err error) {
576 | s.bufferLock.Lock()
577 | s.buffers = append(s.buffers, pbuf)
578 | s.heads = append(s.heads, pbuf)
579 | s.bufferLock.Unlock()
580 | return
581 | }
582 |
583 | // recycleTokens transform remaining bytes to tokens(will truncate buffer)
584 | func (s *stream) recycleTokens() (n int) {
585 | s.bufferLock.Lock()
586 | for k := range s.buffers {
587 | n += len(*s.buffers[k])
588 | defaultAllocator.Put(s.heads[k])
589 | }
590 | s.buffers = nil
591 | s.heads = nil
592 | s.bufferLock.Unlock()
593 | return
594 | }
595 |
596 | // notify read event
597 | func (s *stream) notifyReadEvent() {
598 | select {
599 | case s.chReadEvent <- struct{}{}:
600 | default:
601 | }
602 | }
603 |
604 | // update command
605 | func (s *stream) update(consumed uint32, window uint32) {
606 | atomic.StoreUint32(&s.peerConsumed, consumed)
607 | atomic.StoreUint32(&s.peerWindow, window)
608 | select {
609 | case s.chUpdate <- struct{}{}:
610 | default:
611 | }
612 | }
613 |
614 | // mark this stream has been closed in protocol
615 | func (s *stream) fin() {
616 | s.finEventOnce.Do(func() {
617 | close(s.chFinEvent)
618 | })
619 | }
620 |
--------------------------------------------------------------------------------