├── .circleci └── config.yml ├── .gitignore ├── README.md ├── appveyor.yml ├── dataloaden.go ├── example ├── benchmark_test.go ├── pkgname │ ├── user.go │ └── userloader_gen.go ├── slice │ ├── user.go │ ├── usersliceloader_gen.go │ └── usersliceloader_test.go ├── user.go ├── user_test.go └── userloader_gen.go ├── go.mod ├── go.sum ├── licence.md └── pkg └── generator ├── generator.go ├── generator_test.go ├── template.go └── testdata └── mismatch └── mismatch.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: golang:1.11 6 | working_directory: /projects/dataloaden 7 | steps: &steps 8 | - checkout 9 | - run: go generate ./example/... && if [[ $(git diff) ]] ; then echo "you need to run go generate" ; git diff ; exit 1 ; fi 10 | - run: go test -bench=. -benchmem -v ./example/... 11 | - run: go test -bench=. -benchmem -v ./example/... -race 12 | - run: go test -coverprofile=coverage.txt -covermode=atomic ./example && bash <(curl -s https://codecov.io/bash) 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /vendor 2 | /.idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### The DATALOADer gENerator [![CircleCI](https://circleci.com/gh/Vektah/dataloaden.svg?style=svg)](https://circleci.com/gh/vektah/dataloaden) [![Go Report Card](https://goreportcard.com/badge/github.com/vektah/dataloaden)](https://goreportcard.com/report/github.com/vektah/dataloaden) [![codecov](https://codecov.io/gh/vektah/dataloaden/branch/master/graph/badge.svg)](https://codecov.io/gh/vektah/dataloaden) 2 | 3 | Requires golang 1.11+ for modules support. 
4 | 5 | This is a tool for generating type-safe data loaders for Go, inspired by https://github.com/facebook/dataloader. 6 | 7 | The intended use is in GraphQL servers, to reduce the number of queries being sent to the database. These dataloader 8 | objects should be request-scoped and short-lived. They should be cheap to create in every request, even if they don't 9 | get used. 10 | 11 | #### Getting started 12 | 13 | From inside the package you want the dataloader in: 14 | ```bash 15 | go run github.com/vektah/dataloaden UserLoader string *github.com/dataloaden/example.User 16 | ``` 17 | 18 | This will generate a dataloader called `UserLoader` that looks up `*github.com/dataloaden/example.User` objects 19 | based on a `string` key. 20 | 21 | In another file in the same package, create the constructor method: 22 | ```go 23 | func NewUserLoader() *UserLoader { 24 | return &UserLoader{ 25 | wait: 2 * time.Millisecond, 26 | maxBatch: 100, 27 | fetch: func(keys []string) ([]*User, []error) { 28 | users := make([]*User, len(keys)) 29 | errors := make([]error, len(keys)) 30 | 31 | for i, key := range keys { 32 | users[i] = &User{ID: key, Name: "user " + key} 33 | } 34 | return users, errors 35 | }, 36 | } 37 | } 38 | ``` 39 | 40 | Then, wherever you want to call the dataloader: 41 | ```go 42 | loader := NewUserLoader() 43 | 44 | user, err := loader.Load("123") 45 | ``` 46 | 47 | This method will block for a short amount of time, waiting for any other similar requests to come in, then call your fetch 48 | function once. It also caches values and won't request duplicates within a batch. 49 | 50 | #### Returning Slices 51 | 52 | You may want to generate a dataloader that returns slices instead of single values. 
Both key and value types can be a 53 | simple Go type expression: 54 | 55 | ```bash 56 | go run github.com/vektah/dataloaden UserSliceLoader string []*github.com/dataloaden/example.User 57 | ``` 58 | 59 | Now each key is expected to return a slice of values, and the `fetch` function has the return type `[][]*User`. 60 | 61 | #### Using with Go modules 62 | 63 | Create a `tools.go` that looks like this: 64 | ```go 65 | // +build tools 66 | 67 | package main 68 | 69 | import _ "github.com/vektah/dataloaden" 70 | ``` 71 | 72 | This allows Go modules to see the dependency. 73 | 74 | You can now invoke it from anywhere within your module using `go run github.com/vektah/dataloaden` and 75 | always get the pinned version. 76 | 77 | #### Wait, how do I use context with this? 78 | 79 | I don't think it makes sense to pass context through a data loader. Consider a few scenarios: 80 | 1. a dataloader shared between requests: requests A and B get batched together, so which context should be passed to the DB? `context.Background` is probably the most suitable. 81 | 2. a dataloader per request for GraphQL: two different nodes in the graph get batched together; they have different contexts for tracing purposes, so which should be passed to the DB? Neither; just use the root request context. 82 | 83 | 84 | So be explicit about your context: 85 | ```go 86 | func NewLoader(ctx context.Context) *UserLoader { 87 | return &UserLoader{ 88 | wait: 2 * time.Millisecond, 89 | maxBatch: 100, 90 | fetch: func(keys []string) ([]*User, []error) { 91 | // you now have a ctx to work with 92 | }, 93 | } 94 | } 95 | ``` 96 | 97 | If you feel I'm wrong, please raise an issue. 
98 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | version: "{build}" 2 | 3 | # Source Config 4 | 5 | skip_branch_with_pr: true 6 | clone_folder: c:\projects\dataloaden 7 | 8 | # Build host 9 | 10 | environment: 11 | GOPATH: c:\gopath 12 | GOVERSION: 1.11.5 13 | PATH: '%PATH%;c:\gopath\bin' 14 | 15 | init: 16 | - git config --global core.autocrlf input 17 | 18 | # Build 19 | 20 | install: 21 | # Install the specific Go version. 22 | - rmdir c:\go /s /q 23 | - appveyor DownloadFile https://storage.googleapis.com/golang/go%GOVERSION%.windows-amd64.msi 24 | - msiexec /i go%GOVERSION%.windows-amd64.msi /q 25 | - go version 26 | 27 | build: false 28 | deploy: false 29 | 30 | test_script: 31 | - go generate ./... 32 | - go test -parallel 8 ./... 33 | -------------------------------------------------------------------------------- /dataloaden.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/vektah/dataloaden/pkg/generator" 8 | ) 9 | 10 | func main() { 11 | if len(os.Args) != 4 { 12 | fmt.Println("usage: dataloaden <name> <keyType> <valueType>") 13 | fmt.Println(" example:") 14 | fmt.Println(" dataloaden UserLoader int '[]*github.com/my/package.User'") 15 | os.Exit(1) 16 | } 17 | 18 | wd, err := os.Getwd() 19 | if err != nil { 20 | fmt.Fprintln(os.Stderr, err.Error()) 21 | os.Exit(2) 22 | } 23 | 24 | if err := generator.Generate(os.Args[1], os.Args[2], os.Args[3], wd); err != nil { 25 | fmt.Fprintln(os.Stderr, err.Error()) 26 | os.Exit(2) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /example/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "strconv" 7 | "sync" 8 | "testing" 9 | 
"time" 10 | ) 11 | 12 | func BenchmarkLoader(b *testing.B) { 13 | dl := &UserLoader{ 14 | wait: 500 * time.Nanosecond, 15 | maxBatch: 100, 16 | fetch: func(keys []string) ([]*User, []error) { 17 | users := make([]*User, len(keys)) 18 | errors := make([]error, len(keys)) 19 | 20 | for i, key := range keys { 21 | if rand.Int()%100 == 1 { 22 | errors[i] = fmt.Errorf("user not found") 23 | } else if rand.Int()%100 == 1 { 24 | users[i] = nil 25 | } else { 26 | users[i] = &User{ID: key, Name: "user " + key} 27 | } 28 | } 29 | return users, errors 30 | }, 31 | } 32 | 33 | b.Run("caches", func(b *testing.B) { 34 | thunks := make([]func() (*User, error), b.N) 35 | for i := 0; i < b.N; i++ { 36 | thunks[i] = dl.LoadThunk(strconv.Itoa(rand.Int() % 300)) 37 | } 38 | 39 | for i := 0; i < b.N; i++ { 40 | thunks[i]() 41 | } 42 | }) 43 | 44 | b.Run("random spread", func(b *testing.B) { 45 | thunks := make([]func() (*User, error), b.N) 46 | for i := 0; i < b.N; i++ { 47 | thunks[i] = dl.LoadThunk(strconv.Itoa(rand.Int())) 48 | } 49 | 50 | for i := 0; i < b.N; i++ { 51 | thunks[i]() 52 | } 53 | }) 54 | 55 | b.Run("concurrently", func(b *testing.B) { 56 | var wg sync.WaitGroup 57 | for i := 0; i < 10; i++ { 58 | wg.Add(1) 59 | go func() { 60 | for j := 0; j < b.N; j++ { 61 | dl.Load(strconv.Itoa(rand.Int())) 62 | } 63 | wg.Done() 64 | }() 65 | } 66 | wg.Wait() 67 | }) 68 | } 69 | -------------------------------------------------------------------------------- /example/pkgname/user.go: -------------------------------------------------------------------------------- 1 | package differentpkg 2 | 3 | //go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User 4 | -------------------------------------------------------------------------------- /example/pkgname/userloader_gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by github.com/vektah/dataloaden, DO NOT EDIT. 
2 | 3 | package differentpkg 4 | 5 | import ( 6 | "sync" 7 | "time" 8 | 9 | "github.com/vektah/dataloaden/example" 10 | ) 11 | 12 | // UserLoaderConfig captures the config to create a new UserLoader 13 | type UserLoaderConfig struct { 14 | // Fetch is a method that provides the data for the loader 15 | Fetch func(keys []string) ([]*example.User, []error) 16 | 17 | // Wait is how long wait before sending a batch 18 | Wait time.Duration 19 | 20 | // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit 21 | MaxBatch int 22 | } 23 | 24 | // NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch 25 | func NewUserLoader(config UserLoaderConfig) *UserLoader { 26 | return &UserLoader{ 27 | fetch: config.Fetch, 28 | wait: config.Wait, 29 | maxBatch: config.MaxBatch, 30 | } 31 | } 32 | 33 | // UserLoader batches and caches requests 34 | type UserLoader struct { 35 | // this method provides the data for the loader 36 | fetch func(keys []string) ([]*example.User, []error) 37 | 38 | // how long to done before sending a batch 39 | wait time.Duration 40 | 41 | // this will limit the maximum number of keys to send in one batch, 0 = no limit 42 | maxBatch int 43 | 44 | // INTERNAL 45 | 46 | // lazily created cache 47 | cache map[string]*example.User 48 | 49 | // the current batch. keys will continue to be collected until timeout is hit, 50 | // then everything will be sent to the fetch method and out to the listeners 51 | batch *userLoaderBatch 52 | 53 | // mutex to prevent races 54 | mu sync.Mutex 55 | } 56 | 57 | type userLoaderBatch struct { 58 | keys []string 59 | data []*example.User 60 | error []error 61 | closing bool 62 | done chan struct{} 63 | } 64 | 65 | // Load a User by key, batching and caching will be applied automatically 66 | func (l *UserLoader) Load(key string) (*example.User, error) { 67 | return l.LoadThunk(key)() 68 | } 69 | 70 | // LoadThunk returns a function that when called will block waiting for a User. 
71 | // This method should be used if you want one goroutine to make requests to many 72 | // different data loaders without blocking until the thunk is called. 73 | func (l *UserLoader) LoadThunk(key string) func() (*example.User, error) { 74 | l.mu.Lock() 75 | if it, ok := l.cache[key]; ok { 76 | l.mu.Unlock() 77 | return func() (*example.User, error) { 78 | return it, nil 79 | } 80 | } 81 | if l.batch == nil { 82 | l.batch = &userLoaderBatch{done: make(chan struct{})} 83 | } 84 | batch := l.batch 85 | pos := batch.keyIndex(l, key) 86 | l.mu.Unlock() 87 | 88 | return func() (*example.User, error) { 89 | <-batch.done 90 | 91 | var data *example.User 92 | if pos < len(batch.data) { 93 | data = batch.data[pos] 94 | } 95 | 96 | var err error 97 | // its convenient to be able to return a single error for everything 98 | if len(batch.error) == 1 { 99 | err = batch.error[0] 100 | } else if batch.error != nil { 101 | err = batch.error[pos] 102 | } 103 | 104 | if err == nil { 105 | l.mu.Lock() 106 | l.unsafeSet(key, data) 107 | l.mu.Unlock() 108 | } 109 | 110 | return data, err 111 | } 112 | } 113 | 114 | // LoadAll fetches many keys at once. It will be broken into appropriate sized 115 | // sub batches depending on how the loader is configured 116 | func (l *UserLoader) LoadAll(keys []string) ([]*example.User, []error) { 117 | results := make([]func() (*example.User, error), len(keys)) 118 | 119 | for i, key := range keys { 120 | results[i] = l.LoadThunk(key) 121 | } 122 | 123 | users := make([]*example.User, len(keys)) 124 | errors := make([]error, len(keys)) 125 | for i, thunk := range results { 126 | users[i], errors[i] = thunk() 127 | } 128 | return users, errors 129 | } 130 | 131 | // LoadAllThunk returns a function that when called will block waiting for a Users. 132 | // This method should be used if you want one goroutine to make requests to many 133 | // different data loaders without blocking until the thunk is called. 
134 | func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*example.User, []error) { 135 | results := make([]func() (*example.User, error), len(keys)) 136 | for i, key := range keys { 137 | results[i] = l.LoadThunk(key) 138 | } 139 | return func() ([]*example.User, []error) { 140 | users := make([]*example.User, len(keys)) 141 | errors := make([]error, len(keys)) 142 | for i, thunk := range results { 143 | users[i], errors[i] = thunk() 144 | } 145 | return users, errors 146 | } 147 | } 148 | 149 | // Prime the cache with the provided key and value. If the key already exists, no change is made 150 | // and false is returned. 151 | // (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) 152 | func (l *UserLoader) Prime(key string, value *example.User) bool { 153 | l.mu.Lock() 154 | var found bool 155 | if _, found = l.cache[key]; !found { 156 | // make a copy when writing to the cache, its easy to pass a pointer in from a loop var 157 | // and end up with the whole cache pointing to the same value. 
158 | cpy := *value 159 | l.unsafeSet(key, &cpy) 160 | } 161 | l.mu.Unlock() 162 | return !found 163 | } 164 | 165 | // Clear the value at key from the cache, if it exists 166 | func (l *UserLoader) Clear(key string) { 167 | l.mu.Lock() 168 | delete(l.cache, key) 169 | l.mu.Unlock() 170 | } 171 | 172 | func (l *UserLoader) unsafeSet(key string, value *example.User) { 173 | if l.cache == nil { 174 | l.cache = map[string]*example.User{} 175 | } 176 | l.cache[key] = value 177 | } 178 | 179 | // keyIndex will return the location of the key in the batch, if its not found 180 | // it will add the key to the batch 181 | func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { 182 | for i, existingKey := range b.keys { 183 | if key == existingKey { 184 | return i 185 | } 186 | } 187 | 188 | pos := len(b.keys) 189 | b.keys = append(b.keys, key) 190 | if pos == 0 { 191 | go b.startTimer(l) 192 | } 193 | 194 | if l.maxBatch != 0 && pos >= l.maxBatch-1 { 195 | if !b.closing { 196 | b.closing = true 197 | l.batch = nil 198 | go b.end(l) 199 | } 200 | } 201 | 202 | return pos 203 | } 204 | 205 | func (b *userLoaderBatch) startTimer(l *UserLoader) { 206 | time.Sleep(l.wait) 207 | l.mu.Lock() 208 | 209 | // we must have hit a batch limit and are already finalizing this batch 210 | if b.closing { 211 | l.mu.Unlock() 212 | return 213 | } 214 | 215 | l.batch = nil 216 | l.mu.Unlock() 217 | 218 | b.end(l) 219 | } 220 | 221 | func (b *userLoaderBatch) end(l *UserLoader) { 222 | b.data, b.error = l.fetch(b.keys) 223 | close(b.done) 224 | } 225 | -------------------------------------------------------------------------------- /example/slice/user.go: -------------------------------------------------------------------------------- 1 | //go:generate go run github.com/vektah/dataloaden UserSliceLoader int []github.com/vektah/dataloaden/example.User 2 | 3 | package slice 4 | 5 | import ( 6 | "strconv" 7 | "time" 8 | 9 | "github.com/vektah/dataloaden/example" 10 | ) 11 | 12 | func 
NewLoader() *UserSliceLoader { 13 | return &UserSliceLoader{ 14 | wait: 2 * time.Millisecond, 15 | maxBatch: 100, 16 | fetch: func(keys []int) ([][]example.User, []error) { 17 | users := make([][]example.User, len(keys)) 18 | errors := make([]error, len(keys)) 19 | 20 | for i, key := range keys { 21 | users[i] = []example.User{{ID: strconv.Itoa(key), Name: "user " + strconv.Itoa(key)}} 22 | } 23 | return users, errors 24 | }, 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /example/slice/usersliceloader_gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by github.com/vektah/dataloaden, DO NOT EDIT. 2 | 3 | package slice 4 | 5 | import ( 6 | "sync" 7 | "time" 8 | 9 | "github.com/vektah/dataloaden/example" 10 | ) 11 | 12 | // UserSliceLoaderConfig captures the config to create a new UserSliceLoader 13 | type UserSliceLoaderConfig struct { 14 | // Fetch is a method that provides the data for the loader 15 | Fetch func(keys []int) ([][]example.User, []error) 16 | 17 | // Wait is how long wait before sending a batch 18 | Wait time.Duration 19 | 20 | // MaxBatch will limit the maximum number of keys to send in one batch, 0 = not limit 21 | MaxBatch int 22 | } 23 | 24 | // NewUserSliceLoader creates a new UserSliceLoader given a fetch, wait, and maxBatch 25 | func NewUserSliceLoader(config UserSliceLoaderConfig) *UserSliceLoader { 26 | return &UserSliceLoader{ 27 | fetch: config.Fetch, 28 | wait: config.Wait, 29 | maxBatch: config.MaxBatch, 30 | } 31 | } 32 | 33 | // UserSliceLoader batches and caches requests 34 | type UserSliceLoader struct { 35 | // this method provides the data for the loader 36 | fetch func(keys []int) ([][]example.User, []error) 37 | 38 | // how long to done before sending a batch 39 | wait time.Duration 40 | 41 | // this will limit the maximum number of keys to send in one batch, 0 = no limit 42 | maxBatch int 43 | 44 | // INTERNAL 45 | 
46 | // lazily created cache 47 | cache map[int][]example.User 48 | 49 | // the current batch. keys will continue to be collected until timeout is hit, 50 | // then everything will be sent to the fetch method and out to the listeners 51 | batch *userSliceLoaderBatch 52 | 53 | // mutex to prevent races 54 | mu sync.Mutex 55 | } 56 | 57 | type userSliceLoaderBatch struct { 58 | keys []int 59 | data [][]example.User 60 | error []error 61 | closing bool 62 | done chan struct{} 63 | } 64 | 65 | // Load a User by key, batching and caching will be applied automatically 66 | func (l *UserSliceLoader) Load(key int) ([]example.User, error) { 67 | return l.LoadThunk(key)() 68 | } 69 | 70 | // LoadThunk returns a function that when called will block waiting for a User. 71 | // This method should be used if you want one goroutine to make requests to many 72 | // different data loaders without blocking until the thunk is called. 73 | func (l *UserSliceLoader) LoadThunk(key int) func() ([]example.User, error) { 74 | l.mu.Lock() 75 | if it, ok := l.cache[key]; ok { 76 | l.mu.Unlock() 77 | return func() ([]example.User, error) { 78 | return it, nil 79 | } 80 | } 81 | if l.batch == nil { 82 | l.batch = &userSliceLoaderBatch{done: make(chan struct{})} 83 | } 84 | batch := l.batch 85 | pos := batch.keyIndex(l, key) 86 | l.mu.Unlock() 87 | 88 | return func() ([]example.User, error) { 89 | <-batch.done 90 | 91 | var data []example.User 92 | if pos < len(batch.data) { 93 | data = batch.data[pos] 94 | } 95 | 96 | var err error 97 | // its convenient to be able to return a single error for everything 98 | if len(batch.error) == 1 { 99 | err = batch.error[0] 100 | } else if batch.error != nil { 101 | err = batch.error[pos] 102 | } 103 | 104 | if err == nil { 105 | l.mu.Lock() 106 | l.unsafeSet(key, data) 107 | l.mu.Unlock() 108 | } 109 | 110 | return data, err 111 | } 112 | } 113 | 114 | // LoadAll fetches many keys at once. 
It will be broken into appropriate sized 115 | // sub batches depending on how the loader is configured 116 | func (l *UserSliceLoader) LoadAll(keys []int) ([][]example.User, []error) { 117 | results := make([]func() ([]example.User, error), len(keys)) 118 | 119 | for i, key := range keys { 120 | results[i] = l.LoadThunk(key) 121 | } 122 | 123 | users := make([][]example.User, len(keys)) 124 | errors := make([]error, len(keys)) 125 | for i, thunk := range results { 126 | users[i], errors[i] = thunk() 127 | } 128 | return users, errors 129 | } 130 | 131 | // LoadAllThunk returns a function that when called will block waiting for a Users. 132 | // This method should be used if you want one goroutine to make requests to many 133 | // different data loaders without blocking until the thunk is called. 134 | func (l *UserSliceLoader) LoadAllThunk(keys []int) func() ([][]example.User, []error) { 135 | results := make([]func() ([]example.User, error), len(keys)) 136 | for i, key := range keys { 137 | results[i] = l.LoadThunk(key) 138 | } 139 | return func() ([][]example.User, []error) { 140 | users := make([][]example.User, len(keys)) 141 | errors := make([]error, len(keys)) 142 | for i, thunk := range results { 143 | users[i], errors[i] = thunk() 144 | } 145 | return users, errors 146 | } 147 | } 148 | 149 | // Prime the cache with the provided key and value. If the key already exists, no change is made 150 | // and false is returned. 151 | // (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) 152 | func (l *UserSliceLoader) Prime(key int, value []example.User) bool { 153 | l.mu.Lock() 154 | var found bool 155 | if _, found = l.cache[key]; !found { 156 | // make a copy when writing to the cache, its easy to pass a pointer in from a loop var 157 | // and end up with the whole cache pointing to the same value. 
158 | cpy := make([]example.User, len(value)) 159 | copy(cpy, value) 160 | l.unsafeSet(key, cpy) 161 | } 162 | l.mu.Unlock() 163 | return !found 164 | } 165 | 166 | // Clear the value at key from the cache, if it exists 167 | func (l *UserSliceLoader) Clear(key int) { 168 | l.mu.Lock() 169 | delete(l.cache, key) 170 | l.mu.Unlock() 171 | } 172 | 173 | func (l *UserSliceLoader) unsafeSet(key int, value []example.User) { 174 | if l.cache == nil { 175 | l.cache = map[int][]example.User{} 176 | } 177 | l.cache[key] = value 178 | } 179 | 180 | // keyIndex will return the location of the key in the batch, if its not found 181 | // it will add the key to the batch 182 | func (b *userSliceLoaderBatch) keyIndex(l *UserSliceLoader, key int) int { 183 | for i, existingKey := range b.keys { 184 | if key == existingKey { 185 | return i 186 | } 187 | } 188 | 189 | pos := len(b.keys) 190 | b.keys = append(b.keys, key) 191 | if pos == 0 { 192 | go b.startTimer(l) 193 | } 194 | 195 | if l.maxBatch != 0 && pos >= l.maxBatch-1 { 196 | if !b.closing { 197 | b.closing = true 198 | l.batch = nil 199 | go b.end(l) 200 | } 201 | } 202 | 203 | return pos 204 | } 205 | 206 | func (b *userSliceLoaderBatch) startTimer(l *UserSliceLoader) { 207 | time.Sleep(l.wait) 208 | l.mu.Lock() 209 | 210 | // we must have hit a batch limit and are already finalizing this batch 211 | if b.closing { 212 | l.mu.Unlock() 213 | return 214 | } 215 | 216 | l.batch = nil 217 | l.mu.Unlock() 218 | 219 | b.end(l) 220 | } 221 | 222 | func (b *userSliceLoaderBatch) end(l *UserSliceLoader) { 223 | b.data, b.error = l.fetch(b.keys) 224 | close(b.done) 225 | } 226 | -------------------------------------------------------------------------------- /example/slice/usersliceloader_test.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | 
"github.com/stretchr/testify/require" 12 | "github.com/vektah/dataloaden/example" 13 | ) 14 | 15 | func TestUserLoader(t *testing.T) { 16 | var fetches [][]int 17 | var mu sync.Mutex 18 | 19 | dl := &UserSliceLoader{ 20 | wait: 10 * time.Millisecond, 21 | maxBatch: 5, 22 | fetch: func(keys []int) (users [][]example.User, errors []error) { 23 | mu.Lock() 24 | fetches = append(fetches, keys) 25 | mu.Unlock() 26 | 27 | users = make([][]example.User, len(keys)) 28 | errors = make([]error, len(keys)) 29 | 30 | for i, key := range keys { 31 | if key%10 == 0 { // anything ending in zero is bad 32 | errors[i] = fmt.Errorf("users not found") 33 | } else { 34 | users[i] = []example.User{ 35 | {ID: strconv.Itoa(key), Name: "user " + strconv.Itoa(key)}, 36 | {ID: strconv.Itoa(key), Name: "user " + strconv.Itoa(key)}, 37 | } 38 | } 39 | } 40 | return users, errors 41 | }, 42 | } 43 | 44 | t.Run("fetch concurrent data", func(t *testing.T) { 45 | t.Run("load user successfully", func(t *testing.T) { 46 | t.Parallel() 47 | u, err := dl.Load(1) 48 | require.NoError(t, err) 49 | require.Equal(t, u[0].ID, "1") 50 | require.Equal(t, u[1].ID, "1") 51 | }) 52 | 53 | t.Run("load failed user", func(t *testing.T) { 54 | t.Parallel() 55 | u, err := dl.Load(10) 56 | require.Error(t, err) 57 | require.Nil(t, u) 58 | }) 59 | 60 | t.Run("load many users", func(t *testing.T) { 61 | t.Parallel() 62 | u, err := dl.LoadAll([]int{2, 10, 20, 4}) 63 | require.Equal(t, u[0][0].Name, "user 2") 64 | require.Error(t, err[1]) 65 | require.Error(t, err[2]) 66 | require.Equal(t, u[3][0].Name, "user 4") 67 | }) 68 | 69 | t.Run("load thunk", func(t *testing.T) { 70 | t.Parallel() 71 | thunk1 := dl.LoadThunk(5) 72 | thunk2 := dl.LoadThunk(50) 73 | 74 | u1, err1 := thunk1() 75 | require.NoError(t, err1) 76 | require.Equal(t, "user 5", u1[0].Name) 77 | 78 | u2, err2 := thunk2() 79 | require.Error(t, err2) 80 | require.Nil(t, u2) 81 | }) 82 | }) 83 | 84 | t.Run("it sent two batches", func(t *testing.T) { 85 | 
mu.Lock() 86 | defer mu.Unlock() 87 | 88 | require.Len(t, fetches, 2) 89 | assert.Len(t, fetches[0], 5) 90 | assert.Len(t, fetches[1], 3) 91 | }) 92 | 93 | t.Run("fetch more", func(t *testing.T) { 94 | 95 | t.Run("previously cached", func(t *testing.T) { 96 | t.Parallel() 97 | u, err := dl.Load(1) 98 | require.NoError(t, err) 99 | require.Equal(t, u[0].ID, "1") 100 | }) 101 | 102 | t.Run("load many users", func(t *testing.T) { 103 | t.Parallel() 104 | u, err := dl.LoadAll([]int{2, 4}) 105 | require.NoError(t, err[0]) 106 | require.NoError(t, err[1]) 107 | require.Equal(t, u[0][0].Name, "user 2") 108 | require.Equal(t, u[1][0].Name, "user 4") 109 | }) 110 | }) 111 | 112 | t.Run("no round trips", func(t *testing.T) { 113 | mu.Lock() 114 | defer mu.Unlock() 115 | 116 | require.Len(t, fetches, 2) 117 | }) 118 | 119 | t.Run("fetch partial", func(t *testing.T) { 120 | t.Run("errors not in cache cache value", func(t *testing.T) { 121 | t.Parallel() 122 | u, err := dl.Load(20) 123 | require.Nil(t, u) 124 | require.Error(t, err) 125 | }) 126 | 127 | t.Run("load all", func(t *testing.T) { 128 | t.Parallel() 129 | u, err := dl.LoadAll([]int{1, 4, 10, 9, 5}) 130 | require.Equal(t, u[0][0].ID, "1") 131 | require.Equal(t, u[1][0].ID, "4") 132 | require.Error(t, err[2]) 133 | require.Equal(t, u[3][0].ID, "9") 134 | require.Equal(t, u[4][0].ID, "5") 135 | }) 136 | }) 137 | 138 | t.Run("one partial trip", func(t *testing.T) { 139 | mu.Lock() 140 | defer mu.Unlock() 141 | 142 | require.Len(t, fetches, 3) 143 | require.Len(t, fetches[2], 3) // E1 U9 E2 in some random order 144 | }) 145 | 146 | t.Run("primed reads dont hit the fetcher", func(t *testing.T) { 147 | dl.Prime(99, []example.User{ 148 | {ID: "U99", Name: "Primed user"}, 149 | {ID: "U99", Name: "Primed user"}, 150 | }) 151 | u, err := dl.Load(99) 152 | require.NoError(t, err) 153 | require.Equal(t, "Primed user", u[0].Name) 154 | 155 | require.Len(t, fetches, 3) 156 | }) 157 | 158 | t.Run("priming in a loop is safe", func(t 
*testing.T) { 159 | users := [][]example.User{ 160 | {{ID: "123", Name: "Alpha"}, {ID: "123", Name: "Alpha"}}, 161 | {{ID: "124", Name: "Omega"}, {ID: "124", Name: "Omega"}}, 162 | } 163 | for _, user := range users { 164 | id, _ := strconv.Atoi(user[0].ID) 165 | dl.Prime(id, user) 166 | } 167 | 168 | u, err := dl.Load(123) 169 | require.NoError(t, err) 170 | require.Equal(t, "Alpha", u[0].Name) 171 | 172 | u, err = dl.Load(124) 173 | require.NoError(t, err) 174 | require.Equal(t, "Omega", u[0].Name) 175 | 176 | require.Len(t, fetches, 3) 177 | }) 178 | 179 | t.Run("cleared results will go back to the fetcher", func(t *testing.T) { 180 | dl.Clear(99) 181 | u, err := dl.Load(99) 182 | require.NoError(t, err) 183 | require.Equal(t, "user 99", u[0].Name) 184 | 185 | require.Len(t, fetches, 4) 186 | }) 187 | 188 | t.Run("load all thunk", func(t *testing.T) { 189 | thunk1 := dl.LoadAllThunk([]int{5, 6}) 190 | thunk2 := dl.LoadAllThunk([]int{6, 60}) 191 | 192 | users1, err1 := thunk1() 193 | 194 | require.NoError(t, err1[0]) 195 | require.NoError(t, err1[1]) 196 | require.Equal(t, "user 5", users1[0][0].Name) 197 | require.Equal(t, "user 6", users1[1][0].Name) 198 | 199 | users2, err2 := thunk2() 200 | 201 | require.NoError(t, err2[0]) 202 | require.Error(t, err2[1]) 203 | require.Equal(t, "user 6", users2[0][0].Name) 204 | }) 205 | } 206 | -------------------------------------------------------------------------------- /example/user.go: -------------------------------------------------------------------------------- 1 | //go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User 2 | 3 | package example 4 | 5 | import ( 6 | "time" 7 | ) 8 | 9 | // User is some kind of database backed model 10 | type User struct { 11 | ID string 12 | Name string 13 | } 14 | 15 | // NewLoader will collect user requests for 2 milliseconds and send them as a single batch to the fetch func 16 | // normally fetch would be a database call. 
17 | func NewLoader() *UserLoader { 18 | return &UserLoader{ 19 | wait: 2 * time.Millisecond, 20 | maxBatch: 100, 21 | fetch: func(keys []string) ([]*User, []error) { 22 | users := make([]*User, len(keys)) 23 | errors := make([]error, len(keys)) 24 | 25 | for i, key := range keys { 26 | users[i] = &User{ID: key, Name: "user " + key} 27 | } 28 | return users, errors 29 | }, 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /example/user_test.go: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestUserLoader(t *testing.T) { 15 | var fetches [][]string 16 | var mu sync.Mutex 17 | 18 | dl := &UserLoader{ 19 | wait: 10 * time.Millisecond, 20 | maxBatch: 5, 21 | fetch: func(keys []string) ([]*User, []error) { 22 | mu.Lock() 23 | fetches = append(fetches, keys) 24 | mu.Unlock() 25 | 26 | users := make([]*User, len(keys)) 27 | errors := make([]error, len(keys)) 28 | 29 | for i, key := range keys { 30 | if strings.HasPrefix(key, "E") { 31 | errors[i] = fmt.Errorf("user not found") 32 | } else { 33 | users[i] = &User{ID: key, Name: "user " + key} 34 | } 35 | } 36 | return users, errors 37 | }, 38 | } 39 | 40 | t.Run("fetch concurrent data", func(t *testing.T) { 41 | t.Run("load user successfully", func(t *testing.T) { 42 | t.Parallel() 43 | u, err := dl.Load("U1") 44 | require.NoError(t, err) 45 | require.Equal(t, u.ID, "U1") 46 | }) 47 | 48 | t.Run("load failed user", func(t *testing.T) { 49 | t.Parallel() 50 | u, err := dl.Load("E1") 51 | require.Error(t, err) 52 | require.Nil(t, u) 53 | }) 54 | 55 | t.Run("load many users", func(t *testing.T) { 56 | t.Parallel() 57 | u, err := dl.LoadAll([]string{"U2", "E2", "E3", "U4"}) 58 | require.Equal(t, u[0].Name, "user U2") 59 | require.Equal(t, 
u[3].Name, "user U4") 60 | require.Error(t, err[1]) 61 | require.Error(t, err[2]) 62 | }) 63 | 64 | t.Run("load thunk", func(t *testing.T) { 65 | t.Parallel() 66 | thunk1 := dl.LoadThunk("U5") 67 | thunk2 := dl.LoadThunk("E5") 68 | 69 | u1, err1 := thunk1() 70 | require.NoError(t, err1) 71 | require.Equal(t, "user U5", u1.Name) 72 | 73 | u2, err2 := thunk2() 74 | require.Error(t, err2) 75 | require.Nil(t, u2) 76 | }) 77 | }) 78 | 79 | t.Run("it sent two batches", func(t *testing.T) { 80 | mu.Lock() 81 | defer mu.Unlock() 82 | 83 | require.Len(t, fetches, 2) 84 | assert.Len(t, fetches[0], 5) 85 | assert.Len(t, fetches[1], 3) 86 | }) 87 | 88 | t.Run("fetch more", func(t *testing.T) { 89 | 90 | t.Run("previously cached", func(t *testing.T) { 91 | t.Parallel() 92 | u, err := dl.Load("U1") 93 | require.NoError(t, err) 94 | require.Equal(t, u.ID, "U1") 95 | }) 96 | 97 | t.Run("load many users", func(t *testing.T) { 98 | t.Parallel() 99 | u, err := dl.LoadAll([]string{"U2", "U4"}) 100 | require.NoError(t, err[0]) 101 | require.NoError(t, err[1]) 102 | require.Equal(t, u[0].Name, "user U2") 103 | require.Equal(t, u[1].Name, "user U4") 104 | }) 105 | }) 106 | 107 | t.Run("no round trips", func(t *testing.T) { 108 | mu.Lock() 109 | defer mu.Unlock() 110 | 111 | require.Len(t, fetches, 2) 112 | }) 113 | 114 | t.Run("fetch partial", func(t *testing.T) { 115 | t.Run("errored results are not cached", func(t *testing.T) { 116 | t.Parallel() 117 | u, err := dl.Load("E2") 118 | require.Nil(t, u) 119 | require.Error(t, err) 120 | }) 121 | 122 | t.Run("load all", func(t *testing.T) { 123 | t.Parallel() 124 | u, err := dl.LoadAll([]string{"U1", "U4", "E1", "U9", "U5"}) 125 | require.Equal(t, u[0].ID, "U1") 126 | require.Equal(t, u[1].ID, "U4") 127 | require.Error(t, err[2]) 128 | require.Equal(t, u[3].ID, "U9") 129 | require.Equal(t, u[4].ID, "U5") 130 | }) 131 | }) 132 | 133 | t.Run("one partial trip", func(t *testing.T) { 134 | mu.Lock() 135 | defer mu.Unlock() 136 | 137 | 
require.Len(t, fetches, 3) 138 | require.Len(t, fetches[2], 3) // E1 U9 E2 in some random order 139 | }) 140 | 141 | t.Run("primed reads don't hit the fetcher", func(t *testing.T) { 142 | dl.Prime("U99", &User{ID: "U99", Name: "Primed user"}) 143 | u, err := dl.Load("U99") 144 | require.NoError(t, err) 145 | require.Equal(t, "Primed user", u.Name) 146 | 147 | require.Len(t, fetches, 3) 148 | }) 149 | 150 | t.Run("priming in a loop is safe", func(t *testing.T) { 151 | users := []User{ 152 | {ID: "Alpha", Name: "Alpha"}, 153 | {ID: "Omega", Name: "Omega"}, 154 | } 155 | for _, user := range users { 156 | dl.Prime(user.ID, &user) 157 | } 158 | 159 | u, err := dl.Load("Alpha") 160 | require.NoError(t, err) 161 | require.Equal(t, "Alpha", u.Name) 162 | 163 | u, err = dl.Load("Omega") 164 | require.NoError(t, err) 165 | require.Equal(t, "Omega", u.Name) 166 | 167 | require.Len(t, fetches, 3) 168 | }) 169 | 170 | t.Run("cleared results will go back to the fetcher", func(t *testing.T) { 171 | dl.Clear("U99") 172 | u, err := dl.Load("U99") 173 | require.NoError(t, err) 174 | require.Equal(t, "user U99", u.Name) 175 | 176 | require.Len(t, fetches, 4) 177 | }) 178 | 179 | t.Run("load all thunk", func(t *testing.T) { 180 | thunk1 := dl.LoadAllThunk([]string{"U5", "U6"}) 181 | thunk2 := dl.LoadAllThunk([]string{"U6", "E6"}) 182 | 183 | users1, err1 := thunk1() 184 | 185 | require.NoError(t, err1[0]) 186 | require.NoError(t, err1[1]) 187 | require.Equal(t, "user U5", users1[0].Name) 188 | require.Equal(t, "user U6", users1[1].Name) 189 | 190 | users2, err2 := thunk2() 191 | 192 | require.NoError(t, err2[0]) 193 | require.Error(t, err2[1]) 194 | require.Equal(t, "user U6", users2[0].Name) 195 | }) 196 | } 197 | -------------------------------------------------------------------------------- /example/userloader_gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by github.com/vektah/dataloaden, DO NOT EDIT. 
2 | 3 | package example 4 | 5 | import ( 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // UserLoaderConfig captures the config to create a new UserLoader 11 | type UserLoaderConfig struct { 12 | // Fetch is a method that provides the data for the loader 13 | Fetch func(keys []string) ([]*User, []error) 14 | 15 | // Wait is how long to wait before sending a batch 16 | Wait time.Duration 17 | 18 | // MaxBatch will limit the maximum number of keys to send in one batch, 0 = no limit 19 | MaxBatch int 20 | } 21 | 22 | // NewUserLoader creates a new UserLoader given a fetch, wait, and maxBatch 23 | func NewUserLoader(config UserLoaderConfig) *UserLoader { 24 | return &UserLoader{ 25 | fetch: config.Fetch, 26 | wait: config.Wait, 27 | maxBatch: config.MaxBatch, 28 | } 29 | } 30 | 31 | // UserLoader batches and caches requests 32 | type UserLoader struct { 33 | // this method provides the data for the loader 34 | fetch func(keys []string) ([]*User, []error) 35 | 36 | // how long to wait before sending a batch 37 | wait time.Duration 38 | 39 | // this will limit the maximum number of keys to send in one batch, 0 = no limit 40 | maxBatch int 41 | 42 | // INTERNAL 43 | 44 | // lazily created cache 45 | cache map[string]*User 46 | 47 | // the current batch. keys will continue to be collected until timeout is hit, 48 | // then everything will be sent to the fetch method and out to the listeners 49 | batch *userLoaderBatch 50 | 51 | // mutex to prevent races 52 | mu sync.Mutex 53 | } 54 | 55 | type userLoaderBatch struct { 56 | keys []string 57 | data []*User 58 | error []error 59 | closing bool 60 | done chan struct{} 61 | } 62 | 63 | // Load a User by key, batching and caching will be applied automatically 64 | func (l *UserLoader) Load(key string) (*User, error) { 65 | return l.LoadThunk(key)() 66 | } 67 | 68 | // LoadThunk returns a function that when called will block waiting for a User. 
69 | // This method should be used if you want one goroutine to make requests to many 70 | // different data loaders without blocking until the thunk is called. 71 | func (l *UserLoader) LoadThunk(key string) func() (*User, error) { 72 | l.mu.Lock() 73 | if it, ok := l.cache[key]; ok { 74 | l.mu.Unlock() 75 | return func() (*User, error) { 76 | return it, nil 77 | } 78 | } 79 | if l.batch == nil { 80 | l.batch = &userLoaderBatch{done: make(chan struct{})} 81 | } 82 | batch := l.batch 83 | pos := batch.keyIndex(l, key) 84 | l.mu.Unlock() 85 | 86 | return func() (*User, error) { 87 | <-batch.done 88 | 89 | var data *User 90 | if pos < len(batch.data) { 91 | data = batch.data[pos] 92 | } 93 | 94 | var err error 95 | // it's convenient to be able to return a single error for everything 96 | if len(batch.error) == 1 { 97 | err = batch.error[0] 98 | } else if batch.error != nil { 99 | err = batch.error[pos] 100 | } 101 | 102 | if err == nil { 103 | l.mu.Lock() 104 | l.unsafeSet(key, data) 105 | l.mu.Unlock() 106 | } 107 | 108 | return data, err 109 | } 110 | } 111 | 112 | // LoadAll fetches many keys at once. It will be broken into appropriate sized 113 | // sub batches depending on how the loader is configured 114 | func (l *UserLoader) LoadAll(keys []string) ([]*User, []error) { 115 | results := make([]func() (*User, error), len(keys)) 116 | 117 | for i, key := range keys { 118 | results[i] = l.LoadThunk(key) 119 | } 120 | 121 | users := make([]*User, len(keys)) 122 | errors := make([]error, len(keys)) 123 | for i, thunk := range results { 124 | users[i], errors[i] = thunk() 125 | } 126 | return users, errors 127 | } 128 | 129 | // LoadAllThunk returns a function that when called will block waiting for the Users. 130 | // This method should be used if you want one goroutine to make requests to many 131 | // different data loaders without blocking until the thunk is called. 
132 | func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*User, []error) { 133 | results := make([]func() (*User, error), len(keys)) 134 | for i, key := range keys { 135 | results[i] = l.LoadThunk(key) 136 | } 137 | return func() ([]*User, []error) { 138 | users := make([]*User, len(keys)) 139 | errors := make([]error, len(keys)) 140 | for i, thunk := range results { 141 | users[i], errors[i] = thunk() 142 | } 143 | return users, errors 144 | } 145 | } 146 | 147 | // Prime the cache with the provided key and value. If the key already exists, no change is made 148 | // and false is returned. 149 | // (To forcefully prime the cache, clear the key first with loader.Clear(key) and then call loader.Prime(key, value).) 150 | func (l *UserLoader) Prime(key string, value *User) bool { 151 | l.mu.Lock() 152 | var found bool 153 | if _, found = l.cache[key]; !found { 154 | // make a copy when writing to the cache, it's easy to pass a pointer in from a loop var 155 | // and end up with the whole cache pointing to the same value. 
156 | cpy := *value 157 | l.unsafeSet(key, &cpy) 158 | } 159 | l.mu.Unlock() 160 | return !found 161 | } 162 | 163 | // Clear the value at key from the cache, if it exists 164 | func (l *UserLoader) Clear(key string) { 165 | l.mu.Lock() 166 | delete(l.cache, key) 167 | l.mu.Unlock() 168 | } 169 | 170 | func (l *UserLoader) unsafeSet(key string, value *User) { 171 | if l.cache == nil { 172 | l.cache = map[string]*User{} 173 | } 174 | l.cache[key] = value 175 | } 176 | 177 | // keyIndex will return the location of the key in the batch, if it's not found 178 | // it will add the key to the batch 179 | func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { 180 | for i, existingKey := range b.keys { 181 | if key == existingKey { 182 | return i 183 | } 184 | } 185 | 186 | pos := len(b.keys) 187 | b.keys = append(b.keys, key) 188 | if pos == 0 { 189 | go b.startTimer(l) 190 | } 191 | 192 | if l.maxBatch != 0 && pos >= l.maxBatch-1 { 193 | if !b.closing { 194 | b.closing = true 195 | l.batch = nil 196 | go b.end(l) 197 | } 198 | } 199 | 200 | return pos 201 | } 202 | 203 | func (b *userLoaderBatch) startTimer(l *UserLoader) { 204 | time.Sleep(l.wait) 205 | l.mu.Lock() 206 | 207 | // we must have hit a batch limit and are already finalizing this batch 208 | if b.closing { 209 | l.mu.Unlock() 210 | return 211 | } 212 | 213 | l.batch = nil 214 | l.mu.Unlock() 215 | 216 | b.end(l) 217 | } 218 | 219 | func (b *userLoaderBatch) end(l *UserLoader) { 220 | b.data, b.error = l.fetch(b.keys) 221 | close(b.done) 222 | } 223 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/vektah/dataloaden 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/davecgh/go-spew v1.1.1 // indirect 7 | github.com/pkg/errors v0.9.1 8 | github.com/stretchr/testify v1.7.0 9 | golang.org/x/tools v0.1.10 10 | ) 11 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 5 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 9 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 10 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 11 | github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 12 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 13 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 14 | golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= 15 | golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= 16 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 17 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 18 | golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 19 | golang.org/x/sync 
v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 20 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 21 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 22 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 23 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 24 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 25 | golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= 26 | golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 27 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 30 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 31 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 32 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 33 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 34 | golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= 35 | golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= 36 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 37 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 38 | golang.org/x/xerrors 
v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 39 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 40 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 41 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 42 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 43 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 44 | -------------------------------------------------------------------------------- /licence.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Adam Scarr 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /pkg/generator/generator.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "regexp" 9 | "strings" 10 | "unicode" 11 | 12 | "github.com/pkg/errors" 13 | "golang.org/x/tools/go/packages" 14 | "golang.org/x/tools/imports" 15 | ) 16 | 17 | type templateData struct { 18 | Package string 19 | Name string 20 | KeyType *goType 21 | ValType *goType 22 | } 23 | 24 | type goType struct { 25 | Modifiers string 26 | ImportPath string 27 | ImportName string 28 | Name string 29 | } 30 | 31 | func (t *goType) String() string { 32 | if t.ImportName != "" { 33 | return t.Modifiers + t.ImportName + "." + t.Name 34 | } 35 | 36 | return t.Modifiers + t.Name 37 | } 38 | 39 | func (t *goType) IsPtr() bool { 40 | return strings.HasPrefix(t.Modifiers, "*") 41 | } 42 | 43 | func (t *goType) IsSlice() bool { 44 | return strings.HasPrefix(t.Modifiers, "[]") 45 | } 46 | 47 | var partsRe = regexp.MustCompile(`^([\[\]\*]*)(.*?)(\.\w*)?$`) 48 | 49 | func parseType(str string) (*goType, error) { 50 | parts := partsRe.FindStringSubmatch(str) 51 | if len(parts) != 4 { 52 | return nil, fmt.Errorf("type must be in the form []*github.com/import/path.Name") 53 | } 54 | 55 | t := &goType{ 56 | Modifiers: parts[1], 57 | ImportPath: parts[2], 58 | Name: strings.TrimPrefix(parts[3], "."), 59 | } 60 | 61 | if t.Name == "" { 62 | t.Name = t.ImportPath 63 | t.ImportPath = "" 64 | } 65 | 66 | if t.ImportPath != "" { 67 | p, err := packages.Load(&packages.Config{Mode: packages.NeedName}, t.ImportPath) 68 | if err != nil { 69 | return nil, err 70 | } 71 | if len(p) != 1 { 72 | return nil, fmt.Errorf("not found") 73 | } 74 | 75 | t.ImportName = p[0].Name 76 | } 77 | 78 | return t, nil 79 | } 80 | 81 | func Generate(name string, keyType string, valueType string, wd string) error { 82 | data, err := 
getData(name, keyType, valueType, wd) 83 | if err != nil { 84 | return err 85 | } 86 | 87 | filename := strings.ToLower(data.Name) + "_gen.go" 88 | 89 | if err := writeTemplate(filepath.Join(wd, filename), data); err != nil { 90 | return err 91 | } 92 | 93 | return nil 94 | } 95 | 96 | func getData(name string, keyType string, valueType string, wd string) (templateData, error) { 97 | var data templateData 98 | 99 | genPkg := getPackage(wd) 100 | if genPkg == nil { 101 | return templateData{}, fmt.Errorf("unable to find package info for %s", wd) 102 | } 103 | 104 | var err error 105 | data.Name = name 106 | data.Package = genPkg.Name 107 | data.KeyType, err = parseType(keyType) 108 | if err != nil { 109 | return templateData{}, fmt.Errorf("key type: %s", err.Error()) 110 | } 111 | data.ValType, err = parseType(valueType) 112 | if err != nil { 113 | return templateData{}, fmt.Errorf("value type: %s", err.Error()) 114 | } 115 | 116 | // if we are inside the same package as the type we don't need an import and can refer directly to the type 117 | if genPkg.PkgPath == data.ValType.ImportPath { 118 | data.ValType.ImportName = "" 119 | data.ValType.ImportPath = "" 120 | } 121 | if genPkg.PkgPath == data.KeyType.ImportPath { 122 | data.KeyType.ImportName = "" 123 | data.KeyType.ImportPath = "" 124 | } 125 | 126 | return data, nil 127 | } 128 | 129 | func getPackage(dir string) *packages.Package { 130 | p, _ := packages.Load(&packages.Config{ 131 | Dir: dir, 132 | }, ".") 133 | 134 | if len(p) != 1 { 135 | return nil 136 | } 137 | 138 | return p[0] 139 | } 140 | 141 | func writeTemplate(filepath string, data templateData) error { 142 | var buf bytes.Buffer 143 | if err := tpl.Execute(&buf, data); err != nil { 144 | return errors.Wrap(err, "generating code") 145 | } 146 | 147 | src, err := imports.Process(filepath, buf.Bytes(), nil) 148 | if err != nil { 149 | return errors.Wrap(err, "unable to gofmt") 150 | } 151 | 152 | if err := os.WriteFile(filepath, src, 0644); err != nil 
{ 153 | return errors.Wrap(err, "writing output") 154 | } 155 | 156 | return nil 157 | } 158 | 159 | func lcFirst(s string) string { 160 | r := []rune(s) 161 | r[0] = unicode.ToLower(r[0]) 162 | return string(r) 163 | } 164 | -------------------------------------------------------------------------------- /pkg/generator/generator_test.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestParseType(t *testing.T) { 10 | require.Equal(t, &goType{Name: "string"}, parse("string")) 11 | require.Equal(t, &goType{Name: "Time", ImportPath: "time", ImportName: "time"}, parse("time.Time")) 12 | require.Equal(t, &goType{ 13 | Name: "Foo", 14 | ImportPath: "github.com/vektah/dataloaden/pkg/generator/testdata/mismatch", 15 | ImportName: "mismatched", 16 | }, parse("github.com/vektah/dataloaden/pkg/generator/testdata/mismatch.Foo")) 17 | } 18 | 19 | func parse(s string) *goType { 20 | t, err := parseType(s) 21 | if err != nil { 22 | panic(err) 23 | } 24 | 25 | return t 26 | } 27 | -------------------------------------------------------------------------------- /pkg/generator/template.go: -------------------------------------------------------------------------------- 1 | package generator 2 | 3 | import "text/template" 4 | 5 | var tpl = template.Must(template.New("generated"). 6 | Funcs(template.FuncMap{ 7 | "lcFirst": lcFirst, 8 | }). 9 | Parse(` 10 | // Code generated by github.com/vektah/dataloaden, DO NOT EDIT. 
11 | 12 | package {{.Package}} 13 | 14 | import ( 15 | "sync" 16 | "time" 17 | 18 | {{if .KeyType.ImportPath}}"{{.KeyType.ImportPath}}"{{end}} 19 | {{if .ValType.ImportPath}}"{{.ValType.ImportPath}}"{{end}} 20 | ) 21 | 22 | // {{.Name}}Config captures the config to create a new {{.Name}} 23 | type {{.Name}}Config struct { 24 | // Fetch is a method that provides the data for the loader 25 | Fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) 26 | 27 | // Wait is how long to wait before sending a batch 28 | Wait time.Duration 29 | 30 | // MaxBatch will limit the maximum number of keys to send in one batch, 0 = no limit 31 | MaxBatch int 32 | } 33 | 34 | // New{{.Name}} creates a new {{.Name}} given a fetch, wait, and maxBatch 35 | func New{{.Name}}(config {{.Name}}Config) *{{.Name}} { 36 | return &{{.Name}}{ 37 | fetch: config.Fetch, 38 | wait: config.Wait, 39 | maxBatch: config.MaxBatch, 40 | } 41 | } 42 | 43 | // {{.Name}} batches and caches requests 44 | type {{.Name}} struct { 45 | // this method provides the data for the loader 46 | fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) 47 | 48 | // how long to wait before sending a batch 49 | wait time.Duration 50 | 51 | // this will limit the maximum number of keys to send in one batch, 0 = no limit 52 | maxBatch int 53 | 54 | // INTERNAL 55 | 56 | // lazily created cache 57 | cache map[{{.KeyType.String}}]{{.ValType.String}} 58 | 59 | // the current batch. 
keys will continue to be collected until timeout is hit, 60 | // then everything will be sent to the fetch method and out to the listeners 61 | batch *{{.Name|lcFirst}}Batch 62 | 63 | // mutex to prevent races 64 | mu sync.Mutex 65 | } 66 | 67 | type {{.Name|lcFirst}}Batch struct { 68 | keys []{{.KeyType}} 69 | data []{{.ValType.String}} 70 | error []error 71 | closing bool 72 | done chan struct{} 73 | } 74 | 75 | // Load a {{.ValType.Name}} by key, batching and caching will be applied automatically 76 | func (l *{{.Name}}) Load(key {{.KeyType.String}}) ({{.ValType.String}}, error) { 77 | return l.LoadThunk(key)() 78 | } 79 | 80 | // LoadThunk returns a function that when called will block waiting for a {{.ValType.Name}}. 81 | // This method should be used if you want one goroutine to make requests to many 82 | // different data loaders without blocking until the thunk is called. 83 | func (l *{{.Name}}) LoadThunk(key {{.KeyType.String}}) func() ({{.ValType.String}}, error) { 84 | l.mu.Lock() 85 | if it, ok := l.cache[key]; ok { 86 | l.mu.Unlock() 87 | return func() ({{.ValType.String}}, error) { 88 | return it, nil 89 | } 90 | } 91 | if l.batch == nil { 92 | l.batch = &{{.Name|lcFirst}}Batch{done: make(chan struct{})} 93 | } 94 | batch := l.batch 95 | pos := batch.keyIndex(l, key) 96 | l.mu.Unlock() 97 | 98 | return func() ({{.ValType.String}}, error) { 99 | <-batch.done 100 | 101 | var data {{.ValType.String}} 102 | if pos < len(batch.data) { 103 | data = batch.data[pos] 104 | } 105 | 106 | var err error 107 | // it's convenient to be able to return a single error for everything 108 | if len(batch.error) == 1 { 109 | err = batch.error[0] 110 | } else if batch.error != nil { 111 | err = batch.error[pos] 112 | } 113 | 114 | if err == nil { 115 | l.mu.Lock() 116 | l.unsafeSet(key, data) 117 | l.mu.Unlock() 118 | } 119 | 120 | return data, err 121 | } 122 | } 123 | 124 | // LoadAll fetches many keys at once. 
It will be broken into appropriate sized 125 | // sub batches depending on how the loader is configured 126 | func (l *{{.Name}}) LoadAll(keys []{{.KeyType}}) ([]{{.ValType.String}}, []error) { 127 | results := make([]func() ({{.ValType.String}}, error), len(keys)) 128 | 129 | for i, key := range keys { 130 | results[i] = l.LoadThunk(key) 131 | } 132 | 133 | {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) 134 | errors := make([]error, len(keys)) 135 | for i, thunk := range results { 136 | {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() 137 | } 138 | return {{.ValType.Name|lcFirst}}s, errors 139 | } 140 | 141 | // LoadAllThunk returns a function that when called will block waiting for the {{.ValType.Name}}s. 142 | // This method should be used if you want one goroutine to make requests to many 143 | // different data loaders without blocking until the thunk is called. 144 | func (l *{{.Name}}) LoadAllThunk(keys []{{.KeyType}}) (func() ([]{{.ValType.String}}, []error)) { 145 | results := make([]func() ({{.ValType.String}}, error), len(keys)) 146 | for i, key := range keys { 147 | results[i] = l.LoadThunk(key) 148 | } 149 | return func() ([]{{.ValType.String}}, []error) { 150 | {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) 151 | errors := make([]error, len(keys)) 152 | for i, thunk := range results { 153 | {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() 154 | } 155 | return {{.ValType.Name|lcFirst}}s, errors 156 | } 157 | } 158 | 159 | // Prime the cache with the provided key and value. If the key already exists, no change is made 160 | // and false is returned. 161 | // (To forcefully prime the cache, clear the key first with loader.Clear(key) and then call loader.Prime(key, value).) 
162 | func (l *{{.Name}}) Prime(key {{.KeyType}}, value {{.ValType.String}}) bool { 163 | l.mu.Lock() 164 | var found bool 165 | if _, found = l.cache[key]; !found { 166 | {{- if .ValType.IsPtr }} 167 | // make a copy when writing to the cache, it's easy to pass a pointer in from a loop var 168 | // and end up with the whole cache pointing to the same value. 169 | cpy := *value 170 | l.unsafeSet(key, &cpy) 171 | {{- else if .ValType.IsSlice }} 172 | // make a copy when writing to the cache, it's easy to pass a pointer in from a loop var 173 | // and end up with the whole cache pointing to the same value. 174 | cpy := make({{.ValType.String}}, len(value)) 175 | copy(cpy, value) 176 | l.unsafeSet(key, cpy) 177 | {{- else }} 178 | l.unsafeSet(key, value) 179 | {{- end }} 180 | } 181 | l.mu.Unlock() 182 | return !found 183 | } 184 | 185 | // Clear the value at key from the cache, if it exists 186 | func (l *{{.Name}}) Clear(key {{.KeyType}}) { 187 | l.mu.Lock() 188 | delete(l.cache, key) 189 | l.mu.Unlock() 190 | } 191 | 192 | func (l *{{.Name}}) unsafeSet(key {{.KeyType}}, value {{.ValType.String}}) { 193 | if l.cache == nil { 194 | l.cache = map[{{.KeyType}}]{{.ValType.String}}{} 195 | } 196 | l.cache[key] = value 197 | } 198 | 199 | // keyIndex will return the location of the key in the batch, if it's not found 200 | // it will add the key to the batch 201 | func (b *{{.Name|lcFirst}}Batch) keyIndex(l *{{.Name}}, key {{.KeyType}}) int { 202 | for i, existingKey := range b.keys { 203 | if key == existingKey { 204 | return i 205 | } 206 | } 207 | 208 | pos := len(b.keys) 209 | b.keys = append(b.keys, key) 210 | if pos == 0 { 211 | go b.startTimer(l) 212 | } 213 | 214 | if l.maxBatch != 0 && pos >= l.maxBatch-1 { 215 | if !b.closing { 216 | b.closing = true 217 | l.batch = nil 218 | go b.end(l) 219 | } 220 | } 221 | 222 | return pos 223 | } 224 | 225 | func (b *{{.Name|lcFirst}}Batch) startTimer(l *{{.Name}}) { 226 | time.Sleep(l.wait) 227 | l.mu.Lock() 228 | 229 | // we 
must have hit a batch limit and are already finalizing this batch 230 | if b.closing { 231 | l.mu.Unlock() 232 | return 233 | } 234 | 235 | l.batch = nil 236 | l.mu.Unlock() 237 | 238 | b.end(l) 239 | } 240 | 241 | func (b *{{.Name|lcFirst}}Batch) end(l *{{.Name}}) { 242 | b.data, b.error = l.fetch(b.keys) 243 | close(b.done) 244 | } 245 | `)) 246 | -------------------------------------------------------------------------------- /pkg/generator/testdata/mismatch/mismatch.go: -------------------------------------------------------------------------------- 1 | package mismatched 2 | 3 | type Foo struct { 4 | Name string 5 | } 6 | --------------------------------------------------------------------------------