├── LICENSE ├── README.md ├── dag ├── dagctx.go └── dagctx_test.go ├── doc.go ├── frac ├── fracctx.go └── fracctx_test.go └── io ├── ctxio.go └── ctxio_test.go /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Juan Batiz-Benet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### go-context - jbenet's CONText EXTensions 2 | 3 | https://godoc.org/github.com/jbenet/go-context 4 | 5 | - `WithDeadlineFraction`: https://godoc.org/github.com/jbenet/go-context/frac#WithDeadlineFraction 6 | - `WithParents`: https://godoc.org/github.com/jbenet/go-context/dag#WithParents 7 | - `io.{Reader, Writer}`: https://godoc.org/github.com/jbenet/go-context/io 8 | 9 | -------------------------------------------------------------------------------- /dag/dagctx.go: -------------------------------------------------------------------------------- 1 | package ctxext 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | context "golang.org/x/net/context" 8 | ) 9 | 10 | // WithParents returns a Context that listens to all given 11 | // parents. It effectively transforms the Context Tree into 12 | // a Directed Acyclic Graph. This is useful when a context 13 | // may be cancelled for more than one reason. For example, 14 | // consider a database with the following Get function: 15 | // 16 | // func (db *DB) Get(ctx context.Context, ...) {} 17 | // 18 | // DB.Get may have to stop for two different contexts: 19 | // * the caller's context (caller might cancel) 20 | // * the database's context (might be shut down mid-request) 21 | // 22 | // WithParents saves the day by allowing us to "merge" contexts 23 | // and continue on our merry contextual way: 24 | // 25 | // ctx = ctxext.WithParents(ctx, db.ctx) 26 | // 27 | // Passing related (mutually derived) contexts to WithParents is 28 | // actually ok. The child is cancelled when any of its parents is 29 | // cancelled, so if any of its parents are also related, the cancel 30 | // propagation will reach the child via the shortest path.
31 | func WithParents(ctxts ...context.Context) context.Context { 32 | if len(ctxts) < 1 { 33 | panic("no contexts provided") 34 | } 35 | 36 | ctx := &errCtx{ 37 | done: make(chan struct{}), 38 | dead: earliestDeadline(ctxts), 39 | } 40 | 41 | // listen to all contexts and use the first. 42 | for _, c2 := range ctxts { 43 | go func(pctx context.Context) { 44 | select { 45 | case <-ctx.Done(): // cancelled by another parent 46 | return 47 | case <-pctx.Done(): // this parent cancelled 48 | // race: two parents may have cancelled at the same time. 49 | // break tie with mutex (inside c.cancel) 50 | ctx.cancel(pctx.Err()) 51 | } 52 | }(c2) 53 | } 54 | 55 | return ctx 56 | } 57 | 58 | func earliestDeadline(ctxts []context.Context) *time.Time { 59 | var d1 *time.Time 60 | for _, c := range ctxts { 61 | if c == nil { 62 | panic("given nil context.Context") 63 | } 64 | 65 | // use earliest deadline. 66 | d2, ok := c.Deadline() 67 | if !ok { 68 | continue 69 | } 70 | 71 | if d1 == nil || (*d1).After(d2) { 72 | d1 = &d2 73 | } 74 | } 75 | return d1 76 | } 77 | 78 | type errCtx struct { 79 | dead *time.Time 80 | done chan struct{} 81 | err error 82 | mu sync.RWMutex 83 | } 84 | 85 | func (c *errCtx) cancel(err error) { 86 | c.mu.Lock() 87 | defer c.mu.Unlock() 88 | 89 | select { 90 | case <-c.Done(): 91 | return 92 | default: 93 | } 94 | 95 | c.err = err 96 | close(c.done) // signal done to all 97 | } 98 | 99 | func (c *errCtx) Done() <-chan struct{} { 100 | return c.done 101 | } 102 | 103 | func (c *errCtx) Err() error { 104 | c.mu.Lock() 105 | defer c.mu.Unlock() 106 | return c.err 107 | } 108 | 109 | func (c *errCtx) Value(key interface{}) interface{} { 110 | return nil 111 | } 112 | 113 | func (c *errCtx) Deadline() (deadline time.Time, ok bool) { 114 | if c.dead == nil { 115 | return 116 | } 117 | 118 | return *c.dead, true 119 | } 120 | -------------------------------------------------------------------------------- /dag/dagctx_test.go: 
-------------------------------------------------------------------------------- 1 | package ctxext 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | "time" 7 | 8 | context "golang.org/x/net/context" 9 | ) 10 | 11 | func TestWithParentsSingle(t *testing.T) { 12 | ctx1, cancel := context.WithCancel(context.Background()) 13 | ctx2 := WithParents(ctx1) 14 | 15 | select { 16 | case <-ctx2.Done(): 17 | t.Fatal("ended too early") 18 | case <-time.After(time.Millisecond): 19 | } 20 | 21 | cancel() 22 | 23 | select { 24 | case <-ctx2.Done(): 25 | case <-time.After(time.Millisecond): 26 | t.Error("should've cancelled it") 27 | } 28 | 29 | if ctx2.Err() != ctx1.Err() { 30 | t.Error("errors should match") 31 | } 32 | } 33 | 34 | func TestWithParentsDeadline(t *testing.T) { 35 | ctx1, _ := context.WithCancel(context.Background()) 36 | ctx2, _ := context.WithTimeout(context.Background(), time.Second) 37 | ctx3, _ := context.WithTimeout(context.Background(), time.Second*2) 38 | 39 | ctx := WithParents(ctx1) 40 | d, ok := ctx.Deadline() 41 | if ok { 42 | t.Error("ctx should have no deadline") 43 | } 44 | 45 | ctx = WithParents(ctx1, ctx2, ctx3) 46 | d, ok = ctx.Deadline() 47 | d2, ok2 := ctx2.Deadline() 48 | if !ok { 49 | t.Error("ctx should have deadline") 50 | } else if !ok2 { 51 | t.Error("ctx2 should have deadline") 52 | } else if !d.Equal(d2) { 53 | t.Error("ctx should have ctx2 deadline") 54 | } 55 | } 56 | 57 | func SubtestWithParentsMany(t *testing.T, n int) { 58 | 59 | ctxs := make([]context.Context, n) 60 | cancels := make([]context.CancelFunc, n) 61 | for i := 0; i < n; i++ { 62 | if i == 0 { // first must be new. 
63 | ctxs[i], cancels[i] = context.WithCancel(context.Background()) 64 | continue 65 | } 66 | 67 | r := rand.Intn(i) // select a previous context 68 | switch rand.Intn(6) { 69 | case 0: // same as old 70 | ctxs[i], cancels[i] = ctxs[r], cancels[r] 71 | case 1: // derive from another 72 | ctxs[i], cancels[i] = context.WithCancel(ctxs[r]) 73 | case 2: // deadline 74 | t := (time.Second) * time.Duration(r+2) // +2 so we dont run into 0 or timing bugs 75 | ctxs[i], cancels[i] = context.WithTimeout(ctxs[r], t) 76 | default: // new context 77 | ctxs[i], cancels[i] = context.WithCancel(context.Background()) 78 | } 79 | } 80 | 81 | ctx := WithParents(ctxs...) 82 | 83 | // test deadline is earliest. 84 | d1 := earliestDeadline(ctxs) 85 | d2, ok := ctx.Deadline() 86 | switch { 87 | case d1 == nil && ok: 88 | t.Error("nil, should not have deadline") 89 | case d1 != nil && !ok: 90 | t.Error("not nil, should have deadline") 91 | case d1 != nil && ok && !d1.Equal(d2): 92 | t.Error("should find same deadline") 93 | } 94 | if ok { 95 | t.Logf("deadline - now: %s", d2.Sub(time.Now())) 96 | } 97 | 98 | select { 99 | case <-ctx.Done(): 100 | t.Fatal("ended too early") 101 | case <-time.After(time.Millisecond): 102 | } 103 | 104 | // cancel just one 105 | r := rand.Intn(len(cancels)) 106 | cancels[r]() 107 | 108 | select { 109 | case <-ctx.Done(): 110 | case <-time.After(time.Millisecond): 111 | t.Error("any should've cancelled it") 112 | } 113 | 114 | if ctx.Err() != ctxs[r].Err() { 115 | t.Error("errors should match") 116 | } 117 | } 118 | 119 | func TestWithParentsMany(t *testing.T) { 120 | n := 100 121 | for i := 1; i < n; i++ { 122 | SubtestWithParentsMany(t, i) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package context contains some extenstions to go.net/context by @jbenet 2 | package context 3 | 
-------------------------------------------------------------------------------- /frac/fracctx.go: -------------------------------------------------------------------------------- 1 | // Package ctxext provides multiple useful context constructors. 2 | package ctxext 3 | 4 | import ( 5 | "time" 6 | 7 | context "golang.org/x/net/context" 8 | ) 9 | 10 | // WithDeadlineFraction returns a Context with a fraction of the 11 | // original context's timeout. This is useful in sequential pipelines 12 | // of work, where one might try options and fall back to others 13 | // depending on the time available, or failure to respond. For example: 14 | // 15 | // // getPicture returns a picture from our encrypted database 16 | // // we have a pipeline of multiple steps. we need to: 17 | // // - get the data from a database 18 | // // - decrypt it 19 | // // - apply many transforms 20 | // // 21 | // // we **know** that each step takes increasingly more time. 22 | // // The transforms are much more expensive than decryption, and 23 | // // decryption is more expensive than the database lookup. 24 | // // If our database takes too long (i.e. >0.2 of available time), 25 | // // there's no use in continuing. 26 | // func getPicture(ctx context.Context, key string) ([]byte, error) { 27 | // // fractional timeout contexts to the rescue! 28 | // 29 | // // try the database with 0.2 of remaining time. 30 | // ctx1, _ := ctxext.WithDeadlineFraction(ctx, 0.2) 31 | // val, err := db.Get(ctx1, key) 32 | // if err != nil { 33 | // return nil, err 34 | // } 35 | // 36 | // // try decryption with 0.3 of remaining time. 37 | // ctx2, _ := ctxext.WithDeadlineFraction(ctx, 0.3) 38 | // if val, err = decryptor.Decrypt(ctx2, val); err != nil { 39 | // return nil, err 40 | // } 41 | // 42 | // // try transforms with all remaining time. hopefully it's enough! 
43 | // return transformer.Transform(ctx, val) 44 | // } 45 | // 46 | // 47 | func WithDeadlineFraction(ctx context.Context, fraction float64) ( 48 | context.Context, context.CancelFunc) { 49 | 50 | d, found := ctx.Deadline() 51 | if !found { // no deadline 52 | return context.WithCancel(ctx) 53 | } 54 | 55 | left := d.Sub(time.Now()) 56 | if left < 0 { // already passed... 57 | return context.WithCancel(ctx) 58 | } 59 | 60 | left = time.Duration(float64(left) * fraction) 61 | return context.WithTimeout(ctx, left) 62 | } 63 | -------------------------------------------------------------------------------- /frac/fracctx_test.go: -------------------------------------------------------------------------------- 1 | package ctxext 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | context "golang.org/x/net/context" 9 | ) 10 | 11 | // this test is on the context tool itself, not our stuff. it's for sanity on ours. 12 | func TestDeadline(t *testing.T) { 13 | if os.Getenv("TRAVIS") == "true" { 14 | t.Skip("timeouts don't work reliably on travis") 15 | } 16 | 17 | ctx, _ := context.WithTimeout(context.Background(), 5*time.Millisecond) 18 | 19 | select { 20 | case <-ctx.Done(): 21 | t.Fatal("ended too early") 22 | default: 23 | } 24 | 25 | <-time.After(6 * time.Millisecond) 26 | 27 | select { 28 | case <-ctx.Done(): 29 | default: 30 | t.Fatal("ended too late") 31 | } 32 | } 33 | 34 | func TestDeadlineFractionForever(t *testing.T) { 35 | 36 | ctx, _ := WithDeadlineFraction(context.Background(), 0.5) 37 | 38 | _, found := ctx.Deadline() 39 | if found { 40 | t.Fatal("should last forever") 41 | } 42 | } 43 | 44 | func TestDeadlineFractionHalf(t *testing.T) { 45 | if os.Getenv("TRAVIS") == "true" { 46 | t.Skip("timeouts don't work reliably on travis") 47 | } 48 | 49 | ctx1, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) 50 | ctx2, _ := WithDeadlineFraction(ctx1, 0.5) 51 | 52 | select { 53 | case <-ctx1.Done(): 54 | t.Fatal("ctx1 ended too early") 55 | 
case <-ctx2.Done(): 56 | t.Fatal("ctx2 ended too early") 57 | default: 58 | } 59 | 60 | <-time.After(2 * time.Millisecond) 61 | 62 | select { 63 | case <-ctx1.Done(): 64 | t.Fatal("ctx1 ended too early") 65 | case <-ctx2.Done(): 66 | t.Fatal("ctx2 ended too early") 67 | default: 68 | } 69 | 70 | <-time.After(4 * time.Millisecond) 71 | 72 | select { 73 | case <-ctx1.Done(): 74 | t.Fatal("ctx1 ended too early") 75 | case <-ctx2.Done(): 76 | default: 77 | t.Fatal("ctx2 ended too late") 78 | } 79 | 80 | <-time.After(6 * time.Millisecond) 81 | 82 | select { 83 | case <-ctx1.Done(): 84 | default: 85 | t.Fatal("ctx1 ended too late") 86 | } 87 | 88 | } 89 | 90 | func TestDeadlineFractionCancel(t *testing.T) { 91 | 92 | ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Millisecond) 93 | ctx2, cancel2 := WithDeadlineFraction(ctx1, 0.5) 94 | 95 | select { 96 | case <-ctx1.Done(): 97 | t.Fatal("ctx1 ended too early") 98 | case <-ctx2.Done(): 99 | t.Fatal("ctx2 ended too early") 100 | default: 101 | } 102 | 103 | cancel2() 104 | 105 | select { 106 | case <-ctx1.Done(): 107 | t.Fatal("ctx1 should NOT be cancelled") 108 | case <-ctx2.Done(): 109 | default: 110 | t.Fatal("ctx2 should be cancelled") 111 | } 112 | 113 | cancel1() 114 | 115 | select { 116 | case <-ctx1.Done(): 117 | case <-ctx2.Done(): 118 | default: 119 | t.Fatal("ctx1 should be cancelled") 120 | } 121 | 122 | } 123 | 124 | func TestDeadlineFractionObeysParent(t *testing.T) { 125 | 126 | ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Millisecond) 127 | ctx2, _ := WithDeadlineFraction(ctx1, 0.5) 128 | 129 | select { 130 | case <-ctx1.Done(): 131 | t.Fatal("ctx1 ended too early") 132 | case <-ctx2.Done(): 133 | t.Fatal("ctx2 ended too early") 134 | default: 135 | } 136 | 137 | cancel1() 138 | 139 | select { 140 | case <-ctx2.Done(): 141 | default: 142 | t.Fatal("ctx2 should be cancelled") 143 | } 144 | 145 | } 146 | 
--------------------------------------------------------------------------------
/io/ctxio.go:
--------------------------------------------------------------------------------
1 | // Package ctxio provides io.Reader and io.Writer wrappers that
2 | // respect context.Contexts. Use these at the interface between
3 | // your context code and your io.
4 | //
5 | // WARNING: read the code. see how writes and reads will continue
6 | // until you cancel the io. Maybe this package should provide
7 | // versions of io.ReadCloser and io.WriteCloser that automatically
8 | // call .Close when the context expires. But for now -- since in my
9 | // use cases I have long-lived connections with ephemeral io wrappers
10 | // -- this has yet to be a need.
11 | package ctxio
12 | 
13 | import (
14 | 	"io"
15 | 
16 | 	context "golang.org/x/net/context"
17 | )
18 | 
19 | // ioret carries the result of an underlying Read or Write call back from
20 | // the goroutine that performed it.
21 | type ioret struct {
22 | 	n   int
23 | 	err error
24 | }
25 | 
26 | // Writer is an io.Writer; the value returned by NewWriter satisfies it.
27 | type Writer interface {
28 | 	io.Writer
29 | }
30 | 
31 | type ctxWriter struct {
32 | 	w   io.Writer
33 | 	ctx context.Context
34 | }
35 | 
36 | // NewWriter wraps a writer to make it respect given Context.
37 | // If there is a blocking write, the returned Writer will return
38 | // whenever the context is cancelled (the return values are n=0
39 | // and err=ctx.Err().) A nil ctx is treated as context.Background().
40 | //
41 | // Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
42 | // write-- there is no way to do that with the standard go io
43 | // interface. So the write _will_ happen or hang. So, use
44 | // this sparingly, make sure to cancel the write as necessary
45 | // (e.g. closing a connection whose context is up, etc.)
46 | //
47 | // Furthermore, in order to protect your memory from being read
48 | // _after_ you've cancelled the context, this io.Writer will
49 | // first make a **copy** of the buffer.
50 | func NewWriter(ctx context.Context, w io.Writer) *ctxWriter {
51 | 	if ctx == nil {
52 | 		ctx = context.Background()
53 | 	}
54 | 	return &ctxWriter{ctx: ctx, w: w}
55 | }
56 | 
57 | // Write writes a private copy of buf to the underlying writer in a new
58 | // goroutine. It returns that write's result, or 0 and ctx.Err() if the
59 | // context is done first. In the latter case the underlying write may
60 | // still complete.
61 | func (w *ctxWriter) Write(buf []byte) (int, error) {
62 | 	buf2 := make([]byte, len(buf))
63 | 	copy(buf2, buf)
64 | 
65 | 	// Buffered so the goroutine can always deliver its result and exit,
66 | 	// even after Write has returned on cancellation.
67 | 	c := make(chan ioret, 1)
68 | 
69 | 	go func() {
70 | 		n, err := w.w.Write(buf2)
71 | 		c <- ioret{n, err}
72 | 		close(c)
73 | 	}()
74 | 
75 | 	select {
76 | 	case r := <-c:
77 | 		return r.n, r.err
78 | 	case <-w.ctx.Done():
79 | 		return 0, w.ctx.Err()
80 | 	}
81 | }
82 | 
83 | // Reader is an io.Reader; the value returned by NewReader satisfies it.
84 | type Reader interface {
85 | 	io.Reader
86 | }
87 | 
88 | type ctxReader struct {
89 | 	r   io.Reader
90 | 	ctx context.Context
91 | }
92 | 
93 | // NewReader wraps a reader to make it respect given Context.
94 | // If there is a blocking read, the returned Reader will return
95 | // whenever the context is cancelled (the return values are n=0
96 | // and err=ctx.Err().) A nil ctx is treated as context.Background(),
97 | // as in NewWriter.
98 | //
99 | // Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
100 | // read-- there is no way to do that with the standard go io
101 | // interface. So the read _will_ happen or hang. So, use
102 | // this sparingly, make sure to cancel the read as necessary
103 | // (e.g. closing a connection whose context is up, etc.)
104 | //
105 | // Furthermore, in order to protect your memory from being written to
106 | // _after_ you've cancelled the context, this io.Reader will
107 | // allocate a buffer of the same size, and **copy** into the client's
108 | // if the read succeeds in time.
109 | func NewReader(ctx context.Context, r io.Reader) *ctxReader {
110 | 	if ctx == nil {
111 | 		ctx = context.Background()
112 | 	}
113 | 	return &ctxReader{ctx: ctx, r: r}
114 | }
115 | 
116 | // Read reads into a private buffer in a new goroutine. If that read's
117 | // result arrives before the context is done, its n bytes are copied into
118 | // buf and its result is returned. Otherwise Read returns 0 and ctx.Err()
119 | // and leaves buf untouched.
120 | func (r *ctxReader) Read(buf []byte) (int, error) {
121 | 	buf2 := make([]byte, len(buf))
122 | 
123 | 	// Buffered so the goroutine can always deliver its result and exit,
124 | 	// even after Read has returned on cancellation.
125 | 	c := make(chan ioret, 1)
126 | 
127 | 	go func() {
128 | 		n, err := r.r.Read(buf2)
129 | 		c <- ioret{n, err}
130 | 		close(c)
131 | 	}()
132 | 
133 | 	select {
134 | 	case ret := <-c:
135 | 		// Copy only the bytes actually read (io.Reader guarantees
136 | 		// 0 <= n <= len(p)); the rest of buf keeps its contents.
137 | 		copy(buf, buf2[:ret.n])
138 | 		return ret.n, ret.err
139 | 	case <-r.ctx.Done():
140 | 		return 0, r.ctx.Err()
141 | 	}
142 | }
143 | 
--------------------------------------------------------------------------------
/io/ctxio_test.go:
--------------------------------------------------------------------------------
1 | package ctxio
2 | 
3 | import (
4 | 	"bytes"
5 | 	"io"
6 | 	"testing"
7 | 	"time"
8 | 
9 | 	context "golang.org/x/net/context"
10 | )
11 | 
12 | func TestReader(t *testing.T) {
13 | 	buf := []byte("abcdef")
14 | 	buf2 := make([]byte, 3)
15 | 	r := NewReader(context.Background(), bytes.NewReader(buf))
16 | 
17 | 	// read first half
18 | 	n, err := r.Read(buf2)
19 | 	if n != 3 {
20 | 		t.Error("n should be 3")
21 | 	}
22 | 	if err != nil {
23 | 		t.Error("should have no error")
24 | 	}
25 | 	if string(buf2) != string(buf[:3]) {
26 | 		t.Error("incorrect contents")
27 | 	}
28 | 
29 | 	// read second half
30 | 	n, err = r.Read(buf2)
31 | 	if n != 3 {
32 | 		t.Error("n should be 3")
33 | 	}
34 | 	if err != nil {
35 | 		t.Error("should have no error")
36 | 	}
37 | 	if string(buf2) != string(buf[3:6]) {
38 | 		t.Error("incorrect contents")
39 | 	}
40 | 
41 | 	// read more.
42 | 	n, err = r.Read(buf2)
43 | 	if n != 0 {
44 | 		t.Error("n should be 0", n)
45 | 	}
46 | 	if err != io.EOF {
47 | 		t.Error("should be EOF", err)
48 | 	}
49 | }
50 | 
51 | func TestWriter(t *testing.T) {
52 | 	var buf bytes.Buffer
53 | 	w := NewWriter(context.Background(), &buf)
54 | 
55 | 	// write three
56 | 	n, err := w.Write([]byte("abc"))
57 | 	if n != 3 {
58 | 		t.Error("n should be 3")
59 | 	}
60 | 	if err != nil {
61 | 		t.Error("should have no error")
62 | 	}
63 | 	if string(buf.Bytes()) != string("abc") {
64 | 		t.Error("incorrect contents")
65 | 	}
66 | 
67 | 	// write three more
68 | 	n, err = w.Write([]byte("def"))
69 | 	if n != 3 {
70 | 		t.Error("n should be 3")
71 | 	}
72 | 	if err != nil {
73 | 		t.Error("should have no error")
74 | 	}
75 | 	if string(buf.Bytes()) != string("abcdef") {
76 | 		t.Error("incorrect contents")
77 | 	}
78 | }
79 | 
80 | func TestReaderCancel(t *testing.T) {
81 | 	ctx, cancel := context.WithCancel(context.Background())
82 | 	piper, pipew := io.Pipe()
83 | 	r := NewReader(ctx, piper)
84 | 
85 | 	buf := make([]byte, 10)
86 | 	done := make(chan ioret)
87 | 
88 | 	go func() {
89 | 		n, err := r.Read(buf)
90 | 		done <- ioret{n, err}
91 | 	}()
92 | 
93 | 	pipew.Write([]byte("abcdefghij"))
94 | 
95 | 	select {
96 | 	case ret := <-done:
97 | 		if ret.n != 10 {
98 | 			t.Error("ret.n should be 10", ret.n)
99 | 		}
100 | 		if ret.err != nil {
101 | 			t.Error("ret.err should be nil", ret.err)
102 | 		}
103 | 		if string(buf) != "abcdefghij" {
104 | 			t.Error("read contents differ")
105 | 		}
106 | 	case <-time.After(20 * time.Millisecond):
107 | 		t.Fatal("failed to read")
108 | 	}
109 | 
110 | 	go func() {
111 | 		n, err := r.Read(buf)
112 | 		done <- ioret{n, err}
113 | 	}()
114 | 
115 | 	cancel()
116 | 
117 | 	select {
118 | 	case ret := <-done:
119 | 		if ret.n != 0 {
120 | 			t.Error("ret.n should be 0", ret.n)
121 | 		}
122 | 		if ret.err == nil {
123 | 			t.Error("ret.err should be ctx error", ret.err)
124 | 		}
125 | 	case <-time.After(20 * time.Millisecond):
126 | 		t.Fatal("failed to stop reading after cancel")
127 | 	}
128 | }
129 | 
130 | func TestWriterCancel(t *testing.T) {
131 | 	ctx, cancel := context.WithCancel(context.Background())
132 | 	piper, pipew := io.Pipe()
133 | 	w := NewWriter(ctx, pipew)
134 | 
135 | 	buf := make([]byte, 10)
136 | 	done := make(chan ioret)
137 | 
138 | 	go func() {
139 | 		n, err := w.Write([]byte("abcdefghij"))
140 | 		done <- ioret{n, err}
141 | 	}()
142 | 
143 | 	piper.Read(buf)
144 | 
145 | 	select {
146 | 	case ret := <-done:
147 | 		if ret.n != 10 {
148 | 			t.Error("ret.n should be 10", ret.n)
149 | 		}
150 | 		if ret.err != nil {
151 | 			t.Error("ret.err should be nil", ret.err)
152 | 		}
153 | 		if string(buf) != "abcdefghij" {
154 | 			t.Error("write contents differ")
155 | 		}
156 | 	case <-time.After(20 * time.Millisecond):
157 | 		t.Fatal("failed to write")
158 | 	}
159 | 
160 | 	go func() {
161 | 		n, err := w.Write([]byte("abcdefghij"))
162 | 		done <- ioret{n, err}
163 | 	}()
164 | 
165 | 	cancel()
166 | 
167 | 	select {
168 | 	case ret := <-done:
169 | 		if ret.n != 0 {
170 | 			t.Error("ret.n should be 0", ret.n)
171 | 		}
172 | 		if ret.err == nil {
173 | 			t.Error("ret.err should be ctx error", ret.err)
174 | 		}
175 | 	case <-time.After(20 * time.Millisecond):
176 | 		t.Fatal("failed to stop writing after cancel")
177 | 	}
178 | }
179 | 
180 | func TestReadPostCancel(t *testing.T) {
181 | 	ctx, cancel := context.WithCancel(context.Background())
182 | 	piper, pipew := io.Pipe()
183 | 	r := NewReader(ctx, piper)
184 | 
185 | 	buf := make([]byte, 10)
186 | 	done := make(chan ioret)
187 | 
188 | 	go func() {
189 | 		n, err := r.Read(buf)
190 | 		done <- ioret{n, err}
191 | 	}()
192 | 
193 | 	cancel()
194 | 
195 | 	select {
196 | 	case ret := <-done:
197 | 		if ret.n != 0 {
198 | 			t.Error("ret.n should be 0", ret.n)
199 | 		}
200 | 		if ret.err == nil {
201 | 			t.Error("ret.err should be ctx error", ret.err)
202 | 		}
203 | 	case <-time.After(20 * time.Millisecond):
204 | 		t.Fatal("failed to stop reading after cancel")
205 | 	}
206 | 
207 | 	pipew.Write([]byte("abcdefghij"))
208 | 
209 | 	if !bytes.Equal(buf, make([]byte, len(buf))) {
210 | 		t.Fatal("buffer should have not been written to")
211 | 	}
212 | }
213 | 
214 | func TestWritePostCancel(t *testing.T) {
215 | 	ctx, cancel := context.WithCancel(context.Background())
216 | 	piper, pipew := io.Pipe()
217 | 	w := NewWriter(ctx, pipew)
218 | 
219 | 	buf := []byte("abcdefghij")
220 | 	buf2 := make([]byte, 10)
221 | 	done := make(chan ioret)
222 | 
223 | 	go func() {
224 | 		n, err := w.Write(buf)
225 | 		done <- ioret{n, err}
226 | 	}()
227 | 
228 | 	piper.Read(buf2)
229 | 
230 | 	select {
231 | 	case ret := <-done:
232 | 		if ret.n != 10 {
233 | 			t.Error("ret.n should be 10", ret.n)
234 | 		}
235 | 		if ret.err != nil {
236 | 			t.Error("ret.err should be nil", ret.err)
237 | 		}
238 | 		if string(buf2) != "abcdefghij" {
239 | 			t.Error("write contents differ")
240 | 		}
241 | 	case <-time.After(20 * time.Millisecond):
242 | 		t.Fatal("failed to write")
243 | 	}
244 | 
245 | 	go func() {
246 | 		n, err := w.Write(buf)
247 | 		done <- ioret{n, err}
248 | 	}()
249 | 
250 | 	cancel()
251 | 
252 | 	select {
253 | 	case ret := <-done:
254 | 		if ret.n != 0 {
255 | 			t.Error("ret.n should be 0", ret.n)
256 | 		}
257 | 		if ret.err == nil {
258 | 			t.Error("ret.err should be ctx error", ret.err)
259 | 		}
260 | 	case <-time.After(20 * time.Millisecond):
261 | 		t.Fatal("failed to stop writing after cancel")
262 | 	}
263 | 
264 | 	copy(buf, []byte("aaaaaaaaaa"))
265 | 
266 | 	piper.Read(buf2)
267 | 
268 | 	if string(buf2) == "aaaaaaaaaa" {
269 | 		t.Error("buffer was read from after ctx cancel")
270 | 	} else if string(buf2) != "abcdefghij" {
271 | 		t.Error("write contents differ from expected")
272 | 	}
273 | }
274 | 
275 | // TestReaderNilContext checks that NewReader, like NewWriter, treats a
276 | // nil context as context.Background() instead of panicking in Read.
277 | func TestReaderNilContext(t *testing.T) {
278 | 	r := NewReader(nil, bytes.NewReader([]byte("abc")))
279 | 	buf := make([]byte, 3)
280 | 	n, err := r.Read(buf)
281 | 	if n != 3 || err != nil {
282 | 		t.Error("read should succeed", n, err)
283 | 	}
284 | 	if string(buf) != "abc" {
285 | 		t.Error("read contents differ")
286 | 	}
287 | }
288 | 
289 | // TestReaderShortRead checks that a short read only overwrites the
290 | // first n bytes of the caller's buffer.
291 | func TestReaderShortRead(t *testing.T) {
292 | 	r := NewReader(context.Background(), bytes.NewReader([]byte("ab")))
293 | 	buf := []byte("xyz")
294 | 	n, err := r.Read(buf)
295 | 	if n != 2 || err != nil {
296 | 		t.Error("read should return 2 bytes", n, err)
297 | 	}
298 | 	if string(buf) != "abz" {
299 | 		t.Error("bytes past n should be untouched", string(buf))
300 | 	}
301 | }
302 | 
--------------------------------------------------------------------------------