├── .gitignore ├── LICENSE ├── README.md ├── client.go ├── client_test.go ├── conn.go ├── examples └── session │ └── main.go ├── go.mod ├── go.sum ├── io.go ├── net.go ├── pool.go ├── server.go ├── server_test.go ├── session.go └── session_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kenta Iwasaki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # monte 2 | 3 | [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) 4 | [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/lithdew/monte) 5 | [![Discord Chat](https://img.shields.io/discord/697002823123992617)](https://discord.gg/HZEbkeQ) 6 | 7 | The bare minimum for high-performance, fully-encrypted RPC over TCP in Go. 8 | 9 | ## Features 10 | 11 | 1. Send requests, receive responses, or send messages without waiting for a response. 12 | 2. Send from 50MiB/s to 1500MiB/s, with zero allocations per sent message or RPC call. 13 | 3. Gracefully establish multiple client connections to a single endpoint, up to a configurable limit. 14 | 4. Set the total number of connections that may be concurrently accepted and handled by a single endpoint. 15 | 5. Configure read/write timeouts, dial timeouts, and handshake timeouts, or customize the handshake protocol. 16 | 6. All messages, once the handshake protocol is complete, are encrypted and indistinguishable from each other. 17 | 7. Supports graceful shutdowns for both client and server, with extensive tests for highly concurrent scenarios. 18 | 19 | ## Protocol 20 | 21 | ### Handshake 22 | 23 | 1. Send our X25519 curve point (32 bytes) to our peer. 24 | 2. Receive an X25519 curve point (32 bytes) from our peer. 25 | 3. Multiply our X25519 curve scalar with the X25519 curve point received from our peer. 26 | 4. Derive a shared key by using BLAKE2b as a key derivation function over the result of the scalar multiplication. 27 | 5. Encrypt further communication with AES-256-GCM using our shared key, with a nonce counter that increases for every 28 | incoming/outgoing message. 29 | 30 | ### Message Format 31 | 32 | 1.
Encrypted messages are prefixed with an unsigned 32-bit integer denoting the message's length. 33 | 2. The decoded message content is prefixed with an unsigned 32-bit integer designating a sequence number. 34 | 3. The sequence number is used as an identifier to identify requests/responses from one another. 35 | 4. The sequence number 0 is reserved for requests that do not expect a response. 36 | 37 | ## Benchmarks 38 | 39 | ``` 40 | $ cat /proc/cpuinfo | grep 'model name' | uniq 41 | model name : Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz 42 | 43 | $ go test -bench=. -benchtime=10s 44 | goos: linux 45 | goarch: amd64 46 | pkg: github.com/lithdew/monte 47 | BenchmarkSend-8 1814391 6690 ns/op 209.27 MB/s 115 B/op 0 allocs/op 48 | BenchmarkSendNoWait-8 10638730 1153 ns/op 1214.19 MB/s 141 B/op 0 allocs/op 49 | BenchmarkRequest-8 438381 28556 ns/op 49.03 MB/s 140 B/op 0 allocs/op 50 | BenchmarkParallelSend-8 4917001 2876 ns/op 486.70 MB/s 115 B/op 0 allocs/op 51 | BenchmarkParallelSendNoWait-8 10317255 1291 ns/op 1084.78 MB/s 150 B/op 0 allocs/op 52 | BenchmarkParallelRequest-8 1341444 8520 ns/op 164.32 MB/s 140 B/op 0 allocs/op 53 | ``` -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var DefaultMaxClientConns = 4 10 | var DefaultNumDialAttempts = 1 11 | var DefaultReadBufferSize = 4096 12 | var DefaultWriteBufferSize = 4096 13 | var DefaultDialTimeout = 3 * time.Second 14 | var DefaultReadTimeout = 3 * time.Second 15 | var DefaultWriteTimeout = 3 * time.Second 16 | var DefaultClientSeqOffset uint32 = 1 17 | var DefaultClientSeqDelta uint32 = 2 18 | 19 | type clientConn struct { 20 | conn *Conn 21 | ready chan struct{} 22 | err error 23 | } 24 | 25 | type Client struct { 26 | Addr string 27 | 28 | Handler Handler 29 | ConnState ConnStateHandler 30 | 31 | Handshaker 
Handshaker 32 | HandshakeTimeout time.Duration 33 | 34 | MaxConns int 35 | NumDialAttempts int 36 | 37 | ReadBufferSize int 38 | WriteBufferSize int 39 | 40 | DialTimeout time.Duration 41 | ReadTimeout time.Duration 42 | WriteTimeout time.Duration 43 | 44 | SeqOffset uint32 45 | SeqDelta uint32 46 | 47 | once sync.Once 48 | shutdown sync.Once 49 | 50 | done chan struct{} 51 | 52 | mu sync.Mutex 53 | conns []*clientConn 54 | } 55 | 56 | func (c *Client) Get() (*Conn, error) { 57 | c.once.Do(c.init) 58 | 59 | cc := c.getClientConn() 60 | 61 | <-cc.ready 62 | if cc.err != nil { 63 | return nil, cc.err 64 | } 65 | 66 | return cc.conn, nil 67 | } 68 | 69 | func (c *Client) Send(buf []byte) error { 70 | conn, err := c.Get() 71 | if err != nil { 72 | return err 73 | } 74 | return conn.Send(buf) 75 | } 76 | 77 | func (c *Client) SendNoWait(buf []byte) error { 78 | conn, err := c.Get() 79 | if err != nil { 80 | return err 81 | } 82 | return conn.SendNoWait(buf) 83 | } 84 | 85 | func (c *Client) Request(dst, buf []byte) ([]byte, error) { 86 | conn, err := c.Get() 87 | if err != nil { 88 | return nil, err 89 | } 90 | return conn.Request(dst, buf) 91 | } 92 | 93 | func (c *Client) NumPendingWrites() int { 94 | c.mu.Lock() 95 | defer c.mu.Unlock() 96 | 97 | n := 0 98 | for _, cc := range c.conns { 99 | n += cc.conn.NumPendingWrites() 100 | } 101 | return n 102 | } 103 | 104 | func (c *Client) Shutdown() { 105 | c.once.Do(c.init) 106 | 107 | c.shutdown.Do(func() { 108 | close(c.done) 109 | }) 110 | } 111 | 112 | func (c *Client) init() { 113 | c.done = make(chan struct{}) 114 | } 115 | 116 | func (c *Client) deleteClientConn(conn *clientConn) { 117 | c.mu.Lock() 118 | defer c.mu.Unlock() 119 | 120 | entries := c.conns[:] 121 | 122 | c.conns = c.conns[:0] 123 | for i := 0; i < len(entries); i++ { 124 | if entries[i] == conn { 125 | continue 126 | } 127 | c.conns = append(c.conns, entries[i]) 128 | } 129 | } 130 | 131 | func (c *Client) newClientConn() *clientConn { 132 | cc := 
&clientConn{ 133 | ready: make(chan struct{}), 134 | conn: &Conn{ 135 | SeqOffset: c.getSeqOffset(), 136 | SeqDelta: c.getSeqDelta(), 137 | Handler: c.getHandler(), 138 | ReadBufferSize: c.getReadBufferSize(), 139 | WriteBufferSize: c.getWriteBufferSize(), 140 | ReadTimeout: c.getReadTimeout(), 141 | WriteTimeout: c.getWriteTimeout(), 142 | }, 143 | } 144 | c.conns = append(c.conns, cc) 145 | 146 | go func() { 147 | defer c.deleteClientConn(cc) 148 | 149 | dialer := net.Dialer{Timeout: c.getDialTimeout()} 150 | 151 | var ( 152 | conn net.Conn 153 | bufConn BufferedConn 154 | ) 155 | 156 | for i := 0; i < c.getNumDialAttempts(); i++ { 157 | conn, cc.err = dialer.Dial("tcp", c.Addr) 158 | if cc.err == nil { 159 | cc.err = conn.SetDeadline(time.Now().Add(c.getHandshakeTimeout())) 160 | } 161 | if cc.err == nil { 162 | bufConn, cc.err = c.getHandshaker().Handshake(conn) 163 | } 164 | if cc.err == nil { 165 | cc.err = conn.SetDeadline(zeroTime) 166 | } 167 | if cc.err == nil { 168 | break 169 | } 170 | if conn != nil { 171 | // Close the socket of every failed attempt (retried or final) so none of them leak. 172 | conn.Close() 173 | } 174 | } 175 | 176 | if cc.err != nil { 177 | close(cc.ready) 178 | return 179 | } 180 | close(cc.ready) 181 | 182 | c.getConnStateHandler().HandleConnState(cc.conn, StateNew) 183 | 184 | cc.conn.close(cc.conn.Handle(c.done, bufConn)) 185 | 186 | c.getConnStateHandler().HandleConnState(cc.conn, StateClosed) 187 | }() 188 | 189 | return cc 190 | } 191 | 192 | func (c *Client) getClientConn() *clientConn { 193 | c.mu.Lock() 194 | defer c.mu.Unlock() 195 | 196 | if len(c.conns) == 0 { 197 | return c.newClientConn() 198 | } 199 | 200 | mc := c.conns[0] 201 | mp := mc.conn.NumPendingWrites() 202 | if mp == 0 { 203 | return mc 204 | } 205 | for i := 1; i < len(c.conns); i++ { 206 | cc := c.conns[i] 207 | cp := cc.conn.NumPendingWrites() 208 | if cp == 0 { 209 | return cc 210 | } 211 | if cp < mp { 212 | mc, mp = cc, cp 213 | } 214 | } 215 | if len(c.conns) < c.getMaxConns() { 216 | return c.newClientConn() 217 | } 218 | return mc
219 | } 220 | 221 | func (c *Client) getHandler() Handler { 222 | if c.Handler == nil { 223 | return DefaultHandler 224 | } 225 | return c.Handler 226 | } 227 | 228 | func (c *Client) getConnStateHandler() ConnStateHandler { 229 | if c.ConnState == nil { 230 | return DefaultConnStateHandler 231 | } 232 | return c.ConnState 233 | } 234 | 235 | func (c *Client) getHandshaker() Handshaker { 236 | if c.Handshaker == nil { 237 | return DefaultClientHandshaker 238 | } 239 | return c.Handshaker 240 | } 241 | 242 | func (c *Client) getMaxConns() int { 243 | if c.MaxConns <= 0 { 244 | return DefaultMaxClientConns 245 | } 246 | return c.MaxConns 247 | } 248 | 249 | func (c *Client) getNumDialAttempts() int { 250 | if c.NumDialAttempts <= 0 { 251 | return DefaultNumDialAttempts 252 | } 253 | return c.NumDialAttempts 254 | } 255 | 256 | func (c *Client) getReadBufferSize() int { 257 | if c.ReadBufferSize <= 0 { 258 | return DefaultReadBufferSize 259 | } 260 | return c.ReadBufferSize 261 | } 262 | 263 | func (c *Client) getWriteBufferSize() int { 264 | if c.WriteBufferSize <= 0 { 265 | return DefaultWriteBufferSize 266 | } 267 | return c.WriteBufferSize 268 | } 269 | 270 | func (c *Client) getHandshakeTimeout() time.Duration { 271 | if c.HandshakeTimeout <= 0 { 272 | return DefaultHandshakeTimeout 273 | } 274 | return c.HandshakeTimeout 275 | } 276 | 277 | func (c *Client) getDialTimeout() time.Duration { 278 | if c.DialTimeout <= 0 { 279 | return DefaultDialTimeout 280 | } 281 | return c.DialTimeout 282 | } 283 | 284 | func (c *Client) getReadTimeout() time.Duration { 285 | if c.ReadTimeout < 0 { 286 | return DefaultReadTimeout 287 | } 288 | return c.ReadTimeout 289 | } 290 | 291 | func (c *Client) getWriteTimeout() time.Duration { 292 | if c.WriteTimeout < 0 { 293 | return DefaultWriteTimeout 294 | } 295 | return c.WriteTimeout 296 | } 297 | 298 | func (c *Client) getSeqOffset() uint32 { 299 | if c.SeqOffset == 0 { 300 | return DefaultClientSeqOffset 301 | } 302 | return 
c.SeqOffset 303 | } 304 | func (c *Client) getSeqDelta() uint32 { 305 | if c.SeqDelta == 0 { 306 | return DefaultClientSeqDelta 307 | } 308 | return c.SeqDelta 309 | } 310 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "fmt" 5 | "github.com/stretchr/testify/require" 6 | "go.uber.org/goleak" 7 | "math/rand" 8 | "net" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func TestClientHandshakeTimeout(t *testing.T) { 16 | defer goleak.VerifyNone(t) 17 | 18 | ln, err := net.Listen("tcp", ":0") 19 | require.NoError(t, err) 20 | 21 | client := &Client{Addr: ln.Addr().String(), HandshakeTimeout: 1 * time.Millisecond} 22 | 23 | defer func() { 24 | client.Shutdown() 25 | require.NoError(t, ln.Close()) 26 | }() 27 | 28 | attempts := 16 29 | go func() { 30 | for i := 0; i < attempts; i++ { 31 | _, _ = ln.Accept() 32 | } 33 | }() 34 | 35 | for i := 0; i < attempts; i++ { 36 | require.Error(t, client.Send([]byte("hello\n"))) 37 | } 38 | } 39 | 40 | func TestClientSend(t *testing.T) { 41 | defer goleak.VerifyNone(t) 42 | 43 | n := 4 44 | m := 1024 45 | c := uint32(n * m) 46 | 47 | ln, err := net.Listen("tcp", ":0") 48 | require.NoError(t, err) 49 | 50 | var server Server 51 | 52 | client := &Client{Addr: ln.Addr().String()} 53 | 54 | go func() { 55 | require.NoError(t, server.Serve(ln)) 56 | }() 57 | 58 | defer func() { 59 | server.Shutdown() 60 | client.Shutdown() 61 | 62 | require.NoError(t, ln.Close()) 63 | require.EqualValues(t, 0, atomic.LoadUint32(&c)) 64 | }() 65 | 66 | var wg sync.WaitGroup 67 | wg.Add(n) 68 | 69 | for i := 0; i < n; i++ { 70 | go func(i int) { 71 | defer wg.Done() 72 | for j := 0; j < m; j++ { 73 | require.NoError(t, client.Send([]byte(fmt.Sprintf("[%d] hello %d", i, j)))) 74 | atomic.AddUint32(&c, ^uint32(0)) 75 | } 76 | }(i) 77 | } 78 | 79 | wg.Wait() 80 | } 
81 | 82 | func TestClientRequest(t *testing.T) { 83 | defer goleak.VerifyNone(t) 84 | 85 | n := 4 86 | m := 1024 87 | c := uint32(n * m * 2) 88 | 89 | ln, err := net.Listen("tcp", ":0") 90 | require.NoError(t, err) 91 | 92 | handler := func(ctx *Context) error { 93 | atomic.AddUint32(&c, ^uint32(0)) 94 | return ctx.Reply([]byte("a reply!")) 95 | } 96 | 97 | var server Server 98 | server.Handler = HandlerFunc(handler) 99 | 100 | client := &Client{Addr: ln.Addr().String()} 101 | 102 | go func() { 103 | require.NoError(t, server.Serve(ln)) 104 | }() 105 | 106 | defer func() { 107 | server.Shutdown() 108 | client.Shutdown() 109 | 110 | require.NoError(t, ln.Close()) 111 | require.EqualValues(t, 0, atomic.LoadUint32(&c)) 112 | }() 113 | 114 | var wg sync.WaitGroup 115 | wg.Add(n) 116 | 117 | for i := 0; i < n; i++ { 118 | go func(i int) { 119 | defer wg.Done() 120 | for j := 0; j < m; j++ { 121 | res, err := client.Request(nil, []byte(fmt.Sprintf("[%d] hello %d", i, j))) 122 | require.NoError(t, err) 123 | require.EqualValues(t, []byte("a reply!"), res) 124 | atomic.AddUint32(&c, ^uint32(0)) 125 | } 126 | }(i) 127 | } 128 | 129 | wg.Wait() 130 | } 131 | 132 | func BenchmarkSend(b *testing.B) { 133 | ln, err := net.Listen("tcp", ":0") 134 | require.NoError(b, err) 135 | 136 | var server Server 137 | 138 | client := &Client{Addr: ln.Addr().String()} 139 | 140 | go func() { 141 | require.NoError(b, server.Serve(ln)) 142 | }() 143 | 144 | defer func() { 145 | server.Shutdown() 146 | client.Shutdown() 147 | 148 | require.NoError(b, ln.Close()) 149 | }() 150 | 151 | buf := make([]byte, 1400) 152 | _, err = rand.Read(buf) 153 | require.NoError(b, err) 154 | 155 | b.SetBytes(int64(len(buf))) 156 | b.ReportAllocs() 157 | b.ResetTimer() 158 | 159 | for i := 0; i < b.N; i++ { 160 | err := client.Send(buf) 161 | if err != nil { 162 | b.Fatal(err) 163 | } 164 | } 165 | } 166 | 167 | func BenchmarkSendNoWait(b *testing.B) { 168 | ln, err := net.Listen("tcp", ":0") 169 | 
require.NoError(b, err) 170 | 171 | var server Server 172 | 173 | client := &Client{Addr: ln.Addr().String()} 174 | 175 | go func() { 176 | require.NoError(b, server.Serve(ln)) 177 | }() 178 | 179 | defer func() { 180 | server.Shutdown() 181 | client.Shutdown() 182 | 183 | require.NoError(b, ln.Close()) 184 | }() 185 | 186 | buf := make([]byte, 1400) 187 | _, err = rand.Read(buf) 188 | require.NoError(b, err) 189 | 190 | b.SetBytes(int64(len(buf))) 191 | b.ReportAllocs() 192 | b.ResetTimer() 193 | 194 | for i := 0; i < b.N; i++ { 195 | err := client.SendNoWait(buf) 196 | if err != nil { 197 | b.Fatal(err) 198 | } 199 | } 200 | } 201 | 202 | func BenchmarkRequest(b *testing.B) { 203 | ln, err := net.Listen("tcp", ":0") 204 | require.NoError(b, err) 205 | 206 | var server Server 207 | server.Handler = HandlerFunc(func(ctx *Context) error { 208 | return ctx.Reply(nil) 209 | }) 210 | 211 | client := &Client{Addr: ln.Addr().String()} 212 | 213 | go func() { 214 | require.NoError(b, server.Serve(ln)) 215 | }() 216 | 217 | defer func() { 218 | server.Shutdown() 219 | client.Shutdown() 220 | 221 | require.NoError(b, ln.Close()) 222 | }() 223 | 224 | buf := make([]byte, 1400) 225 | _, err = rand.Read(buf) 226 | require.NoError(b, err) 227 | 228 | b.SetBytes(int64(len(buf))) 229 | b.ReportAllocs() 230 | b.ResetTimer() 231 | 232 | for i := 0; i < b.N; i++ { 233 | res, err := client.Request(nil, buf) 234 | if err != nil { 235 | b.Fatal(err) 236 | } 237 | if len(res) != 0 { 238 | b.Fatalf("expected empty response, got '%s'", string(res)) 239 | } 240 | } 241 | } 242 | 243 | func BenchmarkParallelSend(b *testing.B) { 244 | ln, err := net.Listen("tcp", ":0") 245 | require.NoError(b, err) 246 | 247 | var server Server 248 | 249 | client := &Client{Addr: ln.Addr().String()} 250 | 251 | go func() { 252 | require.NoError(b, server.Serve(ln)) 253 | }() 254 | 255 | defer func() { 256 | server.Shutdown() 257 | client.Shutdown() 258 | 259 | require.NoError(b, ln.Close()) 260 | }() 261 | 
262 | buf := make([]byte, 1400) 263 | _, err = rand.Read(buf) 264 | require.NoError(b, err) 265 | 266 | b.SetBytes(int64(len(buf))) 267 | b.ReportAllocs() 268 | b.ResetTimer() 269 | 270 | b.RunParallel(func(pb *testing.PB) { 271 | for pb.Next() { 272 | err := client.Send(buf) 273 | if err != nil { 274 | b.Fatal(err) 275 | } 276 | } 277 | }) 278 | } 279 | 280 | func BenchmarkParallelSendNoWait(b *testing.B) { 281 | ln, err := net.Listen("tcp", ":0") 282 | require.NoError(b, err) 283 | 284 | var server Server 285 | 286 | client := &Client{Addr: ln.Addr().String()} 287 | 288 | go func() { 289 | require.NoError(b, server.Serve(ln)) 290 | }() 291 | 292 | defer func() { 293 | server.Shutdown() 294 | client.Shutdown() 295 | 296 | require.NoError(b, ln.Close()) 297 | }() 298 | 299 | buf := make([]byte, 1400) 300 | _, err = rand.Read(buf) 301 | require.NoError(b, err) 302 | 303 | b.SetBytes(int64(len(buf))) 304 | b.ReportAllocs() 305 | b.ResetTimer() 306 | 307 | b.RunParallel(func(pb *testing.PB) { 308 | for pb.Next() { 309 | err := client.SendNoWait(buf) 310 | if err != nil { 311 | b.Fatal(err) 312 | } 313 | } 314 | }) 315 | } 316 | 317 | func BenchmarkParallelRequest(b *testing.B) { 318 | ln, err := net.Listen("tcp", ":0") 319 | require.NoError(b, err) 320 | 321 | var server Server 322 | server.Handler = HandlerFunc(func(ctx *Context) error { 323 | return ctx.Reply(nil) 324 | }) 325 | 326 | client := &Client{Addr: ln.Addr().String()} 327 | 328 | go func() { 329 | require.NoError(b, server.Serve(ln)) 330 | }() 331 | 332 | defer func() { 333 | server.Shutdown() 334 | client.Shutdown() 335 | 336 | require.NoError(b, ln.Close()) 337 | }() 338 | 339 | buf := make([]byte, 1400) 340 | _, err = rand.Read(buf) 341 | require.NoError(b, err) 342 | 343 | b.SetBytes(int64(len(buf))) 344 | b.ReportAllocs() 345 | b.ResetTimer() 346 | 347 | b.RunParallel(func(pb *testing.PB) { 348 | for pb.Next() { 349 | res, err := client.Request(nil, buf) 350 | if err != nil { 351 | b.Fatal(err) 352 | 
} 353 | if len(res) != 0 { 354 | b.Fatalf("expected empty response, got '%s'", string(res)) 355 | } 356 | } 357 | }) 358 | } 359 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "github.com/lithdew/bytesutil" 7 | "github.com/valyala/bytebufferpool" 8 | "io" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | var DefaultSeqOffset uint32 = 1 14 | var DefaultSeqDelta uint32 = 2 15 | 16 | type Conn struct { 17 | Handler Handler 18 | 19 | ReadBufferSize int 20 | WriteBufferSize int 21 | 22 | ReadTimeout time.Duration 23 | WriteTimeout time.Duration 24 | 25 | SeqOffset uint32 26 | SeqDelta uint32 27 | 28 | mu sync.Mutex 29 | once sync.Once 30 | 31 | writerQueue []*pendingWrite 32 | writerCond sync.Cond 33 | writerDone bool 34 | 35 | reqs map[uint32]*pendingRequest 36 | seq uint32 37 | } 38 | 39 | func (c *Conn) NumPendingWrites() int { 40 | c.mu.Lock() 41 | defer c.mu.Unlock() 42 | return len(c.writerQueue) 43 | } 44 | 45 | func (c *Conn) Handle(done chan struct{}, conn BufferedConn) error { 46 | c.once.Do(c.init) 47 | 48 | writerDone := make(chan error) 49 | go func() { 50 | writerDone <- c.writeLoop(conn) 51 | close(writerDone) 52 | }() 53 | 54 | readerDone := make(chan error) 55 | go func() { 56 | readerDone <- c.readLoop(conn) 57 | close(readerDone) 58 | }() 59 | 60 | var err error 61 | 62 | select { 63 | case <-done: 64 | c.closeWriter() 65 | err = <-writerDone 66 | conn.Close() 67 | if err == nil { 68 | err = <-readerDone 69 | } else { 70 | <-readerDone 71 | } 72 | case err = <-writerDone: 73 | c.closeWriter() 74 | conn.Close() 75 | if err == nil { 76 | err = <-readerDone 77 | } else { 78 | <-readerDone 79 | } 80 | case err = <-readerDone: 81 | c.closeWriter() 82 | if err == nil { 83 | err = <-writerDone 84 | } else { 85 | <-writerDone 86 | } 87 | conn.Close() 88 | } 89 | 90 | return 
err 91 | } 92 | 93 | func (c *Conn) Send(payload []byte) error { c.once.Do(c.init); return c.send(0, payload) } 94 | func (c *Conn) SendNoWait(payload []byte) error { c.once.Do(c.init); return c.sendNoWait(0, payload) } 95 | 96 | func (c *Conn) Request(dst []byte, payload []byte) ([]byte, error) { 97 | c.once.Do(c.init) 98 | 99 | pr := acquirePendingRequest(dst) 100 | defer releasePendingRequest(pr) 101 | pr.wg.Add(1) 102 | 103 | seq := c.next() 104 | c.mu.Lock() 105 | c.reqs[seq] = pr 106 | c.mu.Unlock() 107 | 108 | if err := c.sendNoWait(seq, payload); err != nil { 109 | c.mu.Lock() 110 | _, pending := c.reqs[seq] 111 | delete(c.reqs, seq) 112 | c.mu.Unlock() 113 | if pending { 114 | pr.wg.Done() // still registered: balance the Add(1) above 115 | } else { 116 | // close() or readLoop already claimed pr and calls Done on it; a second 117 | // Done here would panic (negative WaitGroup counter), so wait for theirs. 118 | pr.wg.Wait() 119 | } 120 | return nil, err 121 | } 122 | pr.wg.Wait() 123 | return pr.dst, pr.err 124 | } 125 | 126 | func (c *Conn) init() { 127 | c.reqs = make(map[uint32]*pendingRequest) 128 | c.writerCond.L = &c.mu 129 | } 130 | 131 | func (c *Conn) send(seq uint32, payload []byte) error { 132 | buf := bytebufferpool.Get() 133 | defer bytebufferpool.Put(buf) 134 | 135 | buf.B = bytesutil.ExtendSlice(buf.B, 4+len(payload)) 136 | binary.BigEndian.PutUint32(buf.B[:4], seq) 137 | copy(buf.B[4:], payload) 138 | 139 | return c.write(buf) 140 | } 141 | 142 | func (c *Conn) sendNoWait(seq uint32, payload []byte) error { 143 | buf := bytebufferpool.Get() 144 | buf.B = bytesutil.ExtendSlice(buf.B, 4+len(payload)) 145 | binary.BigEndian.PutUint32(buf.B[:4], seq) 146 | copy(buf.B[4:], payload) 147 | return c.writeNoWait(buf) 148 | } 149 | 150 | func (c *Conn) write(buf *bytebufferpool.ByteBuffer) error { 151 | pw, err := c.preparePendingWrite(buf, true) 152 | if err != nil { 153 | return err 154 | } 155 | defer releasePendingWrite(pw) 156 | pw.wg.Wait() 157 | return pw.err 158 | } 159 | 160 | func (c *Conn) writeNoWait(buf *bytebufferpool.ByteBuffer) error { 161 | _, err := c.preparePendingWrite(buf, false) 162 | return err 163 | } 164 | 165 | func (c *Conn) preparePendingWrite(buf *bytebufferpool.ByteBuffer,
wait bool) (*pendingWrite, error) { 166 | c.mu.Lock() 167 | defer c.mu.Unlock() 168 | 169 | if c.writerDone { 170 | return nil, fmt.Errorf("node is shut down: %w", io.EOF) 171 | } 172 | 173 | pw := acquirePendingWrite(buf, wait) 174 | if wait { 175 | pw.wg.Add(1) 176 | } 177 | 178 | c.writerQueue = append(c.writerQueue, pw) 179 | c.writerCond.Signal() 180 | 181 | return pw, nil 182 | } 183 | 184 | func (c *Conn) closeWriter() { 185 | c.mu.Lock() 186 | defer c.mu.Unlock() 187 | c.writerDone = true 188 | c.writerCond.Signal() 189 | } 190 | 191 | func (c *Conn) getHandler() Handler { 192 | if c.Handler == nil { 193 | return DefaultHandler 194 | } 195 | return c.Handler 196 | } 197 | 198 | func (c *Conn) getReadBufferSize() int { 199 | if c.ReadBufferSize <= 0 { 200 | return DefaultReadBufferSize 201 | } 202 | return c.ReadBufferSize 203 | } 204 | 205 | func (c *Conn) getWriteBufferSize() int { 206 | if c.WriteBufferSize <= 0 { 207 | return DefaultWriteBufferSize 208 | } 209 | return c.WriteBufferSize 210 | } 211 | 212 | func (c *Conn) getReadTimeout() time.Duration { 213 | if c.ReadTimeout < 0 { 214 | return DefaultReadTimeout 215 | } 216 | return c.ReadTimeout 217 | } 218 | 219 | func (c *Conn) getWriteTimeout() time.Duration { 220 | if c.WriteTimeout < 0 { 221 | return DefaultWriteTimeout 222 | } 223 | return c.WriteTimeout 224 | } 225 | 226 | func (c *Conn) getSeqOffset() uint32 { 227 | if c.SeqOffset == 0 { 228 | return DefaultSeqOffset 229 | } 230 | return c.SeqOffset 231 | } 232 | 233 | func (c *Conn) getSeqDelta() uint32 { 234 | if c.SeqDelta == 0 { 235 | return DefaultSeqDelta 236 | } 237 | return c.SeqDelta 238 | } 239 | 240 | func (c *Conn) next() uint32 { 241 | c.mu.Lock() 242 | defer c.mu.Unlock() 243 | 244 | if c.seq == 0 { 245 | c.seq = c.getSeqOffset() 246 | } else { 247 | c.seq += c.getSeqDelta() 248 | } 249 | return c.seq 250 | } 251 | 252 | func (c *Conn) writeLoop(conn BufferedConn) error { 253 | var queue []*pendingWrite 254 | var err error 255 | 
256 | for { 257 | c.mu.Lock() 258 | for !c.writerDone && len(c.writerQueue) == 0 { 259 | c.writerCond.Wait() 260 | } 261 | done := c.writerDone 262 | 263 | if n := len(c.writerQueue) - cap(queue); n > 0 { 264 | queue = append(queue[:cap(queue)], make([]*pendingWrite, n)...) 265 | } 266 | queue = queue[:len(c.writerQueue)] 267 | 268 | copy(queue, c.writerQueue) 269 | 270 | c.writerQueue = c.writerQueue[:0] 271 | c.mu.Unlock() 272 | 273 | if done && len(queue) == 0 { 274 | break 275 | } 276 | 277 | timeout := c.getWriteTimeout() 278 | if timeout > 0 { 279 | err = conn.SetWriteDeadline(time.Now().Add(timeout)) 280 | if err != nil { 281 | for _, pw := range queue { 282 | if pw.wait { 283 | pw.err = err 284 | pw.wg.Done() 285 | } else { 286 | bytebufferpool.Put(pw.buf) 287 | releasePendingWrite(pw) 288 | } 289 | } 290 | break 291 | } 292 | } 293 | 294 | for _, pw := range queue { 295 | if err == nil { 296 | _, err = conn.Write(pw.buf.B) 297 | } 298 | if pw.wait { 299 | pw.err = err 300 | pw.wg.Done() 301 | } else { 302 | bytebufferpool.Put(pw.buf) 303 | releasePendingWrite(pw) 304 | } 305 | } 306 | 307 | if err != nil { 308 | break 309 | } 310 | 311 | err = conn.Flush() 312 | if err != nil { 313 | break 314 | } 315 | } 316 | 317 | if err != nil { 318 | err = fmt.Errorf("write_loop: %w", err) 319 | } 320 | 321 | return err 322 | } 323 | 324 | func (c *Conn) readLoop(conn BufferedConn) error { 325 | buf := make([]byte, c.getReadBufferSize()) 326 | 327 | var ( 328 | n int 329 | err error 330 | ) 331 | 332 | for { 333 | timeout := c.getReadTimeout() 334 | if timeout > 0 { 335 | err = conn.SetReadDeadline(time.Now().Add(timeout)) 336 | if err != nil { 337 | break 338 | } 339 | } 340 | 341 | n, err = conn.Read(buf) 342 | if err != nil { 343 | break 344 | } 345 | 346 | data := buf[:n] 347 | if len(data) < 4 { 348 | err = fmt.Errorf("no sequence number to decode: %w", io.ErrUnexpectedEOF) 349 | break 350 | } 351 | 352 | seq := bytesutil.Uint32BE(data) 353 | data = data[4:] 354 
| 355 | c.mu.Lock() 356 | pr, exists := c.reqs[seq] 357 | if exists { 358 | delete(c.reqs, seq) 359 | } 360 | c.mu.Unlock() 361 | 362 | if seq == 0 || !exists { 363 | err = c.call(seq, data) 364 | if err != nil { 365 | err = fmt.Errorf("handler encountered an error: %w", err) 366 | break 367 | } 368 | continue 369 | } 370 | 371 | // received response 372 | 373 | pr.dst = bytesutil.ExtendSlice(pr.dst, len(data)) 374 | copy(pr.dst, data) 375 | 376 | pr.wg.Done() 377 | } 378 | 379 | return fmt.Errorf("read_loop: %w", err) 380 | } 381 | 382 | func (c *Conn) call(seq uint32, data []byte) error { 383 | ctx := acquireContext(c, seq, data) 384 | defer releaseContext(ctx) 385 | return c.getHandler().HandleMessage(ctx) 386 | } 387 | 388 | func (c *Conn) close(err error) { 389 | c.mu.Lock() 390 | defer c.mu.Unlock() 391 | 392 | for _, pw := range c.writerQueue { 393 | if pw.wait { 394 | pw.err = err 395 | pw.wg.Done() 396 | } else { 397 | bytebufferpool.Put(pw.buf) 398 | releasePendingWrite(pw) 399 | } 400 | } 401 | 402 | c.writerQueue = nil 403 | 404 | for seq := range c.reqs { 405 | pr := c.reqs[seq] 406 | pr.err = err 407 | pr.wg.Done() 408 | 409 | delete(c.reqs, seq) 410 | } 411 | 412 | c.seq = 0 413 | } 414 | -------------------------------------------------------------------------------- /examples/session/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/hex" 5 | "fmt" 6 | "github.com/lithdew/monte" 7 | "net" 8 | ) 9 | 10 | func main() { 11 | check := func(err error) { 12 | if err != nil { 13 | panic(err) 14 | } 15 | } 16 | 17 | go func() { 18 | conn, err := net.Dial("tcp", ":4444") 19 | check(err) // the listener below may not be up yet; fail loudly rather than hand a nil conn to DoClient 20 | var sess monte.Session 21 | check(sess.DoClient(conn)) 22 | 23 | fmt.Println(hex.EncodeToString(sess.SharedKey())) 24 | 25 | sc := monte.NewSessionConn(sess.Suite(), conn) 26 | 27 | for i := 0; i < 100; i++ { 28 | _, err = sc.Write([]byte(fmt.Sprintf("[%d] Hello from Go!", i))) 29 | check(err)
30 | check(sc.Flush()) 31 | } 32 | 33 | }() 34 | 35 | ln, err := net.Listen("tcp", ":4444") 36 | check(err) 37 | defer ln.Close() 38 | 39 | conn, err := ln.Accept() 40 | check(err) 41 | defer conn.Close() 42 | 43 | var sess monte.Session 44 | check(sess.DoServer(conn)) 45 | 46 | fmt.Println(hex.EncodeToString(sess.SharedKey())) 47 | 48 | sc := monte.NewSessionConn(sess.Suite(), conn) 49 | 50 | buf := make([]byte, 1024) 51 | 52 | for i := 0; i < 100; i++ { 53 | n, err := sc.Read(buf) 54 | check(err) 55 | 56 | fmt.Println("Decrypted:", string(buf[:n])) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lithdew/monte 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/davecgh/go-spew v1.1.1 7 | github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59 8 | github.com/oasislabs/ed25519 v0.0.0-20200302143042-29f6767a7c3e 9 | github.com/stretchr/testify v1.5.1 10 | github.com/valyala/bytebufferpool v1.0.0 11 | go.uber.org/goleak v1.0.0 12 | golang.org/x/crypto v0.0.0-20191119213627-4f8c1d86b1ba 13 | golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect 14 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd // indirect 15 | golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 5 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 
6 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 7 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 8 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 9 | github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59 h1:CQpoOQecHxhvgOU/ijue/yWuShZYDtNpI9bsD4Dkzrk= 10 | github.com/lithdew/bytesutil v0.0.0-20200409052507-d98389230a59/go.mod h1:89JlULMIJ/+YWzAp5aHXgAD2d02S2mY+a+PMgXDtoNs= 11 | github.com/oasislabs/ed25519 v0.0.0-20200302143042-29f6767a7c3e h1:85L+lUTJHx4O7UP9y/65XV8iq7oaA2Uqe5WiUSB8XE4= 12 | github.com/oasislabs/ed25519 v0.0.0-20200302143042-29f6767a7c3e/go.mod h1:xIpCyrK2ouGA4QBGbiNbkoONrvJ00u9P3QOkXSOAC0c= 13 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 14 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 15 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 16 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 17 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 18 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 19 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 20 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 21 | github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 22 | go.uber.org/goleak v1.0.0 h1:qsup4IcBdlmsnGfqyLl4Ntn3C2XCCuKAE7DwHpScyUo= 23 | go.uber.org/goleak v1.0.0/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= 24 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 25 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 26 | golang.org/x/crypto 
v0.0.0-20191119213627-4f8c1d86b1ba h1:9bFeDpN3gTqNanMVqNcoR/pJQuP5uroC3t1D7eXozTE= 27 | golang.org/x/crypto v0.0.0-20191119213627-4f8c1d86b1ba/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 28 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 29 | golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= 30 | golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= 31 | golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= 32 | golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= 33 | golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 34 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 35 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 36 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 37 | golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 38 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 39 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 40 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 41 | golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= 42 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 43 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= 44 | golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 45 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 46 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 47 | golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 48 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 49 | golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= 50 | golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b h1:2hSR2MyOaYEy6yJYg/CpErymr/m7xJEJpm9kfT7ZMg4= 51 | golang.org/x/tools v0.0.0-20200501005904-d351ea090f9b/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= 52 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 53 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 54 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 55 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 56 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 57 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 58 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 59 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 60 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 61 | -------------------------------------------------------------------------------- /io.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "encoding/binary" 5 | 
"errors" 6 | "fmt" 7 | "github.com/lithdew/bytesutil" 8 | "io" 9 | "net" 10 | ) 11 | 12 | type BufferedConn interface { 13 | net.Conn 14 | Flush() error 15 | } 16 | 17 | func Read(dst []byte, r io.Reader) ([]byte, error) { 18 | _, err := io.ReadFull(r, dst[:]) 19 | if err != nil { 20 | return nil, err 21 | } 22 | return dst, nil 23 | } 24 | 25 | func Write(w io.Writer, buf []byte) error { 26 | n, err := w.Write(buf) 27 | if n != len(buf) { 28 | return io.ErrShortWrite 29 | } 30 | return err 31 | } 32 | 33 | func ReadSized(dst []byte, r io.Reader, max int) ([]byte, error) { 34 | dst = bytesutil.ExtendSlice(dst, 4) 35 | _, err := io.ReadFull(r, dst[:]) 36 | if err != nil { 37 | return nil, err 38 | } 39 | n := bytesutil.Uint32BE(dst[:]) 40 | if int(n) > max { 41 | return nil, fmt.Errorf("max is %d bytes, got %d bytes", max, n) 42 | } 43 | dst = bytesutil.ExtendSlice(dst, int(n)) 44 | _, err = io.ReadFull(r, dst[:]) 45 | if err != nil { 46 | return nil, err 47 | } 48 | return dst, nil 49 | } 50 | 51 | func WriteSized(w io.Writer, buf []byte) error { 52 | buf = bytesutil.ExtendSlice(buf, len(buf)+4) 53 | binary.BigEndian.PutUint32(buf[len(buf)-4:], uint32(len(buf))-4) 54 | _, err := w.Write(buf[len(buf)-4:]) 55 | if err == nil { 56 | _, err = w.Write(buf[:len(buf)-4]) 57 | } 58 | return err 59 | } 60 | 61 | func IsEOF(err error) bool { 62 | if errors.Is(err, io.EOF) { 63 | return true 64 | } 65 | var netErr *net.OpError 66 | if !errors.As(err, &netErr) { 67 | return false 68 | } 69 | if netErr.Err.Error() == "use of closed network connection" { 70 | return true 71 | } 72 | return false 73 | } 74 | -------------------------------------------------------------------------------- /net.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import "net" 4 | 5 | type ConnState int 6 | 7 | const ( 8 | StateNew ConnState = iota 9 | StateClosed 10 | ) 11 | 12 | type ConnStateHandler interface { 13 | HandleConnState(conn 
*Conn, state ConnState) 14 | } 15 | 16 | type ConnStateHandlerFunc func(conn *Conn, state ConnState) 17 | 18 | func (fn ConnStateHandlerFunc) HandleConnState(conn *Conn, state ConnState) { fn(conn, state) } 19 | 20 | var DefaultConnStateHandler ConnStateHandlerFunc = func(conn *Conn, state ConnState) {} 21 | 22 | type Handler interface { 23 | HandleMessage(ctx *Context) error 24 | } 25 | 26 | type HandlerFunc func(ctx *Context) error 27 | 28 | func (fn HandlerFunc) HandleMessage(ctx *Context) error { return fn(ctx) } 29 | 30 | var DefaultHandler HandlerFunc = func(ctx *Context) error { return nil } 31 | 32 | type Handshaker interface { 33 | Handshake(conn net.Conn) (BufferedConn, error) 34 | } 35 | 36 | type HandshakerFunc func(conn net.Conn) (BufferedConn, error) 37 | 38 | func (fn HandshakerFunc) Handshake(conn net.Conn) (BufferedConn, error) { return fn(conn) } 39 | 40 | var DefaultClientHandshaker HandshakerFunc = func(conn net.Conn) (BufferedConn, error) { 41 | var session Session 42 | err := session.DoClient(conn) 43 | if err != nil { 44 | return nil, err 45 | } 46 | return NewSessionConn(session.Suite(), conn), nil 47 | } 48 | 49 | var DefaultServerHandshaker HandshakerFunc = func(conn net.Conn) (BufferedConn, error) { 50 | var session Session 51 | err := session.DoServer(conn) 52 | if err != nil { 53 | return nil, err 54 | } 55 | return NewSessionConn(session.Suite(), conn), nil 56 | } 57 | -------------------------------------------------------------------------------- /pool.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "github.com/valyala/bytebufferpool" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type Context struct { 10 | conn *Conn 11 | seq uint32 12 | buf []byte 13 | } 14 | 15 | func (c *Context) Conn() *Conn { return c.conn } 16 | func (c *Context) Body() []byte { return c.buf } 17 | func (c *Context) Reply(buf []byte) error { return c.conn.send(c.seq, buf) } 18 | 19 | var 
contextPool sync.Pool 20 | 21 | func acquireContext(conn *Conn, seq uint32, buf []byte) *Context { 22 | v := contextPool.Get() 23 | if v == nil { 24 | v = &Context{} 25 | } 26 | ctx := v.(*Context) 27 | ctx.conn = conn 28 | ctx.seq = seq 29 | ctx.buf = buf 30 | return ctx 31 | } 32 | 33 | func releaseContext(ctx *Context) { contextPool.Put(ctx) } 34 | 35 | type pendingWrite struct { 36 | buf *bytebufferpool.ByteBuffer // payload 37 | wait bool // signal to caller if they're waiting 38 | err error // keeps track of any socket errors on write 39 | wg sync.WaitGroup // signals the caller that this write is complete 40 | } 41 | 42 | var pendingWritePool sync.Pool 43 | 44 | func acquirePendingWrite(buf *bytebufferpool.ByteBuffer, wait bool) *pendingWrite { 45 | v := pendingWritePool.Get() 46 | if v == nil { 47 | v = &pendingWrite{} 48 | } 49 | pw := v.(*pendingWrite) 50 | pw.buf = buf 51 | pw.wait = wait 52 | return pw 53 | } 54 | 55 | func releasePendingWrite(pw *pendingWrite) { pw.err = nil; pendingWritePool.Put(pw) } 56 | 57 | type pendingRequest struct { 58 | dst []byte // dst to copy response to 59 | err error // error while waiting for response 60 | wg sync.WaitGroup // signals the caller that the response has been received 61 | } 62 | 63 | var pendingRequestPool sync.Pool 64 | 65 | func acquirePendingRequest(dst []byte) *pendingRequest { 66 | v := pendingRequestPool.Get() 67 | if v == nil { 68 | v = &pendingRequest{} 69 | } 70 | pr := v.(*pendingRequest) 71 | pr.dst = dst 72 | return pr 73 | } 74 | 75 | func releasePendingRequest(pr *pendingRequest) { 76 | pr.dst = nil 77 | pr.err = nil 78 | pendingRequestPool.Put(pr) 79 | } 80 | 81 | var zeroTime time.Time 82 | 83 | var timerPool sync.Pool 84 | 85 | func AcquireTimer(timeout time.Duration) *time.Timer { 86 | v := timerPool.Get() 87 | if v == nil { 88 | return time.NewTimer(timeout) 89 | } 90 | t := v.(*time.Timer) 91 | t.Reset(timeout) 92 | return t 93 | } 94 | 95 | func ReleaseTimer(t *time.Timer) { 96 | if 
!t.Stop() { 97 | select { 98 | case <-t.C: 99 | default: 100 | } 101 | } 102 | timerPool.Put(t) 103 | } 104 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | var DefaultMaxServerConns = 1024 12 | 13 | var DefaultHandshakeTimeout = 3 * time.Second 14 | var DefaultMaxConnWaitTimeout = 3 * time.Second 15 | 16 | var DefaultServerSeqOffset uint32 = 2 17 | var DefaultServerSeqDelta uint32 = 2 18 | 19 | type Server struct { 20 | Handler Handler 21 | ConnState ConnStateHandler 22 | 23 | Handshaker Handshaker 24 | HandshakeTimeout time.Duration 25 | 26 | MaxConns int 27 | MaxConnWaitTimeout time.Duration 28 | 29 | ReadBufferSize int 30 | WriteBufferSize int 31 | 32 | ReadTimeout time.Duration 33 | WriteTimeout time.Duration 34 | 35 | SeqOffset uint32 36 | SeqDelta uint32 37 | 38 | once sync.Once 39 | mu sync.Mutex 40 | wg sync.WaitGroup 41 | 42 | sem chan struct{} 43 | done chan struct{} 44 | } 45 | 46 | func (s *Server) init() { 47 | s.sem = make(chan struct{}, s.getMaxConns()) 48 | s.done = make(chan struct{}) 49 | } 50 | 51 | func (s *Server) getHandler() Handler { 52 | if s.Handler == nil { 53 | return DefaultHandler 54 | } 55 | return s.Handler 56 | } 57 | 58 | func (s *Server) getConnStateHandler() ConnStateHandler { 59 | if s.ConnState == nil { 60 | return DefaultConnStateHandler 61 | } 62 | return s.ConnState 63 | } 64 | 65 | func (s *Server) getHandshaker() Handshaker { 66 | if s.Handshaker == nil { 67 | return DefaultServerHandshaker 68 | } 69 | return s.Handshaker 70 | } 71 | 72 | func (s *Server) getHandshakeTimeout() time.Duration { 73 | if s.HandshakeTimeout < 0 { 74 | return DefaultHandshakeTimeout 75 | } 76 | return s.HandshakeTimeout 77 | } 78 | 79 | func (s *Server) getMaxConns() int { 80 | if s.MaxConns <= 0 { 81 | return 
DefaultMaxServerConns 82 | } 83 | return s.MaxConns 84 | } 85 | 86 | func (s *Server) getMaxConnWaitTimeout() time.Duration { 87 | if s.MaxConnWaitTimeout <= 0 { 88 | return DefaultMaxConnWaitTimeout 89 | } 90 | return s.MaxConnWaitTimeout 91 | } 92 | 93 | func (s *Server) getReadTimeout() time.Duration { 94 | if s.ReadTimeout < 0 { 95 | return DefaultReadTimeout 96 | } 97 | return s.ReadTimeout 98 | } 99 | 100 | func (s *Server) getWriteTimeout() time.Duration { 101 | if s.WriteTimeout < 0 { 102 | return DefaultWriteTimeout 103 | } 104 | return s.WriteTimeout 105 | } 106 | 107 | func (s *Server) getReadBufferSize() int { 108 | if s.ReadBufferSize <= 0 { 109 | return DefaultReadBufferSize 110 | } 111 | return s.ReadBufferSize 112 | } 113 | 114 | func (s *Server) getWriteBufferSize() int { 115 | if s.WriteBufferSize <= 0 { 116 | return DefaultWriteBufferSize 117 | } 118 | return s.WriteBufferSize 119 | } 120 | 121 | func (s *Server) getSeqOffset() uint32 { 122 | if s.SeqOffset == 0 { 123 | return DefaultServerSeqOffset 124 | } 125 | return s.SeqOffset 126 | } 127 | 128 | func (s *Server) getSeqDelta() uint32 { 129 | if s.SeqDelta == 0 { 130 | return DefaultServerSeqDelta 131 | } 132 | return s.SeqDelta 133 | } 134 | 135 | func (s *Server) serverAvailable() bool { 136 | select { 137 | case <-s.done: 138 | return false 139 | case s.sem <- struct{}{}: 140 | return true 141 | default: 142 | timer := AcquireTimer(s.getMaxConnWaitTimeout()) 143 | defer ReleaseTimer(timer) 144 | 145 | select { 146 | case <-timer.C: 147 | return false 148 | case <-s.done: 149 | return false 150 | case s.sem <- struct{}{}: 151 | return true 152 | } 153 | } 154 | } 155 | 156 | func (s *Server) wait(duration time.Duration) bool { 157 | timer := AcquireTimer(duration) 158 | defer ReleaseTimer(timer) 159 | 160 | select { 161 | case <-timer.C: 162 | return true 163 | case <-s.done: 164 | return false 165 | } 166 | } 167 | 168 | func (s *Server) client(conn net.Conn) error { 169 | defer func() { 
<-s.sem }() 170 | 171 | timeout := s.getHandshakeTimeout() 172 | 173 | if timeout != 0 { 174 | err := conn.SetDeadline(time.Now().Add(timeout)) 175 | if err != nil { 176 | return err 177 | } 178 | } 179 | 180 | bufConn, err := s.getHandshaker().Handshake(conn) 181 | if err != nil { 182 | return err 183 | } 184 | 185 | if timeout != 0 { 186 | err = conn.SetDeadline(zeroTime) 187 | if err != nil { 188 | return err 189 | } 190 | } 191 | 192 | cc := &Conn{ 193 | SeqOffset: s.getSeqOffset(), 194 | SeqDelta: s.getSeqDelta(), 195 | Handler: s.getHandler(), 196 | ReadBufferSize: s.getReadBufferSize(), 197 | WriteBufferSize: s.getWriteBufferSize(), 198 | ReadTimeout: s.getReadTimeout(), 199 | WriteTimeout: s.getWriteTimeout(), 200 | } 201 | 202 | s.getConnStateHandler().HandleConnState(cc, StateNew) 203 | 204 | cc.close(cc.Handle(s.done, bufConn)) 205 | 206 | s.getConnStateHandler().HandleConnState(cc, StateClosed) 207 | 208 | return nil 209 | } 210 | 211 | func (s *Server) Serve(ln net.Listener) error { 212 | s.once.Do(s.init) 213 | 214 | for { 215 | conn, err := ln.Accept() 216 | if err != nil { 217 | if errors.Is(err, io.EOF) { 218 | return nil 219 | } 220 | var netErr *net.OpError 221 | if !errors.As(err, &netErr) { 222 | return err 223 | } 224 | if netErr.Err.Error() == "use of closed network connection" { 225 | return nil 226 | } 227 | if !netErr.Temporary() { 228 | return err 229 | } 230 | ok := s.wait(100 * time.Millisecond) 231 | if !ok { 232 | return nil 233 | } 234 | continue 235 | } 236 | 237 | if !s.serverAvailable() { 238 | conn.Close() 239 | continue 240 | } 241 | 242 | s.wg.Add(1) 243 | 244 | go func() { 245 | defer s.wg.Done() 246 | s.client(conn) 247 | conn.Close() 248 | }() 249 | } 250 | } 251 | 252 | func (s *Server) Shutdown() { 253 | s.once.Do(s.init) 254 | 255 | close(s.done) 256 | s.wg.Wait() 257 | } 258 | -------------------------------------------------------------------------------- /server_test.go: 
-------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "github.com/stretchr/testify/require" 5 | "go.uber.org/goleak" 6 | "net" 7 | "testing" 8 | ) 9 | 10 | func TestServerShutdown(t *testing.T) { 11 | defer goleak.VerifyNone(t) 12 | 13 | srv := &Server{} 14 | 15 | ln, err := net.Listen("tcp", ":0") 16 | require.NoError(t, err) 17 | 18 | go func() { 19 | srv.Shutdown() 20 | ln.Close() 21 | }() 22 | 23 | require.NoError(t, srv.Serve(ln)) 24 | } 25 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package monte 2 | 3 | import ( 4 | "bufio" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | "github.com/lithdew/bytesutil" 11 | "github.com/oasislabs/ed25519" 12 | "github.com/oasislabs/ed25519/extra/x25519" 13 | "golang.org/x/crypto/blake2b" 14 | "net" 15 | "time" 16 | ) 17 | 18 | var _ BufferedConn = (*SessionConn)(nil) 19 | 20 | // SessionConn is not safe for concurrent use. It decrypts on reads and encrypts on writes 21 | // via a provided cipher.AEAD suite for a given conn that implements net.Conn. It assumes 22 | // all packets sent/received are to be prefixed with a 32-bit unsigned integer that 23 | // designates the length of each individual packet. 24 | // 25 | // The same cipher.AEAD suite must not be used for multiple SessionConn instances. Doing 26 | // so will cause for plaintext data to be leaked. 
27 | type SessionConn struct {
28 | 	suite cipher.AEAD
29 | 	conn  net.Conn
30 |
31 | 	bw *bufio.Writer
32 | 	br *bufio.Reader
33 |
34 | 	rb []byte // read buffer, reused across calls to Read
35 | 	wb []byte // write buffer, reused across calls to Write
36 | 	wn uint64 // write nonce counter
37 | 	rn uint64 // read nonce counter
38 | }
39 |
40 | // NewSessionConn wraps conn so that Write seals data with suite and Read opens it.
41 | // Both nonce counters start at zero. Writes are buffered until Flush is called.
42 | func NewSessionConn(suite cipher.AEAD, conn net.Conn) *SessionConn {
43 | 	return &SessionConn{
44 | 		suite: suite,
45 | 		conn:  conn,
46 |
47 | 		bw: bufio.NewWriter(conn),
48 | 		br: bufio.NewReader(conn),
49 | 	}
50 | }
51 |
52 | // Read reads one length-prefixed sealed frame, opens it, and copies the plaintext
53 | // into b, returning its length. A frame whose plaintext would not fit in len(b) is
54 | // rejected with an error instead of being silently truncated. Frames must be read
55 | // in the order they were written, since each one advances the read nonce.
56 | func (s *SessionConn) Read(b []byte) (int, error) {
57 | 	// A sealed frame is suite.Overhead() bytes longer than its plaintext.
58 | 	frame, err := ReadSized(s.rb[:0], s.br, len(b)+s.suite.Overhead())
59 | 	if err != nil {
60 | 		return 0, err
61 | 	}
62 |
63 | 	// Stage the nonce right after the ciphertext: an 8-byte big-endian counter
64 | 	// followed by zero padding, the same layout Write uses.
65 | 	size := len(frame)
66 | 	ns := s.suite.NonceSize()
67 | 	s.rb = bytesutil.ExtendSlice(frame, size+ns)
68 | 	nonce := s.rb[size:]
69 | 	for i := range nonce {
70 | 		nonce[i] = 0
71 | 	}
72 | 	binary.BigEndian.PutUint64(nonce, s.rn)
73 | 	s.rn++
74 |
75 | 	// Open in place: the plaintext is shorter than the ciphertext, so it never
76 | 	// overwrites the nonce stored after it.
77 | 	plaintext, err := s.suite.Open(s.rb[:0], nonce, s.rb[:size], nil)
78 | 	if err != nil {
79 | 		return 0, err
80 | 	}
81 | 	return copy(b, plaintext), nil
82 | }
83 |
84 | // Write seals b as one frame and buffers it, length-prefixed, for the underlying
85 | // connection; call Flush to send it. On success it reports len(b), as io.Writer
86 | // requires, even though the frame written to the connection is longer.
87 | func (s *SessionConn) Write(b []byte) (int, error) {
88 | 	ns := s.suite.NonceSize()
89 | 	// s.wb layout: [nonce | ciphertext+tag | 4 spare bytes]. The spare bytes let
90 | 	// WriteSized stage its length prefix without growing the buffer.
91 | 	s.wb = bytesutil.ExtendSlice(s.wb, ns+len(b)+s.suite.Overhead()+4)
92 | 	nonce := s.wb[:ns]
93 | 	binary.BigEndian.PutUint64(nonce, s.wn)
94 | 	for i := 8; i < ns; i++ {
95 | 		nonce[i] = 0
96 | 	}
97 | 	s.wn++
98 |
99 | 	sealed := s.suite.Seal(s.wb[ns:ns], nonce, b, nil)
100 | 	if err := WriteSized(s.bw, sealed); err != nil {
101 | 		return 0, err
102 | 	}
103 | 	return len(b), nil
104 | }
105 |
106 | // Flush writes any buffered frames to the underlying connection.
107 | func (s *SessionConn) Flush() error { return s.bw.Flush() }
108 |
109 | // Close, LocalAddr, RemoteAddr and the deadline setters delegate directly to the
110 | // underlying net.Conn. Close does not flush buffered writes; call Flush first.
111 |
112 | func (s *SessionConn) Close() error                       { return s.conn.Close() }
113 | func (s *SessionConn) LocalAddr() net.Addr                { return s.conn.LocalAddr() }
114 | func (s *SessionConn) RemoteAddr() net.Addr               { return s.conn.RemoteAddr() }
115 | func (s *SessionConn) SetDeadline(t time.Time) error      { return s.conn.SetDeadline(t) }
116 | func (s *SessionConn) SetReadDeadline(t time.Time) error  { return s.conn.SetReadDeadline(t) }
117 | func (s *SessionConn) SetWriteDeadline(t time.Time) error { return s.conn.SetWriteDeadline(t) }
118 |
119 | // Session performs an ephemeral X25519 key exchange over a net.Conn and derives an
120 | // AES-256-GCM suite from the shared secret. Session is not safe for concurrent use.
121 | type Session struct {
122 | 	suite     cipher.AEAD
123 | 	theirPub  []byte
124 | 	sharedKey []byte
125 | }
126 |
127 | // Suite returns the AEAD suite derived by Establish, or nil before it has succeeded.
128 | func (s *Session) Suite() cipher.AEAD {
129 | 	return s.suite
130 | }
131 |
132 | // SharedKey returns the 32-byte key derived by Establish, or nil before it has succeeded.
133 | func (s *Session) SharedKey() []byte {
134 | 	return s.sharedKey
135 | }
136 |
137 | // GenerateEphemeralKeys generates a fresh Ed25519 key pair and returns it
138 | // converted to an X25519 public and private key.
139 | func (s *Session) GenerateEphemeralKeys() ([]byte, []byte, error) {
140 | 	publicKey, privateKey, err := ed25519.GenerateKey(nil)
141 | 	if err != nil {
142 | 		return nil, nil, err
143 | 	}
144 |
145 | 	ourPub, ok := x25519.EdPublicKeyToX25519(publicKey)
146 | 	if !ok {
147 | 		return nil, nil, errors.New("unable to derive ed25519 key to x25519 key")
148 | 	}
149 |
150 | 	ourPriv := x25519.EdPrivateKeyToX25519(privateKey)
151 |
152 | 	return ourPub, ourPriv, nil
153 | }
154 |
155 | // DoClient runs the client side of the handshake: it sends our ephemeral public
156 | // key, reads the peer's, and derives the session suite.
157 | func (s *Session) DoClient(conn net.Conn) error {
158 | 	ourPub, ourPriv, err := s.GenerateEphemeralKeys()
159 | 	if err != nil {
160 | 		return err
161 | 	}
162 | 	if err := s.Write(conn, ourPub); err != nil {
163 | 		return err
164 | 	}
165 | 	if err := s.Read(conn); err != nil {
166 | 		return err
167 | 	}
168 | 	return s.Establish(ourPriv)
169 | }
170 |
171 | // DoServer runs the server side of the handshake: it reads the peer's ephemeral
172 | // public key, sends ours, and derives the session suite.
173 | func (s *Session) DoServer(conn net.Conn) error {
174 | 	ourPub, ourPriv, err := s.GenerateEphemeralKeys()
175 | 	if err != nil {
176 | 		return err
177 | 	}
178 | 	if err := s.Read(conn); err != nil {
179 | 		return err
180 | 	}
181 | 	if err := s.Write(conn, ourPub); err != nil {
182 | 		return err
183 | 	}
184 | 	return s.Establish(ourPriv)
185 | }
186 |
187 | // Write sends our public key, ourPub, to the peer.
188 | func (s *Session) Write(conn net.Conn, ourPub []byte) error {
189 | 	if err := Write(conn, ourPub); err != nil {
190 | 		return fmt.Errorf("failed to write session public key: %w", err)
191 | 	}
192 | 	return nil
193 | }
194 |
195 | // Read receives the peer's x25519.PointSize-byte public key and keeps it for Establish.
196 | func (s *Session) Read(conn net.Conn) error {
197 | 	publicKey, err := Read(make([]byte, x25519.PointSize), conn)
198 | 	if err != nil {
199 | 		return fmt.Errorf("failed to read peer session public key: %w", err)
200 | 	}
201 | 	s.theirPub = publicKey
202 | 	return nil
203 | }
204 |
205 | // Establish computes the X25519 shared secret from ourPriv and the peer's public
206 | // key, hashes it with BLAKE2b-256, and uses the digest as an AES-256-GCM key.
207 | // Read must have succeeded first.
208 | func (s *Session) Establish(ourPriv []byte) error {
209 | 	if s.theirPub == nil {
210 | 		return errors.New("did not read peer session public key yet")
211 | 	}
212 | 	sharedKey, err := x25519.X25519(ourPriv, s.theirPub)
213 | 	if err != nil {
214 | 		return fmt.Errorf("failed to derive shared session key: %w", err)
215 | 	}
216 | 	derivedKey := blake2b.Sum256(sharedKey)
217 | 	block, err := aes.NewCipher(derivedKey[:])
218 | 	if err != nil {
219 | 		return fmt.Errorf("failed to init aes cipher: %w", err)
220 | 	}
221 | 	suite, err := cipher.NewGCM(block)
222 | 	if err != nil {
223 | 		return fmt.Errorf("failed to init aead suite: %w", err)
224 | 	}
225 | 	s.sharedKey = derivedKey[:]
226 | 	s.suite = suite
227 | 	return nil
228 | }
229 |
--------------------------------------------------------------------------------
/session_test.go:
--------------------------------------------------------------------------------
1 | package monte
2 |
3 | import (
4 | 	"fmt"
5 | 	"github.com/stretchr/testify/require"
6 | 	"go.uber.org/goleak"
7 | 	"net"
8 | 	"strconv"
9 | 	"sync"
10 | 	"testing"
11 | )
12 |
13 | // TestSessionConn runs a handshake over an in-memory pipe and checks that numbered
14 | // messages sealed by one side are opened, in order, by the other.
15 | func TestSessionConn(t *testing.T) {
16 | 	defer goleak.VerifyNone(t)
17 |
18 | 	alice, bob := net.Pipe()
19 | 	defer func() {
20 | 		require.NoError(t, alice.Close())
21 | 		require.NoError(t, bob.Close())
22 | 	}()
23 |
24 | 	// require must only be called from the test goroutine, so the helper
25 | 	// goroutines below hand their errors back instead of asserting themselves.
26 | 	var a, b Session
27 | 	var errA, errB error
28 |
29 | 	var wg sync.WaitGroup
30 | 	wg.Add(2)
31 | 	go func() {
32 | 		defer wg.Done()
33 | 		errA = a.DoClient(alice)
34 | 	}()
35 | 	go func() {
36 | 		defer wg.Done()
37 | 		errB = b.DoServer(bob)
38 | 	}()
39 | 	wg.Wait()
40 | 	require.NoError(t, errA)
41 | 	require.NoError(t, errB)
42 |
43 | 	aliceConn := NewSessionConn(a.Suite(), alice)
44 | 	bobConn := NewSessionConn(b.Suite(), bob)
45 |
46 | 	trials := 1024
47 |
48 | 	writeErr := make(chan error, 1)
49 | 	go func() {
50 | 		err := func() error {
51 | 			for i := 0; i < trials; i++ {
52 | 				msg := strconv.AppendUint(nil, uint64(i), 10)
53 | 				n, err := aliceConn.Write(msg)
54 | 				if err != nil {
55 | 					return err
56 | 				}
57 | 				if n != len(msg) {
58 | 					return fmt.Errorf("write reported %d bytes, want %d", n, len(msg))
59 | 				}
60 | 			}
61 | 			return aliceConn.Flush()
62 | 		}()
63 | 		writeErr <- err
64 | 		if err != nil {
65 | 			// Unblock the reader so the test fails instead of hanging.
66 | 			_ = alice.Close()
67 | 		}
68 | 	}()
69 |
70 | 	buf := make([]byte, 1024)
71 |
72 | 	for i := 0; i < trials; i++ {
73 | 		n, err := bobConn.Read(buf)
74 | 		if err != nil {
75 | 			// Prefer the writer's error when it is what broke the stream.
76 | 			select {
77 | 			case werr := <-writeErr:
78 | 				require.NoError(t, werr)
79 | 			default:
80 | 			}
81 | 		}
82 | 		require.NoError(t, err)
83 | 		require.EqualValues(t, strconv.AppendUint(nil, uint64(i), 10), buf[:n])
84 | 	}
85 | 	require.NoError(t, <-writeErr)
86 | }
87 |
88 | // TestSession performs the handshake over TCP and checks that both sides derive
89 | // the same shared key.
90 | func TestSession(t *testing.T) {
91 | 	defer goleak.VerifyNone(t)
92 |
93 | 	var aliceSession Session
94 | 	var bobSession Session
95 |
96 | 	bob, err := net.Listen("tcp", ":0")
97 | 	require.NoError(t, err)
98 |
99 | 	// The dialing goroutine reports back over ch, since require must only be
100 | 	// called from the test goroutine.
101 | 	type result struct {
102 | 		key []byte
103 | 		err error
104 | 	}
105 | 	ch := make(chan result, 1)
106 | 	go func() {
107 | 		alice, err := net.Dial("tcp", bob.Addr().String())
108 | 		if err != nil {
109 | 			ch <- result{err: err}
110 | 			return
111 | 		}
112 | 		if err := aliceSession.DoClient(alice); err != nil {
113 | 			_ = alice.Close()
114 | 			ch <- result{err: err}
115 | 			return
116 | 		}
117 | 		ch <- result{key: aliceSession.SharedKey(), err: alice.Close()}
118 | 	}()
119 |
120 | 	conn, err := bob.Accept()
121 | 	require.NoError(t, err)
122 |
123 | 	require.NoError(t, bobSession.DoServer(conn))
124 | 	require.NotNil(t, bobSession.SharedKey())
125 | 	require.NotNil(t, bobSession.Suite())
126 |
127 | 	res := <-ch
128 | 	require.NoError(t, res.err)
129 | 	require.NotNil(t, res.key)
130 | 	require.NotNil(t, aliceSession.Suite())
131 | 	require.EqualValues(t, bobSession.SharedKey(), res.key)
132 |
133 | 	require.NoError(t, conn.Close())
134 | 	require.NoError(t, bob.Close())
135 | }
136 |
--------------------------------------------------------------------------------