├── .gitignore ├── Godeps ├── Readme └── Godeps.json ├── .travis.yml ├── codecov.yml ├── example ├── no-cache │ └── no-cache.go ├── lru-cache │ └── golang-lru.go └── ttl-cache │ └── go-cache.go ├── cache.go ├── Gopkg.toml ├── Gopkg.lock ├── LICENSE ├── key.go ├── inMemoryCache_go19.go ├── inMemoryCache.go ├── README.md ├── trace.go ├── TRACE.md ├── MIGRATE.md ├── dataloader.go └── dataloader_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | vendor/ 2 | -------------------------------------------------------------------------------- /Godeps/Readme: -------------------------------------------------------------------------------- 1 | This directory tree is generated automatically by godep. 2 | 3 | Please do not edit. 4 | 5 | See https://github.com/tools/godep for more information. 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.8 5 | - 1.x 6 | 7 | install: 8 | - go get -u github.com/golang/dep/... 
9 | - dep ensure 10 | 11 | script: 12 | - go test -v -race -coverprofile=coverage.txt -covermode=atomic 13 | 14 | after_success: 15 | - bash <(curl -s https://codecov.io/bash) 16 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | require_ci_to_pass: true 4 | comment: 5 | behavior: default 6 | layout: header, diff 7 | require_changes: false 8 | coverage: 9 | precision: 2 10 | range: 11 | - 70.0 12 | - 100.0 13 | round: down 14 | status: 15 | changes: false 16 | patch: true 17 | project: true 18 | parsers: 19 | gcov: 20 | branch_detection: 21 | conditional: true 22 | loop: true 23 | macro: false 24 | method: false 25 | javascript: 26 | enable_partials: false 27 | -------------------------------------------------------------------------------- /Godeps/Godeps.json: -------------------------------------------------------------------------------- 1 | { 2 | "ImportPath": "github.com/graph-gophers/dataloader", 3 | "GoVersion": "go1.10", 4 | "GodepVersion": "v80", 5 | "Deps": [ 6 | { 7 | "ImportPath": "github.com/opentracing/opentracing-go", 8 | "Comment": "v1.0.2-5-g1361b9c", 9 | "Rev": "1361b9cd60be79c4c3a7fa9841b3c132e40066a7" 10 | }, 11 | { 12 | "ImportPath": "github.com/opentracing/opentracing-go/log", 13 | "Comment": "v1.0.2-5-g1361b9c", 14 | "Rev": "1361b9cd60be79c4c3a7fa9841b3c132e40066a7" 15 | }, 16 | { 17 | "ImportPath": "golang.org/x/net/context", 18 | "Rev": "5ccada7d0a7ba9aeb5d3aca8d3501b4c2a509fec" 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /example/no-cache/no-cache.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/graph-gophers/dataloader" 8 | ) 9 | 10 | func main() { 11 | // go-cache will automaticlly cleanup expired items on 
given diration 12 | cache := &dataloader.NoCache{} 13 | loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache(cache)) 14 | 15 | result, err := loader.Load(context.TODO(), dataloader.StringKey("some key"))() 16 | if err != nil { 17 | // handle error 18 | } 19 | 20 | fmt.Printf("identity: %s\n", result) 21 | } 22 | 23 | func batchFunc(_ context.Context, keys dataloader.Keys) []*dataloader.Result { 24 | var results []*dataloader.Result 25 | // do some pretend work to resolve keys 26 | for _, key := range keys { 27 | results = append(results, &dataloader.Result{key.String(), nil}) 28 | } 29 | return results 30 | } 31 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package dataloader 2 | 3 | import "context" 4 | 5 | // The Cache interface. If a custom cache is provided, it must implement this interface. 6 | type Cache interface { 7 | Get(context.Context, Key) (Thunk, bool) 8 | Set(context.Context, Key, Thunk) 9 | Delete(context.Context, Key) bool 10 | Clear() 11 | } 12 | 13 | // NoCache implements Cache interface where all methods are noops. 
14 | // This is useful for when you don't want to cache items but still 15 | // want to use a data loader 16 | type NoCache struct{} 17 | 18 | // Get is a NOOP 19 | func (c *NoCache) Get(context.Context, Key) (Thunk, bool) { return nil, false } 20 | 21 | // Set is a NOOP 22 | func (c *NoCache) Set(context.Context, Key, Thunk) { return } 23 | 24 | // Delete is a NOOP 25 | func (c *NoCache) Delete(context.Context, Key) bool { return false } 26 | 27 | // Clear is a NOOP 28 | func (c *NoCache) Clear() { return } 29 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | 2 | # Gopkg.toml example 3 | # 4 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 5 | # for detailed Gopkg.toml documentation. 6 | # 7 | # required = ["github.com/user/thing/cmd/thing"] 8 | # ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] 9 | # 10 | # [[constraint]] 11 | # name = "github.com/user/project" 12 | # version = "1.0.0" 13 | # 14 | # [[constraint]] 15 | # name = "github.com/user/project2" 16 | # branch = "dev" 17 | # source = "github.com/myfork/project2" 18 | # 19 | # [[override]] 20 | # name = "github.com/x/y" 21 | # version = "2.4.0" 22 | 23 | 24 | [[constraint]] 25 | branch = "master" 26 | name = "github.com/hashicorp/golang-lru" 27 | 28 | [[constraint]] 29 | name = "github.com/opentracing/opentracing-go" 30 | version = "1.0.2" 31 | 32 | [[constraint]] 33 | name = "github.com/patrickmn/go-cache" 34 | version = "2.1.0" 35 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
2 | 3 | 4 | [[projects]] 5 | branch = "master" 6 | name = "github.com/hashicorp/golang-lru" 7 | packages = [".","simplelru"] 8 | revision = "0a025b7e63adc15a622f29b0b2c4c3848243bbf6" 9 | 10 | [[projects]] 11 | name = "github.com/opentracing/opentracing-go" 12 | packages = [".","log"] 13 | revision = "1949ddbfd147afd4d964a9f00b24eb291e0e7c38" 14 | version = "v1.0.2" 15 | 16 | [[projects]] 17 | name = "github.com/patrickmn/go-cache" 18 | packages = ["."] 19 | revision = "a3647f8e31d79543b2d0f0ae2fe5c379d72cedc0" 20 | version = "v2.1.0" 21 | 22 | [[projects]] 23 | branch = "master" 24 | name = "golang.org/x/net" 25 | packages = ["context"] 26 | revision = "a8b9294777976932365dabb6640cf1468d95c70f" 27 | 28 | [solve-meta] 29 | analyzer-name = "dep" 30 | analyzer-version = 1 31 | inputs-digest = "a0b8606d9f2ed9df7e69cae570c65c7d7b090bb7a08f58d3535b584693d44da9" 32 | solver-name = "gps-cdcl" 33 | solver-version = 1 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Nick Randall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /key.go: -------------------------------------------------------------------------------- 1 | package dataloader 2 | 3 | // Key is the interface that all keys need to implement 4 | type Key interface { 5 | // String returns a guaranteed unique string that can be used to identify an object 6 | String() string 7 | // Raw returns the raw, underlaying value of the key 8 | Raw() interface{} 9 | } 10 | 11 | // Keys wraps a slice of Key types to provide some convenience methods. 12 | type Keys []Key 13 | 14 | // Keys returns the list of strings. One for each "Key" in the list 15 | func (l Keys) Keys() []string { 16 | list := make([]string, len(l)) 17 | for i := range l { 18 | list[i] = l[i].String() 19 | } 20 | return list 21 | } 22 | 23 | // StringKey implements the Key interface for a string 24 | type StringKey string 25 | 26 | // String is an identity method. Used to implement String interface 27 | func (k StringKey) String() string { return string(k) } 28 | 29 | // Raw is an identity method. 
Used to implement the Key interface's Raw method 30 | func (k StringKey) Raw() interface{} { return k } 31 | 32 | // NewKeysFromStrings converts a `[]string` to a `Keys` ([]Key) 33 | func NewKeysFromStrings(strings []string) Keys { 34 | list := make(Keys, len(strings)) 35 | for i := range strings { 36 | list[i] = StringKey(strings[i]) 37 | } 38 | return list 39 | } 40 | -------------------------------------------------------------------------------- /inMemoryCache_go19.go: -------------------------------------------------------------------------------- 1 | // +build go1.9 2 | 3 | package dataloader 4 | 5 | import ( 6 | "context" 7 | "sync" 8 | ) 9 | 10 | // InMemoryCache is an in memory implementation of Cache interface. 11 | // This simple implementation is well suited for 12 | // a "per-request" dataloader (i.e. one that only lives 13 | // for the life of an http request) but it's not well suited 14 | // for long lived cached items. 15 | type InMemoryCache struct { 16 | items *sync.Map 17 | } 18 | 19 | // NewCache constructs a new InMemoryCache 20 | func NewCache() *InMemoryCache { 21 | return &InMemoryCache{ 22 | items: &sync.Map{}, 23 | } 24 | } 25 | 26 | // Set sets the `value` at `key` in the cache 27 | func (c *InMemoryCache) Set(_ context.Context, key Key, value Thunk) { 28 | c.items.Store(key.String(), value) 29 | } 30 | 31 | // Get gets the value at `key` if it exists, returns value (or nil) and bool 32 | // indicating whether the value was found 33 | func (c *InMemoryCache) Get(_ context.Context, key Key) (Thunk, bool) { 34 | item, found := c.items.Load(key.String()) 35 | if !found { 36 | return nil, false 37 | } 38 | 39 | return item.(Thunk), true 40 | } 41 | 42 | // Delete deletes item at `key` from cache 43 | func (c *InMemoryCache) Delete(_ context.Context, key Key) bool { 44 | if _, found := c.items.Load(key.String()); found { 45 | c.items.Delete(key.String()) 46 | return true 47 | } 48 | return false 49 | } 50 | 51 | // Clear clears the entire cache 52 | func (c *InMemoryCache)
Clear() { 53 | c.items.Range(func(key, _ interface{}) bool { 54 | c.items.Delete(key) 55 | return true 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /inMemoryCache.go: -------------------------------------------------------------------------------- 1 | // +build !go1.9 2 | 3 | package dataloader 4 | 5 | import ( 6 | "context" 7 | "sync" 8 | ) 9 | 10 | // InMemoryCache is an in memory implementation of Cache interface. 11 | // This simple implementation is well suited for 12 | // a "per-request" dataloader (i.e. one that only lives 13 | // for the life of an http request) but it's not well suited 14 | // for long lived cached items. 15 | type InMemoryCache struct { 16 | items map[string]Thunk 17 | mu sync.RWMutex 18 | } 19 | 20 | // NewCache constructs a new InMemoryCache 21 | func NewCache() *InMemoryCache { 22 | items := make(map[string]Thunk) 23 | return &InMemoryCache{ 24 | items: items, 25 | } 26 | } 27 | 28 | // Set sets the `value` at `key` in the cache 29 | func (c *InMemoryCache) Set(_ context.Context, key Key, value Thunk) { 30 | c.mu.Lock() 31 | c.items[key.String()] = value 32 | c.mu.Unlock() 33 | } 34 | 35 | // Get gets the value at `key` if it exsits, returns value (or nil) and bool 36 | // indicating of value was found 37 | func (c *InMemoryCache) Get(_ context.Context, key Key) (Thunk, bool) { 38 | c.mu.RLock() 39 | defer c.mu.RUnlock() 40 | 41 | item, found := c.items[key.String()] 42 | if !found { 43 | return nil, false 44 | } 45 | 46 | return item, true 47 | } 48 | 49 | // Delete deletes item at `key` from cache 50 | func (c *InMemoryCache) Delete(ctx context.Context, key Key) bool { 51 | if _, found := c.Get(ctx, key); found { 52 | c.mu.Lock() 53 | defer c.mu.Unlock() 54 | delete(c.items, key.String()) 55 | return true 56 | } 57 | return false 58 | } 59 | 60 | // Clear clears the entire cache 61 | func (c *InMemoryCache) Clear() { 62 | c.mu.Lock() 63 | c.items = map[string]Thunk{} 64 | 
c.mu.Unlock() 65 | } 66 | -------------------------------------------------------------------------------- /example/lru-cache/golang-lru.go: -------------------------------------------------------------------------------- 1 | // This is an exmaple of using go-cache as a long term cache solution for 2 | // dataloader. 3 | package main 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | lru "github.com/hashicorp/golang-lru" 10 | "github.com/nicksrandall/dataloader" 11 | ) 12 | 13 | // Cache implements the dataloader.Cache interface 14 | type Cache struct { 15 | *lru.ARCCache 16 | } 17 | 18 | // Get gets an item from the cache 19 | func (c *Cache) Get(_ context.Context, key dataloader.Key) (dataloader.Thunk, bool) { 20 | v, ok := c.ARCCache.Get(key) 21 | if ok { 22 | return v.(dataloader.Thunk), ok 23 | } 24 | return nil, ok 25 | } 26 | 27 | // Set sets an item in the cache 28 | func (c *Cache) Set(_ context.Context, key dataloader.Key, value dataloader.Thunk) { 29 | c.ARCCache.Add(key, value) 30 | } 31 | 32 | // Delete deletes an item in the cache 33 | func (c *Cache) Delete(_ context.Context, key dataloader.Key) bool { 34 | if c.ARCCache.Contains(key) { 35 | c.ARCCache.Remove(key) 36 | return true 37 | } 38 | return false 39 | } 40 | 41 | // Clear cleasrs the cache 42 | func (c *Cache) Clear() { 43 | c.ARCCache.Purge() 44 | } 45 | 46 | func main() { 47 | // go-cache will automaticlly cleanup expired items on given diration 48 | c, _ := lru.NewARC(100) 49 | cache := &Cache{c} 50 | loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache(cache)) 51 | 52 | // immediately call the future function from loader 53 | result, err := loader.Load(context.TODO(), dataloader.StringKey("some key"))() 54 | if err != nil { 55 | // handle error 56 | } 57 | 58 | fmt.Printf("identity: %s\n", result) 59 | } 60 | 61 | func batchFunc(_ context.Context, keys dataloader.Keys) []*dataloader.Result { 62 | var results []*dataloader.Result 63 | // do some pretend work to resolve keys 
64 | for _, key := range keys { 65 | results = append(results, &dataloader.Result{key.String(), nil}) 66 | } 67 | return results 68 | } 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DataLoader 2 | [![GoDoc](https://godoc.org/gopkg.in/graph-gophers/dataloader.v3?status.svg)](https://godoc.org/github.com/graph-gophers/dataloader) 3 | [![Build Status](https://travis-ci.org/graph-gophers/dataloader.svg?branch=master)](https://travis-ci.org/graph-gophers/dataloader) 4 | 5 | This is an implementation of [Facebook's DataLoader](https://github.com/facebook/dataloader) in Golang. 6 | 7 | ## Install 8 | `go get -u github.com/graph-gophers/dataloader` 9 | 10 | ## Usage 11 | ```go 12 | // setup batch function 13 | batchFn := func(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 14 | var results []*dataloader.Result 15 | // do some async work to get data for specified keys 16 | // append to this list resolved values 17 | return results 18 | } 19 | 20 | // create Loader with an in-memory cache 21 | loader := dataloader.NewBatchedLoader(batchFn) 22 | 23 | /** 24 | * Use loader 25 | * 26 | * A thunk is a function returned from a function that is a 27 | * closure over a value (in this case an interface value and error). 28 | * When called, it will block until the value is resolved. 29 | */ 30 | thunk := loader.Load(ctx.TODO(), dataloader.StringKey("key1")) // StringKey is a convenience method that make wraps string to implement `Key` interface 31 | result, err := thunk() 32 | if err != nil { 33 | // handle data error 34 | } 35 | 36 | log.Printf("value: %#v", result) 37 | ``` 38 | 39 | ### Don't need/want to use context? 40 | You're welcome to install the v1 version of this library. 41 | 42 | ## Cache 43 | This implementation contains a very basic cache that is intended only to be used for short lived DataLoaders (i.e. 
DataLoaders that ony exsist for the life of an http request). You may use your own implementation if you want. 44 | 45 | > it also has a `NoCache` type that implements the cache interface but all methods are noop. If you do not wish to cache anything. 46 | 47 | ## Examples 48 | There are a few basic examples in the example folder. 49 | -------------------------------------------------------------------------------- /example/ttl-cache/go-cache.go: -------------------------------------------------------------------------------- 1 | // This is an exmaple of using go-cache as a long term cache solution for 2 | // dataloader. 3 | package main 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "time" 9 | 10 | "github.com/nicksrandall/dataloader" 11 | cache "github.com/patrickmn/go-cache" 12 | ) 13 | 14 | // Cache implements the dataloader.Cache interface 15 | type Cache struct { 16 | c *cache.Cache 17 | } 18 | 19 | // Get gets a value from the cache 20 | func (c *Cache) Get(_ context.Context, key dataloader.Key) (dataloader.Thunk, bool) { 21 | v, ok := c.c.Get(key.String()) 22 | if ok { 23 | return v.(dataloader.Thunk), ok 24 | } 25 | return nil, ok 26 | } 27 | 28 | // Set sets a value in the cache 29 | func (c *Cache) Set(_ context.Context, key dataloader.Key, value dataloader.Thunk) { 30 | c.c.Set(key.String(), value, 0) 31 | } 32 | 33 | // Delete deletes and item in the cache 34 | func (c *Cache) Delete(_ context.Context, key dataloader.Key) bool { 35 | if _, found := c.c.Get(key.String()); found { 36 | c.c.Delete(key.String()) 37 | return true 38 | } 39 | return false 40 | } 41 | 42 | // Clear clears the cache 43 | func (c *Cache) Clear() { 44 | c.c.Flush() 45 | } 46 | 47 | func main() { 48 | // go-cache will automaticlly cleanup expired items on given diration 49 | c := cache.New(15*time.Minute, 15*time.Minute) 50 | cache := &Cache{c} 51 | loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache(cache)) 52 | 53 | // immediately call the future function from 
loader 54 | result, err := loader.Load(context.TODO(), dataloader.StringKey("some key"))() 55 | if err != nil { 56 | // handle error 57 | } 58 | 59 | fmt.Printf("identity: %s\n", result) 60 | } 61 | 62 | func batchFunc(_ context.Context, keys dataloader.Keys) []*dataloader.Result { 63 | var results []*dataloader.Result 64 | // do some pretend work to resolve keys 65 | for _, key := range keys { 66 | results = append(results, &dataloader.Result{key.String(), nil}) 67 | } 68 | return results 69 | } 70 | -------------------------------------------------------------------------------- /trace.go: -------------------------------------------------------------------------------- 1 | package dataloader 2 | 3 | import ( 4 | "context" 5 | 6 | opentracing "github.com/opentracing/opentracing-go" 7 | ) 8 | 9 | type TraceLoadFinishFunc func(Thunk) 10 | type TraceLoadManyFinishFunc func(ThunkMany) 11 | type TraceBatchFinishFunc func([]*Result) 12 | 13 | // Tracer is an interface that may be used to implement tracing. 14 | type Tracer interface { 15 | // TraceLoad will trace the calls to Load 16 | TraceLoad(ctx context.Context, key Key) (context.Context, TraceLoadFinishFunc) 17 | // TraceLoadMany will trace the calls to LoadMany 18 | TraceLoadMany(ctx context.Context, keys Keys) (context.Context, TraceLoadManyFinishFunc) 19 | // TraceBatch will trace data loader batches 20 | TraceBatch(ctx context.Context, keys Keys) (context.Context, TraceBatchFinishFunc) 21 | } 22 | 23 | // OpenTracing Tracer implements a tracer that can be used with the Open Tracing standard. 
24 | type OpenTracingTracer struct{} 25 | 26 | // TraceLoad will trace a call to dataloader.Load with Open Tracing 27 | func (OpenTracingTracer) TraceLoad(ctx context.Context, key Key) (context.Context, TraceLoadFinishFunc) { 28 | span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: load") 29 | 30 | span.SetTag("dataloader.key", key.String()) 31 | 32 | return spanCtx, func(thunk Thunk) { 33 | // TODO: is there anything we should do with the results? 34 | span.Finish() 35 | } 36 | } 37 | 38 | // TraceLoadMany will trace a call to dataloader.LoadMany with Open Tracing 39 | func (OpenTracingTracer) TraceLoadMany(ctx context.Context, keys Keys) (context.Context, TraceLoadManyFinishFunc) { 40 | span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: loadmany") 41 | 42 | span.SetTag("dataloader.keys", keys.Keys()) 43 | 44 | return spanCtx, func(thunk ThunkMany) { 45 | // TODO: is there anything we should do with the results? 46 | span.Finish() 47 | } 48 | } 49 | 50 | // TraceBatch will trace a dataloader batch with Open Tracing 51 | func (OpenTracingTracer) TraceBatch(ctx context.Context, keys Keys) (context.Context, TraceBatchFinishFunc) { 52 | span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: batch") 53 | 54 | span.SetTag("dataloader.keys", keys.Keys()) 55 | 56 | return spanCtx, func(results []*Result) { 57 | // TODO: is there anything we should do with the results?
58 | span.Finish() 59 | } 60 | } 61 | 62 | // NoopTracer is the default (noop) tracer 63 | type NoopTracer struct{} 64 | 65 | // TraceLoad is a noop function 66 | func (NoopTracer) TraceLoad(ctx context.Context, key Key) (context.Context, TraceLoadFinishFunc) { 67 | return ctx, func(Thunk) {} 68 | } 69 | 70 | // TraceLoadMany is a noop function 71 | func (NoopTracer) TraceLoadMany(ctx context.Context, keys Keys) (context.Context, TraceLoadManyFinishFunc) { 72 | return ctx, func(ThunkMany) {} 73 | } 74 | 75 | // TraceBatch is a noop function 76 | func (NoopTracer) TraceBatch(ctx context.Context, keys Keys) (context.Context, TraceBatchFinishFunc) { 77 | return ctx, func(result []*Result) {} 78 | } 79 | -------------------------------------------------------------------------------- /TRACE.md: -------------------------------------------------------------------------------- 1 | # Adding a new trace backend. 2 | 3 | If you want to add a new tracing backend all you need to do is implement the 4 | `Tracer` interface and pass it as an option to the dataloader on initialization. 5 | 6 | As an example, this is how you could implement it for an OpenCensus backend. 7 | 8 | ```go 9 | package main 10 | 11 | import ( 12 | "context" 13 | "strings" 14 | 15 | exp "go.opencensus.io/examples/exporter" 16 | "github.com/nicksrandall/dataloader" 17 | "go.opencensus.io/trace" 18 | ) 19 | 20 | // OpenCensusTracer implements a tracer that can be used with the OpenCensus framework.
21 | type OpenCensusTracer struct{} 22 | 23 | // TraceLoad will trace a call to dataloader.LoadMany with Open Tracing 24 | func (OpenCensusTracer) TraceLoad(ctx context.Context, key dataloader.Key) (context.Context, dataloader.TraceLoadFinishFunc) { 25 | cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: load") 26 | cSpan.AddAttributes( 27 | trace.StringAttribute("dataloader.key", key.String()), 28 | ) 29 | return cCtx, func(thunk dataloader.Thunk) { 30 | // TODO: is there anything we should do with the results? 31 | cSpan.End() 32 | } 33 | } 34 | 35 | // TraceLoadMany will trace a call to dataloader.LoadMany with Open Tracing 36 | func (OpenCensusTracer) TraceLoadMany(ctx context.Context, keys dataloader.Keys) (context.Context, dataloader.TraceLoadManyFinishFunc) { 37 | cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: loadmany") 38 | cSpan.AddAttributes( 39 | trace.StringAttribute("dataloader.keys", strings.Join(keys.Keys(), ",")), 40 | ) 41 | return cCtx, func(thunk dataloader.ThunkMany) { 42 | // TODO: is there anything we should do with the results? 43 | cSpan.End() 44 | } 45 | } 46 | 47 | // TraceBatch will trace a call to dataloader.LoadMany with Open Tracing 48 | func (OpenCensusTracer) TraceBatch(ctx context.Context, keys dataloader.Keys) (context.Context, dataloader.TraceBatchFinishFunc) { 49 | cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: batch") 50 | cSpan.AddAttributes( 51 | trace.StringAttribute("dataloader.keys", strings.Join(keys.Keys(), ",")), 52 | ) 53 | return cCtx, func(results []*dataloader.Result) { 54 | // TODO: is there anything we should do with the results? 
55 | cSpan.End() 56 | } 57 | } 58 | 59 | func batchFunc(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { 60 | // ...loader logic goes here 61 | } 62 | 63 | func main(){ 64 | //initialize an example exporter that just logs to the console 65 | trace.ApplyConfig(trace.Config{ 66 | DefaultSampler: trace.AlwaysSample(), 67 | }) 68 | trace.RegisterExporter(&exp.PrintExporter{}) 69 | // initialize the dataloader with your new tracer backend 70 | loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithTracer(OpenCensusTracer{})) 71 | // initialize a context since it's not receiving one from anywhere else. 72 | ctx, span := trace.StartSpan(context.TODO(), "Span Name") 73 | defer span.End() 74 | // request from the dataloader as usual 75 | value, err := loader.Load(ctx, dataloader.StringKey(SomeID))() 76 | // ... 77 | } 78 | ``` 79 | 80 | Don't forget to initialize the exporters of your choice and register it with `trace.RegisterExporter(&exporterInstance)`. 81 | -------------------------------------------------------------------------------- /MIGRATE.md: -------------------------------------------------------------------------------- 1 | ## Upgrade from v1 to v2 2 | The only difference between v1 and v2 is that we added use of [context](https://golang.org/pkg/context). 
3 | 4 | ```diff 5 | - loader.Load(key string) Thunk 6 | + loader.Load(ctx context.Context, key string) Thunk 7 | - loader.LoadMany(keys []string) ThunkMany 8 | + loader.LoadMany(ctx context.Context, keys []string) ThunkMany 9 | ``` 10 | 11 | ```diff 12 | - type BatchFunc func([]string) []*Result 13 | + type BatchFunc func(context.Context, []string) []*Result 14 | ``` 15 | 16 | ## Upgrade from v2 to v3 17 | ```diff 18 | // dataloader.Interface has added context.Context to methods 19 | - loader.Prime(key string, value interface{}) Interface 20 | + loader.Prime(ctx context.Context, key string, value interface{}) Interface 21 | - loader.Clear(key string) Interface 22 | + loader.Clear(ctx context.Context, key string) Interface 23 | ``` 24 | 25 | ```diff 26 | // cache interface has added context.Context to methods 27 | type Cache interface { 28 | - Get(string) (Thunk, bool) 29 | + Get(context.Context, string) (Thunk, bool) 30 | - Set(string, Thunk) 31 | + Set(context.Context, string, Thunk) 32 | - Delete(string) bool 33 | + Delete(context.Context, string) bool 34 | Clear() 35 | } 36 | ``` 37 | 38 | ## Upgrade from v3 to v4 39 | ```diff 40 | // dataloader.Interface now allows interface{} as key rather than string 41 | - loader.Load(context.Context, key string) Thunk 42 | + loader.Load(ctx context.Context, key interface{}) Thunk 43 | - loader.LoadMany(context.Context, key []string) ThunkMany 44 | + loader.LoadMany(ctx context.Context, keys []interface{}) ThunkMany 45 | - loader.Prime(context.Context, key string, value interface{}) Interface 46 | + loader.Prime(ctx context.Context, key interface{}, value interface{}) Interface 47 | - loader.Clear(context.Context, key string) Interface 48 | + loader.Clear(ctx context.Context, key interface{}) Interface 49 | ``` 50 | 51 | ```diff 52 | // cache interface now allows interface{} as key instead of string 53 | type Cache interface { 54 | - Get(context.Context, string) (Thunk, bool) 55 | + Get(context.Context, interface{}) (Thunk,
bool) 56 | - Set(context.Context, string, Thunk) 57 | + Set(context.Context, interface{}, Thunk) 58 | - Delete(context.Context, string) bool 59 | + Delete(context.Context, interface{}) bool 60 | Clear() 61 | } 62 | ``` 63 | 64 | ## Upgrade from v4 to v5 65 | ```diff 66 | // dataloader.Interface as now allows interace{} as key rather than string 67 | - loader.Load(context.Context, key interface{}) Thunk 68 | + loader.Load(ctx context.Context, key Key) Thunk 69 | - loader.LoadMany(context.Context, key []interface{}) ThunkMany 70 | + loader.LoadMany(ctx context.Context, keys Keys) ThunkMany 71 | - loader.Prime(context.Context, key interface{}, value interface{}) Interface 72 | + loader.Prime(ctx context.Context, key Key, value interface{}) Interface 73 | - loader.Clear(context.Context, key interface{}) Interface 74 | + loader.Clear(ctx context.Context, key Key) Interface 75 | ``` 76 | 77 | ```diff 78 | // cache interface now allows interface{} as key instead of string 79 | type Cache interface { 80 | - Get(context.Context, interface{}) (Thunk, bool) 81 | + Get(context.Context, Key) (Thunk, bool) 82 | - Set(context.Context, interface{}, Thunk) 83 | + Set(context.Context, Key, Thunk) 84 | - Delete(context.Context, interface{}) bool 85 | + Delete(context.Context, Key) bool 86 | Clear() 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /dataloader.go: -------------------------------------------------------------------------------- 1 | // Package dataloader is an implimentation of facebook's dataloader in go. 
2 | // See https://github.com/facebook/dataloader for more information 3 | package dataloader 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log" 9 | "runtime" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // Interface is a `DataLoader` Interface which defines a public API for loading data from a particular 15 | // data back-end with unique keys such as the `id` column of a SQL table or 16 | // document name in a MongoDB database, given a batch loading function. 17 | // 18 | // Each `DataLoader` instance should contain a unique memoized cache. Use caution when 19 | // used in long-lived applications or those which serve many users with 20 | // different access permissions and consider creating a new instance per 21 | // web request. 22 | type Interface interface { 23 | Load(context.Context, Key) Thunk 24 | LoadMany(context.Context, Keys) ThunkMany 25 | Clear(context.Context, Key) Interface 26 | ClearAll() Interface 27 | Prime(ctx context.Context, key Key, value interface{}) Interface 28 | } 29 | 30 | // BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`. 31 | // It's important that the length of the input keys matches the length of the output results. 32 | // 33 | // The keys passed to this function are guaranteed to be unique 34 | type BatchFunc func(context.Context, Keys) []*Result 35 | 36 | // Result is the data structure that a BatchFunc returns. 37 | // It contains the resolved data, and any errors that may have occurred while fetching the data. 38 | type Result struct { 39 | Data interface{} 40 | Error error 41 | } 42 | 43 | // ResultMany is used by the LoadMany method. 44 | // It contains a list of resolved data and a list of errors. 45 | // The lengths of the data list and error list will match, and elements at each index correspond to each other. 46 | type ResultMany struct { 47 | Data []interface{} 48 | Error []error 49 | } 50 | 51 | // Loader implements the dataloader.Interface. 
type Loader struct {
	// the batch function to be used by this loader
	batchFn BatchFunc

	// the maximum batch size. Set to 0 if you want it to be unbounded.
	batchCap int

	// the internal cache. This package contains a basic cache implementation but any custom cache
	// implementation could be used as long as it implements the `Cache` interface.
	// cacheLock guards cache access and the clearCacheOnBatch flag.
	cacheLock sync.Mutex
	cache     Cache
	// should we clear the cache on each batch?
	// this would allow batching but no long term caching
	clearCacheOnBatch bool

	// count of queued up items (only tracked when batchCap > 0)
	count int

	// the maximum input queue size. Set to 0 if you want it to be unbounded.
	inputCap int

	// the amount of time to wait before triggering a batch
	wait time.Duration

	// lock to protect the batching operations (curBatcher, endSleeper, count)
	batchLock sync.Mutex

	// current batcher; nil when no batch window is open
	curBatcher *batcher

	// used to close the sleeper of the current batcher
	endSleeper chan bool

	// used by tests to prevent logs
	silent bool

	// can be set to trace calls to dataloader
	tracer Tracer
}

// Thunk is a function that will block until the value (*Result) it contains is resolved.
// After the value it contains is resolved, this function will return the result.
// This function can be called many times, much like a Promise in other languages.
// The value will only need to be resolved once so subsequent calls will return immediately.
type Thunk func() (interface{}, error)

// ThunkMany is much like the Thunk func type but it contains a list of results.
type ThunkMany func() ([]interface{}, []error)

// batchRequest is the type sent on the batcher's input channel. It pairs a key
// with the channel on which that key's *Result will be delivered.
type batchRequest struct {
	key     Key
	channel chan *Result
}

// Option allows for configuration of Loader fields.
108 | type Option func(*Loader) 109 | 110 | // WithCache sets the BatchedLoader cache. Defaults to InMemoryCache if a Cache is not set. 111 | func WithCache(c Cache) Option { 112 | return func(l *Loader) { 113 | l.cache = c 114 | } 115 | } 116 | 117 | // WithBatchCapacity sets the batch capacity. Default is 0 (unbounded). 118 | func WithBatchCapacity(c int) Option { 119 | return func(l *Loader) { 120 | l.batchCap = c 121 | } 122 | } 123 | 124 | // WithInputCapacity sets the input capacity. Default is 1000. 125 | func WithInputCapacity(c int) Option { 126 | return func(l *Loader) { 127 | l.inputCap = c 128 | } 129 | } 130 | 131 | // WithWait sets the amount of time to wait before triggering a batch. 132 | // Default duration is 16 milliseconds. 133 | func WithWait(d time.Duration) Option { 134 | return func(l *Loader) { 135 | l.wait = d 136 | } 137 | } 138 | 139 | // WithClearCacheOnBatch allows batching of items but no long term caching. 140 | // It accomplishes this by clearing the cache after each batch operation. 141 | func WithClearCacheOnBatch() Option { 142 | return func(l *Loader) { 143 | l.cacheLock.Lock() 144 | l.clearCacheOnBatch = true 145 | l.cacheLock.Unlock() 146 | } 147 | } 148 | 149 | // withSilentLogger turns of log messages. It's used by the tests 150 | func withSilentLogger() Option { 151 | return func(l *Loader) { 152 | l.silent = true 153 | } 154 | } 155 | 156 | // WithTracer allows tracing of calls to Load and LoadMany 157 | func WithTracer(tracer Tracer) Option { 158 | return func(l *Loader) { 159 | l.tracer = tracer 160 | } 161 | } 162 | 163 | // WithOpenTracingTracer allows tracing of calls to Load and LoadMany 164 | func WithOpenTracingTracer() Option { 165 | return WithTracer(&OpenTracingTracer{}) 166 | } 167 | 168 | // NewBatchedLoader constructs a new Loader with given options. 
func NewBatchedLoader(batchFn BatchFunc, opts ...Option) *Loader {
	loader := &Loader{
		batchFn:  batchFn,
		inputCap: 1000,
		wait:     16 * time.Millisecond,
	}

	// Apply options
	for _, apply := range opts {
		apply(loader)
	}

	// Set defaults for anything the options did not configure
	if loader.cache == nil {
		loader.cache = NewCache()
	}

	if loader.tracer == nil {
		loader.tracer = &NoopTracer{}
	}

	return loader
}

// Load loads/resolves the given key, returning a Thunk that will block until
// the batch containing the key has been executed and then yield its value and error.
func (l *Loader) Load(originalContext context.Context, key Key) Thunk {
	ctx, finish := l.tracer.TraceLoad(originalContext, key)

	// buffered so the batcher can deliver the result without blocking
	c := make(chan *Result, 1)
	var result struct {
		mu    sync.RWMutex
		value *Result
	}

	// lock to prevent duplicate keys coming in before item has been added to cache.
	l.cacheLock.Lock()
	if v, ok := l.cache.Get(ctx, key); ok {
		defer finish(v)
		defer l.cacheLock.Unlock()
		return v
	}

	// thunk reads the result from c exactly once; later calls (and concurrent
	// callers) see the memoized result.value instead of the (closed) channel.
	thunk := func() (interface{}, error) {
		result.mu.RLock()
		resultNotSet := result.value == nil
		result.mu.RUnlock()

		if resultNotSet {
			result.mu.Lock()
			// a closed channel yields ok == false, in which case another
			// goroutine has already stored result.value
			if v, ok := <-c; ok {
				result.value = v
			}
			result.mu.Unlock()
		}
		result.mu.RLock()
		defer result.mu.RUnlock()
		return result.value.Data, result.value.Error
	}
	defer finish(thunk)

	l.cache.Set(ctx, key, thunk)
	l.cacheLock.Unlock()

	// this is sent to batch fn. It contains the key and the channel to return
	// the result on
	req := &batchRequest{key, c}

	l.batchLock.Lock()
	// start the batch window if it hasn't already started.
	if l.curBatcher == nil {
		l.curBatcher = l.newBatcher(l.silent, l.tracer)
		// start the current batcher batch function
		go l.curBatcher.batch(originalContext)
		// start a sleeper for the current batcher
		l.endSleeper = make(chan bool)
		go l.sleeper(l.curBatcher, l.endSleeper)
	}

	l.curBatcher.input <- req

	// if we need to keep track of the count (max batch), then do so.
	if l.batchCap > 0 {
		l.count++
		// if we hit our limit, force the batch to start
		if l.count == l.batchCap {
			// end the batcher synchronously here because another call to Load
			// may concurrently happen and needs to go to a new batcher.
			l.curBatcher.end()
			// end the sleeper for the current batcher.
			// this is to stop the goroutine without waiting for the
			// sleeper timeout.
			close(l.endSleeper)
			l.reset()
		}
	}
	l.batchLock.Unlock()

	return thunk
}

// LoadMany loads multiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
func (l *Loader) LoadMany(originalContext context.Context, keys Keys) ThunkMany {
	ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)

	var (
		length = len(keys)
		data   = make([]interface{}, length)
		errors = make([]error, length)
		c      = make(chan *ResultMany, 1)
		wg     sync.WaitGroup
	)

	// resolve delegates each key to Load so all keys share the same batch window
	resolve := func(ctx context.Context, i int) {
		defer wg.Done()
		thunk := l.Load(ctx, keys[i])
		result, err := thunk()
		data[i] = result
		errors[i] = err
	}

	wg.Add(length)
	for i := range keys {
		go resolve(ctx, i)
	}

	go func() {
		wg.Wait()

		// errs is nil unless there exists a non-nil error.
		// This prevents dataloader from returning a slice of all-nil errors.
		var errs []error
		for _, e := range errors {
			if e != nil {
				errs = errors
				break
			}
		}

		c <- &ResultMany{Data: data, Error: errs}
		close(c)
	}()

	var result struct {
		mu    sync.RWMutex
		value *ResultMany
	}

	// same memoization pattern as the Thunk returned by Load
	thunkMany := func() ([]interface{}, []error) {
		result.mu.RLock()
		resultNotSet := result.value == nil
		result.mu.RUnlock()

		if resultNotSet {
			result.mu.Lock()
			if v, ok := <-c; ok {
				result.value = v
			}
			result.mu.Unlock()
		}
		result.mu.RLock()
		defer result.mu.RUnlock()
		return result.value.Data, result.value.Error
	}

	defer finish(thunkMany)
	return thunkMany
}

// Clear clears the value at `key` from the cache, if it exists. Returns self for method chaining.
func (l *Loader) Clear(ctx context.Context, key Key) Interface {
	l.cacheLock.Lock()
	l.cache.Delete(ctx, key)
	l.cacheLock.Unlock()
	return l
}

// ClearAll clears the entire cache. To be used when some event results in unknown invalidations.
// Returns self for method chaining.
func (l *Loader) ClearAll() Interface {
	l.cacheLock.Lock()
	l.cache.Clear()
	l.cacheLock.Unlock()
	return l
}

// Prime adds the provided key and value to the cache. If the key already exists, no change is made.
// Returns self for method chaining.
// NOTE(review): unlike Clear/ClearAll, this reads and writes the cache without
// taking cacheLock — confirm whether that is intentional.
func (l *Loader) Prime(ctx context.Context, key Key, value interface{}) Interface {
	if _, ok := l.cache.Get(ctx, key); !ok {
		thunk := func() (interface{}, error) {
			return value, nil
		}
		l.cache.Set(ctx, key, thunk)
	}
	return l
}

// reset clears the pending-item count and current batcher, and — when
// clearCacheOnBatch is set — drops the cache so items are only batched,
// never cached long term. Callers must hold batchLock.
func (l *Loader) reset() {
	l.count = 0
	l.curBatcher = nil

	if l.clearCacheOnBatch {
		l.cache.Clear()
	}
}

// batcher collects requests for one batch window and executes the batch function once.
type batcher struct {
	input    chan *batchRequest
	batchFn  BatchFunc
	finished bool
	silent   bool
	tracer   Tracer
}

// newBatcher returns a batcher for the current requests
// all the batcher methods must be protected by a global batchLock
func (l *Loader) newBatcher(silent bool, tracer Tracer) *batcher {
	return &batcher{
		input:   make(chan *batchRequest, l.inputCap),
		batchFn: l.batchFn,
		silent:  silent,
		tracer:  tracer,
	}
}

// stop receiving input and process batch function
func (b *batcher) end() {
	if !b.finished {
		close(b.input)
		b.finished = true
	}
}

// execute the batch of all items in queue
func (b *batcher) batch(originalContext context.Context) {
	var (
		keys     = make(Keys, 0)
		reqs     = make([]*batchRequest, 0)
		items    = make([]*Result, 0)
		panicErr interface{}
	)

	// drain the input channel; this loop ends when end() closes it
	for item := range b.input {
		keys = append(keys, item.key)
		reqs = append(reqs, item)
	}

	ctx, finish := b.tracer.TraceBatch(originalContext, keys)
	defer finish(items)

	// run the user-supplied batch function, converting a panic into panicErr
	// so every waiting request still receives a result
	func() {
		defer func() {
			if r := recover(); r != nil {
				panicErr = r
				if b.silent {
					return
				}
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				log.Printf("Dataloader: Panic received in batch function:: %v\n%s", panicErr, buf)
			}
		}()
		items = b.batchFn(ctx, keys)
	}()

	if panicErr != nil {
		for _, req := range reqs {
			req.channel <- &Result{Error: fmt.Errorf("Panic received in batch function: %v", panicErr)}
			close(req.channel)
		}
		return
	}

	// the batch function broke its contract: report the mismatch to every caller
	if len(items) != len(keys) {
		err := &Result{Error: fmt.Errorf(`
The batch function supplied did not return an array of responses
the same length as the array of keys.

Keys:
%v

Values:
%v
`, keys, items)}

		for _, req := range reqs {
			req.channel <- err
			close(req.channel)
		}

		return
	}

	// deliver each result to the request that asked for the matching key
	for i, req := range reqs {
		req.channel <- items[i]
		close(req.channel)
	}
}

// wait the appropriate amount of time for the provided batcher
// (note: the `close` parameter shadows the builtin, but the builtin is not used here)
func (l *Loader) sleeper(b *batcher, close chan bool) {
	select {
	// used by batch to close early. usually triggered by max batch size
	case <-close:
		return
	// batch window elapsed; fall through and trigger the batch
	case <-time.After(l.wait):
	}

	// reset
	// this is protected by the batchLock to avoid closing the batcher input
	// channel while Load is inserting a request
	l.batchLock.Lock()
	b.end()

	// We can end here also if the batcher has already been closed and a
	// new one has been created.
	// So reset the loader state only if the batcher
	// is the current one
	if l.curBatcher == b {
		l.reset()
	}
	l.batchLock.Unlock()
}
-------------------------------------------------------------------------------- /dataloader_test.go: --------------------------------------------------------------------------------
package dataloader

import (
	"context"
	"errors"
	"fmt"
	"log"
	"reflect"
	"strconv"
	"sync"
	"testing"
)

///////////////////////////////////////////////////
// Tests
///////////////////////////////////////////////////
func TestLoader(t *testing.T) {
	t.Run("test Load method", func(t *testing.T) {
		t.Parallel()
		identityLoader, _ := IDLoader(0)
		ctx := context.Background()
		future := identityLoader.Load(ctx, StringKey("1"))
		value, err := future()
		if err != nil {
			t.Error(err.Error())
		}
		if value != "1" {
			t.Error("load didn't return the right value")
		}
	})

	// run under -race: two concurrent reads of the same thunk must be safe
	t.Run("test thunk does not contain race conditions", func(t *testing.T) {
		t.Parallel()
		identityLoader, _ := IDLoader(0)
		ctx := context.Background()
		future := identityLoader.Load(ctx, StringKey("1"))
		go future()
		go future()
	})

	t.Run("test Load Method Panic Safety", func(t *testing.T) {
		t.Parallel()
		defer func() {
			r := recover()
			if r != nil {
				t.Error("Panic Loader's panic should have been handled'")
			}
		}()
		panicLoader, _ := PanicLoader(0)
		ctx := context.Background()
		future := panicLoader.Load(ctx, StringKey("1"))
		_, err := future()
		if err == nil || err.Error() != "Panic received in batch function: Programming error" {
			t.Error("Panic was not propagated as an error.")
		}
	})

	t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
		t.Parallel()
		defer func() {
			r := recover()
			if r != nil {
				t.Error("Panic Loader's panic should have been handled'")
			}
		}()
		panicLoader, _ := PanicLoader(0)
		futures := []Thunk{}
		ctx := context.Background()
		for i := 0; i < 3; i++ {
			futures = append(futures, panicLoader.Load(ctx, StringKey(strconv.Itoa(i))))
		}
		for _, f := range futures {
			_, err := f()
			if err == nil || err.Error() != "Panic received in batch function: Programming error" {
				t.Error("Panic was not propagated as an error.")
			}
		}
	})

	t.Run("test LoadMany returns errors", func(t *testing.T) {
		t.Parallel()
		errorLoader, _ := ErrorLoader(0)
		ctx := context.Background()
		future := errorLoader.LoadMany(ctx, Keys{StringKey("1"), StringKey("2"), StringKey("3")})
		_, err := future()
		if len(err) != 3 {
			t.Error("LoadMany didn't return right number of errors")
		}
	})

	t.Run("test LoadMany returns len(errors) == len(keys)", func(t *testing.T) {
		t.Parallel()
		loader, _ := OneErrorLoader(3)
		ctx := context.Background()
		future := loader.LoadMany(ctx, Keys{StringKey("1"), StringKey("2"), StringKey("3")})
		_, errs := future()
		if len(errs) != 3 {
			t.Errorf("LoadMany didn't return right number of errors (should match size of input)")
		}

		var errCount int = 0
		var nilCount int = 0
		for _, err := range errs {
			if err == nil {
				nilCount++
			} else {
				errCount++
			}
		}
		if errCount != 1 {
			t.Error("Expected an error on only one of the items loaded")
		}

		if nilCount != 2 {
			t.Error("Expected second and third errors to be nil")
		}
	})

	t.Run("test LoadMany returns nil []error when no errors occurred", func(t *testing.T) {
		t.Parallel()
		loader, _ := IDLoader(0)
		ctx := context.Background()
		_, err := loader.LoadMany(ctx, Keys{StringKey("1"), StringKey("2"), StringKey("3")})()
		if err != nil {
			t.Errorf("Expected LoadMany() to return nil error slice when no errors occurred")
		}
	})

	t.Run("test thunkmany does not contain race conditions", func(t *testing.T) {
		t.Parallel()
		identityLoader, _ := IDLoader(0)
		ctx := context.Background()
		future := identityLoader.LoadMany(ctx, Keys{StringKey("1"), StringKey("2"), StringKey("3")})
		go future()
		go future()
	})

	t.Run("test Load Many Method Panic Safety", func(t *testing.T) {
		t.Parallel()
		defer func() {
			r := recover()
			if r != nil {
				t.Error("Panic Loader's panic should have been handled'")
			}
		}()
		panicLoader, _ := PanicLoader(0)
		ctx := context.Background()
		future := panicLoader.LoadMany(ctx, Keys{StringKey("1")})
		_, errs := future()
		if len(errs) < 1 || errs[0].Error() != "Panic received in batch function: Programming error" {
			t.Error("Panic was not propagated as an error.")
		}
	})

	t.Run("test LoadMany method", func(t *testing.T) {
		t.Parallel()
		identityLoader, _ := IDLoader(0)
		ctx := context.Background()
		future := identityLoader.LoadMany(ctx, Keys{StringKey("1"), StringKey("2"), StringKey("3")})
		results, _ := future()
		if results[0].(string) != "1" || results[1].(string) != "2" || results[2].(string) != "3" {
			t.Error("loadmany didn't return the right value")
		}
	})

	t.Run("batches many requests", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(0)
		ctx := context.Background()
		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("2"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1", "2"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not call batchFn in right order. Expected %#v, got %#v", expected, calls)
		}
	})

	t.Run("number of results matches number of keys", func(t *testing.T) {
		t.Parallel()
		faultyLoader, _ := FaultyLoader()
		ctx := context.Background()

		n := 10
		reqs := []Thunk{}
		keys := Keys{}
		for i := 0; i < n; i++ {
			key := StringKey(strconv.Itoa(i))
			reqs = append(reqs, faultyLoader.Load(ctx, key))
			keys = append(keys, key)
		}

		for _, future := range reqs {
			_, err := future()
			if err == nil {
				t.Error("if number of results doesn't match keys, all keys should contain error")
			}
		}

		// TODO: expect to get some kind of warning
	})

	t.Run("responds to max batch size", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(2)
		ctx := context.Background()
		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("2"))
		future3 := identityLoader.Load(ctx, StringKey("3"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future3()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner1 := []string{"1", "2"}
		inner2 := []string{"3"}
		expected := [][]string{inner1, inner2}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}
	})

	t.Run("caches repeated requests", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(0)
		ctx := context.Background()
		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("1"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}
	})

	t.Run("allows primed cache", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(0)
		ctx := context.Background()
		identityLoader.Prime(ctx, StringKey("A"), "Cached")
		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("A"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		value, err := future2()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}

		if value.(string) != "Cached" {
			t.Errorf("did not use primed cache value. Expected '%#v', got '%#v'", "Cached", value)
		}
	})

	t.Run("allows clear value in cache", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(0)
		ctx := context.Background()
		identityLoader.Prime(ctx, StringKey("A"), "Cached")
		identityLoader.Prime(ctx, StringKey("B"), "B")
		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Clear(ctx, StringKey("A")).Load(ctx, StringKey("A"))
		future3 := identityLoader.Load(ctx, StringKey("B"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		value, err := future2()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future3()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1", "A"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}

		if value != "A" {
			t.Errorf("did not use primed cache value. Expected '%#v', got '%#v'", "Cached", value)
		}
	})

	t.Run("allows clearAll values in cache", func(t *testing.T) {
		t.Parallel()
		batchOnlyLoader, loadCalls := BatchOnlyLoader(0)
		ctx := context.Background()
		future1 := batchOnlyLoader.Load(ctx, StringKey("1"))
		future2 := batchOnlyLoader.Load(ctx, StringKey("1"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not batch queries. Expected %#v, got %#v", expected, calls)
		}

		if _, found := batchOnlyLoader.cache.Get(ctx, StringKey("1")); found {
			t.Errorf("did not clear cache after batch. Expected %#v, got %#v", false, found)
		}
	})

	t.Run("allows clearAll values in cache", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := IDLoader(0)
		ctx := context.Background()
		identityLoader.Prime(ctx, StringKey("A"), "Cached")
		identityLoader.Prime(ctx, StringKey("B"), "B")

		identityLoader.ClearAll()

		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("A"))
		future3 := identityLoader.Load(ctx, StringKey("B"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future3()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1", "A", "B"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}
	})

	t.Run("all methods on NoCache are Noops", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := NoCacheLoader(0)
		ctx := context.Background()
		identityLoader.Prime(ctx, StringKey("A"), "Cached")
		identityLoader.Prime(ctx, StringKey("B"), "B")

		identityLoader.ClearAll()

		future1 := identityLoader.Clear(ctx, StringKey("1")).Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("A"))
		future3 := identityLoader.Load(ctx, StringKey("B"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future3()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1", "A", "B"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}
	})

	t.Run("no cache does not cache anything", func(t *testing.T) {
		t.Parallel()
		identityLoader, loadCalls := NoCacheLoader(0)
		ctx := context.Background()
		identityLoader.Prime(ctx, StringKey("A"), "Cached")
		identityLoader.Prime(ctx, StringKey("B"), "B")

		future1 := identityLoader.Load(ctx, StringKey("1"))
		future2 := identityLoader.Load(ctx, StringKey("A"))
		future3 := identityLoader.Load(ctx, StringKey("B"))

		_, err := future1()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future2()
		if err != nil {
			t.Error(err.Error())
		}
		_, err = future3()
		if err != nil {
			t.Error(err.Error())
		}

		calls := *loadCalls
		inner := []string{"1", "A", "B"}
		expected := [][]string{inner}
		if !reflect.DeepEqual(calls, expected) {
			t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
		}
	})

}

// test helpers

// IDLoader returns each key back as its own value and records the batches it receives.
func IDLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		for _, key := range keys {
			results = append(results, &Result{key.String(), nil})
		}
		return results
	}, WithBatchCapacity(max))
	return identityLoader, &loadCalls
}

// BatchOnlyLoader is like IDLoader but clears the cache after each batch.
func BatchOnlyLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		for _, key := range keys {
			results = append(results, &Result{key, nil})
		}
		return results
	}, WithBatchCapacity(max), WithClearCacheOnBatch())
	return identityLoader, &loadCalls
}

// ErrorLoader returns an error for every key.
func ErrorLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		for _, key := range keys {
			results = append(results, &Result{key, fmt.Errorf("this is a test error")})
		}
		return results
	}, WithBatchCapacity(max))
	return identityLoader, &loadCalls
}

// OneErrorLoader errors only on the first key of each batch.
func OneErrorLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		results := make([]*Result, max)
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		for i := range keys {
			var err error
			if i == 0 {
				err = errors.New("always error on the first key")
			}
			results[i] = &Result{keys[i], err}
		}
		return results
	}, WithBatchCapacity(max))
	return identityLoader, &loadCalls
}

// PanicLoader panics inside the batch function to exercise panic recovery.
func PanicLoader(max int) (*Loader, *[][]string) {
	var loadCalls [][]string
	panicLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		panic("Programming error")
	}, WithBatchCapacity(max), withSilentLogger())
	return panicLoader, &loadCalls
}

// BadLoader returns only one result no matter how many keys it receives.
func BadLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		results = append(results, &Result{keys[0], nil})
		return results
	}, WithBatchCapacity(max))
	return identityLoader, &loadCalls
}

// NoCacheLoader is like IDLoader but uses the no-op cache.
func NoCacheLoader(max int) (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string
	cache := &NoCache{}
	identityLoader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()
		for _, key := range keys {
			results = append(results, &Result{key, nil})
		}
		return results
	}, WithCache(cache), WithBatchCapacity(max))
	return identityLoader, &loadCalls
}

// FaultyLoader gives len(keys)-1 results.
func FaultyLoader() (*Loader, *[][]string) {
	var mu sync.Mutex
	var loadCalls [][]string

	loader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
		var results []*Result
		mu.Lock()
		loadCalls = append(loadCalls, keys.Keys())
		mu.Unlock()

		lastKeyIndex := len(keys) - 1
		for i, key := range keys {
			if i == lastKeyIndex {
				break
			}

			results = append(results, &Result{key, nil})
		}
		return results
	})

	return loader, &loadCalls
}

///////////////////////////////////////////////////
// Benchmarks
///////////////////////////////////////////////////
var a = &Avg{}

func batchIdentity(_ context.Context, keys Keys) (results []*Result) {
	a.Add(len(keys))
	for _, key := range keys {
		results = append(results, &Result{key, nil})
	}
	return
}

var _ctx = context.Background()

func BenchmarkLoader(b *testing.B) {
	UserLoader := NewBatchedLoader(batchIdentity)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		UserLoader.Load(_ctx, StringKey(strconv.Itoa(i)))
	}
	log.Printf("avg: %f", a.Avg())
}

// Avg tracks a running average of batch sizes, guarded by an RWMutex.
type Avg struct {
	total  float64
	length float64
	lock   sync.RWMutex
}

// Add folds one observation into the running average.
func (a *Avg) Add(v int) {
	a.lock.Lock()
	a.total += float64(v)
	a.length++
	a.lock.Unlock()
}

// Avg returns the current average, or 0 when nothing has been recorded.
func (a *Avg) Avg() float64 {
	a.lock.RLock()
	defer a.lock.RUnlock()
	if a.total == 0 {
		return 0
	} else if a.length == 0 {
		return 0
	}
	return a.total / a.length
}