├── .travis.yml ├── LICENSE ├── README.md ├── chan.go ├── chan_test.go ├── fuzz.go ├── fuzz_test.go ├── limit.go ├── mpool ├── pool.go └── pool_test.go ├── msgio.go ├── msgio ├── .gitignore ├── README.md └── msgio.go ├── msgio_test.go ├── num.go ├── varint.go └── varint_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.3 5 | - 1.4 6 | - release 7 | 8 | script: 9 | - go test -race -cpu=5 -v ./... 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Juan Batiz-Benet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-msgio - Message IO 2 | 3 | This is a simple package that helps read and write length-delimited slices. It's helpful for building wire protocols. 4 | 5 | ## Usage 6 | 7 | ### Reading 8 | 9 | ```go 10 | import "github.com/jbenet/go-msgio" 11 | rdr := ... // some reader from a wire 12 | mrdr := msgio.NewReader(rdr) 13 | 14 | for { 15 | msg, err := mrdr.ReadMsg() 16 | if err != nil { 17 | return err 18 | } 19 | 20 | doSomething(msg) 21 | } 22 | ``` 23 | 24 | ### Writing 25 | 26 | ```go 27 | import "github.com/jbenet/go-msgio" 28 | wtr := genWriter() 29 | mwtr := msgio.NewWriter(wtr) 30 | 31 | for { 32 | msg := genMessage() 33 | err := mwtr.WriteMsg(msg) 34 | if err != nil { 35 | return err 36 | } 37 | } 38 | ``` 39 | 40 | ### Duplex 41 | 42 | ```go 43 | import "github.com/jbenet/go-msgio" 44 | rw := genReadWriter() 45 | mrw := msgio.NewReadWriter(rw) 46 | 47 | for { 48 | msg, err := mrw.ReadMsg() 49 | if err != nil { 50 | return err 51 | } 52 | 53 | // echo it back :) 54 | err = mrw.WriteMsg(msg) 55 | if err != nil { 56 | return err 57 | } 58 | } 59 | ``` 60 | 61 | ### Channels 62 | 63 | ```go 64 | import "github.com/jbenet/go-msgio" 65 | rw := genReadWriter() 66 | rch := msgio.NewReadChannel(rw) 67 | wch := msgio.NewWriteChannel(rw) 68 | 69 | for { 70 | msg, err := <-rch 71 | if err != nil { 72 | return err 73 | } 74 | 75 | // echo it back :) 76 | wch <- msg 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /chan.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "io" 5 | 6 | mpool "github.com/jbenet/go-msgio/mpool" 7 | ) 8 | 9 | // Chan is a msgio duplex channel. It is used to have a channel interface 10 | // around a msgio.Reader or Writer. 
11 | type Chan struct { 12 | MsgChan chan []byte 13 | ErrChan chan error 14 | CloseChan chan bool 15 | } 16 | 17 | // NewChan constructs a Chan with a given buffer size. 18 | func NewChan(chanSize int) *Chan { 19 | return &Chan{ 20 | MsgChan: make(chan []byte, chanSize), 21 | ErrChan: make(chan error, 1), 22 | CloseChan: make(chan bool, 2), 23 | } 24 | } 25 | 26 | // ReadFrom wraps the given io.Reader with a msgio.Reader, reads all 27 | // messages, ands sends them down the channel. 28 | func (s *Chan) ReadFrom(r io.Reader) { 29 | s.readFrom(NewReader(r)) 30 | } 31 | 32 | // ReadFromWithPool wraps the given io.Reader with a msgio.Reader, reads all 33 | // messages, ands sends them down the channel. Uses given Pool 34 | func (s *Chan) ReadFromWithPool(r io.Reader, p *mpool.Pool) { 35 | s.readFrom(NewReaderWithPool(r, p)) 36 | } 37 | 38 | // ReadFrom wraps the given io.Reader with a msgio.Reader, reads all 39 | // messages, ands sends them down the channel. 40 | func (s *Chan) readFrom(mr Reader) { 41 | // single reader, no need for Mutex 42 | mr.(*reader).lock = new(nullLocker) 43 | 44 | Loop: 45 | for { 46 | buf, err := mr.ReadMsg() 47 | if err != nil { 48 | if err == io.EOF { 49 | break Loop // done 50 | } 51 | 52 | // unexpected error. tell the client. 53 | s.ErrChan <- err 54 | break Loop 55 | } 56 | 57 | select { 58 | case <-s.CloseChan: 59 | break Loop // told we're done 60 | case s.MsgChan <- buf: 61 | // ok seems fine. send it away 62 | } 63 | } 64 | 65 | close(s.MsgChan) 66 | // signal we're done 67 | s.CloseChan <- true 68 | } 69 | 70 | // WriteTo wraps the given io.Writer with a msgio.Writer, listens on the 71 | // channel and writes all messages to the writer. 
72 | func (s *Chan) WriteTo(w io.Writer) { 73 | // new buffer per message 74 | // if bottleneck, cycle around a set of buffers 75 | mw := NewWriter(w) 76 | 77 | // single writer, no need for Mutex 78 | mw.(*writer).lock = new(nullLocker) 79 | Loop: 80 | for { 81 | select { 82 | case <-s.CloseChan: 83 | break Loop // told we're done 84 | 85 | case msg, ok := <-s.MsgChan: 86 | if !ok { // chan closed 87 | break Loop 88 | } 89 | 90 | if err := mw.WriteMsg(msg); err != nil { 91 | if err != io.EOF { 92 | // unexpected error. tell the client. 93 | s.ErrChan <- err 94 | } 95 | 96 | break Loop 97 | } 98 | } 99 | } 100 | 101 | // signal we're done 102 | s.CloseChan <- true 103 | } 104 | 105 | // Close the Chan 106 | func (s *Chan) Close() { 107 | s.CloseChan <- true 108 | } 109 | 110 | // nullLocker conforms to the sync.Locker interface but does nothing. 111 | type nullLocker struct{} 112 | 113 | func (l *nullLocker) Lock() {} 114 | func (l *nullLocker) Unlock() {} 115 | -------------------------------------------------------------------------------- /chan_test.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "bytes" 5 | randbuf "github.com/jbenet/go-randbuf" 6 | "io" 7 | "math/rand" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestReadChan(t *testing.T) { 13 | buf := bytes.NewBuffer(nil) 14 | writer := NewWriter(buf) 15 | rchan := NewChan(10) 16 | msgs := [1000][]byte{} 17 | 18 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 19 | for i := range msgs { 20 | msgs[i] = randbuf.RandBuf(r, r.Intn(1000)) 21 | err := writer.WriteMsg(msgs[i]) 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | } 26 | 27 | if err := writer.Close(); err != nil { 28 | t.Fatal(err) 29 | } 30 | 31 | go rchan.ReadFrom(buf) 32 | defer rchan.Close() 33 | 34 | Loop: 35 | for i := 0; ; i++ { 36 | select { 37 | case err := <-rchan.ErrChan: 38 | if err != nil { 39 | t.Fatal("unexpected error", err) 40 | } 41 | 42 | case msg2, 
ok := <-rchan.MsgChan: 43 | if !ok { 44 | if i < len(msg2) { 45 | t.Error("failed to read all messages", len(msgs), i) 46 | } 47 | break Loop 48 | } 49 | 50 | msg1 := msgs[i] 51 | if !bytes.Equal(msg1, msg2) { 52 | t.Fatal("message retrieved not equal\n", msg1, "\n\n", msg2) 53 | } 54 | } 55 | } 56 | } 57 | 58 | func TestWriteChan(t *testing.T) { 59 | buf := bytes.NewBuffer(nil) 60 | reader := NewReader(buf) 61 | wchan := NewChan(10) 62 | msgs := [1000][]byte{} 63 | 64 | go wchan.WriteTo(buf) 65 | 66 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 67 | for i := range msgs { 68 | msgs[i] = randbuf.RandBuf(r, r.Intn(1000)) 69 | 70 | select { 71 | case err := <-wchan.ErrChan: 72 | if err != nil { 73 | t.Fatal("unexpected error", err) 74 | } 75 | 76 | case wchan.MsgChan <- msgs[i]: 77 | } 78 | } 79 | 80 | // tell chan we're done. 81 | close(wchan.MsgChan) 82 | // wait for writing to end 83 | <-wchan.CloseChan 84 | 85 | defer wchan.Close() 86 | 87 | for i := 0; ; i++ { 88 | msg2, err := reader.ReadMsg() 89 | if err != nil { 90 | if err == io.EOF { 91 | if i < len(msg2) { 92 | t.Error("failed to read all messages", len(msgs), i) 93 | } 94 | break 95 | } 96 | t.Error("unexpected error", err) 97 | } 98 | 99 | msg1 := msgs[i] 100 | if !bytes.Equal(msg1, msg2) { 101 | t.Fatal("message retrieved not equal\n", msg1, "\n\n", msg2) 102 | } 103 | } 104 | 105 | if err := reader.Close(); err != nil { 106 | t.Error(err) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /fuzz.go: -------------------------------------------------------------------------------- 1 | // +build gofuzz 2 | 3 | package msgio 4 | 5 | import "bytes" 6 | 7 | // get the go-fuzz tools and build a fuzzer 8 | // $ go get -u github.com/dvyukov/go-fuzz/... 
9 | // $ go-fuzz-build github.com/jbenet/go-msgio 10 | 11 | // put a corpus of random (even better if actual, structured) data in a corpus directry 12 | // $ go-fuzz -bin ./msgio-fuzz -corpus corpus -workdir=wdir -timeout=15 13 | 14 | func Fuzz(data []byte) int { 15 | rc := NewReader(bytes.NewReader(data)) 16 | // rc := NewVarintReader(bytes.NewReader(data)) 17 | 18 | if _, err := rc.ReadMsg(); err != nil { 19 | return 0 20 | } 21 | 22 | return 1 23 | } 24 | -------------------------------------------------------------------------------- /fuzz_test.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestReader_CrashOne(t *testing.T) { 9 | rc := NewReader(strings.NewReader("\x83000")) 10 | _, err := rc.ReadMsg() 11 | if err != ErrMsgTooLarge { 12 | t.Error("should get ErrMsgTooLarge") 13 | t.Log(err) 14 | } 15 | } 16 | 17 | func TestVarintReader_CrashOne(t *testing.T) { 18 | rc := NewVarintReader(strings.NewReader("\x9a\xf1\xed\x9a0")) 19 | _, err := rc.ReadMsg() 20 | if err != ErrMsgTooLarge { 21 | t.Error("should get ErrMsgTooLarge") 22 | t.Log(err) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /limit.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "sync" 7 | ) 8 | 9 | // LimitedReader wraps an io.Reader with a msgio framed reader. The LimitedReader 10 | // will return a reader which will io.EOF when the msg length is done. 11 | func LimitedReader(r io.Reader) (io.Reader, error) { 12 | l, err := ReadLen(r, nil) 13 | return io.LimitReader(r, int64(l)), err 14 | } 15 | 16 | // LimitedWriter wraps an io.Writer with a msgio framed writer. It is the inverse 17 | // of LimitedReader: it will buffer all writes until "Flush" is called. 
When Flush 18 | // is called, it will write the size of the buffer first, flush the buffer, reset 19 | // the buffer, and begin accept more incoming writes. 20 | func NewLimitedWriter(w io.Writer) *LimitedWriter { 21 | return &LimitedWriter{W: w} 22 | } 23 | 24 | type LimitedWriter struct { 25 | W io.Writer 26 | B bytes.Buffer 27 | M sync.Mutex 28 | } 29 | 30 | func (w *LimitedWriter) Write(buf []byte) (n int, err error) { 31 | w.M.Lock() 32 | n, err = w.B.Write(buf) 33 | w.M.Unlock() 34 | return n, err 35 | } 36 | 37 | func (w *LimitedWriter) Flush() error { 38 | w.M.Lock() 39 | defer w.M.Unlock() 40 | if err := WriteLen(w.W, w.B.Len()); err != nil { 41 | return err 42 | } 43 | _, err := w.B.WriteTo(w.W) 44 | return err 45 | } 46 | -------------------------------------------------------------------------------- /mpool/pool.go: -------------------------------------------------------------------------------- 1 | // Package mpool provides a sync.Pool equivalent that buckets incoming 2 | // requests to one of 32 sub-pools, one for each power of 2, 0-32. 3 | // 4 | // import "github.com/jbenet/go-msgio/mpool" 5 | // var p mpool.Pool 6 | // 7 | // small := make([]byte, 1024) 8 | // large := make([]byte, 4194304) 9 | // p.Put(1024, small) 10 | // p.Put(4194304, large) 11 | // 12 | // small2 := p.Get(1024).([]byte) 13 | // large2 := p.Get(4194304).([]byte) 14 | // fmt.Println("small2 len:", len(small2)) 15 | // fmt.Println("large2 len:", len(large2)) 16 | // 17 | // // Output: 18 | // // small2 len: 1024 19 | // // large2 len: 4194304 20 | // 21 | package mpool 22 | 23 | import ( 24 | "fmt" 25 | "sync" 26 | ) 27 | 28 | // ByteSlicePool is a static Pool for reusing byteslices of various sizes. 29 | var ByteSlicePool Pool 30 | 31 | func init() { 32 | ByteSlicePool.New = func(length int) interface{} { 33 | return make([]byte, length) 34 | } 35 | } 36 | 37 | // MaxLength is the maximum length of an element that can be added to the Pool. 
38 | const MaxLength = 1 << 32 39 | 40 | // Pool is a pool to handle cases of reusing elements of varying sizes. 41 | // It maintains up to 32 internal pools, for each power of 2 in 0-32. 42 | type Pool struct { 43 | small int // the size of the first pool 44 | pools [32]*sync.Pool // a list of singlePools 45 | sync.Mutex // protecting list 46 | 47 | // New is a function that constructs a new element in the pool, with given len 48 | New func(len int) interface{} 49 | } 50 | 51 | func (p *Pool) getPool(idx uint32) *sync.Pool { 52 | if idx > uint32(len(p.pools)) { 53 | panic(fmt.Errorf("index too large: %d", idx)) 54 | } 55 | 56 | p.Lock() 57 | defer p.Unlock() 58 | 59 | sp := p.pools[idx] 60 | if sp == nil { 61 | sp = new(sync.Pool) 62 | p.pools[idx] = sp 63 | } 64 | return sp 65 | } 66 | 67 | // Get selects an arbitrary item from the Pool, removes it from the Pool, 68 | // and returns it to the caller. Get may choose to ignore the pool and 69 | // treat it as empty. Callers should not assume any relation between values 70 | // passed to Put and the values returned by Get. 71 | // 72 | // If Get would otherwise return nil and p.New is non-nil, Get returns the 73 | // result of calling p.New. 74 | func (p *Pool) Get(length uint32) interface{} { 75 | idx := nextPowerOfTwo(length) 76 | sp := p.getPool(idx) 77 | // fmt.Printf("Get(%d) idx(%d)\n", length, idx) 78 | val := sp.Get() 79 | if val == nil && p.New != nil { 80 | val = p.New(0x1 << idx) 81 | } 82 | return val 83 | } 84 | 85 | // Put adds x to the pool. 
86 | func (p *Pool) Put(length uint32, val interface{}) { 87 | idx := prevPowerOfTwo(length) 88 | // fmt.Printf("Put(%d, -) idx(%d)\n", length, idx) 89 | sp := p.getPool(idx) 90 | sp.Put(val) 91 | } 92 | 93 | func nextPowerOfTwo(v uint32) uint32 { 94 | // fmt.Printf("nextPowerOfTwo(%d) ", v) 95 | v-- 96 | v |= v >> 1 97 | v |= v >> 2 98 | v |= v >> 4 99 | v |= v >> 8 100 | v |= v >> 16 101 | v++ 102 | 103 | // fmt.Printf("-> %d", v) 104 | 105 | i := uint32(0) 106 | for i = 0; v > 1; i++ { 107 | v = v >> 1 108 | } 109 | 110 | // fmt.Printf("-> %d\n", i) 111 | return i 112 | } 113 | 114 | func prevPowerOfTwo(num uint32) uint32 { 115 | next := nextPowerOfTwo(num) 116 | // fmt.Printf("prevPowerOfTwo(%d) next: %d", num, next) 117 | switch { 118 | case num == (1 << next): // num is a power of 2 119 | case next == 0: 120 | default: 121 | next = next - 1 // smaller 122 | } 123 | // fmt.Printf(" = %d\n", next) 124 | return next 125 | } 126 | -------------------------------------------------------------------------------- /mpool/pool_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Pool is no-op under race detector, so all these tests do not work. 6 | // +build !race 7 | 8 | package mpool 9 | 10 | import ( 11 | "fmt" 12 | "math/rand" 13 | "runtime" 14 | "runtime/debug" 15 | "sync/atomic" 16 | "testing" 17 | "time" 18 | ) 19 | 20 | func TestPool(t *testing.T) { 21 | // disable GC so we can control when it happens. 
22 | defer debug.SetGCPercent(debug.SetGCPercent(-1)) 23 | var p Pool 24 | if p.Get(10) != nil { 25 | t.Fatal("expected empty") 26 | } 27 | p.Put(16, "a") 28 | p.Put(2048, "b") 29 | if g := p.Get(16); g != "a" { 30 | t.Fatalf("got %#v; want a", g) 31 | } 32 | if g := p.Get(2048); g != "b" { 33 | t.Fatalf("got %#v; want b", g) 34 | } 35 | if g := p.Get(16); g != nil { 36 | t.Fatalf("got %#v; want nil", g) 37 | } 38 | if g := p.Get(2048); g != nil { 39 | t.Fatalf("got %#v; want nil", g) 40 | } 41 | if g := p.Get(1); g != nil { 42 | t.Fatalf("got %#v; want nil", g) 43 | } 44 | p.Put(1023, "d") 45 | if g := p.Get(1024); g != nil { 46 | t.Fatalf("got %#v; want nil", g) 47 | } 48 | if g := p.Get(512); g != "d" { 49 | t.Fatalf("got %#v; want d", g) 50 | } 51 | 52 | debug.SetGCPercent(100) // to allow following GC to actually run 53 | runtime.GC() 54 | if g := p.Get(10); g != nil { 55 | t.Fatalf("got %#v; want nil after GC", g) 56 | } 57 | } 58 | 59 | func TestPoolNew(t *testing.T) { 60 | // disable GC so we can control when it happens. 
61 | defer debug.SetGCPercent(debug.SetGCPercent(-1)) 62 | 63 | s := [32]int{} 64 | p := Pool{ 65 | New: func(length int) interface{} { 66 | idx := nextPowerOfTwo(uint32(length)) 67 | s[idx]++ 68 | return s[idx] 69 | }, 70 | } 71 | if v := p.Get(1 << 5); v != 1 { 72 | t.Fatalf("got %v; want 1", v) 73 | } 74 | if v := p.Get(1 << 2); v != 1 { 75 | t.Fatalf("got %v; want 1", v) 76 | } 77 | if v := p.Get(1 << 2); v != 2 { 78 | t.Fatalf("got %v; want 2", v) 79 | } 80 | if v := p.Get(1 << 5); v != 2 { 81 | t.Fatalf("got %v; want 2", v) 82 | } 83 | p.Put(1<<2, 42) 84 | p.Put(1<<5, 42) 85 | if v := p.Get(1 << 2); v != 42 { 86 | t.Fatalf("got %v; want 42", v) 87 | } 88 | if v := p.Get(1 << 2); v != 3 { 89 | t.Fatalf("got %v; want 3", v) 90 | } 91 | if v := p.Get(1 << 5); v != 42 { 92 | t.Fatalf("got %v; want 42", v) 93 | } 94 | if v := p.Get(1 << 5); v != 3 { 95 | t.Fatalf("got %v; want 3", v) 96 | } 97 | } 98 | 99 | // Test that Pool does not hold pointers to previously cached 100 | // resources 101 | func TestPoolGC(t *testing.T) { 102 | var p Pool 103 | var fin uint32 104 | const N = 100 105 | for i := 0; i < N; i++ { 106 | v := new(string) 107 | runtime.SetFinalizer(v, func(vv *string) { 108 | atomic.AddUint32(&fin, 1) 109 | }) 110 | p.Put(uint32(i), v) 111 | } 112 | for i := 0; i < N; i++ { 113 | p.Get(uint32(i)) 114 | } 115 | for i := 0; i < 5; i++ { 116 | runtime.GC() 117 | time.Sleep(time.Duration(i*100+10) * time.Millisecond) 118 | // 1 pointer can remain on stack or elsewhere 119 | if atomic.LoadUint32(&fin) >= N-1 { 120 | return 121 | } 122 | } 123 | t.Fatalf("only %v out of %v resources are finalized", 124 | atomic.LoadUint32(&fin), N) 125 | } 126 | 127 | func TestPoolStress(t *testing.T) { 128 | const P = 10 129 | N := int(1e6) 130 | if testing.Short() { 131 | N /= 100 132 | } 133 | var p Pool 134 | done := make(chan bool) 135 | for i := 0; i < P; i++ { 136 | go func() { 137 | var v interface{} = 0 138 | for j := 0; j < N; j++ { 139 | if v == nil { 140 | v = 0 
141 | } 142 | p.Put(uint32(j), v) 143 | v = p.Get(uint32(j)) 144 | if v != nil && v.(int) != 0 { 145 | t.Fatalf("expect 0, got %v", v) 146 | } 147 | } 148 | done <- true 149 | }() 150 | } 151 | for i := 0; i < P; i++ { 152 | // fmt.Printf("%d/%d\n", i, P) 153 | <-done 154 | } 155 | } 156 | 157 | func TestPoolStressByteSlicePool(t *testing.T) { 158 | const P = 10 159 | chs := 10 160 | maxSize := uint32(1 << 16) 161 | N := int(1e4) 162 | if testing.Short() { 163 | N /= 100 164 | } 165 | p := ByteSlicePool 166 | done := make(chan bool) 167 | errs := make(chan error) 168 | for i := 0; i < P; i++ { 169 | go func() { 170 | ch := make(chan []byte, chs+1) 171 | 172 | for i := 0; i < chs; i++ { 173 | j := rand.Uint32() % maxSize 174 | ch <- p.Get(j).([]byte) 175 | } 176 | 177 | for j := 0; j < N; j++ { 178 | r := uint32(0) 179 | for i := 0; i < chs; i++ { 180 | v := <-ch 181 | p.Put(uint32(cap(v)), v) 182 | r = rand.Uint32() % maxSize 183 | v = p.Get(r).([]byte) 184 | if uint32(len(v)) < r { 185 | errs <- fmt.Errorf("expect len(v) >= %d, got %d", j, len(v)) 186 | } 187 | ch <- v 188 | } 189 | 190 | if r%1000 == 0 { 191 | runtime.GC() 192 | } 193 | } 194 | done <- true 195 | }() 196 | } 197 | 198 | for i := 0; i < P; { 199 | select { 200 | case <-done: 201 | i++ 202 | // fmt.Printf("%d/%d\n", i, P) 203 | case err := <-errs: 204 | t.Error(err) 205 | } 206 | } 207 | } 208 | 209 | func BenchmarkPool(b *testing.B) { 210 | var p Pool 211 | b.RunParallel(func(pb *testing.PB) { 212 | i := 0 213 | for pb.Next() { 214 | i = i << 1 215 | p.Put(uint32(i), 1) 216 | p.Get(uint32(i)) 217 | } 218 | }) 219 | } 220 | 221 | func BenchmarkPoolOverlflow(b *testing.B) { 222 | var p Pool 223 | b.RunParallel(func(pb *testing.PB) { 224 | for pb.Next() { 225 | for pow := uint32(0); pow < 32; pow++ { 226 | for b := 0; b < 100; b++ { 227 | p.Put(uint32(1< len(msg) { 183 | return 0, io.ErrShortBuffer 184 | } 185 | 186 | _, err = io.ReadFull(s.R, msg[:length]) 187 | s.next = -1 // signal we've consumed 
this msg 188 | return length, err 189 | } 190 | 191 | func (s *reader) ReadMsg() ([]byte, error) { 192 | s.lock.Lock() 193 | defer s.lock.Unlock() 194 | 195 | length, err := s.nextMsgLen() 196 | if err != nil { 197 | return nil, err 198 | } 199 | 200 | if length > s.max || length < 0 { 201 | return nil, ErrMsgTooLarge 202 | } 203 | 204 | msgb := s.pool.Get(uint32(length)) 205 | if msgb == nil { 206 | return nil, io.ErrShortBuffer 207 | } 208 | msg := msgb.([]byte)[:length] 209 | _, err = io.ReadFull(s.R, msg) 210 | s.next = -1 // signal we've consumed this msg 211 | return msg, err 212 | } 213 | 214 | func (s *reader) ReleaseMsg(msg []byte) { 215 | s.pool.Put(uint32(cap(msg)), msg) 216 | } 217 | 218 | func (s *reader) Close() error { 219 | s.lock.Lock() 220 | defer s.lock.Unlock() 221 | 222 | if c, ok := s.R.(io.Closer); ok { 223 | return c.Close() 224 | } 225 | return nil 226 | } 227 | 228 | // readWriter is the underlying type that implements a ReadWriter. 229 | type readWriter struct { 230 | Reader 231 | Writer 232 | } 233 | 234 | // NewReadWriter wraps an io.ReadWriter with a msgio.ReadWriter. Writing 235 | // and Reading will be appropriately framed. 236 | func NewReadWriter(rw io.ReadWriter) ReadWriteCloser { 237 | return &readWriter{ 238 | Reader: NewReader(rw), 239 | Writer: NewWriter(rw), 240 | } 241 | } 242 | 243 | // Combine wraps a pair of msgio.Writer and msgio.Reader with a msgio.ReadWriter. 
244 | func Combine(w Writer, r Reader) ReadWriteCloser { 245 | return &readWriter{Reader: r, Writer: w} 246 | } 247 | 248 | func (rw *readWriter) Close() error { 249 | var errs []error 250 | 251 | if w, ok := rw.Writer.(WriteCloser); ok { 252 | if err := w.Close(); err != nil { 253 | errs = append(errs, err) 254 | } 255 | } 256 | if r, ok := rw.Reader.(ReadCloser); ok { 257 | if err := r.Close(); err != nil { 258 | errs = append(errs, err) 259 | } 260 | } 261 | 262 | if len(errs) > 0 { 263 | return multiErr(errs) 264 | } 265 | return nil 266 | } 267 | 268 | // multiErr is a util to return multiple errors 269 | type multiErr []error 270 | 271 | func (m multiErr) Error() string { 272 | if len(m) == 0 { 273 | return "no errors" 274 | } 275 | 276 | s := "Multiple errors: " 277 | for i, e := range m { 278 | if i != 0 { 279 | s += ", " 280 | } 281 | s += e.Error() 282 | } 283 | return s 284 | } 285 | -------------------------------------------------------------------------------- /msgio/.gitignore: -------------------------------------------------------------------------------- 1 | msgio 2 | -------------------------------------------------------------------------------- /msgio/README.md: -------------------------------------------------------------------------------- 1 | # msgio headers tool 2 | 3 | Conveniently output msgio headers. 
4 | 5 | ## Install 6 | 7 | ``` 8 | go get github.com/jbenet/go-msgio/msgio 9 | ``` 10 | 11 | ## Usage 12 | 13 | ``` 14 | > msgio -h 15 | msgio - tool to wrap messages with msgio header 16 | 17 | Usage 18 | msgio header 1020 >header 19 | cat file | msgio wrap >wrapped 20 | 21 | Commands 22 | header output a msgio header of given size 23 | wrap wrap incoming stream with msgio 24 | ``` 25 | -------------------------------------------------------------------------------- /msgio/msgio.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "os" 9 | "strconv" 10 | "strings" 11 | 12 | msgio "github.com/jbenet/go-msgio" 13 | ) 14 | 15 | var Args ArgType 16 | 17 | type ArgType struct { 18 | Command string 19 | Args []string 20 | } 21 | 22 | func (a *ArgType) Arg(i int) string { 23 | n := i + 1 24 | if len(a.Args) < n { 25 | die(fmt.Sprintf("expected %d argument(s)", n)) 26 | } 27 | return a.Args[i] 28 | } 29 | 30 | var usageStr = ` 31 | msgio - tool to wrap messages with msgio header 32 | 33 | Usage 34 | msgio header 1020 >header 35 | cat file | msgio wrap >wrapped 36 | 37 | Commands 38 | header output a msgio header of given size 39 | wrap wrap incoming stream with msgio 40 | ` 41 | 42 | func usage() { 43 | fmt.Println(strings.TrimSpace(usageStr)) 44 | os.Exit(0) 45 | } 46 | 47 | func die(err string) { 48 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 49 | os.Exit(-1) 50 | } 51 | 52 | func main() { 53 | if err := run(); err != nil { 54 | die(err.Error()) 55 | } 56 | } 57 | 58 | func argParse() { 59 | flag.Usage = usage 60 | flag.Parse() 61 | 62 | args := flag.Args() 63 | if l := len(args); l < 1 || l > 2 { 64 | usage() 65 | } 66 | 67 | Args.Command = flag.Args()[0] 68 | Args.Args = flag.Args()[1:] 69 | } 70 | 71 | func run() error { 72 | argParse() 73 | 74 | w := os.Stdout 75 | r := os.Stdin 76 | 77 | switch Args.Command { 78 | case "header": 79 | size, err := 
strconv.Atoi(Args.Arg(0)) 80 | if err != nil { 81 | return err 82 | } 83 | return header(w, size) 84 | case "wrap": 85 | return wrap(w, r) 86 | default: 87 | usage() 88 | return nil 89 | } 90 | } 91 | 92 | func header(w io.Writer, size int) error { 93 | return msgio.WriteLen(w, size) 94 | } 95 | 96 | func wrap(w io.Writer, r io.Reader) error { 97 | buf, err := ioutil.ReadAll(r) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | if err := msgio.WriteLen(w, len(buf)); err != nil { 103 | return err 104 | } 105 | 106 | _, err = w.Write(buf) 107 | return err 108 | } 109 | -------------------------------------------------------------------------------- /msgio_test.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | randbuf "github.com/jbenet/go-randbuf" 7 | "io" 8 | "math/rand" 9 | "sync" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestReadWrite(t *testing.T) { 15 | buf := bytes.NewBuffer(nil) 16 | writer := NewWriter(buf) 17 | reader := NewReader(buf) 18 | SubtestReadWrite(t, writer, reader) 19 | } 20 | 21 | func TestReadWriteMsg(t *testing.T) { 22 | buf := bytes.NewBuffer(nil) 23 | writer := NewWriter(buf) 24 | reader := NewReader(buf) 25 | SubtestReadWriteMsg(t, writer, reader) 26 | } 27 | 28 | func TestReadWriteMsgSync(t *testing.T) { 29 | buf := bytes.NewBuffer(nil) 30 | writer := NewWriter(buf) 31 | reader := NewReader(buf) 32 | SubtestReadWriteMsgSync(t, writer, reader) 33 | } 34 | 35 | func SubtestReadWrite(t *testing.T, writer WriteCloser, reader ReadCloser) { 36 | msgs := [1000][]byte{} 37 | 38 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 39 | for i := range msgs { 40 | msgs[i] = randbuf.RandBuf(r, r.Intn(1000)) 41 | n, err := writer.Write(msgs[i]) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | if n != len(msgs[i]) { 46 | t.Fatal("wrong length:", n, len(msgs[i])) 47 | } 48 | } 49 | 50 | if err := writer.Close(); err != nil { 51 | t.Fatal(err) 
52 | } 53 | 54 | for i := 0; ; i++ { 55 | msg2 := make([]byte, 1000) 56 | n, err := reader.Read(msg2) 57 | if err != nil { 58 | if err == io.EOF { 59 | if i < len(msg2) { 60 | t.Error("failed to read all messages", len(msgs), i) 61 | } 62 | break 63 | } 64 | t.Error("unexpected error", err) 65 | } 66 | 67 | msg1 := msgs[i] 68 | msg2 = msg2[:n] 69 | if !bytes.Equal(msg1, msg2) { 70 | t.Fatal("message retrieved not equal\n", msg1, "\n\n", msg2) 71 | } 72 | } 73 | 74 | if err := reader.Close(); err != nil { 75 | t.Error(err) 76 | } 77 | } 78 | 79 | func SubtestReadWriteMsg(t *testing.T, writer WriteCloser, reader ReadCloser) { 80 | msgs := [1000][]byte{} 81 | 82 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 83 | for i := range msgs { 84 | msgs[i] = randbuf.RandBuf(r, r.Intn(1000)) 85 | err := writer.WriteMsg(msgs[i]) 86 | if err != nil { 87 | t.Fatal(err) 88 | } 89 | } 90 | 91 | if err := writer.Close(); err != nil { 92 | t.Fatal(err) 93 | } 94 | 95 | for i := 0; ; i++ { 96 | msg2, err := reader.ReadMsg() 97 | if err != nil { 98 | if err == io.EOF { 99 | if i < len(msg2) { 100 | t.Error("failed to read all messages", len(msgs), i) 101 | } 102 | break 103 | } 104 | t.Error("unexpected error", err) 105 | } 106 | 107 | msg1 := msgs[i] 108 | if !bytes.Equal(msg1, msg2) { 109 | t.Fatal("message retrieved not equal\n", msg1, "\n\n", msg2) 110 | } 111 | } 112 | 113 | if err := reader.Close(); err != nil { 114 | t.Error(err) 115 | } 116 | } 117 | 118 | func SubtestReadWriteMsgSync(t *testing.T, writer WriteCloser, reader ReadCloser) { 119 | msgs := [1000][]byte{} 120 | 121 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 122 | for i := range msgs { 123 | msgs[i] = randbuf.RandBuf(r, r.Intn(1000)+4) 124 | NBO.PutUint32(msgs[i][:4], uint32(i)) 125 | } 126 | 127 | var wg1 sync.WaitGroup 128 | var wg2 sync.WaitGroup 129 | 130 | errs := make(chan error, 10000) 131 | for i := range msgs { 132 | wg1.Add(1) 133 | go func(i int) { 134 | defer wg1.Done() 135 | 136 | err 
:= writer.WriteMsg(msgs[i]) 137 | if err != nil { 138 | errs <- err 139 | } 140 | }(i) 141 | } 142 | 143 | wg1.Wait() 144 | if err := writer.Close(); err != nil { 145 | t.Fatal(err) 146 | } 147 | 148 | for i := 0; i < len(msgs)+1; i++ { 149 | wg2.Add(1) 150 | go func(i int) { 151 | defer wg2.Done() 152 | 153 | msg2, err := reader.ReadMsg() 154 | if err != nil { 155 | if err == io.EOF { 156 | if i < len(msg2) { 157 | errs <- fmt.Errorf("failed to read all messages", len(msgs), i) 158 | } 159 | return 160 | } 161 | errs <- fmt.Errorf("unexpected error", err) 162 | } 163 | 164 | mi := NBO.Uint32(msg2[:4]) 165 | msg1 := msgs[mi] 166 | if !bytes.Equal(msg1, msg2) { 167 | errs <- fmt.Errorf("message retrieved not equal\n", msg1, "\n\n", msg2) 168 | } 169 | }(i) 170 | } 171 | 172 | wg2.Wait() 173 | close(errs) 174 | 175 | if err := reader.Close(); err != nil { 176 | t.Error(err) 177 | } 178 | 179 | for e := range errs { 180 | t.Error(e) 181 | } 182 | } 183 | 184 | func TestBadSizes(t *testing.T) { 185 | data := make([]byte, 4) 186 | 187 | // on a 64 bit system, this will fail because its too large 188 | // on a 32 bit system, this will fail because its too small 189 | NBO.PutUint32(data, 4000000000) 190 | buf := bytes.NewReader(data) 191 | read := NewReader(buf) 192 | msg, err := read.ReadMsg() 193 | if err == nil { 194 | t.Fatal(err) 195 | } 196 | _ = msg 197 | } 198 | -------------------------------------------------------------------------------- /num.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | ) 7 | 8 | // NBO is NetworkByteOrder 9 | var NBO = binary.BigEndian 10 | 11 | // WriteLen writes a length to the given writer. 12 | func WriteLen(w io.Writer, l int) error { 13 | ul := uint32(l) 14 | return binary.Write(w, NBO, &ul) 15 | } 16 | 17 | // ReadLen reads a length from the given reader. 18 | // if buf is non-nil, it reuses the buffer. 
Ex: 19 | // l, err := ReadLen(r, nil) 20 | // _, err := ReadLen(r, buf) 21 | func ReadLen(r io.Reader, buf []byte) (int, error) { 22 | if len(buf) < 4 { 23 | buf = make([]byte, 4) 24 | } 25 | buf = buf[:4] 26 | 27 | if _, err := io.ReadFull(r, buf); err != nil { 28 | return 0, err 29 | } 30 | 31 | n := int(NBO.Uint32(buf)) 32 | return n, nil 33 | } 34 | -------------------------------------------------------------------------------- /varint.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "sync" 7 | 8 | mpool "github.com/jbenet/go-msgio/mpool" 9 | ) 10 | 11 | // varintWriter is the underlying type that implements the Writer interface. 12 | type varintWriter struct { 13 | W io.Writer 14 | 15 | lbuf []byte // for encoding varints 16 | lock sync.Locker // for threadsafe writes 17 | } 18 | 19 | // NewVarintWriter wraps an io.Writer with a varint msgio framed writer. 20 | // The msgio.Writer will write the length prefix of every message written 21 | // as a varint, using https://golang.org/pkg/encoding/binary/#PutUvarint 22 | func NewVarintWriter(w io.Writer) WriteCloser { 23 | return &varintWriter{ 24 | W: w, 25 | lbuf: make([]byte, binary.MaxVarintLen64), 26 | lock: new(sync.Mutex), 27 | } 28 | } 29 | 30 | func (s *varintWriter) Write(msg []byte) (int, error) { 31 | err := s.WriteMsg(msg) 32 | if err != nil { 33 | return 0, err 34 | } 35 | return len(msg), nil 36 | } 37 | 38 | func (s *varintWriter) WriteMsg(msg []byte) error { 39 | s.lock.Lock() 40 | defer s.lock.Unlock() 41 | 42 | length := uint64(len(msg)) 43 | n := binary.PutUvarint(s.lbuf, length) 44 | if _, err := s.W.Write(s.lbuf[:n]); err != nil { 45 | return err 46 | } 47 | _, err := s.W.Write(msg) 48 | return err 49 | } 50 | 51 | func (s *varintWriter) Close() error { 52 | s.lock.Lock() 53 | defer s.lock.Unlock() 54 | 55 | if c, ok := s.W.(io.Closer); ok { 56 | return c.Close() 57 | } 58 | return nil 59 
| } 60 | 61 | // varintReader is the underlying type that implements the Reader interface. 62 | type varintReader struct { 63 | R io.Reader 64 | br io.ByteReader // for reading varints. 65 | 66 | lbuf []byte 67 | next int 68 | pool *mpool.Pool 69 | lock sync.Locker 70 | max int // the maximal message size (in bytes) this reader handles 71 | } 72 | 73 | // NewVarintReader wraps an io.Reader with a varint msgio framed reader. 74 | // The msgio.Reader will read whole messages at a time (using the length). 75 | // Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint 76 | // Assumes an equivalent writer on the other side. 77 | func NewVarintReader(r io.Reader) ReadCloser { 78 | return NewVarintReaderWithPool(r, &mpool.ByteSlicePool) 79 | } 80 | 81 | // NewVarintReaderWithPool wraps an io.Reader with a varint msgio framed reader. 82 | // The msgio.Reader will read whole messages at a time (using the length). 83 | // Varints read according to https://golang.org/pkg/encoding/binary/#ReadUvarint 84 | // Assumes an equivalent writer on the other side. It uses a given mpool.Pool 85 | func NewVarintReaderWithPool(r io.Reader, p *mpool.Pool) ReadCloser { 86 | if p == nil { 87 | panic("nil pool") 88 | } 89 | return &varintReader{ 90 | R: r, 91 | br: &simpleByteReader{R: r}, 92 | lbuf: make([]byte, binary.MaxVarintLen64), 93 | next: -1, 94 | pool: p, 95 | lock: new(sync.Mutex), 96 | max: defaultMaxSize, 97 | } 98 | } 99 | 100 | // NextMsgLen reads the length of the next msg into s.lbuf, and returns it. 101 | // WARNING: like Read, NextMsgLen is destructive. It reads from the internal 102 | // reader. 
103 | func (s *varintReader) NextMsgLen() (int, error) { 104 | s.lock.Lock() 105 | defer s.lock.Unlock() 106 | return s.nextMsgLen() 107 | } 108 | 109 | func (s *varintReader) nextMsgLen() (int, error) { 110 | if s.next == -1 { 111 | length, err := binary.ReadUvarint(s.br) 112 | if err != nil { 113 | return 0, err 114 | } 115 | s.next = int(length) 116 | } 117 | return s.next, nil 118 | } 119 | 120 | func (s *varintReader) Read(msg []byte) (int, error) { 121 | s.lock.Lock() 122 | defer s.lock.Unlock() 123 | 124 | length, err := s.nextMsgLen() 125 | if err != nil { 126 | return 0, err 127 | } 128 | 129 | if length > len(msg) { 130 | return 0, io.ErrShortBuffer 131 | } 132 | _, err = io.ReadFull(s.R, msg[:length]) 133 | s.next = -1 // signal we've consumed this msg 134 | return length, err 135 | } 136 | 137 | func (s *varintReader) ReadMsg() ([]byte, error) { 138 | s.lock.Lock() 139 | defer s.lock.Unlock() 140 | 141 | length, err := s.nextMsgLen() 142 | if err != nil { 143 | return nil, err 144 | } 145 | 146 | if length > s.max { 147 | return nil, ErrMsgTooLarge 148 | } 149 | 150 | msgb := s.pool.Get(uint32(length)) 151 | if msgb == nil { 152 | return nil, io.ErrShortBuffer 153 | } 154 | msg := msgb.([]byte)[:length] 155 | _, err = io.ReadFull(s.R, msg) 156 | s.next = -1 // signal we've consumed this msg 157 | return msg, err 158 | } 159 | 160 | func (s *varintReader) ReleaseMsg(msg []byte) { 161 | s.pool.Put(uint32(cap(msg)), msg) 162 | } 163 | 164 | func (s *varintReader) Close() error { 165 | s.lock.Lock() 166 | defer s.lock.Unlock() 167 | 168 | if c, ok := s.R.(io.Closer); ok { 169 | return c.Close() 170 | } 171 | return nil 172 | } 173 | 174 | type simpleByteReader struct { 175 | R io.Reader 176 | buf []byte 177 | } 178 | 179 | func (r *simpleByteReader) ReadByte() (c byte, err error) { 180 | if r.buf == nil { 181 | r.buf = make([]byte, 1) 182 | } 183 | 184 | if _, err := io.ReadFull(r.R, r.buf); err != nil { 185 | return 0, err 186 | } 187 | return r.buf[0], 
nil 188 | } 189 | -------------------------------------------------------------------------------- /varint_test.go: -------------------------------------------------------------------------------- 1 | package msgio 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "testing" 7 | ) 8 | 9 | func TestVarintReadWrite(t *testing.T) { 10 | buf := bytes.NewBuffer(nil) 11 | writer := NewVarintWriter(buf) 12 | reader := NewVarintReader(buf) 13 | SubtestReadWrite(t, writer, reader) 14 | } 15 | 16 | func TestVarintReadWriteMsg(t *testing.T) { 17 | buf := bytes.NewBuffer(nil) 18 | writer := NewVarintWriter(buf) 19 | reader := NewVarintReader(buf) 20 | SubtestReadWriteMsg(t, writer, reader) 21 | } 22 | 23 | func TestVarintReadWriteMsgSync(t *testing.T) { 24 | buf := bytes.NewBuffer(nil) 25 | writer := NewVarintWriter(buf) 26 | reader := NewVarintReader(buf) 27 | SubtestReadWriteMsgSync(t, writer, reader) 28 | } 29 | 30 | func TestVarintWrite(t *testing.T) { 31 | SubtestVarintWrite(t, []byte("hello world")) 32 | SubtestVarintWrite(t, []byte("hello world hello world hello world")) 33 | SubtestVarintWrite(t, make([]byte, 1<<20)) 34 | SubtestVarintWrite(t, []byte("")) 35 | } 36 | 37 | func SubtestVarintWrite(t *testing.T, msg []byte) { 38 | buf := bytes.NewBuffer(nil) 39 | writer := NewVarintWriter(buf) 40 | 41 | if err := writer.WriteMsg(msg); err != nil { 42 | t.Fatal(err) 43 | } 44 | 45 | bb := buf.Bytes() 46 | 47 | sbr := simpleByteReader{R: buf} 48 | length, err := binary.ReadUvarint(&sbr) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | t.Logf("checking varint is %d", len(msg)) 54 | if int(length) != len(msg) { 55 | t.Fatalf("incorrect varint: %d != %d", length, len(msg)) 56 | } 57 | 58 | lbuf := make([]byte, binary.MaxVarintLen64) 59 | n := binary.PutUvarint(lbuf, length) 60 | 61 | bblen := int(length) + n 62 | t.Logf("checking wrote (%d + %d) bytes", length, n) 63 | if len(bb) != bblen { 64 | t.Fatalf("wrote incorrect number of bytes: %d != %d", len(bb), bblen) 
65 | } 66 | } 67 | --------------------------------------------------------------------------------