├── go.mod ├── go.sum ├── .travis.yml ├── testutils_test.go ├── example_test.go ├── LICENSE ├── README.md ├── dnscache_test.go └── dnscache.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/rs/dnscache 2 | 3 | go 1.12 4 | 5 | require golang.org/x/sync v0.0.0-20190423024810-112230192c58 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= 2 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - "1.8" 4 | - "1.9" 5 | - "1.10" 6 | - "1.11" 7 | - "1.12" 8 | - tip 9 | matrix: 10 | allow_failures: 11 | - go: tip 12 | script: 13 | go test -v -race -cpu=1,2,4 -bench . -benchmem ./... 14 | -------------------------------------------------------------------------------- /testutils_test.go: -------------------------------------------------------------------------------- 1 | package dnscache 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | type BadResolver struct { 9 | choke bool 10 | } 11 | 12 | func (r BadResolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) { 13 | return 14 | } 15 | 16 | func (r BadResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { 17 | if r.choke { 18 | err = errors.New("Look Up Failed") 19 | } else { 20 | addrs = []string{"216.58.192.238"} 21 | } 22 | return 23 | } 24 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package dnscache 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | func Example() { 11 | r := &Resolver{} 12 | t := &http.Transport{ 13 | DialContext: func(ctx context.Context, network string, addr string) (conn net.Conn, err error) { 14 | host, port, err := net.SplitHostPort(addr) 15 | if err != nil { 16 | return nil, err 17 | } 18 | ips, err := r.LookupHost(ctx, host) 19 | if err != nil { 20 | return nil, err 21 | } 22 | for _, ip := range ips { 23 | var dialer net.Dialer 24 | conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) 25 | if err == nil { 26 | break 27 | } 28 | } 29 | return 30 | }, 31 | } 32 | c := &http.Client{Transport: t} 33 | res, err := c.Get("http://httpbin.org/status/418") 34 | if err == nil { 35 | fmt.Println(res.StatusCode) 36 | } 37 | // Output: 418 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Olivier Poitrey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DNS Lookup Cache 2 | 3 | [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/dnscache/master/LICENSE) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/rs/dnscache)](https://goreportcard.com/report/github.com/rs/dnscache) 5 | [![Build Status](https://travis-ci.org/rs/dnscache.svg?branch=master)](https://travis-ci.org/rs/dnscache) 6 | [![Coverage](http://gocover.io/_badge/github.com/rs/dnscache)](http://gocover.io/github.com/rs/dnscache) 7 | [![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/dnscache) 8 | 9 | The dnscache package provides a DNS cache layer to Go's `net.Resolver`. 10 | 11 | # Install 12 | 13 | Install using the "go get" command: 14 | 15 | ``` 16 | go get -u github.com/rs/dnscache 17 | ``` 18 | 19 | # Usage 20 | 21 | Create a new instance and use it in place of `net.Resolver`. New names will be cached. Call the `Refresh` method at regular interval to update cached entries and cleanup unused ones. 22 | 23 | ```go 24 | resolver := &dnscache.Resolver{} 25 | 26 | // First call will cache the result 27 | addrs, err := resolver.LookupHost(context.Background(), "example.com") 28 | 29 | // Subsequent calls will use the cached result 30 | addrs, err = resolver.LookupHost(context.Background(), "example.com") 31 | 32 | // Call to refresh will refresh names in cache. If you pass true, it will also 33 | // remove cached names not looked up since the last call to Refresh. It is a good idea 34 | // to call this method on a regular interval. 35 | go func() { 36 | t := time.NewTicker(5 * time.Minute) 37 | defer t.Stop() 38 | for range t.C { 39 | resolver.Refresh(true) 40 | } 41 | }() 42 | ``` 43 | 44 | If you are using an `http.Transport`, you can use this cache by specifying a `DialContext` function: 45 | 46 | ```go 47 | r := &dnscache.Resolver{} 48 | t := &http.Transport{ 49 | DialContext: func(ctx context.Context, network string, addr string) (conn net.Conn, err error) { 50 | host, port, err := net.SplitHostPort(addr) 51 | if err != nil { 52 | return nil, err 53 | } 54 | ips, err := r.LookupHost(ctx, host) 55 | if err != nil { 56 | return nil, err 57 | } 58 | for _, ip := range ips { 59 | var dialer net.Dialer 60 | conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) 61 | if err == nil { 62 | break 63 | } 64 | } 65 | return 66 | }, 67 | } 68 | ``` 69 | 70 | In addition to the `Refresh` method, you can `RefreshWithOptions`. This method adds an option to persist resource records 71 | on failed lookups 72 | ```go 73 | r := &Resolver{} 74 | options := dnscache.ResolverRefreshOptions{} 75 | options.ClearUnused = true 76 | options.PersistOnFailure = false 77 | resolver.RefreshWithOptions(options) 78 | ``` 79 | -------------------------------------------------------------------------------- /dnscache_test.go: -------------------------------------------------------------------------------- 1 | package dnscache 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http/httptrace" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestResolver_LookupHost(t *testing.T) { 12 | r := &Resolver{} 13 | var cacheMiss bool 14 | r.OnCacheMiss = func() { 15 | cacheMiss = true 16 | } 17 | hosts := []string{"google.com", "google.com.", "netflix.com"} 18 | for _, host := range hosts { 19 | t.Run(host, func(t *testing.T) { 20 | for _, wantMiss := range []bool{true, false, false} { 21 | cacheMiss = false 22 | addrs, err := r.LookupHost(context.Background(), host) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | if len(addrs) == 0 { 27 | t.Error("got no record") 28 | } 29 | for _, addr := range addrs { 30 | if net.ParseIP(addr) == nil { 31 | t.Errorf("got %q; want a literal IP address", addr) 32 | } 33 | } 34 | if wantMiss != cacheMiss { 35 | t.Errorf("got cache miss=%v, want %v", cacheMiss, wantMiss) 36 | } 37 | } 38 | }) 39 | } 40 | } 41 | 42 | func TestClearCache(t *testing.T) { 43 | r := &Resolver{} 44 | _, _ = r.LookupHost(context.Background(), "google.com") 45 | if e := r.cache["hgoogle.com"]; e != nil && !e.used { 46 | t.Error("cache entry used flag is false, want true") 47 | } 48 | r.Refresh(true) 49 | if e := r.cache["hgoogle.com"]; e != nil && e.used { 50 | t.Error("cache entry used flag is true, want false") 51 | } 52 | r.Refresh(true) 53 | if e := r.cache["hgoogle.com"]; e != nil { 54 | t.Error("cache entry is not cleared") 55 | } 56 | 57 | options := ResolverRefreshOptions{} 58 | options.ClearUnused = true 59 | options.PersistOnFailure = false 60 | _, _ = r.LookupHost(context.Background(), "google.com") 61 | if e := r.cache["hgoogle.com"]; e != nil && !e.used { 62 | t.Error("cache entry used flag is false, want true") 63 | } 64 | r.RefreshWithOptions(options) 65 | if e := r.cache["hgoogle.com"]; e != nil && e.used { 66 | t.Error("cache entry used flag is true, want false") 67 | } 68 | r.RefreshWithOptions(options) 69 | if e := r.cache["hgoogle.com"]; e != nil { 70 | t.Error("cache entry is not cleared") 71 | } 72 | 73 | options.ClearUnused = false 74 | options.PersistOnFailure = true 75 | br := &Resolver{} 76 | br.Resolver = BadResolver{} 77 | 78 | _, _ = br.LookupHost(context.Background(), "google.com") 79 | br.Resolver = BadResolver{choke: true} 80 | br.RefreshWithOptions(options) 81 | if len(br.cache["hgoogle.com"].rrs) == 0 { 82 | t.Error("cache entry is cleared") 83 | } 84 | } 85 | 86 | func TestRaceOnDelete(t *testing.T) { 87 | r := &Resolver{} 88 | ls := make(chan bool) 89 | rs := make(chan bool) 90 | 91 | go func() { 92 | for { 93 | select { 94 | case <-ls: 95 | return 96 | default: 97 | r.LookupHost(context.Background(), "google.com") 98 | time.Sleep(2 * time.Millisecond) 99 | } 100 | } 101 | }() 102 | 103 | go func() { 104 | for { 105 | select { 106 | case <-rs: 107 | return 108 | default: 109 | r.Refresh(true) 110 | time.Sleep(time.Millisecond) 111 | } 112 | } 113 | }() 114 | 115 | time.Sleep(1 * time.Second) 116 | 117 | ls <- true 118 | rs <- true 119 | } 120 | 121 | func TestResolver_LookupHost_DNSHooksGetTriggerd(t *testing.T) { 122 | var ( 123 | dnsStartInfo *httptrace.DNSStartInfo 124 | dnsDoneInfo *httptrace.DNSDoneInfo 125 | ) 126 | 127 | trace := &httptrace.ClientTrace{ 128 | DNSStart: func(info httptrace.DNSStartInfo) { 129 | dnsStartInfo = &info 130 | }, 131 | DNSDone: func(info httptrace.DNSDoneInfo) { 132 | dnsDoneInfo = &info 133 | }, 134 | } 135 | 136 | ctx := httptrace.WithClientTrace(context.Background(), trace) 137 | 138 | r := &Resolver{} 139 | 140 | _, err := r.LookupHost(ctx, "example.com") 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | 145 | if dnsStartInfo == nil { 146 | t.Error("dnsStartInfo is nil, indicating that DNSStart callback has not been invoked") 147 | } 148 | 149 | if dnsDoneInfo == nil { 150 | t.Error("dnsDoneInfo is nil, indicating that DNSDone callback has not been invoked") 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /dnscache.go: -------------------------------------------------------------------------------- 1 | package dnscache 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http/httptrace" 7 | "sync" 8 | "time" 9 | 10 | "golang.org/x/sync/singleflight" 11 | ) 12 | 13 | type DNSResolver interface { 14 | LookupHost(ctx context.Context, host string) (addrs []string, err error) 15 | LookupAddr(ctx context.Context, addr string) (names []string, err error) 16 | } 17 | 18 | type Resolver struct { 19 | // Timeout defines the maximum allowed time allowed for a lookup. 20 | Timeout time.Duration 21 | 22 | // Resolver is used to perform actual DNS lookup. If nil, 23 | // net.DefaultResolver is used instead. 24 | Resolver DNSResolver 25 | 26 | once sync.Once 27 | mu sync.RWMutex 28 | cache map[string]*cacheEntry 29 | 30 | // OnCacheMiss is executed if the host or address is not included in 31 | // the cache and the default lookup is executed. 32 | OnCacheMiss func() 33 | } 34 | 35 | type ResolverRefreshOptions struct { 36 | ClearUnused bool 37 | PersistOnFailure bool 38 | } 39 | 40 | type cacheEntry struct { 41 | rrs []string 42 | err error 43 | used bool 44 | } 45 | 46 | // LookupAddr performs a reverse lookup for the given address, returning a list 47 | // of names mapping to that address. 48 | func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) { 49 | r.once.Do(r.init) 50 | return r.lookup(ctx, "r"+addr) 51 | } 52 | 53 | // LookupHost looks up the given host using the local resolver. It returns a 54 | // slice of that host's addresses. 55 | func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { 56 | r.once.Do(r.init) 57 | return r.lookup(ctx, "h"+host) 58 | } 59 | 60 | // refreshRecords refreshes cached entries which have been used at least once since 61 | // the last Refresh. If clearUnused is true, entries which haven't be used since the 62 | // last Refresh are removed from the cache. If persistOnFailure is true, stale 63 | // entries will not be removed on failed lookups 64 | func (r *Resolver) refreshRecords(clearUnused bool, persistOnFailure bool) { 65 | r.once.Do(r.init) 66 | r.mu.RLock() 67 | update := make([]string, 0, len(r.cache)) 68 | del := make([]string, 0, len(r.cache)) 69 | for key, entry := range r.cache { 70 | if entry.used { 71 | update = append(update, key) 72 | } else if clearUnused { 73 | del = append(del, key) 74 | } 75 | } 76 | r.mu.RUnlock() 77 | 78 | if len(del) > 0 { 79 | r.mu.Lock() 80 | for _, key := range del { 81 | delete(r.cache, key) 82 | } 83 | r.mu.Unlock() 84 | } 85 | 86 | for _, key := range update { 87 | r.update(context.Background(), key, false, persistOnFailure) 88 | } 89 | } 90 | 91 | func (r *Resolver) Refresh(clearUnused bool) { 92 | r.refreshRecords(clearUnused, false) 93 | } 94 | 95 | func (r *Resolver) RefreshWithOptions(options ResolverRefreshOptions) { 96 | r.refreshRecords(options.ClearUnused, options.PersistOnFailure) 97 | } 98 | 99 | func (r *Resolver) init() { 100 | r.cache = make(map[string]*cacheEntry) 101 | } 102 | 103 | // lookupGroup merges lookup calls together for lookups for the same host. The 104 | // lookupGroup key is is the LookupIPAddr.host argument. 105 | var lookupGroup singleflight.Group 106 | 107 | func (r *Resolver) lookup(ctx context.Context, key string) (rrs []string, err error) { 108 | var found bool 109 | rrs, err, found = r.load(key) 110 | if !found { 111 | if r.OnCacheMiss != nil { 112 | r.OnCacheMiss() 113 | } 114 | rrs, err = r.update(ctx, key, true, false) 115 | } 116 | return 117 | } 118 | 119 | func (r *Resolver) update(ctx context.Context, key string, used bool, persistOnFailure bool) (rrs []string, err error) { 120 | c := lookupGroup.DoChan(key, r.lookupFunc(ctx, key)) 121 | select { 122 | case <-ctx.Done(): 123 | err = ctx.Err() 124 | if err == context.DeadlineExceeded { 125 | // If DNS request timed out for some reason, force future 126 | // request to start the DNS lookup again rather than waiting 127 | // for the current lookup to complete. 128 | lookupGroup.Forget(key) 129 | } 130 | case res := <-c: 131 | if res.Shared { 132 | // We had concurrent lookups, check if the cache is already updated 133 | // by a friend. 134 | var found bool 135 | rrs, err, found = r.load(key) 136 | if found { 137 | return 138 | } 139 | } 140 | err = res.Err 141 | if err == nil { 142 | rrs, _ = res.Val.([]string) 143 | } 144 | 145 | if err != nil && persistOnFailure { 146 | var found bool 147 | rrs, err, found = r.load(key) 148 | if found { 149 | return 150 | } 151 | } 152 | 153 | r.mu.Lock() 154 | r.storeLocked(key, rrs, used, err) 155 | r.mu.Unlock() 156 | } 157 | return 158 | } 159 | 160 | // lookupFunc returns lookup function for key. The type of the key is stored as 161 | // the first char and the lookup subject is the rest of the key. 162 | func (r *Resolver) lookupFunc(ctx context.Context, key string) func() (interface{}, error) { 163 | if len(key) == 0 { 164 | panic("lookupFunc with empty key") 165 | } 166 | 167 | var resolver DNSResolver = defaultResolver 168 | if r.Resolver != nil { 169 | resolver = r.Resolver 170 | } 171 | 172 | switch key[0] { 173 | case 'h': 174 | return func() (interface{}, error) { 175 | ctx, cancel := r.prepareCtx(ctx) 176 | defer cancel() 177 | 178 | return resolver.LookupHost(ctx, key[1:]) 179 | } 180 | case 'r': 181 | return func() (interface{}, error) { 182 | ctx, cancel := r.prepareCtx(ctx) 183 | defer cancel() 184 | 185 | return resolver.LookupAddr(ctx, key[1:]) 186 | } 187 | default: 188 | panic("lookupFunc invalid key type: " + key) 189 | } 190 | } 191 | 192 | func (r *Resolver) prepareCtx(origContext context.Context) (ctx context.Context, cancel context.CancelFunc) { 193 | ctx = context.Background() 194 | if r.Timeout > 0 { 195 | ctx, cancel = context.WithTimeout(ctx, r.Timeout) 196 | } else { 197 | cancel = func() {} 198 | } 199 | 200 | // If a httptrace has been attached to the given context it will be copied over to the newly created context. We only need to copy pointers 201 | // to DNSStart and DNSDone hooks 202 | if trace := httptrace.ContextClientTrace(origContext); trace != nil { 203 | derivedTrace := &httptrace.ClientTrace{ 204 | DNSStart: trace.DNSStart, 205 | DNSDone: trace.DNSDone, 206 | } 207 | 208 | ctx = httptrace.WithClientTrace(ctx, derivedTrace) 209 | } 210 | 211 | return 212 | } 213 | 214 | func (r *Resolver) load(key string) (rrs []string, err error, found bool) { 215 | r.mu.RLock() 216 | var entry *cacheEntry 217 | entry, found = r.cache[key] 218 | if !found { 219 | r.mu.RUnlock() 220 | return 221 | } 222 | rrs = entry.rrs 223 | err = entry.err 224 | used := entry.used 225 | r.mu.RUnlock() 226 | if !used { 227 | r.mu.Lock() 228 | entry.used = true 229 | r.mu.Unlock() 230 | } 231 | return rrs, err, true 232 | } 233 | 234 | func (r *Resolver) storeLocked(key string, rrs []string, used bool, err error) { 235 | if entry, found := r.cache[key]; found { 236 | // Update existing entry in place 237 | entry.rrs = rrs 238 | entry.err = err 239 | entry.used = used 240 | return 241 | } 242 | r.cache[key] = &cacheEntry{ 243 | rrs: rrs, 244 | err: err, 245 | used: used, 246 | } 247 | } 248 | 249 | var defaultResolver = &defaultResolverWithTrace{ 250 | ipVersion: "ip", 251 | } 252 | 253 | // Create a new resolver that only resolves to IPv4 Addresses when looking up Hosts. 254 | // Example: 255 | // 256 | // resolver := dnscache.Resolver{ 257 | // Resolver: NewResolverOnlyV4(), 258 | // } 259 | func NewResolverOnlyV4() DNSResolver { 260 | return &defaultResolverWithTrace{ 261 | ipVersion: "ip4", 262 | } 263 | } 264 | 265 | // Create a new resolver that only resolves to IPv6 Addresses when looking up Hosts. 266 | // Example: 267 | // 268 | // resolver := dnscache.Resolver{ 269 | // Resolver: NewResolverOnlyV6(), 270 | // } 271 | func NewResolverOnlyV6() DNSResolver { 272 | return &defaultResolverWithTrace{ 273 | ipVersion: "ip6", 274 | } 275 | } 276 | 277 | // defaultResolverWithTrace calls `LookupIP` instead of `LookupHost` on `net.DefaultResolver` in order to cause invocation of the `DNSStart` 278 | // and `DNSDone` hooks. By implementing `DNSResolver`, backward compatibility can be ensured. 279 | type defaultResolverWithTrace struct { 280 | ipVersion string 281 | } 282 | 283 | func (d *defaultResolverWithTrace) LookupHost(ctx context.Context, host string) (addrs []string, err error) { 284 | ipVersion := d.ipVersion 285 | if ipVersion != "ip" && ipVersion != "ip4" && ipVersion != "ip6" { 286 | ipVersion = "ip" 287 | } 288 | 289 | // `net.Resolver#LookupHost` does not cause invocation of `net.Resolver#lookupIPAddr`, therefore the `DNSStart` and `DNSDone` tracing hooks 290 | // built into the stdlib are never called. `LookupIP`, despite it's name, can also be used to lookup a hostname but does cause these hooks to be 291 | // triggered. The format of the reponse is different, therefore it needs this thin wrapper converting it. 292 | rawIPs, err := net.DefaultResolver.LookupIP(ctx, ipVersion, host) 293 | if err != nil { 294 | return nil, err 295 | } 296 | 297 | cookedIPs := make([]string, len(rawIPs)) 298 | 299 | for i, v := range rawIPs { 300 | cookedIPs[i] = v.String() 301 | } 302 | 303 | return cookedIPs, nil 304 | } 305 | 306 | func (d *defaultResolverWithTrace) LookupAddr(ctx context.Context, addr string) (names []string, err error) { 307 | return net.DefaultResolver.LookupAddr(ctx, addr) 308 | } 309 | --------------------------------------------------------------------------------