├── .gitignore ├── .travis.yml ├── go.mod ├── codecov.yml ├── .github └── workflows │ └── go.yml ├── trace ├── otel │ ├── trace_test.go │ └── trace.go └── opentracing │ ├── trace_test.go │ └── trace.go ├── cache.go ├── LICENSE ├── example ├── no_cache │ └── no_cache_test.go ├── lru_cache │ └── golang_lru_test.go └── ttl_cache │ └── go_cache_test.go ├── trace.go ├── in_memory_cache.go ├── README.md ├── go.sum ├── TRACE.md ├── MIGRATE.md ├── dataloader.go └── dataloader_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | vendor/ 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.18 5 | 6 | env: 7 | - GO111MODULE=on 8 | 9 | script: 10 | - go test -v -race -coverprofile=coverage.txt -covermode=atomic 11 | 12 | after_success: 13 | - bash <(curl -s https://codecov.io/bash) 14 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/graph-gophers/dataloader/v7 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/hashicorp/golang-lru v0.5.4 7 | github.com/opentracing/opentracing-go v1.2.0 8 | github.com/patrickmn/go-cache v2.1.0+incompatible 9 | go.opentelemetry.io/otel v1.6.3 10 | go.opentelemetry.io/otel/trace v1.6.3 11 | ) 12 | 13 | require ( 14 | github.com/go-logr/logr v1.2.3 // indirect 15 | github.com/go-logr/stdr v1.2.2 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | notify: 3 | require_ci_to_pass: true 4 | comment: 5 | behavior: default 6 | layout: header, diff 7 | require_changes: false 8 | coverage: 9 | precision: 2 10 | range: 11 | - 70.0 12 | - 100.0 13 | round: down 14 | status: 15 | changes: false 16 | patch: true 17 | project: true 18 | parsers: 19 | gcov: 20 | branch_detection: 21 | conditional: true 22 | loop: true 23 | macro: false 24 | method: false 25 | javascript: 26 | enable_partials: false 27 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v3 21 | with: 22 | go-version: 1.19 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 
29 | -------------------------------------------------------------------------------- /trace/otel/trace_test.go: -------------------------------------------------------------------------------- 1 | package otel_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/graph-gophers/dataloader/v7" 7 | "github.com/graph-gophers/dataloader/v7/trace/otel" 8 | ) 9 | 10 | func TestInterfaceImplementation(t *testing.T) { 11 | type User struct { 12 | ID uint 13 | FirstName string 14 | LastName string 15 | Email string 16 | } 17 | var _ dataloader.Tracer[string, int] = otel.Tracer[string, int]{} 18 | var _ dataloader.Tracer[string, string] = otel.Tracer[string, string]{} 19 | var _ dataloader.Tracer[uint, User] = otel.Tracer[uint, User]{} 20 | // check compatibility with loader options 21 | dataloader.WithTracer[uint, User](&otel.Tracer[uint, User]{}) 22 | } 23 | -------------------------------------------------------------------------------- /trace/opentracing/trace_test.go: -------------------------------------------------------------------------------- 1 | package opentracing_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/graph-gophers/dataloader/v7" 7 | "github.com/graph-gophers/dataloader/v7/trace/opentracing" 8 | ) 9 | 10 | func TestInterfaceImplementation(t *testing.T) { 11 | type User struct { 12 | ID uint 13 | FirstName string 14 | LastName string 15 | Email string 16 | } 17 | var _ dataloader.Tracer[string, int] = opentracing.Tracer[string, int]{} 18 | var _ dataloader.Tracer[string, string] = opentracing.Tracer[string, string]{} 19 | var _ dataloader.Tracer[uint, User] = opentracing.Tracer[uint, User]{} 20 | // check compatibility with loader options 21 | dataloader.WithTracer[uint, User](&opentracing.Tracer[uint, User]{}) 22 | } 23 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package dataloader 2 | 3 | import "context" 4 | 5 | // The Cache interface. If a custom cache is provided, it must implement this interface. 6 | type Cache[K comparable, V any] interface { 7 | Get(context.Context, K) (Thunk[V], bool) 8 | Set(context.Context, K, Thunk[V]) 9 | Delete(context.Context, K) bool 10 | Clear() 11 | } 12 | 13 | // NoCache implements Cache interface where all methods are noops. 
14 | // This is useful when you don't want to cache items but still
15 | // want to use a data loader.
16 | type NoCache[K comparable, V any] struct{}
17 |
18 | // Get is a NOOP
19 | func (c *NoCache[K, V]) Get(context.Context, K) (Thunk[V], bool) { return nil, false }
20 |
21 | // Set is a NOOP
22 | func (c *NoCache[K, V]) Set(context.Context, K, Thunk[V]) { return }
23 |
24 | // Delete is a NOOP
25 | func (c *NoCache[K, V]) Delete(context.Context, K) bool { return false }
26 |
27 | // Clear is a NOOP
28 | func (c *NoCache[K, V]) Clear() { return }
29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License
2 |
3 | Copyright (c) 2017 Nick Randall
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /example/no_cache/no_cache_test.go: -------------------------------------------------------------------------------- 1 | package no_cache_test
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 |
7 | 	dataloader "github.com/graph-gophers/dataloader/v7"
8 | )
9 |
10 | func ExampleNoCache() {
11 | 	type User struct {
12 | 		ID        int
13 | 		Email     string
14 | 		FirstName string
15 | 		LastName  string
16 | 	}
17 |
18 | 	m := map[int]*User{
19 | 		5: {ID: 5, FirstName: "John", LastName: "Smith", Email: "john@example.com"},
20 | 	}
21 |
22 | 	batchFunc := func(_ context.Context, keys []int) []*dataloader.Result[*User] {
23 | 		var results []*dataloader.Result[*User]
24 | 		// do some pretend work to resolve keys
25 | 		for _, k := range keys {
26 | 			results = append(results, &dataloader.Result[*User]{Data: m[k]})
27 | 		}
28 | 		return results
29 | 	}
30 |
31 | 	// NoCache is a noop cache, so every Load call reaches the batch function
32 | 	cache := &dataloader.NoCache[int, *User]{}
33 | 	loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache[int, *User](cache))
34 |
35 | 	result, err := loader.Load(context.Background(), 5)()
36 | 	if err != nil {
37 | 		// handle error
38 | 	}
39 |
40 | 	fmt.Printf("result: %+v", result)
41 | 	// Output: result: &{ID:5 Email:john@example.com FirstName:John LastName:Smith}
42 | }
43 | -------------------------------------------------------------------------------- /trace.go: -------------------------------------------------------------------------------- 1 | package dataloader
2 |
3 | import (
4 | 	"context"
5 | )
6 |
7 | type TraceLoadFinishFunc[V any] func(Thunk[V])
8 | type TraceLoadManyFinishFunc[V any] func(ThunkMany[V])
9 | type TraceBatchFinishFunc[V any] func([]*Result[V])
10 |
11 | // Tracer is an interface that may be used to implement tracing.
12 | type Tracer[K comparable, V any] interface {
13 | 	// TraceLoad will trace the calls to Load.
14 | 	TraceLoad(ctx context.Context, key K) (context.Context, TraceLoadFinishFunc[V])
15 | 	// TraceLoadMany will trace the calls to LoadMany.
16 | 	TraceLoadMany(ctx context.Context, keys []K) (context.Context, TraceLoadManyFinishFunc[V])
17 | 	// TraceBatch will trace data loader batches.
18 | 	TraceBatch(ctx context.Context, keys []K) (context.Context, TraceBatchFinishFunc[V])
19 | }
20 |
21 | // NoopTracer is the default (noop) tracer
22 | type NoopTracer[K comparable, V any] struct{}
23 |
24 | // TraceLoad is a noop function
25 | func (NoopTracer[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, TraceLoadFinishFunc[V]) {
26 | 	return ctx, func(Thunk[V]) {}
27 | }
28 |
29 | // TraceLoadMany is a noop function
30 | func (NoopTracer[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, TraceLoadManyFinishFunc[V]) {
31 | 	return ctx, func(ThunkMany[V]) {}
32 | }
33 |
34 | // TraceBatch is a noop function
35 | func (NoopTracer[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, TraceBatchFinishFunc[V]) {
36 | 	return ctx, func(result []*Result[V]) {}
37 | }
38 | -------------------------------------------------------------------------------- /trace/opentracing/trace.go: -------------------------------------------------------------------------------- 1 | package opentracing
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 |
7 | 	"github.com/graph-gophers/dataloader/v7"
8 |
9 | 	"github.com/opentracing/opentracing-go"
10 | )
11 |
12 | // Tracer implements a tracer that can be used with the OpenTracing standard.
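// Register it when constructing a loader, e.g. dataloader.WithTracer[K, V](&Tracer[K, V]{}), as exercised in trace_test.go.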
13 | type Tracer[K comparable, V any] struct{}
14 |
15 | // TraceLoad will trace a call to dataloader.Load with OpenTracing.
16 | func (Tracer[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, dataloader.TraceLoadFinishFunc[V]) {
17 | 	span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: load")
18 |
19 | 	span.SetTag("dataloader.key", fmt.Sprintf("%v", key))
20 |
21 | 	return spanCtx, func(thunk dataloader.Thunk[V]) {
22 | 		span.Finish()
23 | 	}
24 | }
25 |
26 | // TraceLoadMany will trace a call to dataloader.LoadMany with OpenTracing.
27 | func (Tracer[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, dataloader.TraceLoadManyFinishFunc[V]) {
28 | 	span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: loadmany")
29 |
30 | 	span.SetTag("dataloader.keys", fmt.Sprintf("%v", keys))
31 |
32 | 	return spanCtx, func(thunk dataloader.ThunkMany[V]) {
33 | 		span.Finish()
34 | 	}
35 | }
36 |
37 | // TraceBatch will trace data loader batches with OpenTracing.
38 | func (Tracer[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, dataloader.TraceBatchFinishFunc[V]) {
39 | 	span, spanCtx := opentracing.StartSpanFromContext(ctx, "Dataloader: batch")
40 |
41 | 	span.SetTag("dataloader.keys", fmt.Sprintf("%v", keys))
42 |
43 | 	return spanCtx, func(results []*dataloader.Result[V]) {
44 | 		span.Finish()
45 | 	}
46 | }
47 | -------------------------------------------------------------------------------- /in_memory_cache.go: -------------------------------------------------------------------------------- 1 | package dataloader
2 |
3 | import (
4 | 	"context"
5 | 	"sync"
6 | )
7 |
8 | // InMemoryCache is an in-memory implementation of the Cache interface.
9 | // This simple implementation is well suited for
10 | // a "per-request" dataloader (i.e. one that only lives
11 | // for the life of an http request) but it's not well suited
12 | // for long-lived cached items.
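// For long-lived caching, see example/ttl_cache and example/lru_cache, which adapt third-party caches to the Cache interface.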
13 | type InMemoryCache[K comparable, V any] struct {
14 | 	items map[K]Thunk[V]
15 | 	mu    sync.RWMutex
16 | }
17 |
18 | // NewCache constructs a new InMemoryCache
19 | func NewCache[K comparable, V any]() *InMemoryCache[K, V] {
20 | 	items := make(map[K]Thunk[V])
21 | 	return &InMemoryCache[K, V]{
22 | 		items: items,
23 | 	}
24 | }
25 |
26 | // Set sets the `value` at `key` in the cache
27 | func (c *InMemoryCache[K, V]) Set(_ context.Context, key K, value Thunk[V]) {
28 | 	c.mu.Lock()
29 | 	c.items[key] = value
30 | 	c.mu.Unlock()
31 | }
32 |
33 | // Get gets the value at `key` if it exists, returns value (or nil) and bool
34 | // indicating whether the value was found
35 | func (c *InMemoryCache[K, V]) Get(_ context.Context, key K) (Thunk[V], bool) {
36 | 	c.mu.RLock()
37 | 	defer c.mu.RUnlock()
38 |
39 | 	item, found := c.items[key]
40 | 	if !found {
41 | 		return nil, false
42 | 	}
43 |
44 | 	return item, true
45 | }
46 |
47 | // Delete deletes item at `key` from cache
48 | func (c *InMemoryCache[K, V]) Delete(ctx context.Context, key K) bool {
49 | 	if _, found := c.Get(ctx, key); found {
50 | 		c.mu.Lock()
51 | 		defer c.mu.Unlock()
52 | 		delete(c.items, key)
53 | 		return true
54 | 	}
55 | 	return false
56 | }
57 |
58 | // Clear clears the entire cache
59 | func (c *InMemoryCache[K, V]) Clear() {
60 | 	c.mu.Lock()
61 | 	c.items = map[K]Thunk[V]{}
62 | 	c.mu.Unlock()
63 | }
64 | -------------------------------------------------------------------------------- /trace/otel/trace.go: -------------------------------------------------------------------------------- 1 | package otel
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 |
7 | 	"github.com/graph-gophers/dataloader/v7"
8 |
9 | 	"go.opentelemetry.io/otel"
10 | 	"go.opentelemetry.io/otel/attribute"
11 | 	"go.opentelemetry.io/otel/trace"
12 | )
13 |
14 | // Tracer implements a tracer that can be used with OpenTelemetry.
15 | type Tracer[K comparable, V any] struct {
16 | 	tr trace.Tracer
17 | }
18 |
19 | func NewTracer[K comparable, V any](tr trace.Tracer) *Tracer[K, V] {
20 | 	return &Tracer[K, V]{tr: tr}
21 | }
22 |
23 | func (t *Tracer[K, V]) Tracer() trace.Tracer {
24 | 	if t.tr != nil {
25 | 		return t.tr
26 | 	}
27 | 	return otel.Tracer("graph-gophers/dataloader")
28 | }
29 |
30 | // TraceLoad will trace a call to dataloader.Load with OpenTelemetry.
31 | func (t Tracer[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, dataloader.TraceLoadFinishFunc[V]) {
32 | 	spanCtx, span := t.Tracer().Start(ctx, "Dataloader: load")
33 |
34 | 	span.SetAttributes(attribute.String("dataloader.key", fmt.Sprintf("%v", key)))
35 |
36 | 	return spanCtx, func(thunk dataloader.Thunk[V]) {
37 | 		span.End()
38 | 	}
39 | }
40 |
41 | // TraceLoadMany will trace a call to dataloader.LoadMany with OpenTelemetry.
42 | func (t Tracer[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, dataloader.TraceLoadManyFinishFunc[V]) {
43 | 	spanCtx, span := t.Tracer().Start(ctx, "Dataloader: loadmany")
44 |
45 | 	span.SetAttributes(attribute.String("dataloader.keys", fmt.Sprintf("%v", keys)))
46 |
47 | 	return spanCtx, func(thunk dataloader.ThunkMany[V]) {
48 | 		span.End()
49 | 	}
50 | }
51 |
52 | // TraceBatch will trace data loader batches with OpenTelemetry.
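// The returned finish function ends the span; the Loader invokes it once the batch results are available.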
53 | func (t Tracer[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, dataloader.TraceBatchFinishFunc[V]) {
54 | 	spanCtx, span := t.Tracer().Start(ctx, "Dataloader: batch")
55 |
56 | 	span.SetAttributes(attribute.String("dataloader.keys", fmt.Sprintf("%v", keys)))
57 |
58 | 	return spanCtx, func(results []*dataloader.Result[V]) {
59 | 		span.End()
60 | 	}
61 | }
62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DataLoader
2 | [![GoDoc](https://godoc.org/gopkg.in/graph-gophers/dataloader.v7?status.svg)](https://pkg.go.dev/github.com/graph-gophers/dataloader/v7)
3 | [![Build Status](https://travis-ci.org/graph-gophers/dataloader.svg?branch=master)](https://travis-ci.org/graph-gophers/dataloader)
4 |
5 | This is an implementation of [Facebook's DataLoader](https://github.com/facebook/dataloader) in Go.
6 |
7 | ## Install
8 | `go get -u github.com/graph-gophers/dataloader/v7`
9 |
10 | ## Usage
11 | ```go
12 | // setup batch function - the first Context passed to the Loader's Load
13 | // function will be provided when the batch function is called.
14 | // this function is registered with the Loader, and the key and value types are fixed using generics.
15 | batchFn := func(ctx context.Context, keys []int) []*dataloader.Result[*User] {
16 |   var results []*dataloader.Result[*User]
17 |   // do some async work to get data for specified keys
18 |   // append to this list resolved values
19 |   return results
20 | }
21 |
22 | // create Loader with an in-memory cache
23 | loader := dataloader.NewBatchedLoader(batchFn)
24 |
25 | /**
26 |  * Use loader
27 |  *
28 |  * A thunk is a function returned from a function that is a
29 |  * closure over a value (in this case a value and an error).
30 |  * When called, it will block until the value is resolved.
31 |  *
32 |  * loader.Load() may be called multiple times for a given batch window.
33 |  * The first context passed to Load is the object that will be passed
34 |  * to the batch function.
35 |  */
36 | thunk := loader.Load(context.TODO(), 5)
37 | result, err := thunk()
38 | if err != nil {
39 |   // handle data error
40 | }
41 |
42 | log.Printf("value: %#v", result)
43 | ```
44 |
45 | ### Don't need/want to use context?
46 | You're welcome to install the v1 version of this library.
47 |
48 | ## Cache
49 | This implementation contains a very basic cache that is intended only to be used for short-lived DataLoaders (i.e. DataLoaders that only exist for the life of an http request). You may use your own implementation if you want.
50 |
51 | > It also has a `NoCache` type that implements the cache interface with noop methods, for use when you do not wish to cache anything.
52 |
53 | ## Examples
54 | There are a few basic examples in the example folder.
55 |
56 | ## See also
57 | - [TRACE](TRACE.md)
58 | - [MIGRATE](MIGRATE.md)
59 | -------------------------------------------------------------------------------- /example/lru_cache/golang_lru_test.go: -------------------------------------------------------------------------------- 1 | // package lru_cache_test contains an example of using golang-lru as a long term cache solution for dataloader.
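// The adapter below embeds lru.ARCCache and converts its interface{}-based API to the generic dataloader.Cache interface.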
2 | package lru_cache_test
3 |
4 | import (
5 | 	"context"
6 | 	"fmt"
7 |
8 | 	dataloader "github.com/graph-gophers/dataloader/v7"
9 |
10 | 	lru "github.com/hashicorp/golang-lru"
11 | )
12 |
13 | // cache implements the dataloader.Cache interface
14 | type cache[K comparable, V any] struct {
15 | 	*lru.ARCCache
16 | }
17 |
18 | // Get gets an item from the cache
19 | func (c *cache[K, V]) Get(_ context.Context, key K) (dataloader.Thunk[V], bool) {
20 | 	v, ok := c.ARCCache.Get(key)
21 | 	if ok {
22 | 		return v.(dataloader.Thunk[V]), ok
23 | 	}
24 | 	return nil, ok
25 | }
26 |
27 | // Set sets an item in the cache
28 | func (c *cache[K, V]) Set(_ context.Context, key K, value dataloader.Thunk[V]) {
29 | 	c.ARCCache.Add(key, value)
30 | }
31 |
32 | // Delete deletes an item in the cache
33 | func (c *cache[K, V]) Delete(_ context.Context, key K) bool {
34 | 	if c.ARCCache.Contains(key) {
35 | 		c.ARCCache.Remove(key)
36 | 		return true
37 | 	}
38 | 	return false
39 | }
40 |
41 | // Clear clears the cache
42 | func (c *cache[K, V]) Clear() {
43 | 	c.ARCCache.Purge()
44 | }
45 |
46 | func ExampleGolangLRU() {
47 | 	type User struct {
48 | 		ID        int
49 | 		Email     string
50 | 		FirstName string
51 | 		LastName  string
52 | 	}
53 |
54 | 	m := map[int]*User{
55 | 		5: {ID: 5, FirstName: "John", LastName: "Smith", Email: "john@example.com"},
56 | 	}
57 |
58 | 	batchFunc := func(_ context.Context, keys []int) []*dataloader.Result[*User] {
59 | 		var results []*dataloader.Result[*User]
60 | 		// do some pretend work to resolve keys
61 | 		for _, k := range keys {
62 | 			results = append(results, &dataloader.Result[*User]{Data: m[k]})
63 | 		}
64 | 		return results
65 | 	}
66 |
67 | 	// golang-lru's ARC cache evicts the least recently used items once it reaches its capacity (100 entries here).
68 | 	c, _ := lru.NewARC(100)
69 | 	cache := &cache[int, *User]{ARCCache: c}
70 | 	loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache[int, *User](cache))
71 |
72 | 	// immediately call the future function from loader
73 | 	result, err := loader.Load(context.TODO(), 5)()
74 | 	if err != nil {
75 | 		// handle error
76 | 	}
77 |
78 | 	fmt.Printf("result: %+v", result)
79 | 	// Output: result: &{ID:5 Email:john@example.com FirstName:John LastName:Smith}
80 | }
81 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
4 | github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
5 | github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
6 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
7 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
8 | github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
9 | github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
10 | github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
11 | github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
12 | github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
13 | github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
14
| github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= 15 | github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 19 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 20 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 21 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 22 | go.opentelemetry.io/otel v1.6.3 h1:FLOfo8f9JzFVFVyU+MSRJc2HdEAXQgm7pIv2uFKRSZE= 23 | go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= 24 | go.opentelemetry.io/otel/trace v1.6.3 h1:IqN4L+5b0mPNjdXIiZ90Ni4Bl5BRkDQywePLWemd9bc= 25 | go.opentelemetry.io/otel/trace v1.6.3/go.mod h1:GNJQusJlUgZl9/TQBPKU/Y/ty+0iVB5fjhKeJGZPGFs= 26 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 29 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 30 | -------------------------------------------------------------------------------- /example/ttl_cache/go_cache_test.go: -------------------------------------------------------------------------------- 1 | // package ttl_cache_test contains an example of using go-cache as a long term cache solution for dataloader. 
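// Because go-cache stores string keys, the adapter below formats each generic key with fmt.Sprintf before delegating to the underlying cache.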
2 | package ttl_cache_test
3 |
4 | import (
5 | 	"context"
6 | 	"fmt"
7 | 	"time"
8 |
9 | 	dataloader "github.com/graph-gophers/dataloader/v7"
10 |
11 | 	cache "github.com/patrickmn/go-cache"
12 | )
13 |
14 | // Cache implements the dataloader.Cache interface
15 | type Cache[K comparable, V any] struct {
16 | 	c *cache.Cache
17 | }
18 |
19 | // Get gets a value from the cache
20 | func (c *Cache[K, V]) Get(_ context.Context, key K) (dataloader.Thunk[V], bool) {
21 | 	k := fmt.Sprintf("%v", key) // convert the key to a string because the underlying library doesn't support generics yet
22 | 	v, ok := c.c.Get(k)
23 | 	if ok {
24 | 		return v.(dataloader.Thunk[V]), ok
25 | 	}
26 | 	return nil, ok
27 | }
28 |
29 | // Set sets a value in the cache
30 | func (c *Cache[K, V]) Set(_ context.Context, key K, value dataloader.Thunk[V]) {
31 | 	k := fmt.Sprintf("%v", key) // convert the key to a string because the underlying library doesn't support generics yet
32 | 	c.c.Set(k, value, 0)
33 | }
34 |
35 | // Delete deletes an item in the cache
36 | func (c *Cache[K, V]) Delete(_ context.Context, key K) bool {
37 | 	k := fmt.Sprintf("%v", key) // convert the key to a string because the underlying library doesn't support generics yet
38 | 	if _, found := c.c.Get(k); found {
39 | 		c.c.Delete(k)
40 | 		return true
41 | 	}
42 | 	return false
43 | }
44 |
45 | // Clear clears the cache
46 | func (c *Cache[K, V]) Clear() {
47 | 	c.c.Flush()
48 | }
49 |
50 | func ExampleTTLCache() {
51 | 	type User struct {
52 | 		ID        int
53 | 		Email     string
54 | 		FirstName string
55 | 		LastName  string
56 | 	}
57 |
58 | 	m := map[int]*User{
59 | 		5: {ID: 5, FirstName: "John", LastName: "Smith", Email: "john@example.com"},
60 | 	}
61 |
62 | 	batchFunc := func(_ context.Context, keys []int) []*dataloader.Result[*User] {
63 | 		var results []*dataloader.Result[*User]
64 | 		// do some pretend work to resolve keys
65 | 		for _, k := range keys {
66 | 			results = append(results, &dataloader.Result[*User]{Data: m[k]})
67 | 		}
68 | 		return results
69 | 	}
70 |
71 | 	// go-cache will automatically clean up expired items at the given interval
72 | 	c := cache.New(15*time.Minute, 15*time.Minute)
73 | 	cache := &Cache[int, *User]{c}
74 | 	loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithCache[int, *User](cache))
75 |
76 | 	// immediately call the future function from loader
77 | 	result, err := loader.Load(context.Background(), 5)()
78 | 	if err != nil {
79 | 		// handle error
80 | 	}
81 |
82 | 	fmt.Printf("result: %+v", result)
83 | 	// Output: result: &{ID:5 Email:john@example.com FirstName:John LastName:Smith}
84 | }
85 | -------------------------------------------------------------------------------- /TRACE.md: -------------------------------------------------------------------------------- 1 | # Adding a new trace backend
2 |
3 | If you want to add a new tracing backend, all you need to do is implement the
4 | `Tracer` interface and pass it as an option to the dataloader on initialization.
5 |
6 | As an example, this is how you could implement it for an OpenCensus backend.
7 |
8 | ```go
9 | package main
10 |
11 | import (
12 | 	"context"
13 | 	"strings"
14 |
15 | 	"github.com/graph-gophers/dataloader/v7"
16 | 	exp "go.opencensus.io/examples/exporter"
17 | 	"go.opencensus.io/trace"
18 | )
19 |
20 | type User struct {
21 | 	ID string
22 | }
23 |
24 | // OpenCensusTracer implements a tracer that can be used with OpenCensus.
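// It satisfies dataloader.Tracer[string, *User] for this example's key and value types.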
25 | type OpenCensusTracer struct{}
26 |
27 | // TraceLoad will trace a call to dataloader.Load with OpenCensus
28 | func (OpenCensusTracer) TraceLoad(ctx context.Context, key string) (context.Context, dataloader.TraceLoadFinishFunc[*User]) {
29 | 	cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: load")
30 | 	cSpan.AddAttributes(
31 | 		trace.StringAttribute("dataloader.key", key),
32 | 	)
33 | 	return cCtx, func(thunk dataloader.Thunk[*User]) {
34 | 		// TODO: is there anything we should do with the results?
35 | 		cSpan.End()
36 | 	}
37 | }
38 |
39 | // TraceLoadMany will trace a call to dataloader.LoadMany with OpenCensus
40 | func (OpenCensusTracer) TraceLoadMany(ctx context.Context, keys []string) (context.Context, dataloader.TraceLoadManyFinishFunc[*User]) {
41 | 	cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: loadmany")
42 | 	cSpan.AddAttributes(
43 | 		trace.StringAttribute("dataloader.keys", strings.Join(keys, ",")),
44 | 	)
45 | 	return cCtx, func(thunk dataloader.ThunkMany[*User]) {
46 | 		// TODO: is there anything we should do with the results?
47 | 		cSpan.End()
48 | 	}
49 | }
50 |
51 | // TraceBatch will trace data loader batches with OpenCensus
52 | func (OpenCensusTracer) TraceBatch(ctx context.Context, keys []string) (context.Context, dataloader.TraceBatchFinishFunc[*User]) {
53 | 	cCtx, cSpan := trace.StartSpan(ctx, "Dataloader: batch")
54 | 	cSpan.AddAttributes(
55 | 		trace.StringAttribute("dataloader.keys", strings.Join(keys, ",")),
56 | 	)
57 | 	return cCtx, func(results []*dataloader.Result[*User]) {
58 | 		// TODO: is there anything we should do with the results?
59 | 		cSpan.End()
60 | 	}
61 | }
62 |
63 | func batchFunc(ctx context.Context, keys []string) []*dataloader.Result[*User] {
64 | 	// ...loader logic goes here
65 | }
66 |
67 | func main() {
68 | 	// initialize an example exporter that just logs to the console
69 | 	trace.ApplyConfig(trace.Config{
70 | 		DefaultSampler: trace.AlwaysSample(),
71 | 	})
72 | 	trace.RegisterExporter(&exp.PrintExporter{})
73 | 	// initialize the dataloader with your new tracer backend
74 | 	loader := dataloader.NewBatchedLoader(batchFunc, dataloader.WithTracer[string, *User](OpenCensusTracer{}))
75 | 	// initialize a context since it's not receiving one from anywhere else.
76 | 	ctx, span := trace.StartSpan(context.TODO(), "Span Name")
77 | 	defer span.End()
78 | 	// request from the dataloader as usual
79 | 	value, err := loader.Load(ctx, SomeID)()
80 | 	// ...
81 | }
82 | ```
83 |
84 | Don't forget to initialize the exporters of your choice and register them with `trace.RegisterExporter(&exporterInstance)`.
85 | -------------------------------------------------------------------------------- /MIGRATE.md: -------------------------------------------------------------------------------- 1 | ## Upgrade from v1 to v2
2 | The only difference between v1 and v2 is that we added the use of [context](https://golang.org/pkg/context).
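For illustration, a minimal sketch of a v2-style call site (the `loader` variable and key used here are placeholders, not part of the original docs):

```go
// v2: Load takes a context as its first argument and still returns a Thunk.
thunk := loader.Load(context.TODO(), "user-1")
value, err := thunk()
```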
3 |
4 | ```diff
5 | - loader.Load(key string) Thunk
6 | + loader.Load(ctx context.Context, key string) Thunk
7 | - loader.LoadMany(keys []string) ThunkMany
8 | + loader.LoadMany(ctx context.Context, keys []string) ThunkMany
9 | ```
10 |
11 | ```diff
12 | - type BatchFunc func([]string) []*Result
13 | + type BatchFunc func(context.Context, []string) []*Result
14 | ```
15 |
16 | ## Upgrade from v2 to v3
17 | ```diff
18 | // dataloader.Interface has added context.Context to its methods
19 | - loader.Prime(key string, value interface{}) Interface
20 | + loader.Prime(ctx context.Context, key string, value interface{}) Interface
21 | - loader.Clear(key string) Interface
22 | + loader.Clear(ctx context.Context, key string) Interface
23 | ```
24 |
25 | ```diff
26 | // the Cache interface has added context.Context to its methods
27 | type Cache interface {
28 | -  Get(string) (Thunk, bool)
29 | +  Get(context.Context, string) (Thunk, bool)
30 | -  Set(string, Thunk)
31 | +  Set(context.Context, string, Thunk)
32 | -  Delete(string) bool
33 | +  Delete(context.Context, string) bool
34 |   Clear()
35 | }
36 | ```
37 |
38 | ## Upgrade from v3 to v4
39 | ```diff
40 | // dataloader.Interface now allows interface{} as the key rather than string
41 | - loader.Load(context.Context, key string) Thunk
42 | + loader.Load(ctx context.Context, key interface{}) Thunk
43 | - loader.LoadMany(context.Context, key []string) ThunkMany
44 | + loader.LoadMany(ctx context.Context, keys []interface{}) ThunkMany
45 | - loader.Prime(context.Context, key string, value interface{}) Interface
46 | + loader.Prime(ctx context.Context, key interface{}, value interface{}) Interface
47 | - loader.Clear(context.Context, key string) Interface
48 | + loader.Clear(ctx context.Context, key interface{}) Interface
49 | ```
50 |
51 | ```diff
52 | // the Cache interface now allows interface{} as the key instead of string
53 | type Cache interface {
54 | -  Get(context.Context, string) (Thunk, bool)
55 | +  Get(context.Context, interface{}) (Thunk, bool)
56 | -  Set(context.Context, string, Thunk)
57 | +  Set(context.Context, interface{}, Thunk)
58 | -  Delete(context.Context, string) bool
59 | +  Delete(context.Context, interface{}) bool
60 |   Clear()
61 | }
62 | ```
63 |
64 | ## Upgrade from v4 to v5
65 | ```diff
66 | // dataloader.Interface now uses the Key type rather than interface{}
67 | - loader.Load(context.Context, key interface{}) Thunk
68 | + loader.Load(ctx context.Context, key Key) Thunk
69 | - loader.LoadMany(context.Context, key []interface{}) ThunkMany
70 | + loader.LoadMany(ctx context.Context, keys Keys) ThunkMany
71 | - loader.Prime(context.Context, key interface{}, value interface{}) Interface
72 | + loader.Prime(ctx context.Context, key Key, value interface{}) Interface
73 | - loader.Clear(context.Context, key interface{}) Interface
74 | + loader.Clear(ctx context.Context, key Key) Interface
75 | ```
76 |
77 | ```diff
78 | // the Cache interface now uses Key as the key instead of interface{}
79 | type Cache interface {
80 | -  Get(context.Context, interface{}) (Thunk, bool)
81 | +  Get(context.Context, Key) (Thunk, bool)
82 | -  Set(context.Context, interface{}, Thunk)
83 | +  Set(context.Context, Key, Thunk)
84 | -  Delete(context.Context, interface{}) bool
85 | +  Delete(context.Context, Key) bool
86 |   Clear()
87 | }
88 | ```
89 |
90 | ## Upgrade from v5 to v6
91 |
92 | We released a new major version because we switched from dep to Go Modules
93 | and dropped build tags for older versions of Go (1.9).
94 |
95 | The preferred import method includes the major version tag.
96 |
97 | ```go
98 | import "github.com/graph-gophers/dataloader/v6"
99 | ```
100 |
101 | ## Upgrade from v6 to v7
102 |
103 | [Generics](https://go.dev/doc/tutorial/generics) support has been added.
104 | With this update, you can now write more type-safe code.
105 |
106 | Use the major version tag in the import path.
107 |
108 | ```go
109 | import "github.com/graph-gophers/dataloader/v7"
110 | ```
111 | -------------------------------------------------------------------------------- /dataloader.go: -------------------------------------------------------------------------------- 1 | // Package dataloader is an implementation of Facebook's dataloader in Go.
2 | // See https://github.com/facebook/dataloader for more information
3 | package dataloader
4 |
5 | import (
6 | 	"context"
7 | 	"errors"
8 | 	"fmt"
9 | 	"log"
10 | 	"runtime"
11 | 	"sync"
12 | 	"sync/atomic"
13 | 	"time"
14 | )
15 |
16 | // Interface is a `DataLoader` interface which defines a public API for loading data from a particular
17 | // data back-end with unique keys such as the `id` column of a SQL table or
18 | // document name in a MongoDB database, given a batch loading function.
19 | //
20 | // Each `DataLoader` instance should contain a unique memoized cache. Use caution when
21 | // used in long-lived applications or those which serve many users with
22 | // different access permissions and consider creating a new instance per
23 | // web request.
24 | type Interface[K comparable, V any] interface {
25 | 	Load(context.Context, K) Thunk[V]
26 | 	LoadMany(context.Context, []K) ThunkMany[V]
27 | 	Clear(context.Context, K) Interface[K, V]
28 | 	ClearAll() Interface[K, V]
29 | 	Prime(ctx context.Context, key K, value V) Interface[K, V]
30 | 	Flush()
31 | }
32 |
33 | var ErrNoResultProvided = errors.New("no result provided")
34 |
35 | // BatchFunc is a function which, when given a slice of keys, returns a slice of `results`.
36 | // It's important that the length of the input keys matches the length of the output results.
37 | // Should the batch function return nil for a result, it will be treated as returning an error
38 | // of `ErrNoResultProvided` for that key.
39 | //
40 | // The keys passed to this function are guaranteed to be unique.
41 | type BatchFunc[K comparable, V any] func(context.Context, []K) []*Result[V]
42 |
43 | // Result is the data structure that a BatchFunc returns.
44 | // It contains the resolved data, and any errors that may have occurred while fetching the data.
45 | type Result[V any] struct {
46 | 	Data  V
47 | 	Error error
48 | }
49 |
50 | // ResultMany is used by the LoadMany method.
51 | // It contains a list of resolved data and a list of errors.
52 | // The lengths of the data list and error list will match, and elements at each index correspond to each other.
53 | type ResultMany[V any] struct {
54 | 	Data  []V
55 | 	Error []error
56 | }
57 |
58 | // PanicErrorWrapper wraps the error interface.
59 | // This is used to check if the error is a panic error.
60 | // We should not cache panic errors.
61 | type PanicErrorWrapper struct {
62 | 	panicError error
63 | }
64 |
65 | func (p *PanicErrorWrapper) Error() string {
66 | 	return p.panicError.Error()
67 | }
68 |
69 | // SkipCacheError wraps the error interface.
70 | // The cache should not store SkipCacheErrors.
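// Batch functions can wrap transient failures with NewSkipCacheError so the affected keys are retried on a later Load instead of being memoized (see the thunk in Load, which clears such results from the cache).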
71 | type SkipCacheError struct {
72 | 	err error
73 | }
74 |
75 | func (s *SkipCacheError) Error() string {
76 | 	return s.err.Error()
77 | }
78 |
79 | func (s *SkipCacheError) Unwrap() error {
80 | 	return s.err
81 | }
82 |
83 | func NewSkipCacheError(err error) *SkipCacheError {
84 | 	return &SkipCacheError{err: err}
85 | }
86 |
87 | // Loader implements the dataloader.Interface.
88 | type Loader[K comparable, V any] struct {
89 | 	// the batch function to be used by this loader
90 | 	batchFn BatchFunc[K, V]
91 |
92 | 	// the maximum batch size. Set to 0 if you want it to be unbounded.
93 | 	batchCap int
94 |
95 | 	// the internal cache. This package contains a basic cache implementation but any custom cache
96 | 	// implementation could be used as long as it implements the `Cache` interface.
97 | 	cacheLock sync.Mutex
98 | 	cache     Cache[K, V]
99 | 	// should we clear the cache on each batch?
100 | 	// this would allow batching but no long-term caching
101 | 	clearCacheOnBatch bool
102 |
103 | 	// count of queued up items
104 | 	count int
105 |
106 | 	// the maximum input queue size. Set to 0 if you want it to be unbounded.
107 | 	inputCap int
108 |
109 | 	// the amount of time to wait before triggering a batch
110 | 	wait time.Duration
111 |
112 | 	// lock to protect the batching operations
113 | 	batchLock sync.Mutex
114 |
115 | 	// current batcher
116 | 	curBatcher *batcher[K, V]
117 |
118 | 	// used to close the sleeper of the current batcher
119 | 	endSleeper chan bool
120 |
121 | 	// used by tests to prevent logs
122 | 	silent bool
123 |
124 | 	// can be set to trace calls to dataloader
125 | 	tracer Tracer[K, V]
126 | }
127 |
128 | // Thunk is a function that will block until the value (*Result) it contains is resolved.
129 | // After the value it contains is resolved, this function will return the result.
130 | // This function can be called many times, much like a Promise in other languages.
131 | // The value will only need to be resolved once, so subsequent calls will return immediately.
132 | type Thunk[V any] func() (V, error)
133 |
134 | // ThunkMany is much like the Thunk func type but it contains a list of results.
135 | type ThunkMany[V any] func() ([]V, []error)
136 |
137 | // type used on the input channel
138 | type batchRequest[K comparable, V any] struct {
139 | 	key    K
140 | 	result atomic.Pointer[Result[V]]
141 | 	done   chan struct{}
142 | }
143 |
144 | // Option allows for configuration of Loader fields.
145 | type Option[K comparable, V any] func(*Loader[K, V])
146 |
147 | // WithCache sets the BatchedLoader cache. Defaults to InMemoryCache if a Cache is not set.
148 | func WithCache[K comparable, V any](c Cache[K, V]) Option[K, V] {
149 | 	return func(l *Loader[K, V]) {
150 | 		l.cache = c
151 | 	}
152 | }
153 |
154 | // WithBatchCapacity sets the batch capacity. Default is 0 (unbounded).
155 | func WithBatchCapacity[K comparable, V any](c int) Option[K, V] {
156 | 	return func(l *Loader[K, V]) {
157 | 		l.batchCap = c
158 | 	}
159 | }
160 |
161 | // WithInputCapacity sets the input capacity. Default is 1000.
162 | func WithInputCapacity[K comparable, V any](c int) Option[K, V] {
163 | 	return func(l *Loader[K, V]) {
164 | 		l.inputCap = c
165 | 	}
166 | }
167 |
168 | // WithWait sets the amount of time to wait before triggering a batch.
169 | // Default duration is 16 milliseconds.
170 | func WithWait[K comparable, V any](d time.Duration) Option[K, V] {
171 | 	return func(l *Loader[K, V]) {
172 | 		l.wait = d
173 | 	}
174 | }
175 |
176 | // WithClearCacheOnBatch allows batching of items but no long-term caching.
177 | // It accomplishes this by clearing the cache after each batch operation.
178 | func WithClearCacheOnBatch[K comparable, V any]() Option[K, V] {
179 | 	return func(l *Loader[K, V]) {
180 | 		l.cacheLock.Lock()
181 | 		l.clearCacheOnBatch = true
182 | 		l.cacheLock.Unlock()
183 | 	}
184 | }
185 |
186 | // withSilentLogger turns off log messages. It's used by the tests
187 | func withSilentLogger[K comparable, V any]() Option[K, V] {
188 | 	return func(l *Loader[K, V]) {
189 | 		l.silent = true
190 | 	}
191 | }
192 |
193 | // WithTracer allows tracing of calls to Load and LoadMany
194 | func WithTracer[K comparable, V any](tracer Tracer[K, V]) Option[K, V] {
195 | 	return func(l *Loader[K, V]) {
196 | 		l.tracer = tracer
197 | 	}
198 | }
199 |
200 | // NewBatchedLoader constructs a new Loader with given options.
201 | func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts ...Option[K, V]) *Loader[K, V] {
202 | 	loader := &Loader[K, V]{
203 | 		batchFn:  batchFn,
204 | 		inputCap: 1000,
205 | 		wait:     16 * time.Millisecond,
206 | 	}
207 |
208 | 	// Apply options
209 | 	for _, apply := range opts {
210 | 		apply(loader)
211 | 	}
212 |
213 | 	// Set defaults
214 | 	if loader.cache == nil {
215 | 		loader.cache = NewCache[K, V]()
216 | 	}
217 |
218 | 	if loader.tracer == nil {
219 | 		loader.tracer = NoopTracer[K, V]{}
220 | 	}
221 |
222 | 	return loader
223 | }
224 |
225 | // Load loads/resolves the given key, returning a Thunk that will resolve to the value and error.
226 | // The first context passed to this function within a given batch window will be provided to
227 | // the registered BatchFunc.
228 | func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
229 | 	ctx, finish := l.tracer.TraceLoad(originalContext, key)
230 | 	req := &batchRequest[K, V]{
231 | 		key:  key,
232 | 		done: make(chan struct{}),
233 | 	}
234 |
235 | 	// We need to lock both the batchLock and cacheLock because the batcher can
236 | 	// reset the cache when either the batchCap or the wait time is reached.
237 | 	//
238 | 	// If we were to lock only the cacheLock while doing l.cache.Get and/or
239 | 	// l.cache.Set, it could be that the batcher resets the cache after those
240 | 	// operations have finished but before the new request (if any) is sent to the
241 | 	// batcher.
242 | 	//
243 | 	// In that case it is no longer guaranteed that the keys passed to the BatchFunc
244 | 	// function are unique as the cache has been reset, so if the same key is
245 | 	// requested again before the new batcher is started, the same key will be
246 | 	// sent to the batcher again, causing unexpected behavior in the BatchFunc.
247 | 	l.batchLock.Lock()
248 | 	l.cacheLock.Lock()
249 |
250 | 	if v, ok := l.cache.Get(ctx, key); ok {
251 | 		l.cacheLock.Unlock()
252 | 		l.batchLock.Unlock()
253 | 		defer finish(v)
254 | 		return v
255 | 	}
256 |
257 | 	defer l.batchLock.Unlock()
258 | 	defer l.cacheLock.Unlock()
259 |
260 | 	thunk := func() (V, error) {
261 | 		<-req.done
262 | 		result := req.result.Load()
263 | 		var ev *PanicErrorWrapper
264 | 		var es *SkipCacheError
265 | 		if result.Error != nil && (errors.As(result.Error, &ev) || errors.As(result.Error, &es)) {
266 | 			l.Clear(ctx, key)
267 | 		}
268 | 		return result.Data, result.Error
269 | 	}
270 | 	defer finish(thunk)
271 |
272 | 	l.cache.Set(ctx, key, thunk)
273 |
274 | 	// start the batch window if it hasn't already started.
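// The first Load in a window creates the batcher and a sleeper goroutine; the batch is dispatched when the wait timer fires, the batch capacity is reached, or Flush is called.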
275 | 	if l.curBatcher == nil {
276 | 		l.curBatcher = l.newBatcher(l.silent, l.tracer)
277 | 		// start the current batcher batch function
278 | 		go l.curBatcher.batch(originalContext)
279 | 		// start a sleeper for the current batcher
280 | 		l.endSleeper = make(chan bool)
281 | 		go l.sleeper(l.curBatcher, l.endSleeper)
282 | 	}
283 |
284 | 	l.curBatcher.input <- req
285 |
286 | 	// if we need to keep track of the count (max batch), then do so.
287 | 	if l.batchCap > 0 {
288 | 		l.count++
289 | 		// if we hit our limit, force the batch to start
290 | 		if l.count == l.batchCap {
291 | 			// end/flush the batcher synchronously here because another call to Load
292 | 			// may concurrently happen and needs to go to a new batcher.
293 | 			l.flush()
294 | 		}
295 | 	}
296 |
297 | 	return thunk
298 | }
299 |
300 | // flush() is a helper that runs whatever batched items there are immediately.
301 | // it must be called by code protected by l.batchLock.Lock()
302 | func (l *Loader[K, V]) flush() {
303 | 	l.curBatcher.end()
304 |
305 | 	// end the sleeper for the current batcher.
306 | 	// this is to stop the goroutine without waiting for the
307 | 	// sleeper timeout.
308 | 	close(l.endSleeper)
309 | 	l.reset()
310 | }
311 |
312 | // Flush will load the items in the current batch immediately without waiting for the timer.
313 | func (l *Loader[K, V]) Flush() {
314 | 	l.batchLock.Lock()
315 | 	defer l.batchLock.Unlock()
316 | 	if l.curBatcher == nil {
317 | 		return
318 | 	}
319 | 	l.flush()
320 | }
321 |
322 | // LoadMany loads multiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
323 | func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) ThunkMany[V] {
324 | 	ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)
325 |
326 | 	var (
327 | 		length = len(keys)
328 | 		data   = make([]V, length)
329 | 		errors = make([]error, length)
330 | 		result atomic.Pointer[ResultMany[V]]
331 | 		wg     sync.WaitGroup
332 | 		done   = make(chan struct{})
333 | 	)
334 |
335 | 	resolve := func(ctx context.Context, i int) {
336 | 		defer wg.Done()
337 | 		thunk := l.Load(ctx, keys[i])
338 | 		result, err := thunk()
339 | 		data[i] = result
340 | 		errors[i] = err
341 | 	}
342 |
343 | 	wg.Add(length)
344 | 	for i := range keys {
345 | 		go resolve(ctx, i)
346 | 	}
347 |
348 | 	go func() {
349 | 		defer close(done)
350 | 		wg.Wait()
351 |
352 | 		// errs is nil unless there exists a non-nil error.
353 | 		// This prevents dataloader from returning a slice of all-nil errors.
354 | 		var errs []error
355 | 		for _, e := range errors {
356 | 			if e != nil {
357 | 				errs = errors
358 | 				break
359 | 			}
360 | 		}
361 |
362 | 		result.Store(&ResultMany[V]{Data: data, Error: errs})
363 | 	}()
364 |
365 | 	thunkMany := func() ([]V, []error) {
366 | 		<-done
367 | 		r := result.Load()
368 | 		return r.Data, r.Error
369 | 	}
370 |
371 | 	defer finish(thunkMany)
372 | 	return thunkMany
373 | }
374 |
375 | // Clear clears the value at `key` from the cache, if it exists. Returns self for method chaining
376 | func (l *Loader[K, V]) Clear(ctx context.Context, key K) Interface[K, V] {
377 | 	l.cacheLock.Lock()
378 | 	l.cache.Delete(ctx, key)
379 | 	l.cacheLock.Unlock()
380 | 	return l
381 | }
382 |
383 | // ClearAll clears the entire cache. To be used when some event results in unknown invalidations.
384 | // Returns self for method chaining.
385 | func (l *Loader[K, V]) ClearAll() Interface[K, V] {
386 | 	l.cacheLock.Lock()
387 | 	l.cache.Clear()
388 | 	l.cacheLock.Unlock()
389 | 	return l
390 | }
391 |
392 | // Prime adds the provided key and value to the cache.
If the key already exists, no change is made.
393 | // Returns self for method chaining
394 | func (l *Loader[K, V]) Prime(ctx context.Context, key K, value V) Interface[K, V] {
395 | 	if _, ok := l.cache.Get(ctx, key); !ok {
396 | 		thunk := func() (V, error) {
397 | 			return value, nil
398 | 		}
399 | 		l.cache.Set(ctx, key, thunk)
400 | 	}
401 | 	return l
402 | }
403 |
404 | func (l *Loader[K, V]) reset() {
405 | 	l.count = 0
406 | 	l.curBatcher = nil
407 |
408 | 	if l.clearCacheOnBatch {
409 | 		l.cache.Clear()
410 | 	}
411 | }
412 |
413 | type batcher[K comparable, V any] struct {
414 | 	input    chan *batchRequest[K, V]
415 | 	batchFn  BatchFunc[K, V]
416 | 	finished bool
417 | 	silent   bool
418 | 	tracer   Tracer[K, V]
419 | }
420 |
421 | // newBatcher returns a batcher for the current requests
422 | // all the batcher methods must be protected by a global batchLock
423 | func (l *Loader[K, V]) newBatcher(silent bool, tracer Tracer[K, V]) *batcher[K, V] {
424 | 	return &batcher[K, V]{
425 | 		input:   make(chan *batchRequest[K, V], l.inputCap),
426 | 		batchFn: l.batchFn,
427 | 		silent:  silent,
428 | 		tracer:  tracer,
429 | 	}
430 | }
431 |
432 | // stop receiving input and process batch function
433 | func (b *batcher[K, V]) end() {
434 | 	if !b.finished {
435 | 		close(b.input)
436 | 		b.finished = true
437 | 	}
438 | }
439 |
440 | // execute the batch of all items in queue
441 | func (b *batcher[K, V]) batch(originalContext context.Context) {
442 | 	var (
443 | 		keys     = make([]K, 0)
444 | 		reqs     = make([]*batchRequest[K, V], 0)
445 | 		items    = make([]*Result[V], 0)
446 | 		panicErr interface{}
447 | 	)
448 |
449 | 	for item := range b.input {
450 | 		keys = append(keys, item.key)
451 | 		reqs = append(reqs, item)
452 | 	}
453 |
454 | 	ctx, finish := b.tracer.TraceBatch(originalContext, keys)
455 | 	defer finish(items)
456 |
457 | 	func() {
458 | 		defer func() {
459 | 			if r := recover(); r != nil {
460 | 				panicErr = r
461 | 				if b.silent {
462 | 					return
463 | 				}
464 | 				const size = 64 << 10
465 | 				buf := make([]byte, size)
466 | 				buf = buf[:runtime.Stack(buf, false)]
467 | 				log.Printf("Dataloader: Panic received in batch function: %v\n%s", panicErr, buf)
468 | 			}
469 | 		}()
470 | 		items = b.batchFn(ctx, keys)
471 | 	}()
472 |
473 | 	if panicErr != nil {
474 | 		for _, req := range reqs {
475 | 			req.result.Store(&Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}})
476 | 			close(req.done)
477 | 		}
478 | 		return
479 | 	}
480 |
481 | 	if len(items) != len(keys) {
482 | 		err := &Result[V]{Error: fmt.Errorf(`
483 | The batch function supplied did not return an array of responses
484 | the same length as the array of keys.
485 |
486 | Keys:
487 | %v
488 |
489 | Values:
490 | %v
491 | `, keys, items)}
492 |
493 | 		for _, req := range reqs {
494 | 			req.result.Store(err)
495 | 			close(req.done)
496 | 		}
497 |
498 | 		return
499 | 	}
500 |
501 | 	var notSetResult *Result[V] // don't allocate unless we need it
502 | 	for i, req := range reqs {
503 | 		if items[i] == nil {
504 | 			if notSetResult == nil {
505 | 				notSetResult = &Result[V]{Error: ErrNoResultProvided}
506 | 			}
507 | 			req.result.Store(notSetResult)
508 | 		} else {
509 | 			req.result.Store(items[i])
510 | 		}
511 | 		close(req.done)
512 | 	}
513 | }
514 |
515 | // wait the appropriate amount of time for the provided batcher
516 | func (l *Loader[K, V]) sleeper(b *batcher[K, V], close chan bool) {
517 | 	select {
518 | 	// used by batch to close early. usually triggered by max batch size
519 | 	case <-close:
520 | 		return
521 | 	// otherwise, wait for the batch window to elapse
522 | 	case <-time.After(l.wait):
523 | 	}
524 |
525 | 	// reset
526 | 	// this is protected by the batchLock to avoid closing the batcher input
527 | 	// channel while Load is inserting a request
528 | 	l.batchLock.Lock()
529 | 	b.end()
530 |
531 | 	// We can end here also if the batcher has already been closed and a
532 | 	// new one has been created. So reset the loader state only if the batcher
533 | 	// is the current one
534 | 	if l.curBatcher == b {
535 | 		l.reset()
536 | 	}
537 | 	l.batchLock.Unlock()
538 | }
539 | -------------------------------------------------------------------------------- /dataloader_test.go: -------------------------------------------------------------------------------- 1 | package dataloader
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"log"
8 | 	"reflect"
9 | 	"strconv"
10 | 	"sync"
11 | 	"testing"
12 | 	"time"
13 | )
14 |
15 | /*
16 | Tests
17 | */
18 | func TestLoader(t *testing.T) {
19 | 	t.Run("test Load method", func(t *testing.T) {
20 | 		t.Parallel()
21 | 		identityLoader, _ := IDLoader[string](0)
22 | 		ctx := context.Background()
23 | 		future := identityLoader.Load(ctx, "1")
24 | 		value, err := future()
25 | 		if err != nil {
26 | 			t.Error(err.Error())
27 | 		}
28 | 		if value != "1" {
29 | 			t.Error("load didn't return the right value")
30 | 		}
31 | 	})
32 |
33 | 	t.Run("test thunk does not contain race conditions", func(t *testing.T) {
34 | 		t.Parallel()
35 | 		identityLoader, _ := IDLoader[string](0)
36 | 		ctx := context.Background()
37 | 		future := identityLoader.Load(ctx, "1")
38 | 		go future()
39 | 		go future()
40 | 	})
41 |
42 | 	t.Run("test Load Method Panic Safety", func(t *testing.T) {
43 | 		t.Parallel()
44 | 		defer func() {
45 | 			r := recover()
46 | 			if r != nil {
47 | 				t.Error("Panic Loader's panic should have been handled")
48 | 			}
49 | 		}()
50 | 		panicLoader, _ := PanicLoader[string](0)
51 | 		ctx := context.Background()
52 | 		future := panicLoader.Load(ctx, "1")
53 | 		_, err := future()
54 | 		if err == nil || err.Error() != "Panic received in batch function: Programming error" {
55 | 			t.Error("Panic was not propagated as an error.")
56 | 		}
57 | 	})
58 |
59 | 	t.Run("test Load Method cache error", func(t *testing.T) {
60 | 		t.Parallel()
61 | 		errorCacheLoader, _ := ErrorCacheLoader[string](0)
62 | 		ctx := context.Background()
63 | 		futures := []Thunk[string]{}
64 | 		for i := 0; i < 2; i++ {
65 | 			futures = append(futures, errorCacheLoader.Load(ctx, strconv.Itoa(i)))
66 | 		}
67 |
68 | 		for _, f := range futures {
69 | 			_, err := f()
70 | 			if err == nil {
71 | 				t.Error("Error was not propagated")
72 | 			}
73 | 		}
74 | 		nextFuture := errorCacheLoader.Load(ctx, "1")
75 | 		_, err := nextFuture()
76 |
77 | 		// Normal errors should be cached.
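// The repeated Load for key "1" hits the cached thunk, so the stored error is returned without invoking the batch function again.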
78 | 		if err == nil {
79 | 			t.Error("Error from batch function was not cached")
80 | 		}
81 | 	})
82 |
83 | 	t.Run("test Load Method not caching results with errors of type SkipCacheError", func(t *testing.T) {
84 | 		t.Parallel()
85 | 		skipCacheLoader, loadCalls := SkipCacheErrorLoader(3, "1")
86 | 		ctx := context.Background()
87 | 		futures1 := skipCacheLoader.LoadMany(ctx, []string{"1", "2", "3"})
88 | 		_, errs1 := futures1()
89 | 		var errCount int = 0
90 | 		var nilCount int = 0
91 | 		for _, err := range errs1 {
92 | 			if err == nil {
93 | 				nilCount++
94 | 			} else {
95 | 				errCount++
96 | 			}
97 | 		}
98 | 		if errCount != 1 {
99 | 			t.Error("Expected an error on only key \"1\"")
100 | 		}
101 |
102 | 		if nilCount != 2 {
103 | 			t.Error("Expected the other errors to be nil")
104 | 		}
105 |
106 | 		futures2 := skipCacheLoader.LoadMany(ctx, []string{"2", "3", "1"})
107 | 		_, errs2 := futures2()
108 | 		// There should be no errors in the second batch, as the only key that was not cached
109 | 		// this time around will not throw an error
110 | 		if errs2 != nil {
111 | 			t.Error("Expected LoadMany() to return nil error slice when no errors occurred")
112 | 		}
113 |
114 | 		calls := (*loadCalls)[1]
115 | 		expected := []string{"1"}
116 |
117 | 		if !reflect.DeepEqual(calls, expected) {
118 | 			t.Errorf("Expected load calls %#v, got %#v", expected, calls)
119 | 		}
120 | 	})
121 |
122 | 	t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
123 | 		t.Parallel()
124 | 		defer func() {
125 | 			r := recover()
126 | 			if r != nil {
127 | 				t.Error("Panic Loader's panic should have been handled")
128 | 			}
129 | 		}()
130 | 		panicLoader, _ := PanicCacheLoader[string](0)
131 | 		futures := []Thunk[string]{}
132 | 		ctx := context.Background()
133 | 		for i := 0; i < 3; i++ {
134 | 			futures = append(futures, panicLoader.Load(ctx, strconv.Itoa(i)))
135 | 		}
136 | 		for _, f := range futures {
137 | 			_, err := f()
138 | 			if err == nil || err.Error() != "Panic received in batch function: Programming error" {
139 | 				t.Error("Panic was not propagated as an error.")
140 | 			}
141 | 		}
142 |
143 | 		futures = []Thunk[string]{}
144 | 		for i := 0; i < 3; i++ {
145 | 			futures = append(futures, panicLoader.Load(ctx, strconv.Itoa(1)))
146 | 		}
147 |
148 | 		for _, f := range futures {
149 | 			_, err := f()
150 | 			if err != nil {
151 | 				t.Error("Panic error from batch function was cached")
152 | 			}
153 | 		}
154 | 	})
155 |
156 | 	t.Run("test LoadMany returns errors", func(t *testing.T) {
157 | 		t.Parallel()
158 | 		errorLoader, _ := ErrorLoader[string](0)
159 | 		ctx := context.Background()
160 | 		future := errorLoader.LoadMany(ctx, []string{"1", "2", "3"})
161 | 		_, err := future()
162 | 		if len(err) != 3 {
163 | 			t.Error("LoadMany didn't return right number of errors")
164 | 		}
165 | 	})
166 |
167 | 	t.Run("test LoadMany returns len(errors) == len(keys)", func(t *testing.T) {
168 | 		t.Parallel()
169 | 		loader, _ := OneErrorLoader[string](3)
170 | 		ctx := context.Background()
171 | 		future := loader.LoadMany(ctx, []string{"1", "2", "3"})
172 | 		_, errs := future()
173 | 		if len(errs) != 3 {
174 | 			t.Errorf("LoadMany didn't return right number of errors (should match size of input)")
175 | 		}
176 |
177 | 		var errCount int = 0
178 | 		var nilCount int = 0
179 | 		for _, err := range errs {
180 | 			if err == nil {
181 | 				nilCount++
182 | 			} else {
183 | 				errCount++
184 | 			}
185 | 		}
186 | 		if errCount != 1 {
187 | 			t.Error("Expected an error on only one of the items loaded")
188 | 		}
189 |
190 | 		if nilCount != 2 {
191 | 			t.Error("Expected second and third errors to be nil")
192 | 		}
193 | 	})
194 |
195 | 	t.Run("test LoadMany
196 |         t.Parallel()
197 |         loader, _ := IDLoader[string](0)
198 |         ctx := context.Background()
199 |         _, err := loader.LoadMany(ctx, []string{"1", "2", "3"})()
200 |         if err != nil {
201 |             t.Error("Expected LoadMany() to return nil error slice when no errors occurred")
202 |         }
203 |     })
204 | 
205 |     t.Run("test thunkmany does not contain race conditions", func(t *testing.T) {
206 |         t.Parallel()
207 |         identityLoader, _ := IDLoader[string](0)
208 |         ctx := context.Background()
209 |         future := identityLoader.LoadMany(ctx, []string{"1", "2", "3"})
210 |         go future()
211 |         go future()
212 |     })
213 | 
214 |     t.Run("test Load Many Method Panic Safety", func(t *testing.T) {
215 |         t.Parallel()
216 |         defer func() {
217 |             r := recover()
218 |             if r != nil {
219 |                 t.Error("Panic Loader's panic should have been handled")
220 |             }
221 |         }()
222 |         panicLoader, _ := PanicCacheLoader[string](0)
223 |         ctx := context.Background()
224 |         future := panicLoader.LoadMany(ctx, []string{"1", "2"})
225 |         _, errs := future()
226 |         if len(errs) < 2 || errs[0].Error() != "Panic received in batch function: Programming error" {
227 |             t.Error("Panic was not propagated as an error.")
228 |         }
229 | 
230 |         future = panicLoader.LoadMany(ctx, []string{"1"})
231 |         _, errs = future()
232 | 
233 |         if len(errs) > 0 {
234 |             t.Error("Panic error from batch function was cached")
235 |         }
236 | 
237 |     })
238 | 
239 |     t.Run("test LoadMany method", func(t *testing.T) {
240 |         t.Parallel()
241 |         identityLoader, _ := IDLoader[string](0)
242 |         ctx := context.Background()
243 |         future := identityLoader.LoadMany(ctx, []string{"1", "2", "3"})
244 |         results, _ := future()
245 |         if results[0] != "1" || results[1] != "2" || results[2] != "3" {
246 |             t.Error("LoadMany didn't return the right values")
247 |         }
248 |     })
249 | 
250 |     t.Run("batches many requests", func(t *testing.T) {
251 |         t.Parallel()
252 |         identityLoader, loadCalls := IDLoader[string](0)
253 |         ctx := context.Background()
254 |         future1 := identityLoader.Load(ctx, "1")
255 |         future2 := identityLoader.Load(ctx, "2")
256 | 
257 |         _, err := future1()
258 |         if err != nil {
259 |             t.Error(err.Error())
260 |         }
261 |         _, err = future2()
262 |         if err != nil {
263 |             t.Error(err.Error())
264 |         }
265 | 
266 |         calls := *loadCalls
267 |         inner := []string{"1", "2"}
268 |         expected := [][]string{inner}
269 |         if !reflect.DeepEqual(calls, expected) {
270 |             t.Errorf("did not batch the requests into a single batchFn call. Expected %#v, got %#v", expected, calls)
Expected %#v, got %#v", expected, calls) 271 | } 272 | }) 273 | 274 | t.Run("number of results matches number of keys", func(t *testing.T) { 275 | t.Parallel() 276 | faultyLoader, _ := FaultyLoader[string]() 277 | ctx := context.Background() 278 | 279 | n := 10 280 | reqs := []Thunk[string]{} 281 | var keys []string 282 | for i := 0; i < n; i++ { 283 | key := strconv.Itoa(i) 284 | reqs = append(reqs, faultyLoader.Load(ctx, key)) 285 | keys = append(keys, key) 286 | } 287 | 288 | for _, future := range reqs { 289 | _, err := future() 290 | if err == nil { 291 | t.Error("if number of results doesn't match keys, all keys should contain error") 292 | } 293 | } 294 | 295 | // TODO: expect to get some kind of warning 296 | }) 297 | 298 | t.Run("responds to max batch size", func(t *testing.T) { 299 | t.Parallel() 300 | identityLoader, loadCalls := IDLoader[string](2) 301 | ctx := context.Background() 302 | future1 := identityLoader.Load(ctx, "1") 303 | future2 := identityLoader.Load(ctx, "2") 304 | future3 := identityLoader.Load(ctx, "3") 305 | 306 | _, err := future1() 307 | if err != nil { 308 | t.Error(err.Error()) 309 | } 310 | _, err = future2() 311 | if err != nil { 312 | t.Error(err.Error()) 313 | } 314 | _, err = future3() 315 | if err != nil { 316 | t.Error(err.Error()) 317 | } 318 | 319 | calls := *loadCalls 320 | inner1 := []string{"1", "2"} 321 | inner2 := []string{"3"} 322 | expected := [][]string{inner1, inner2} 323 | if !reflect.DeepEqual(calls, expected) { 324 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 325 | } 326 | }) 327 | 328 | t.Run("caches repeated requests", func(t *testing.T) { 329 | t.Parallel() 330 | identityLoader, loadCalls := IDLoader[string](0) 331 | ctx := context.Background() 332 | start := time.Now() 333 | future1 := identityLoader.Load(ctx, "1") 334 | future2 := identityLoader.Load(ctx, "1") 335 | 336 | _, err := future1() 337 | if err != nil { 338 | t.Error(err.Error()) 339 | } 340 | _, err = future2() 341 | if err != nil { 342 | t.Error(err.Error()) 343 | } 344 | 345 | // also check that it took the full timeout to return 346 | var duration = time.Since(start) 347 | if duration < 16*time.Millisecond { 348 | t.Errorf("took %v when expected it to take more than 16 ms because of wait", duration) 349 | } 350 | 351 | calls := *loadCalls 352 | inner := []string{"1"} 353 | expected := [][]string{inner} 354 | if !reflect.DeepEqual(calls, expected) { 355 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 356 | } 357 | }) 358 | 359 | t.Run("doesn't wait for timeout if Flush() is called", func(t *testing.T) { 360 | t.Parallel() 361 | identityLoader, loadCalls := IDLoader[string](0) 362 | ctx := context.Background() 363 | start := time.Now() 364 | future1 := identityLoader.Load(ctx, "1") 365 | future2 := identityLoader.Load(ctx, "2") 366 | 367 | // trigger them to be fetched immediately vs waiting for the 16 ms timer 368 | identityLoader.Flush() 369 | 370 | _, err := future1() 371 | if err != nil { 372 | t.Error(err.Error()) 373 | } 374 | _, err = future2() 375 | if err != nil { 376 | t.Error(err.Error()) 377 | } 378 | 379 | var duration = time.Since(start) 380 | if duration > 2*time.Millisecond { 381 | t.Errorf("took %v when expected it to take less than 2 ms b/c we called Flush()", duration) 382 | } 383 | 384 | calls := *loadCalls 385 | inner := []string{"1", "2"} 386 | expected := [][]string{inner} 387 | if !reflect.DeepEqual(calls, expected) { 388 | t.Errorf("did not respect max batch size. 
Expected %#v, got %#v", expected, calls) 389 | } 390 | }) 391 | 392 | t.Run("Nothing for Flush() to do on empty loader with current batch", func(t *testing.T) { 393 | t.Parallel() 394 | identityLoader, _ := IDLoader[string](0) 395 | identityLoader.Flush() 396 | }) 397 | 398 | t.Run("allows primed cache", func(t *testing.T) { 399 | t.Parallel() 400 | identityLoader, loadCalls := IDLoader[string](0) 401 | ctx := context.Background() 402 | identityLoader.Prime(ctx, "A", "Cached") 403 | future1 := identityLoader.Load(ctx, "1") 404 | future2 := identityLoader.Load(ctx, "A") 405 | 406 | _, err := future1() 407 | if err != nil { 408 | t.Error(err.Error()) 409 | } 410 | value, err := future2() 411 | if err != nil { 412 | t.Error(err.Error()) 413 | } 414 | 415 | calls := *loadCalls 416 | inner := []string{"1"} 417 | expected := [][]string{inner} 418 | if !reflect.DeepEqual(calls, expected) { 419 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 420 | } 421 | 422 | if value != "Cached" { 423 | t.Errorf("did not use primed cache value. Expected '%#v', got '%#v'", "Cached", value) 424 | } 425 | }) 426 | 427 | t.Run("allows clear value in cache", func(t *testing.T) { 428 | t.Parallel() 429 | identityLoader, loadCalls := IDLoader[string](0) 430 | ctx := context.Background() 431 | identityLoader.Prime(ctx, "A", "Cached") 432 | identityLoader.Prime(ctx, "B", "B") 433 | future1 := identityLoader.Load(ctx, "1") 434 | future2 := identityLoader.Clear(ctx, "A").Load(ctx, "A") 435 | future3 := identityLoader.Load(ctx, "B") 436 | 437 | _, err := future1() 438 | if err != nil { 439 | t.Error(err.Error()) 440 | } 441 | value, err := future2() 442 | if err != nil { 443 | t.Error(err.Error()) 444 | } 445 | _, err = future3() 446 | if err != nil { 447 | t.Error(err.Error()) 448 | } 449 | 450 | calls := *loadCalls 451 | inner := []string{"1", "A"} 452 | expected := [][]string{inner} 453 | if !reflect.DeepEqual(calls, expected) { 454 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 455 | } 456 | 457 | if value != "A" { 458 | t.Errorf("did not use primed cache value. Expected '%#v', got '%#v'", "Cached", value) 459 | } 460 | }) 461 | 462 | t.Run("clears cache on batch with WithClearCacheOnBatch", func(t *testing.T) { 463 | t.Parallel() 464 | batchOnlyLoader, loadCalls := BatchOnlyLoader[string](0) 465 | ctx := context.Background() 466 | future1 := batchOnlyLoader.Load(ctx, "1") 467 | future2 := batchOnlyLoader.Load(ctx, "1") 468 | 469 | _, err := future1() 470 | if err != nil { 471 | t.Error(err.Error()) 472 | } 473 | _, err = future2() 474 | if err != nil { 475 | t.Error(err.Error()) 476 | } 477 | 478 | calls := *loadCalls 479 | inner := []string{"1"} 480 | expected := [][]string{inner} 481 | if !reflect.DeepEqual(calls, expected) { 482 | t.Errorf("did not batch queries. Expected %#v, got %#v", expected, calls) 483 | } 484 | 485 | if _, found := batchOnlyLoader.cache.Get(ctx, "1"); found { 486 | t.Errorf("did not clear cache after batch. 
Expected %#v, got %#v", false, found) 487 | } 488 | }) 489 | 490 | t.Run("allows clearAll values in cache", func(t *testing.T) { 491 | t.Parallel() 492 | identityLoader, loadCalls := IDLoader[string](0) 493 | ctx := context.Background() 494 | identityLoader.Prime(ctx, "A", "Cached") 495 | identityLoader.Prime(ctx, "B", "B") 496 | 497 | identityLoader.ClearAll() 498 | 499 | future1 := identityLoader.Load(ctx, "1") 500 | future2 := identityLoader.Load(ctx, "A") 501 | future3 := identityLoader.Load(ctx, "B") 502 | 503 | _, err := future1() 504 | if err != nil { 505 | t.Error(err.Error()) 506 | } 507 | _, err = future2() 508 | if err != nil { 509 | t.Error(err.Error()) 510 | } 511 | _, err = future3() 512 | if err != nil { 513 | t.Error(err.Error()) 514 | } 515 | 516 | calls := *loadCalls 517 | inner := []string{"1", "A", "B"} 518 | expected := [][]string{inner} 519 | if !reflect.DeepEqual(calls, expected) { 520 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 521 | } 522 | }) 523 | 524 | t.Run("all methods on NoCache are Noops", func(t *testing.T) { 525 | t.Parallel() 526 | identityLoader, loadCalls := NoCacheLoader[string](0) 527 | ctx := context.Background() 528 | identityLoader.Prime(ctx, "A", "Cached") 529 | identityLoader.Prime(ctx, "B", "B") 530 | 531 | identityLoader.ClearAll() 532 | 533 | future1 := identityLoader.Clear(ctx, "1").Load(ctx, "1") 534 | future2 := identityLoader.Load(ctx, "A") 535 | future3 := identityLoader.Load(ctx, "B") 536 | 537 | _, err := future1() 538 | if err != nil { 539 | t.Error(err.Error()) 540 | } 541 | _, err = future2() 542 | if err != nil { 543 | t.Error(err.Error()) 544 | } 545 | _, err = future3() 546 | if err != nil { 547 | t.Error(err.Error()) 548 | } 549 | 550 | calls := *loadCalls 551 | inner := []string{"1", "A", "B"} 552 | expected := [][]string{inner} 553 | if !reflect.DeepEqual(calls, expected) { 554 | t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls) 555 | } 556 | }) 557 | 558 | t.Run("no cache does not cache anything", func(t *testing.T) { 559 | t.Parallel() 560 | identityLoader, loadCalls := NoCacheLoader[string](0) 561 | ctx := context.Background() 562 | identityLoader.Prime(ctx, "A", "Cached") 563 | identityLoader.Prime(ctx, "B", "B") 564 | 565 | future1 := identityLoader.Load(ctx, "1") 566 | future2 := identityLoader.Load(ctx, "A") 567 | future3 := identityLoader.Load(ctx, "B") 568 | 569 | _, err := future1() 570 | if err != nil { 571 | t.Error(err.Error()) 572 | } 573 | _, err = future2() 574 | if err != nil { 575 | t.Error(err.Error()) 576 | } 577 | _, err = future3() 578 | if err != nil { 579 | t.Error(err.Error()) 580 | } 581 | 582 | calls := *loadCalls 583 | inner := []string{"1", "A", "B"} 584 | expected := [][]string{inner} 585 | if !reflect.DeepEqual(calls, expected) { 586 | t.Errorf("did not respect max batch size. 
Expected %#v, got %#v", expected, calls) 587 | } 588 | }) 589 | 590 | } 591 | 592 | // test helpers 593 | func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 594 | var mu sync.Mutex 595 | var loadCalls [][]K 596 | identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 597 | var results []*Result[K] 598 | mu.Lock() 599 | loadCalls = append(loadCalls, keys) 600 | mu.Unlock() 601 | for _, key := range keys { 602 | results = append(results, &Result[K]{key, nil}) 603 | } 604 | return results 605 | }, WithBatchCapacity[K, K](max)) 606 | return identityLoader, &loadCalls 607 | } 608 | func BatchOnlyLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 609 | var mu sync.Mutex 610 | var loadCalls [][]K 611 | identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 612 | var results []*Result[K] 613 | mu.Lock() 614 | loadCalls = append(loadCalls, keys) 615 | mu.Unlock() 616 | for _, key := range keys { 617 | results = append(results, &Result[K]{key, nil}) 618 | } 619 | return results 620 | }, WithBatchCapacity[K, K](max), WithClearCacheOnBatch[K, K]()) 621 | return identityLoader, &loadCalls 622 | } 623 | func ErrorLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 624 | var mu sync.Mutex 625 | var loadCalls [][]K 626 | identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 627 | var results []*Result[K] 628 | mu.Lock() 629 | loadCalls = append(loadCalls, keys) 630 | mu.Unlock() 631 | for _, key := range keys { 632 | results = append(results, &Result[K]{key, fmt.Errorf("this is a test error")}) 633 | } 634 | return results 635 | }, WithBatchCapacity[K, K](max)) 636 | return identityLoader, &loadCalls 637 | } 638 | func OneErrorLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 639 | var mu sync.Mutex 640 | var loadCalls [][]K 641 | identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 642 | results := make([]*Result[K], max) 643 | mu.Lock() 644 | loadCalls = append(loadCalls, keys) 645 | mu.Unlock() 646 | for i := range keys { 647 | var err error 648 | if i == 0 { 649 | err = errors.New("always error on the first key") 650 | } 651 | results[i] = &Result[K]{keys[i], err} 652 | } 653 | return results 654 | }, WithBatchCapacity[K, K](max)) 655 | return identityLoader, &loadCalls 656 | } 657 | func PanicLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 658 | var loadCalls [][]K 659 | panicLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 660 | panic("Programming error") 661 | }, WithBatchCapacity[K, K](max), withSilentLogger[K, K]()) 662 | return panicLoader, &loadCalls 663 | } 664 | 665 | func PanicCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 666 | var loadCalls [][]K 667 | panicCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 668 | if len(keys) > 1 { 669 | panic("Programming error") 670 | } 671 | 672 | returnResult := make([]*Result[K], len(keys)) 673 | for idx := range returnResult { 674 | returnResult[idx] = &Result[K]{ 675 | keys[0], 676 | nil, 677 | } 678 | } 679 | 680 | return returnResult 681 | 682 | }, WithBatchCapacity[K, K](max), withSilentLogger[K, K]()) 683 | return panicCacheLoader, &loadCalls 684 | } 685 | 686 | func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) { 687 | var loadCalls [][]K 688 | errorCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { 689 | if len(keys) > 1 { 690 | var results []*Result[K] 691 | for _, key := 
692 |                 results = append(results, &Result[K]{key, fmt.Errorf("this is a test error")})
693 |             }
694 |             return results
695 |         }
696 | 
697 |         returnResult := make([]*Result[K], len(keys))
698 |         for idx := range returnResult {
699 |             returnResult[idx] = &Result[K]{
700 |                 keys[idx], // len(keys) == 1 here, so this is keys[0]
701 |                 nil,
702 |             }
703 |         }
704 | 
705 |         return returnResult
706 | 
707 |     }, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
708 |     return errorCacheLoader, &loadCalls
709 | }
710 | 
711 | func SkipCacheErrorLoader[K comparable](max int, onceErrorKey K) (*Loader[K, K], *[][]K) {
712 |     var mu sync.Mutex
713 |     var loadCalls [][]K
714 |     errorThrown := false
715 |     skipCacheErrorLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
716 |         var results []*Result[K]
717 |         mu.Lock()
718 |         loadCalls = append(loadCalls, keys)
719 |         mu.Unlock()
720 |         // return a non-cacheable error for the first occurrence of onceErrorKey
721 |         for _, k := range keys {
722 |             if !errorThrown && k == onceErrorKey {
723 |                 results = append(results, &Result[K]{k, NewSkipCacheError(fmt.Errorf("non-cacheable error"))})
724 |                 errorThrown = true
725 |             } else {
726 |                 results = append(results, &Result[K]{k, nil})
727 |             }
728 |         }
729 | 
730 |         return results
731 |     }, WithBatchCapacity[K, K](max))
732 |     return skipCacheErrorLoader, &loadCalls
733 | }
734 | 
735 | func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
736 |     var mu sync.Mutex
737 |     var loadCalls [][]K
738 |     identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
739 |         var results []*Result[K]
740 |         mu.Lock()
741 |         loadCalls = append(loadCalls, keys)
742 |         mu.Unlock()
743 |         results = append(results, &Result[K]{keys[0], nil})
744 |         return results
745 |     }, WithBatchCapacity[K, K](max))
746 |     return identityLoader, &loadCalls
747 | }
748 | 
749 | func NoCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
750 |     var mu sync.Mutex
751 |     var loadCalls [][]K
752 |     cache := &NoCache[K, K]{}
753 |     identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
754 |         var results []*Result[K]
755 |         mu.Lock()
756 |         loadCalls = append(loadCalls, keys)
757 |         mu.Unlock()
758 |         for _, key := range keys {
759 |             results = append(results, &Result[K]{key, nil})
760 |         }
761 |         return results
762 |     }, WithCache[K, K](cache), WithBatchCapacity[K, K](max))
763 |     return identityLoader, &loadCalls
764 | }
765 | 
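Editor's aside: NoCacheLoader above shows how WithCache swaps the loader's cache, with NoCache simply dropping everything. As a hedged sketch of what a working replacement could look like, here is a minimal mutex-guarded map cache satisfying the same Cache interface. The mapCache name and this implementation are illustrative only and not part of this repository; the repo's own in_memory_cache.go is the canonical version.

package examplecache

import (
    "context"
    "sync"

    "github.com/graph-gophers/dataloader/v7"
)

// mapCache is a hypothetical, mutex-guarded map cache for thunks.
type mapCache[K comparable, V any] struct {
    mu sync.Mutex
    m  map[K]dataloader.Thunk[V]
}

// compile-time interface check
var _ dataloader.Cache[string, int] = &mapCache[string, int]{}

func (c *mapCache[K, V]) Get(_ context.Context, key K) (dataloader.Thunk[V], bool) {
    c.mu.Lock()
    defer c.mu.Unlock()
    thunk, ok := c.m[key]
    return thunk, ok
}

func (c *mapCache[K, V]) Set(_ context.Context, key K, thunk dataloader.Thunk[V]) {
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.m == nil {
        c.m = make(map[K]dataloader.Thunk[V])
    }
    c.m[key] = thunk
}

func (c *mapCache[K, V]) Delete(_ context.Context, key K) bool {
    c.mu.Lock()
    defer c.mu.Unlock()
    if _, ok := c.m[key]; ok {
        delete(c.m, key)
        return true
    }
    return false
}

func (c *mapCache[K, V]) Clear() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.m = make(map[K]dataloader.Thunk[V])
}

It would be wired in exactly like the noop cache above: NewBatchedLoader(batchFn, WithCache[string, string](&mapCache[string, string]{})).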
766 | // FaultyLoader gives len(keys)-1 results, simulating a batch function that drops a key.
767 | func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) {
768 |     var mu sync.Mutex
769 |     var loadCalls [][]K
770 | 
771 |     loader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
772 |         var results []*Result[K]
773 |         mu.Lock()
774 |         loadCalls = append(loadCalls, keys)
775 |         mu.Unlock()
776 | 
777 |         lastKeyIndex := len(keys) - 1
778 |         for i, key := range keys {
779 |             if i == lastKeyIndex {
780 |                 break
781 |             }
782 | 
783 |             results = append(results, &Result[K]{key, nil})
784 |         }
785 |         return results
786 |     })
787 | 
788 |     return loader, &loadCalls
789 | }
790 | 
791 | /*
792 | Benchmarks
793 | */
794 | var a = &Avg{}
795 | 
796 | func batchIdentity[K comparable](_ context.Context, keys []K) (results []*Result[K]) {
797 |     a.Add(len(keys))
798 |     for _, key := range keys {
799 |         results = append(results, &Result[K]{key, nil})
800 |     }
801 |     return
802 | }
803 | 
804 | var _ctx = context.Background()
805 | 
806 | func BenchmarkLoader(b *testing.B) {
807 |     UserLoader := NewBatchedLoader(batchIdentity[string])
808 |     b.ResetTimer()
809 |     for i := 0; i < b.N; i++ {
810 |         UserLoader.Load(_ctx, strconv.Itoa(i))
811 |     }
812 |     log.Printf("avg: %f", a.Avg())
813 | }
814 | 
815 | type Avg struct {
816 |     total  float64
817 |     length float64
818 |     lock   sync.RWMutex
819 | }
820 | 
821 | func (a *Avg) Add(v int) {
822 |     a.lock.Lock()
823 |     a.total += float64(v)
824 |     a.length++
825 |     a.lock.Unlock()
826 | }
827 | 
828 | func (a *Avg) Avg() float64 {
829 |     a.lock.RLock()
830 |     defer a.lock.RUnlock()
831 |     if a.length == 0 {
832 |         // no samples recorded yet; avoid division by zero
833 |         return 0
834 |     }
835 | 
836 |     return a.total / a.length
837 | }
838 | 
--------------------------------------------------------------------------------
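Editor's aside: to tie the tests and benchmark together, here is a hedged, minimal end-to-end sketch of the API they exercise. The batch function, key, and printed value are illustrative; NewBatchedLoader, WithBatchCapacity, Load, and the thunk call are taken from the code above, and the Result field names Data and Error are assumed from the published library (the test helpers above construct Result positionally as {value, error}).

package main

import (
    "context"
    "fmt"

    "github.com/graph-gophers/dataloader/v7"
)

func main() {
    ctx := context.Background()

    // Batch function: receives every key queued in one batch window and
    // returns one *Result per key, in the same order as the keys.
    batchFn := func(_ context.Context, keys []string) []*dataloader.Result[string] {
        results := make([]*dataloader.Result[string], 0, len(keys))
        for _, k := range keys {
            results = append(results, &dataloader.Result[string]{Data: "value-for-" + k})
        }
        return results
    }

    loader := dataloader.NewBatchedLoader(batchFn, dataloader.WithBatchCapacity[string, string](100))

    // Load queues the key and returns a thunk immediately.
    thunk := loader.Load(ctx, "1")

    // Calling the thunk blocks until the batch resolves.
    value, err := thunk()
    fmt.Println(value, err) // value-for-1 <nil>
}

Calling the thunk blocks until the batch is dispatched, which the tests above trigger three ways: the batch capacity filling up, the wait timer firing, or an explicit Flush().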