├── .gitignore ├── zdns.service ├── .github ├── workflows │ └── ci.yml └── dependabot.yml ├── http ├── prometheus.go ├── router.go ├── http_test.go └── http.go ├── Makefile ├── cmd └── zdns │ ├── main_test.go │ └── main.go ├── go.mod ├── signal ├── signal_test.go └── signal.go ├── sql ├── cache_test.go ├── cache.go ├── logger.go ├── logger_test.go ├── sql_test.go └── sql.go ├── dns ├── http │ ├── client.go │ └── client_test.go ├── dnsutil │ ├── dnsutil.go │ └── dnsutil_test.go ├── proxy.go └── proxy_test.go ├── hosts ├── hosts.go └── hosts_test.go ├── zdnsrc ├── server.go ├── config_test.go ├── server_test.go ├── go.sum ├── README.md ├── config.go ├── cache ├── cache.go └── cache_test.go └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | new.txt 2 | old.txt 3 | -------------------------------------------------------------------------------- /zdns.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | After=network-online.target 3 | 4 | [Service] 5 | ExecStart=/usr/local/bin/zdns -f /etc/zdnsrc 6 | ExecReload=/bin/kill -HUP $MAINPID 7 | Restart=always 8 | 9 | [Install] 10 | WantedBy=multi-user.target 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: install go 15 | uses: actions/setup-go@v2 16 | with: 17 | go-version: 1.24 18 | - name: build and test 19 | run: make 20 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, 
you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "gomod" 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | open-pull-requests-limit: 5 13 | -------------------------------------------------------------------------------- /http/prometheus.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "github.com/prometheus/client_golang/prometheus/promauto" 6 | "github.com/prometheus/client_golang/prometheus/promhttp" 7 | ) 8 | 9 | var ( 10 | totalRequestsGauge = promauto.NewGauge(prometheus.GaugeOpts{ 11 | Name: "zdns_requests_total", 12 | Help: "The total number of DNS requests.", 13 | }) 14 | hijackedRequestsGauge = promauto.NewGauge(prometheus.GaugeOpts{ 15 | Name: "zdns_requests_hijacked", 16 | Help: "The number of hijacked DNS requests.", 17 | }) 18 | prometheusHandler = promhttp.Handler() 19 | ) 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | XGOARCH := amd64 2 | XGOOS := linux 3 | XBIN := $(XGOOS)_$(XGOARCH)/zdns 4 | 5 | all: lint test-race install 6 | 7 | test: 8 | go test ./... 9 | 10 | test-race: 11 | go test -race ./... 12 | 13 | vet: 14 | go vet ./... 15 | 16 | fmt: 17 | @sh -c "test -z $$(gofmt -l .)" || { echo "one or more files need to be formatted: try make fmt to fix this automatically"; exit 1; } 18 | 19 | lint: fmt vet 20 | 21 | install: 22 | go install ./... 23 | 24 | xinstall: 25 | # TODO: Switch to -static flag once 1.14 is released. 
26 | # https://github.com/golang/go/issues/26492 27 | env GOOS=$(XGOOS) GOARCH=$(XGOARCH) CGO_ENABLED=1 \ 28 | CC=x86_64-linux-musl-gcc go install -ldflags '-extldflags "-static"' ./... 29 | 30 | publish: 31 | ifndef DEST_PATH 32 | $(error DEST_PATH must be set when publishing) 33 | endif 34 | rsync -az $(GOPATH)/bin/$(XBIN) $(DEST_PATH)/$(XBIN) 35 | @sha256sum $(GOPATH)/bin/$(XBIN) 36 | -------------------------------------------------------------------------------- /cmd/zdns/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "syscall" 7 | "testing" 8 | ) 9 | 10 | func tempFile(t *testing.T, s string) (string, error) { 11 | f, err := ioutil.TempFile("", "zdns") 12 | if err != nil { 13 | return "", err 14 | } 15 | defer f.Close() 16 | if err := ioutil.WriteFile(f.Name(), []byte(s), 0644); err != nil { 17 | return "", err 18 | } 19 | return f.Name(), nil 20 | } 21 | 22 | func TestMain(t *testing.T) { 23 | conf := ` 24 | [dns] 25 | listen = "127.0.0.1:0" 26 | listen_http = "127.0.0.1:0" 27 | 28 | [resolver] 29 | protocol = "udp" 30 | timeout = "1s" 31 | 32 | [filter] 33 | hijack_mode = "zero" 34 | ` 35 | f, err := tempFile(t, conf) 36 | if err != nil { 37 | t.Fatal(err) 38 | } 39 | defer os.Remove(f) 40 | 41 | sig := make(chan os.Signal, 1) 42 | cli := newCli(ioutil.Discard, []string{"-f", f}, f, sig) 43 | sig <- syscall.SIGTERM 44 | cli.sh.Close() 45 | } 46 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mpolden/zdns 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | require ( 8 | github.com/BurntSushi/toml v1.5.0 9 | github.com/cenkalti/backoff/v4 v4.3.0 10 | github.com/jmoiron/sqlx v1.4.0 11 | github.com/mattn/go-sqlite3 v1.14.30 12 | github.com/miekg/dns v1.1.67 13 | github.com/prometheus/client_golang v1.23.2 
14 | ) 15 | 16 | require ( 17 | github.com/beorn7/perks v1.0.1 // indirect 18 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 19 | github.com/kr/text v0.2.0 // indirect 20 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 21 | github.com/prometheus/client_model v0.6.2 // indirect 22 | github.com/prometheus/common v0.66.1 // indirect 23 | github.com/prometheus/procfs v0.16.1 // indirect 24 | go.yaml.in/yaml/v2 v2.4.2 // indirect 25 | golang.org/x/mod v0.24.0 // indirect 26 | golang.org/x/net v0.43.0 // indirect 27 | golang.org/x/sync v0.14.0 // indirect 28 | golang.org/x/sys v0.35.0 // indirect 29 | golang.org/x/tools v0.33.0 // indirect 30 | google.golang.org/protobuf v1.36.8 // indirect 31 | ) 32 | -------------------------------------------------------------------------------- /signal/signal_test.go: -------------------------------------------------------------------------------- 1 | package signal 2 | 3 | import ( 4 | "os" 5 | "sync" 6 | "syscall" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | type reloaderCloser struct { 12 | mu sync.RWMutex 13 | reloaded bool 14 | closed bool 15 | } 16 | 17 | func (rc *reloaderCloser) Reload() { 18 | rc.mu.Lock() 19 | defer rc.mu.Unlock() 20 | rc.reloaded = true 21 | } 22 | func (rc *reloaderCloser) Close() error { 23 | rc.mu.Lock() 24 | defer rc.mu.Unlock() 25 | rc.closed = true 26 | return nil 27 | } 28 | func (rc *reloaderCloser) isReloaded() bool { 29 | rc.mu.RLock() 30 | defer rc.mu.RUnlock() 31 | return rc.reloaded 32 | } 33 | func (rc *reloaderCloser) isClosed() bool { 34 | rc.mu.RLock() 35 | defer rc.mu.RUnlock() 36 | return rc.closed 37 | } 38 | 39 | func (rc *reloaderCloser) reset() { 40 | rc.mu.Lock() 41 | defer rc.mu.Unlock() 42 | rc.reloaded = false 43 | rc.closed = false 44 | } 45 | 46 | func TestHandler(t *testing.T) { 47 | h := NewHandler(make(chan os.Signal, 1)) 48 | 49 | rc := &reloaderCloser{} 50 | h.OnReload(rc) 51 | h.OnClose(rc) 52 | 53 | var tests = []struct { 54 | signal 
syscall.Signal 55 | value func() bool 56 | }{ 57 | {syscall.SIGHUP, rc.isReloaded}, 58 | {syscall.SIGTERM, rc.isClosed}, 59 | {syscall.SIGINT, rc.isClosed}, 60 | } 61 | 62 | for _, tt := range tests { 63 | rc.reset() 64 | h.signal <- tt.signal 65 | ts := time.Now() 66 | for !tt.value() { 67 | time.Sleep(10 * time.Millisecond) 68 | if time.Since(ts) > 2*time.Second { 69 | t.Fatalf("timed out waiting for handler of signal %s", tt.signal) 70 | } 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /sql/cache_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/mpolden/zdns/cache" 8 | ) 9 | 10 | func TestCache(t *testing.T) { 11 | data1 := "1 1578680472 00000100000100000000000003777777076578616d706c6503636f6d0000010001" 12 | v1, err := cache.Unpack(data1) 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | data2 := "2 1578680472 00000100000100000000000003777777076578616d706c6503636f6d0000010001" 17 | v2, err := cache.Unpack(data2) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | client, err := New(":memory:") 22 | if err != nil { 23 | panic(err) 24 | } 25 | c := NewCache(client) 26 | 27 | // Set and read 28 | c.Set(v1.Key, v1) 29 | values := c.Read() 30 | if got, want := len(values), 1; got != want { 31 | t.Fatalf("len(values) = %d, want %d", got, want) 32 | } 33 | if got, want := values[0], v1; !reflect.DeepEqual(got, want) { 34 | t.Errorf("got %+v, want %+v", got, want) 35 | } 36 | 37 | // Reset and read 38 | c.Reset() 39 | values = c.Read() 40 | if got, want := len(values), 0; got != want { 41 | t.Fatalf("len(values) = %d, want %d", got, want) 42 | } 43 | 44 | // Insert, remove and read 45 | c.Set(v1.Key, v1) 46 | c.Set(v2.Key, v2) 47 | c.Evict(v1.Key) 48 | values = c.Read() 49 | if got, want := len(values), 1; got != want { 50 | t.Fatalf("len(values) = %d, want %d", got, want) 51 | } 52 | 53 | 
// Replacing existing value changes order 54 | c.Reset() 55 | c.Set(v1.Key, v1) 56 | c.Set(v2.Key, v2) 57 | c.Set(v1.Key, v1) 58 | values = c.Read() 59 | if got, want := values[len(values)-1].Key, v1.Key; got != want { 60 | t.Fatalf("last Key = %d, want %d", got, want) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /http/router.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | ) 7 | 8 | type router struct { 9 | routes []*route 10 | } 11 | 12 | type route struct { 13 | method string 14 | path string 15 | handler appHandler 16 | } 17 | 18 | type appHandler func(http.ResponseWriter, *http.Request) *httpError 19 | 20 | func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 21 | if e := fn(w, r); e != nil { // e is *httpError, not os.Error. 22 | if e.Message == "" { 23 | e.Message = e.err.Error() 24 | } 25 | w.WriteHeader(e.Status) 26 | if w.Header().Get("Content-Type") == jsonMediaType { 27 | out, err := json.Marshal(e) 28 | if err != nil { 29 | panic(err) 30 | } 31 | w.Write(out) 32 | } else { 33 | w.Write([]byte(e.Message)) 34 | } 35 | } 36 | } 37 | 38 | func notFoundHandler(w http.ResponseWriter, r *http.Request) *httpError { 39 | writeJSONHeader(w) 40 | return &httpError{ 41 | Status: http.StatusNotFound, 42 | Message: "Resource not found", 43 | } 44 | } 45 | 46 | func (r *router) route(method, path string, handler appHandler) *route { 47 | route := route{ 48 | method: method, 49 | path: path, 50 | handler: handler, 51 | } 52 | r.routes = append(r.routes, &route) 53 | return &route 54 | } 55 | 56 | func (r *router) handler() http.Handler { 57 | return appHandler(func(w http.ResponseWriter, req *http.Request) *httpError { 58 | for _, route := range r.routes { 59 | if route.match(req) { 60 | return route.handler(w, req) 61 | } 62 | } 63 | return notFoundHandler(w, req) 64 | }) 65 | } 66 | 67 
| func (r *route) match(req *http.Request) bool { 68 | if req.Method != r.method { 69 | return false 70 | } 71 | if r.path != req.URL.Path { 72 | return false 73 | } 74 | return true 75 | } 76 | -------------------------------------------------------------------------------- /signal/signal.go: -------------------------------------------------------------------------------- 1 | package signal 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "os" 7 | "os/signal" 8 | "sync" 9 | "syscall" 10 | ) 11 | 12 | // Reloader is the interface for types that need to act on a reload signal. 13 | type Reloader interface { 14 | Reload() 15 | } 16 | 17 | // Handler represents a signal handler and holds references to types that should act on operating system signals. 18 | type Handler struct { 19 | signal chan os.Signal 20 | reloaders []Reloader 21 | closers []io.Closer 22 | wg sync.WaitGroup 23 | } 24 | 25 | // NewHandler creates a new handler for handling operating system signals. 26 | func NewHandler(c chan os.Signal) *Handler { 27 | h := &Handler{signal: c} 28 | signal.Notify(h.signal) 29 | h.wg.Add(1) 30 | go h.readSignal() 31 | return h 32 | } 33 | 34 | // OnReload registers a reloader to call for the signal SIGHUP. 35 | func (h *Handler) OnReload(r Reloader) { h.reloaders = append(h.reloaders, r) } 36 | 37 | // OnClose registers a closer to call for signals SIGTERM and SIGINT. 38 | func (h *Handler) OnClose(c io.Closer) { h.closers = append(h.closers, c) } 39 | 40 | // Close stops handling any new signals and completes processing of pending signals before returning. 
41 | func (h *Handler) Close() error { 42 | signal.Stop(h.signal) 43 | close(h.signal) 44 | h.wg.Wait() 45 | return nil 46 | } 47 | 48 | func (h *Handler) readSignal() { 49 | defer h.wg.Done() 50 | for sig := range h.signal { 51 | switch sig { 52 | case syscall.SIGHUP: 53 | log.Printf("received signal %s: reloading", sig) 54 | for _, r := range h.reloaders { 55 | r.Reload() 56 | } 57 | case syscall.SIGTERM, syscall.SIGINT: 58 | log.Printf("received signal %s: shutting down", sig) 59 | for _, c := range h.closers { 60 | if err := c.Close(); err != nil { 61 | log.Printf("close of %T failed: %s", c, err) 62 | } 63 | } 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /dns/http/client.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "net/http" 8 | "net/url" 9 | "time" 10 | 11 | "github.com/miekg/dns" 12 | ) 13 | 14 | // RFC8484 (https://tools.ietf.org/html/rfc8484) claims that application/dns-message should be used, as does 15 | // https://developers.cloudflare.com/1.1.1.1/dns-over-https/wireformat/. 16 | // 17 | // However, Cloudflare's service only accept this media type from one of the older RFC drafts 18 | // (https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-05). 19 | const mimeType = "application/dns-udpwireformat" 20 | 21 | // Client is a DNS-over-HTTPS client. 22 | type Client struct { 23 | httpClient *http.Client 24 | } 25 | 26 | // NewClient creates a new DNS-over-HTTPS client. 27 | func NewClient(timeout time.Duration) *Client { 28 | return &Client{httpClient: &http.Client{Timeout: timeout}} 29 | } 30 | 31 | // Exchange sends the DNS message msg to the DNS-over-HTTPS endpoint addr and returns the response. 
32 | func (c *Client) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { 33 | u, err := url.Parse(addr) 34 | if err != nil { 35 | return nil, 0, fmt.Errorf("invalid url: %w", err) 36 | } 37 | 38 | p, err := msg.Pack() 39 | if err != nil { 40 | return nil, 0, err 41 | } 42 | 43 | r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(p)) 44 | if err != nil { 45 | return nil, 0, err 46 | } 47 | r.Header.Set("Content-Type", mimeType) 48 | r.Header.Set("Accept", mimeType) 49 | 50 | t := time.Now() 51 | resp, err := c.httpClient.Do(r) 52 | if err != nil { 53 | return nil, 0, err 54 | } 55 | defer resp.Body.Close() 56 | 57 | if resp.StatusCode != http.StatusOK { 58 | return nil, 0, fmt.Errorf("server returned HTTP %d error: %q", resp.StatusCode, resp.Status) 59 | } 60 | if contentType := resp.Header.Get("Content-Type"); contentType != mimeType { 61 | return nil, 0, fmt.Errorf("server returned unexpected ContentType %q, want %q", contentType, mimeType) 62 | } 63 | 64 | p, err = ioutil.ReadAll(resp.Body) 65 | if err != nil { 66 | return nil, 0, err 67 | } 68 | rtt := time.Since(t) 69 | reply := dns.Msg{} 70 | if err := reply.Unpack(p); err != nil { 71 | return nil, 0, err 72 | } 73 | return &reply, rtt, nil 74 | } 75 | -------------------------------------------------------------------------------- /hosts/hosts.go: -------------------------------------------------------------------------------- 1 | package hosts 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "net" 8 | "strings" 9 | ) 10 | 11 | // LocalNames represent host names that are considered local. 
12 | var LocalNames = []string{ 13 | "localhost", 14 | "localhost.localdomain", 15 | "local", 16 | "broadcasthost", 17 | "ip6-localhost", 18 | "ip6-loopback", 19 | "ip6-localnet", 20 | "ip6-mcastprefix", 21 | "ip6-allnodes", 22 | "ip6-allrouters", 23 | "ip6-allhosts", 24 | "0.0.0.0", 25 | } 26 | 27 | // DefaultParser is the default parser 28 | var DefaultParser = &Parser{IgnoredHosts: LocalNames} 29 | 30 | // Parser represents a hosts parser. 31 | type Parser struct { 32 | IgnoredHosts []string 33 | } 34 | 35 | // Hosts represents a hosts file. 36 | type Hosts map[string][]net.IPAddr 37 | 38 | // Parse uses DefaultParser to parse hosts from reader r. 39 | func Parse(r io.Reader) (Hosts, error) { 40 | return DefaultParser.Parse(r) 41 | } 42 | 43 | // Get returns the IP addresses of name. 44 | func (h Hosts) Get(name string) ([]net.IPAddr, bool) { 45 | ipAddrs, ok := h[name] 46 | return ipAddrs, ok 47 | } 48 | 49 | // Del deletes the hosts entry of name. 50 | func (h Hosts) Del(name string) { 51 | delete(h, name) 52 | } 53 | 54 | func (p *Parser) ignore(name string) bool { 55 | for _, ignored := range p.IgnoredHosts { 56 | if ignored == name { 57 | return true 58 | } 59 | } 60 | return false 61 | } 62 | 63 | // Parse parses hosts from reader r. 
64 | func (p *Parser) Parse(r io.Reader) (Hosts, error) { 65 | entries := make(map[string][]net.IPAddr) 66 | scanner := bufio.NewScanner(r) 67 | n := 0 68 | for scanner.Scan() { 69 | n++ 70 | line := scanner.Text() 71 | fields := strings.Fields(line) 72 | if len(fields) < 2 { 73 | continue 74 | } 75 | ip := fields[0] 76 | if strings.HasPrefix(ip, "#") { 77 | continue 78 | } 79 | ipAddr, err := net.ResolveIPAddr("", ip) 80 | if err != nil { 81 | return nil, fmt.Errorf("line %d: invalid ip address: %s - %s", n, fields[0], line) 82 | } 83 | for _, name := range fields[1:] { 84 | if strings.HasPrefix(name, "#") { 85 | break 86 | } 87 | if p.ignore(name) { 88 | continue 89 | } 90 | entries[name] = append(entries[name], *ipAddr) 91 | } 92 | } 93 | return entries, nil 94 | } 95 | -------------------------------------------------------------------------------- /dns/http/client_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/hex" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/miekg/dns" 13 | ) 14 | 15 | // Reponse and request data from https://developers.cloudflare.com/1.1.1.1/dns-over-https/wireformat/. 
16 | const response = ` 17 | 00 00 81 80 00 01 00 01 00 00 00 00 03 77 77 77 18 | 07 65 78 61 6d 70 6c 65 03 63 6f 6d 00 00 01 00 19 | 01 03 77 77 77 07 65 78 61 6d 70 6c 65 03 63 6f 20 | 6d 00 00 01 00 01 00 00 00 80 00 04 C0 00 02 01 21 | ` 22 | 23 | const request = ` 24 | 00 00 01 00 00 01 00 00 00 00 00 00 03 77 77 77 25 | 07 65 78 61 6d 70 6c 65 03 63 6f 6d 00 00 01 00 26 | 01 27 | ` 28 | 29 | func hexDecode(s string) []byte { 30 | replacer := strings.NewReplacer(" ", "", "\n", "") 31 | b, err := hex.DecodeString(replacer.Replace(s)) 32 | if err != nil { 33 | panic(err) 34 | } 35 | return b 36 | } 37 | 38 | func handler(w http.ResponseWriter, r *http.Request) { 39 | contentType := r.Header.Get("Content-Type") 40 | accept := r.Header.Get("Accept") 41 | const mimeType = "application/dns-udpwireformat" 42 | 43 | if contentType != mimeType { 44 | w.WriteHeader(http.StatusUnsupportedMediaType) 45 | io.WriteString(w, "invalid value for header \"Content-Type\"") 46 | return 47 | } 48 | if accept != mimeType { 49 | w.WriteHeader(http.StatusUnsupportedMediaType) 50 | io.WriteString(w, "invalid value for header \"Accept\"") 51 | return 52 | } 53 | w.Header().Set("Content-Type", mimeType) 54 | w.Write(hexDecode(response)) 55 | } 56 | 57 | func TestExchange(t *testing.T) { 58 | srv := httptest.NewServer(http.HandlerFunc(handler)) 59 | defer srv.Close() 60 | 61 | msg := dns.Msg{} 62 | if err := msg.Unpack(hexDecode(request)); err != nil { 63 | t.Fatal(err) 64 | } 65 | 66 | client := NewClient(10 * time.Second) 67 | reply, _, err := client.Exchange(&msg, srv.URL) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | 72 | want := `;; opcode: QUERY, status: NOERROR, id: 0 73 | ;; flags: qr rd ra; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 74 | 75 | ;; QUESTION SECTION: 76 | ;www.example.com. IN A 77 | 78 | ;; ANSWER SECTION: 79 | www.example.com. 
128 IN A 192.0.2.1 80 | ` 81 | if got := reply.String(); got != want { 82 | t.Errorf("got %s, want %s", got, want) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /hosts/hosts_test.go: -------------------------------------------------------------------------------- 1 | package hosts 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | type test struct { 10 | in string 11 | out []string 12 | ok bool 13 | } 14 | 15 | func testParser(p *Parser, hosts string, tests []test, t *testing.T) { 16 | h, err := p.Parse(strings.NewReader(hosts)) 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | for i, tt := range tests { 21 | ipAddrs, ok := h.Get(tt.in) 22 | var got []string 23 | for _, ipAddr := range ipAddrs { 24 | got = append(got, ipAddr.String()) 25 | } 26 | if ok != tt.ok || !reflect.DeepEqual(got, tt.out) { 27 | t.Errorf("#%d: Get(%q) = (%v, %t), want (%v, %t)", i, tt.in, got, ok, tt.out, tt.ok) 28 | } 29 | } 30 | } 31 | 32 | func TestParse(t *testing.T) { 33 | in := ` 34 | # comment 35 | 36 | # comment with leading whitespace 37 | 38 | incomplete-line 39 | 40 | 127.0.0.1 localhost 41 | 127.0.0.1 localhost.localdomain 42 | 127.0.0.1 local 43 | 255.255.255.255 broadcasthost 44 | ::1 localhost 45 | ::1 ip6-localhost #comment 46 | ::1 ip6-loopback # comment 47 | fe80::1%lo0 localhost 48 | ff00::0 ip6-localnet 49 | ff00::0 ip6-mcastprefix 50 | ff02::1 ip6-allnodes 51 | ff02::2 ip6-allrouters 52 | ff02::3 ip6-allhosts 53 | 0.0.0.0 0.0.0.0 54 | 192.0.2.1 test1 test2 test3 55 | 192.0.2.2 test4 56 | 192.0.2.3 test5 57 | 192.0.2.1 test6 58 | ` 59 | 60 | tests1 := []test{ 61 | {"test1", []string{"192.0.2.1"}, true}, 62 | {"test2", []string{"192.0.2.1"}, true}, 63 | {"test3", []string{"192.0.2.1"}, true}, 64 | {"test4", []string{"192.0.2.2"}, true}, 65 | {"test5", []string{"192.0.2.3"}, true}, 66 | {"#comment", nil, false}, 67 | {"#", nil, false}, 68 | {"nonexistent", nil, false}, 69 | {"localhost", nil, 
false}, 70 | {"localhost.localdomain", nil, false}, 71 | {"local", nil, false}, 72 | {"broadcasthost", nil, false}, 73 | {"ip6-localhost", nil, false}, 74 | {"ip6-loopback", nil, false}, 75 | {"ip6-localnet", nil, false}, 76 | {"ip6-mcastprefix", nil, false}, 77 | {"ip6-allnodes", nil, false}, 78 | {"ip6-allrouters", nil, false}, 79 | {"ip6-allhosts", nil, false}, 80 | {"0.0.0.0", nil, false}, 81 | } 82 | 83 | testParser(DefaultParser, in, tests1, t) 84 | 85 | tests2 := []test{ 86 | {"localhost", []string{"127.0.0.1", "::1", "fe80::1%lo0"}, true}, 87 | {"localhost.localdomain", []string{"127.0.0.1"}, true}, 88 | {"local", []string{"127.0.0.1"}, true}, 89 | {"broadcasthost", []string{"255.255.255.255"}, true}, 90 | {"ip6-localhost", []string{"::1"}, true}, 91 | {"ip6-loopback", []string{"::1"}, true}, 92 | {"ip6-localnet", []string{"ff00::"}, true}, 93 | {"ip6-mcastprefix", []string{"ff00::"}, true}, 94 | {"ip6-allnodes", []string{"ff02::1"}, true}, 95 | {"ip6-allrouters", []string{"ff02::2"}, true}, 96 | {"ip6-allhosts", []string{"ff02::3"}, true}, 97 | {"0.0.0.0", []string{"0.0.0.0"}, true}, 98 | } 99 | testParser(&Parser{}, in, tests2, t) 100 | } 101 | -------------------------------------------------------------------------------- /sql/cache.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | 7 | "github.com/mpolden/zdns/cache" 8 | ) 9 | 10 | const ( 11 | setOp = iota 12 | removeOp 13 | resetOp 14 | ) 15 | 16 | type query struct { 17 | op int 18 | key uint32 19 | value cache.Value 20 | } 21 | 22 | // Cache is a persistent DNS cache. Values added to the cache are written to a SQL database. 23 | type Cache struct { 24 | wg sync.WaitGroup 25 | queue chan query 26 | client *Client 27 | } 28 | 29 | // CacheStats containts cache statistics. 30 | type CacheStats struct { 31 | PendingTasks int 32 | } 33 | 34 | // NewCache creates a new cache using client for persistence. 
35 | func NewCache(client *Client) *Cache { 36 | c := &Cache{ 37 | queue: make(chan query, 1024), 38 | client: client, 39 | } 40 | go c.readQueue() 41 | return c 42 | } 43 | 44 | // Close consumes any outstanding writes and closes the cache. 45 | func (c *Cache) Close() error { 46 | c.wg.Wait() 47 | return nil 48 | } 49 | 50 | // Set queues a write associating value with key. Set is non-blocking, but read operations wait for any pending writes 51 | // to complete before reading. 52 | func (c *Cache) Set(key uint32, value cache.Value) { 53 | c.enqueue(query{op: setOp, key: key, value: value}) 54 | } 55 | 56 | // Evict queues a removal of key. As Set, Evict is non-blocking. 57 | func (c *Cache) Evict(key uint32) { c.enqueue(query{op: removeOp, key: key}) } 58 | 59 | // Reset queues removal of all entries. As Set, Reset is non-blocking. 60 | func (c *Cache) Reset() { c.enqueue(query{op: resetOp}) } 61 | 62 | // Read returns all entries in the cache. 63 | func (c *Cache) Read() []cache.Value { 64 | c.wg.Wait() 65 | entries, err := c.client.readCache() 66 | if err != nil { 67 | log.Print(err) 68 | return nil 69 | } 70 | values := make([]cache.Value, 0, len(entries)) 71 | for _, entry := range entries { 72 | unpacked, err := cache.Unpack(entry.Data) 73 | if err != nil { 74 | panic(err) // Should never happen 75 | } 76 | values = append(values, unpacked) 77 | } 78 | return values 79 | } 80 | 81 | // Stats returns cache statistics. 
82 | func (c *Cache) Stats() CacheStats { return CacheStats{PendingTasks: len(c.queue)} } 83 | 84 | func (c *Cache) enqueue(q query) { 85 | c.wg.Add(1) 86 | c.queue <- q 87 | } 88 | 89 | func (c *Cache) readQueue() { 90 | for q := range c.queue { 91 | switch q.op { 92 | case setOp: 93 | packed, err := q.value.Pack() 94 | if err != nil { 95 | log.Fatalf("failed to pack value: %s", err) 96 | } 97 | if err := c.client.writeCacheValue(q.key, packed); err != nil { 98 | log.Printf("failed to write key=%d data=%q: %s", q.key, packed, err) 99 | } 100 | case removeOp: 101 | if err := c.client.removeCacheValue(q.key); err != nil { 102 | log.Printf("failed to remove key=%d: %s", q.key, err) 103 | } 104 | case resetOp: 105 | if err := c.client.truncateCache(); err != nil { 106 | log.Printf("failed to truncate cache: %s", err) 107 | } 108 | default: 109 | log.Printf("unhandled operation %d", q.op) 110 | } 111 | c.wg.Done() 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /dns/dnsutil/dnsutil.go: -------------------------------------------------------------------------------- 1 | package dnsutil 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | "github.com/miekg/dns" 11 | "github.com/mpolden/zdns/dns/http" 12 | ) 13 | 14 | var ( 15 | // TypeToString contains a mapping of DNS request type to string. 16 | TypeToString = dns.TypeToString 17 | 18 | // RcodeToString contains a mapping of Mapping DNS response code to string. 19 | RcodeToString = dns.RcodeToString 20 | ) 21 | 22 | // Client is the interface of a DNS client. 23 | type Client interface { 24 | Exchange(*dns.Msg) (*dns.Msg, error) 25 | } 26 | 27 | // Config is a structure used to configure a DNS client. 
28 | type Config struct { 29 | Network string 30 | Timeout time.Duration 31 | } 32 | 33 | type resolver interface { 34 | Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) 35 | } 36 | 37 | type client struct { 38 | resolver resolver 39 | address string 40 | } 41 | 42 | type mux struct{ clients []Client } 43 | 44 | // NewMux creates a new multiplexed client which queries all clients in parallel and returns the first successful 45 | // response. 46 | func NewMux(client ...Client) Client { return &mux{clients: client} } 47 | 48 | func (m *mux) Exchange(msg *dns.Msg) (*dns.Msg, error) { 49 | if len(m.clients) == 0 { 50 | return nil, fmt.Errorf("no clients to query") 51 | } 52 | responses := make(chan *dns.Msg, len(m.clients)) 53 | errs := make(chan error, len(m.clients)) 54 | var wg sync.WaitGroup 55 | for _, c := range m.clients { 56 | wg.Add(1) 57 | go func(client Client) { 58 | defer wg.Done() 59 | r, err := client.Exchange(msg) 60 | if err != nil { 61 | errs <- err 62 | return 63 | } 64 | responses <- r 65 | }(c) 66 | } 67 | go func() { 68 | wg.Wait() 69 | close(errs) 70 | close(responses) 71 | }() 72 | for rr := range responses { 73 | return rr, nil 74 | } 75 | return nil, <-errs 76 | } 77 | 78 | // NewClient creates a new Client for addr using config. 
79 | func NewClient(addr string, config Config) Client { 80 | var r resolver 81 | if config.Network == "https" { 82 | r = http.NewClient(config.Timeout) 83 | } else { 84 | var tlsConfig *tls.Config 85 | parts := strings.SplitN(addr, "=", 2) 86 | if len(parts) == 2 { 87 | addr = parts[0] 88 | tlsConfig = &tls.Config{ServerName: parts[1]} 89 | } 90 | r = &dns.Client{Net: config.Network, Timeout: config.Timeout, TLSConfig: tlsConfig} 91 | } 92 | return &client{resolver: r, address: addr} 93 | } 94 | 95 | func (c *client) Exchange(msg *dns.Msg) (*dns.Msg, error) { 96 | r, _, err := c.resolver.Exchange(msg, c.address) 97 | if err != nil { 98 | return nil, fmt.Errorf("resolver %s failed: %w", c.address, err) 99 | } 100 | return r, err 101 | } 102 | 103 | // Answers returns all values in the answer section of DNS message msg. 104 | func Answers(msg *dns.Msg) []string { 105 | var answers []string 106 | for _, answer := range msg.Answer { 107 | for i := 1; i <= dns.NumField(answer); i++ { 108 | answers = append(answers, dns.Field(answer, i)) 109 | } 110 | } 111 | return answers 112 | } 113 | 114 | // MinTTL returns the lowest TTL of of answer, authority and additional sections. 
115 | func MinTTL(msg *dns.Msg) time.Duration { 116 | var ttl uint32 = (1 << 31) - 1 // Maximum TTL from RFC 2181 117 | for _, answer := range msg.Answer { 118 | ttl = min(answer.Header().Ttl, ttl) 119 | } 120 | for _, ns := range msg.Ns { 121 | ttl = min(ns.Header().Ttl, ttl) 122 | } 123 | for _, extra := range msg.Extra { 124 | // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags 125 | if extra.Header().Rrtype == dns.TypeOPT { 126 | continue 127 | } 128 | ttl = min(extra.Header().Ttl, ttl) 129 | } 130 | return time.Duration(ttl) * time.Second 131 | } 132 | 133 | func min(x, y uint32) uint32 { 134 | if x < y { 135 | return x 136 | } 137 | return y 138 | } 139 | -------------------------------------------------------------------------------- /dns/dnsutil/dnsutil_test.go: -------------------------------------------------------------------------------- 1 | package dnsutil 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "reflect" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/miekg/dns" 12 | ) 13 | 14 | type response struct { 15 | answer *dns.Msg 16 | fail bool 17 | mu sync.Mutex 18 | } 19 | 20 | type testResolver struct { 21 | mu sync.RWMutex 22 | response *response 23 | } 24 | 25 | func (e *testResolver) setResponse(r *response) { 26 | e.mu.Lock() 27 | defer e.mu.Unlock() 28 | e.response = r 29 | } 30 | 31 | func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { 32 | e.mu.RLock() 33 | defer e.mu.RUnlock() 34 | r := e.response 35 | if r == nil { 36 | panic("no response set") 37 | } 38 | if r.fail { 39 | return nil, errors.New("error") 40 | } 41 | r.mu.Lock() 42 | defer r.mu.Unlock() 43 | return r.answer, nil 44 | } 45 | 46 | func newA(name string, ttl uint32, ipAddr ...string) *dns.Msg { 47 | m := dns.Msg{} 48 | m.Id = dns.Id() 49 | m.SetQuestion(dns.Fqdn(name), dns.TypeA) 50 | rr := make([]dns.RR, 0, len(ipAddr)) 51 | for _, ip := range ipAddr { 52 | rr = append(rr, &dns.A{ 53 | A: net.ParseIP(ip), 54 | Hdr: 
dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, 55 | }) 56 | } 57 | m.Answer = rr 58 | return &m 59 | } 60 | 61 | func TestMinTTL(t *testing.T) { 62 | var tests = []struct { 63 | answer []dns.RR 64 | ns []dns.RR 65 | extra []dns.RR 66 | ttl time.Duration 67 | }{ 68 | { 69 | []dns.RR{ 70 | &dns.A{Hdr: dns.RR_Header{Ttl: 3600}}, 71 | &dns.A{Hdr: dns.RR_Header{Ttl: 60}}, 72 | }, 73 | nil, 74 | nil, 75 | time.Minute, 76 | }, 77 | { 78 | []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, 79 | []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}}, 80 | nil, 81 | 30 * time.Second, 82 | }, 83 | { 84 | []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, 85 | []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}}, 86 | []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 10}}}, 87 | 10 * time.Second, 88 | }, 89 | { 90 | []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, 91 | nil, 92 | []dns.RR{ 93 | &dns.OPT{Hdr: dns.RR_Header{Ttl: 10, Rrtype: dns.TypeOPT}}, // Ignored 94 | &dns.A{Hdr: dns.RR_Header{Ttl: 30}}, 95 | }, 96 | 30 * time.Second, 97 | }, 98 | } 99 | for i, tt := range tests { 100 | msg := dns.Msg{} 101 | msg.Answer = tt.answer 102 | msg.Ns = tt.ns 103 | msg.Extra = tt.extra 104 | if got := MinTTL(&msg); got != tt.ttl { 105 | t.Errorf("#%d: MinTTL(\n%s) = %s, want %s", i, msg.String(), got, tt.ttl) 106 | } 107 | } 108 | } 109 | 110 | func TestAnswers(t *testing.T) { 111 | var tests = []struct { 112 | rr []dns.RR 113 | out []string 114 | }{ 115 | {[]dns.RR{&dns.A{A: net.ParseIP("192.0.2.1")}}, []string{"192.0.2.1"}}, 116 | {[]dns.RR{ 117 | &dns.A{A: net.ParseIP("192.0.2.1")}, 118 | &dns.A{A: net.ParseIP("192.0.2.2")}, 119 | }, []string{"192.0.2.1", "192.0.2.2"}}, 120 | {[]dns.RR{&dns.AAAA{AAAA: net.ParseIP("2001:db8::1")}}, []string{"2001:db8::1"}}, 121 | } 122 | for i, tt := range tests { 123 | msg := dns.Msg{Answer: tt.rr} 124 | if got, want := Answers(&msg), tt.out; !reflect.DeepEqual(got, want) { 125 | t.Errorf("#%d: Answers(%+v) = %+v, want %+v", i, tt.rr, got, want) 126 
| } 127 | } 128 | } 129 | 130 | func TestExchange(t *testing.T) { 131 | resolver1 := &testResolver{} 132 | resolver2 := &testResolver{} 133 | 134 | // First responding resolver returns answer 135 | answer1 := newA("example.com.", 60, "192.0.2.1") 136 | answer2 := newA("example.com.", 60, "192.0.2.2") 137 | r1 := response{answer: answer1} 138 | r1.mu.Lock() // Locking first resolver so that second wins 139 | resolver1.setResponse(&r1) 140 | resolver2.setResponse(&response{answer: answer2}) 141 | 142 | mux := NewMux(resolver1, resolver2) 143 | r, err := mux.Exchange(&dns.Msg{}) 144 | if err != nil { 145 | t.Fatal(err) 146 | } 147 | if got, want := r.Answer[0].(*dns.A), answer2.Answer[0].(*dns.A); got != want { 148 | t.Errorf("got Answer[0] = %s, want %s", got, want) 149 | } 150 | r1.mu.Unlock() 151 | 152 | // All resolvers fail 153 | resolver1.setResponse(&response{fail: true}) 154 | resolver2.setResponse(&response{fail: true}) 155 | _, err = mux.Exchange(&dns.Msg{}) 156 | if err == nil { 157 | t.Errorf("got %s, want error", err) 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /cmd/zdns/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "os" 7 | "path/filepath" 8 | "sync" 9 | 10 | "flag" 11 | 12 | "github.com/mpolden/zdns" 13 | "github.com/mpolden/zdns/cache" 14 | "github.com/mpolden/zdns/dns" 15 | "github.com/mpolden/zdns/dns/dnsutil" 16 | "github.com/mpolden/zdns/http" 17 | "github.com/mpolden/zdns/signal" 18 | "github.com/mpolden/zdns/sql" 19 | ) 20 | 21 | const ( 22 | name = "zdns" 23 | logPrefix = name + ": " 24 | configName = "." 
+ name + "rc" 25 | ) 26 | 27 | func init() { 28 | log.SetPrefix(logPrefix) 29 | log.SetFlags(log.Lshortfile) 30 | } 31 | 32 | type server interface{ ListenAndServe() error } 33 | 34 | type cli struct { 35 | servers []server 36 | sh *signal.Handler 37 | wg sync.WaitGroup 38 | } 39 | 40 | func configPath() string { return filepath.Join(os.Getenv("HOME"), configName) } 41 | 42 | func readConfig(file string) (zdns.Config, error) { 43 | f, err := os.Open(file) 44 | if err != nil { 45 | return zdns.Config{}, err 46 | } 47 | defer f.Close() 48 | return zdns.ReadConfig(f) 49 | } 50 | 51 | func fatal(err error) { 52 | if err == nil { 53 | return 54 | } 55 | log.Fatal(err) 56 | } 57 | 58 | func (c *cli) runServer(server server) { 59 | c.wg.Add(1) 60 | go func() { 61 | defer c.wg.Done() 62 | if err := server.ListenAndServe(); err != nil { 63 | fatal(err) 64 | } 65 | }() 66 | } 67 | 68 | func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) *cli { 69 | cl := flag.CommandLine 70 | cl.SetOutput(out) 71 | log.SetOutput(out) 72 | confFile := cl.String("f", configFile, "config file `path`") 73 | cl.Parse(args) 74 | 75 | // Config 76 | config, err := readConfig(*confFile) 77 | fatal(err) 78 | 79 | // Signal handler 80 | sigHandler := signal.NewHandler(sig) 81 | 82 | // SQL backends 83 | var ( 84 | sqlClient *sql.Client 85 | sqlLogger *sql.Logger 86 | sqlCache *sql.Cache 87 | ) 88 | if config.DNS.Database != "" { 89 | sqlClient, err = sql.New(config.DNS.Database) 90 | fatal(err) 91 | 92 | // Logger 93 | sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL) 94 | 95 | // Cache 96 | sqlCache = sql.NewCache(sqlClient) 97 | } 98 | 99 | // DNS client 100 | dnsConfig := dnsutil.Config{ 101 | Network: config.Resolver.Protocol, 102 | Timeout: config.Resolver.Timeout, 103 | } 104 | dnsClients := make([]dnsutil.Client, 0, len(config.DNS.Resolvers)) 105 | for _, addr := range config.DNS.Resolvers { 106 | dnsClients = append(dnsClients, 
dnsutil.NewClient(addr, dnsConfig)) 107 | } 108 | dnsClient := dnsutil.NewMux(dnsClients...) 109 | 110 | // Cache 111 | var dnsCache *cache.Cache 112 | var cacheDNS dnsutil.Client 113 | if config.DNS.CachePrefetch { 114 | cacheDNS = dnsClient 115 | } 116 | if sqlCache != nil && config.DNS.CachePersist { 117 | dnsCache = cache.NewWithBackend(config.DNS.CacheSize, cacheDNS, sqlCache) 118 | 119 | } else { 120 | dnsCache = cache.New(config.DNS.CacheSize, cacheDNS) 121 | } 122 | 123 | // DNS server 124 | proxy, err := dns.NewProxy(dnsCache, dnsClient, sqlLogger) 125 | fatal(err) 126 | 127 | dnsSrv, err := zdns.NewServer(proxy, config) 128 | fatal(err) 129 | sigHandler.OnReload(dnsSrv) 130 | servers := []server{dnsSrv} 131 | 132 | // HTTP server 133 | var httpSrv *http.Server 134 | if config.DNS.ListenHTTP != "" { 135 | httpSrv = http.NewServer(dnsCache, sqlLogger, sqlCache, config.DNS.ListenHTTP) 136 | servers = append(servers, httpSrv) 137 | } 138 | 139 | // Close proxy first 140 | sigHandler.OnClose(proxy) 141 | 142 | // ... then HTTP server 143 | if httpSrv != nil { 144 | sigHandler.OnClose(httpSrv) 145 | } 146 | 147 | // ... then cache 148 | sigHandler.OnClose(dnsCache) 149 | 150 | // ... then database components 151 | if config.DNS.Database != "" { 152 | sigHandler.OnClose(sqlLogger) 153 | sigHandler.OnClose(sqlCache) 154 | sigHandler.OnClose(sqlClient) 155 | } 156 | 157 | // ... 
and finally the server itself 158 | sigHandler.OnClose(dnsSrv) 159 | return &cli{servers: servers, sh: sigHandler} 160 | } 161 | 162 | func (c *cli) run() { 163 | for _, s := range c.servers { 164 | c.runServer(s) 165 | } 166 | c.wg.Wait() 167 | c.sh.Close() 168 | } 169 | 170 | func main() { 171 | sig := make(chan os.Signal, 1) 172 | c := newCli(os.Stderr, os.Args[1:], configPath(), sig) 173 | c.run() 174 | } 175 | -------------------------------------------------------------------------------- /dns/proxy.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | "strings" 8 | "sync" 9 | 10 | "github.com/miekg/dns" 11 | "github.com/mpolden/zdns/cache" 12 | "github.com/mpolden/zdns/dns/dnsutil" 13 | "github.com/mpolden/zdns/sql" 14 | ) 15 | 16 | const ( 17 | // TypeA represents th resource record type A, an IPv4 address. 18 | TypeA = dns.TypeA 19 | // TypeAAAA represents the resource record type AAAA, an IPv6 address. 20 | TypeAAAA = dns.TypeAAAA 21 | ) 22 | 23 | // Request represents a simplified DNS request. 24 | type Request struct { 25 | Type uint16 26 | Name string 27 | } 28 | 29 | // Reply represents a simplifed DNS reply. 30 | type Reply struct{ rr []dns.RR } 31 | 32 | // Handler represents the handler for a DNS request. 33 | type Handler func(*Request) *Reply 34 | 35 | // Proxy represents a DNS proxy. 36 | type Proxy struct { 37 | Handler Handler 38 | cache *cache.Cache 39 | logger *sql.Logger 40 | server *dns.Server 41 | client dnsutil.Client 42 | mu sync.RWMutex 43 | } 44 | 45 | // NewProxy creates a new DNS proxy. 46 | func NewProxy(cache *cache.Cache, client dnsutil.Client, logger *sql.Logger) (*Proxy, error) { 47 | return &Proxy{ 48 | logger: logger, 49 | cache: cache, 50 | client: client, 51 | }, nil 52 | } 53 | 54 | // ReplyA creates a resource record of type A. 
55 | func ReplyA(name string, ipAddr ...net.IP) *Reply { 56 | rr := make([]dns.RR, 0, len(ipAddr)) 57 | for _, ip := range ipAddr { 58 | rr = append(rr, &dns.A{ 59 | A: ip, 60 | Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600}, 61 | }) 62 | } 63 | return &Reply{rr} 64 | } 65 | 66 | // ReplyAAAA creates a resource record of type AAAA. 67 | func ReplyAAAA(name string, ipAddr ...net.IP) *Reply { 68 | rr := make([]dns.RR, 0, len(ipAddr)) 69 | for _, ip := range ipAddr { 70 | rr = append(rr, &dns.AAAA{ 71 | AAAA: ip, 72 | Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 3600}, 73 | }) 74 | } 75 | return &Reply{rr} 76 | } 77 | 78 | func (r *Reply) String() string { 79 | b := strings.Builder{} 80 | for i, rr := range r.rr { 81 | b.WriteString(rr.String()) 82 | if i < len(r.rr)-1 { 83 | b.WriteRune('\n') 84 | } 85 | } 86 | return b.String() 87 | } 88 | 89 | func (p *Proxy) reply(r *dns.Msg) *dns.Msg { 90 | if p.Handler == nil || len(r.Question) != 1 { 91 | return nil 92 | } 93 | reply := p.Handler(&Request{ 94 | Name: r.Question[0].Name, 95 | Type: r.Question[0].Qtype, 96 | }) 97 | if reply == nil { 98 | return nil 99 | } 100 | m := dns.Msg{Answer: reply.rr} 101 | // Pretend this is an recursive answer 102 | m.RecursionAvailable = true 103 | m.SetReply(r) 104 | return &m 105 | } 106 | 107 | // Close closes the proxy. 
108 | func (p *Proxy) Close() error { 109 | p.mu.RLock() 110 | defer p.mu.RUnlock() 111 | if p.server != nil { 112 | return p.server.Shutdown() 113 | } 114 | return nil 115 | } 116 | 117 | func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { 118 | var ip net.IP 119 | switch v := w.RemoteAddr().(type) { 120 | case *net.UDPAddr: 121 | ip = v.IP 122 | case *net.TCPAddr: 123 | ip = v.IP 124 | default: 125 | panic(fmt.Sprintf("unexpected remote address type %T", v)) 126 | } 127 | if p.logger != nil { 128 | p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) 129 | } 130 | w.WriteMsg(msg) 131 | } 132 | 133 | // ServeDNS implements the dns.Handler interface. 134 | func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 135 | if reply := p.reply(r); reply != nil { 136 | p.writeMsg(w, reply, true) 137 | return 138 | } 139 | q := r.Question[0] 140 | key := cache.NewKey(q.Name, q.Qtype, q.Qclass) 141 | if msg, ok := p.cache.Get(key); ok { 142 | msg.SetReply(r) 143 | p.writeMsg(w, msg, false) 144 | return 145 | } 146 | rr, err := p.client.Exchange(r) 147 | if err == nil { 148 | p.writeMsg(w, rr, false) 149 | p.cache.Set(key, rr) 150 | } else { 151 | log.Print(err) 152 | dns.HandleFailed(w, r) 153 | } 154 | } 155 | 156 | // ListenAndServe listens on the network address addr and uses the server to process requests. 157 | func (p *Proxy) ListenAndServe(addr string, network string) error { 158 | p.mu.Lock() 159 | p.server = &dns.Server{Addr: addr, Net: network, Handler: p} 160 | p.mu.Unlock() 161 | return p.server.ListenAndServe() 162 | } 163 | -------------------------------------------------------------------------------- /sql/logger.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "log" 5 | "net" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | const ( 11 | // LogDiscard disables logging of DNS requests. 
12 | LogDiscard = iota 13 | // LogAll logs all DNS requests. 14 | LogAll 15 | // LogHijacked only logs hijacked DNS requests. 16 | LogHijacked 17 | ) 18 | 19 | // Logger is a logger that logs DNS requests to a SQL database. 20 | type Logger struct { 21 | mode int 22 | queue chan LogEntry 23 | client *Client 24 | wg sync.WaitGroup 25 | now func() time.Time 26 | } 27 | 28 | // LogEntry represents a log entry for a DNS request. 29 | type LogEntry struct { 30 | Time time.Time 31 | RemoteAddr net.IP 32 | Hijacked bool 33 | Qtype uint16 34 | Question string 35 | Answers []string 36 | } 37 | 38 | // LogStats contains log statistics. 39 | type LogStats struct { 40 | Since time.Time 41 | Total int64 42 | Hijacked int64 43 | PendingTasks int 44 | Events []LogEvent 45 | } 46 | 47 | // LogEvent contains the number of requests at a point in time. 48 | type LogEvent struct { 49 | Time time.Time 50 | Count int64 51 | } 52 | 53 | // NewLogger creates a new logger. Persisted entries are kept according to ttl. 54 | func NewLogger(client *Client, mode int, ttl time.Duration) *Logger { 55 | l := &Logger{ 56 | client: client, 57 | queue: make(chan LogEntry, 1024), 58 | now: time.Now, 59 | mode: mode, 60 | } 61 | if mode != LogDiscard { 62 | go l.readQueue(ttl) 63 | } 64 | return l 65 | } 66 | 67 | // Close consumes any outstanding log requests and closes the logger. 68 | func (l *Logger) Close() error { 69 | l.wg.Wait() 70 | return nil 71 | } 72 | 73 | // Record records the given DNS request to the log database. 
74 | func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) { 75 | if l.mode == LogDiscard { 76 | return 77 | } 78 | if l.mode == LogHijacked && !hijacked { 79 | return 80 | } 81 | l.wg.Add(1) 82 | l.queue <- LogEntry{ 83 | Time: l.now(), 84 | RemoteAddr: remoteAddr, 85 | Hijacked: hijacked, 86 | Qtype: qtype, 87 | Question: question, 88 | Answers: answers, 89 | } 90 | } 91 | 92 | // Read returns the n most recent log entries. 93 | func (l *Logger) Read(n int) ([]LogEntry, error) { 94 | entries, err := l.client.readLog(n) 95 | if err != nil { 96 | return nil, err 97 | } 98 | ids := make(map[int64]*LogEntry) 99 | logEntries := make([]LogEntry, 0, len(entries)) 100 | for _, le := range entries { 101 | entry, ok := ids[le.ID] 102 | if !ok { 103 | newEntry := LogEntry{ 104 | Time: time.Unix(le.Time, 0).UTC(), 105 | RemoteAddr: le.RemoteAddr, 106 | Hijacked: le.Hijacked, 107 | Qtype: le.Qtype, 108 | Question: le.Question, 109 | } 110 | logEntries = append(logEntries, newEntry) 111 | entry = &logEntries[len(logEntries)-1] 112 | ids[le.ID] = entry 113 | } 114 | if le.Answer != "" { 115 | entry.Answers = append(entry.Answers, le.Answer) 116 | } 117 | } 118 | return logEntries, nil 119 | } 120 | 121 | // Stats returns logger statistics. Events will be merged together according to resolution. A zero duration disables 122 | // merging. 
123 | func (l *Logger) Stats(resolution time.Duration) (LogStats, error) { 124 | stats, err := l.client.readLogStats() 125 | if err != nil { 126 | return LogStats{}, err 127 | } 128 | events := make([]LogEvent, 0, len(stats.Events)) 129 | var last *LogEvent 130 | for _, le := range stats.Events { 131 | next := LogEvent{ 132 | Time: time.Unix(le.Time, 0).UTC(), 133 | Count: le.Count, 134 | } 135 | if last != nil && next.Time.Before(last.Time.Add(resolution)) { 136 | last.Count += next.Count 137 | } else { 138 | events = append(events, next) 139 | last = &events[len(events)-1] 140 | } 141 | } 142 | return LogStats{ 143 | Since: time.Unix(stats.Since, 0).UTC(), 144 | Total: stats.Total, 145 | Hijacked: stats.Hijacked, 146 | PendingTasks: len(l.queue), 147 | Events: events, 148 | }, nil 149 | } 150 | 151 | func (l *Logger) readQueue(ttl time.Duration) { 152 | for e := range l.queue { 153 | if err := l.client.writeLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { 154 | log.Printf("write failed: %+v: %s", e, err) 155 | } 156 | if ttl > 0 { 157 | t := l.now().Add(-ttl) 158 | if err := l.client.deleteLogBefore(t); err != nil { 159 | log.Printf("deleting log entries before %v failed: %s", t, err) 160 | } 161 | } 162 | l.wg.Done() 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /zdnsrc: -------------------------------------------------------------------------------- 1 | # -*- mode: conf-toml -*- 2 | 3 | # Each commented option contains the default value. 4 | 5 | [dns] 6 | # Listening address of the resolver. 7 | # 8 | # listen = "127.0.0.1:53000" 9 | 10 | # Listening protocol. The only supported one is "udp". 11 | # 12 | # protocol = "udp" 13 | 14 | # Maximum number of entries to keep in the DNS cache. The cache discards older 15 | # entries once the number of entries exceeds this size. 16 | # 17 | # cache_size = 4096 18 | 19 | # Cache pre-fetching. 
20 | # 21 | # If enabled, cached entries will be re-resolved asynchronously. Note that this 22 | # may lead to slightly stale entries, but cached requests will never block 23 | # waiting for the upstream resolver. 24 | # 25 | # cache_prefetch = true 26 | 27 | # Cache persistence. 28 | # 29 | # If enabled, cache contents are periodically written to disk. The persisted 30 | # content will then be used to pre-populate the cache on startup. 31 | # 32 | # cache_persist = false 33 | 34 | # Upstream DNS servers to use when answering queries. 35 | # 36 | # Each entry has the following format: 37 | # 38 | # addr:port[=tls-name] 39 | # 40 | # The tls-name part is optional. Some DNS servers only have FQDNs in their 41 | # certificate SAN field. This causes certificate validation to fail when 42 | # connecting using an IP address. 43 | # 44 | # When tls-name is set, it is used to verify the hostname of the returned 45 | # certificate. This only makes sense in combination with the tcp-tls protocol. 46 | # 47 | # The default is the Cloudflare DNS servers, which support DNS-over-TLS. 48 | # https://www.cloudflare.com/learning/dns/what-is-1.1.1.1/ 49 | # 50 | # resolvers = [ 51 | # "1.1.1.1:853", 52 | # "1.0.0.1:853", 53 | # ] 54 | # 55 | # Or using DNS-over-HTTPS: 56 | # 57 | # resolvers = [ 58 | # "https://cloudflare-dns.com/dns-query", 59 | # ] 60 | # 61 | # Or using a specific TLS server name, for example with the UncensoredDNS servers 62 | # (https://blog.uncensoreddns.org): 63 | # 64 | # resolvers = [ 65 | # "89.233.43.71:853=unicast.censurfridns.dk", 66 | # "91.239.100.100:853=anycast.censurfridns.dk", 67 | # ] 68 | 69 | # Configure how to answer hijacked DNS requests. 70 | # 71 | # zero: Respond with the IPv4 zero address (0.0.0.0) to type A requests. 72 | # Respond with the IPv6 zero address (::) to type AAAA requests. 73 | # empty: Respond with an empty answer to all hijacked requests. 74 | # hosts: Respond with the addresses of the matching hosts entry, if any.
75 | # 76 | # hijack_mode = "zero" 77 | 78 | # Configures how often each remote hosts list is refreshed. 79 | # 80 | # hosts_refresh_interval = "48h" 81 | 82 | # Path to the database. This is used for persistence, such as logging of DNS requests. 83 | # 84 | # database = "" 85 | 86 | # Set logging mode. The option database must be set when setting this to a 87 | # non-empty value. 88 | # 89 | # all: Logs all requests. 90 | # hijacked: Logs only hijacked requests. 91 | # empty string: Log nothing (default). 92 | # 93 | # log_mode = "" 94 | 95 | # Configure how long logged requests are kept. Log entries older than this will be 96 | # removed. 97 | # 98 | # log_ttl = "168h" 99 | 100 | # HTTP server for inspecting logs and cache. Setting a listening address in the 101 | # form addr:port will enable the server. Set to empty string to disable. 102 | # 103 | # listen_http = "127.0.0.1:8053" 104 | 105 | [resolver] 106 | # Set the protocol to use when sending requests to upstream resolvers. Supported protocols: 107 | # 108 | # tcp-tls: DNS over TLS (encrypted). Note that the upstream resolver must 109 | # support this protocol. 110 | # https: DNS over HTTPS (encrypted). Only recommended for networks where tcp-tls 111 | # does not work, due to e.g. aggressive firewalls. Note that the upstream 112 | # resolver must support this protocol. 113 | # udp: DNS over UDP (plaintext). 114 | # tcp: DNS over TCP (plaintext). 115 | # 116 | # protocol = "tcp-tls" 117 | 118 | # Set the timeout for a single request to an upstream resolver. 119 | # 120 | # timeout = "2s" 121 | 122 | # Answer queries from static hosts files. There are no default values for the 123 | # following examples. 124 | # 125 | # Load hosts from a URL. The hijack option can be one of: 126 | # 127 | # true: Matching requests will be answered according to hijack_mode. 128 | # false: Matching hosts are removed from those loaded by the preceding entries, so they will not be hijacked. This can be used to 129 | # whitelist particular hosts as shown in the example below.
130 | # 131 | # [[hosts]] 132 | # url = "https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts" 133 | # hijack = true 134 | # timeout = "5s" 135 | 136 | # Load hosts from a local file. 137 | # 138 | # [[hosts]] 139 | # url = "file:///home/foo/myhosts.txt" 140 | # hijack = true 141 | 142 | # Inline hosts list. Useful for blocking or whitelisting a small set of hosts. 143 | # 144 | # [[hosts]] 145 | # entries = [ 146 | # # Unblock the following to avoid breaking video watching history 147 | # "0.0.0.0 s.youtube.com", 148 | # ] 149 | # hijack = false 150 | -------------------------------------------------------------------------------- /sql/logger_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "net" 5 | "reflect" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestRecord(t *testing.T) { 11 | client := testClient() 12 | logger := NewLogger(client, LogAll, 0) 13 | logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") 14 | // Flush queue 15 | if err := logger.Close(); err != nil { 16 | t.Fatal(err) 17 | } 18 | logEntries, err := logger.client.readLog(1) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | if want, got := 2, len(logEntries); want != got { 23 | t.Errorf("len(entries) = %d, want %d", got, want) 24 | } 25 | } 26 | 27 | func TestMode(t *testing.T) { 28 | badHost := "badhost1." 29 | goodHost := "goodhost1." 
30 | var tests = []struct { 31 | question string 32 | remoteAddr net.IP 33 | hijacked bool 34 | mode int 35 | log bool 36 | }{ 37 | {badHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, 38 | {goodHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, 39 | {badHost, net.IPv4(192, 0, 2, 100), true, LogHijacked, true}, 40 | {goodHost, net.IPv4(192, 0, 2, 100), false, LogHijacked, false}, 41 | {badHost, net.IPv4(192, 0, 2, 100), true, LogDiscard, false}, 42 | {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false}, 43 | } 44 | for i, tt := range tests { 45 | logger := NewLogger(testClient(), tt.mode, 0) 46 | logger.mode = tt.mode 47 | logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) 48 | if err := logger.Close(); err != nil { // Flush 49 | t.Fatal(err) 50 | } 51 | entries, err := logger.Read(1) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | if len(entries) > 0 != tt.log { 56 | t.Errorf("#%d: question %q (hijacked=%t) should be logged in mode %d", i, tt.question, tt.hijacked, tt.mode) 57 | } 58 | } 59 | } 60 | 61 | func TestAnswerMerging(t *testing.T) { 62 | logger := NewLogger(testClient(), LogAll, 0) 63 | now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) 64 | logger.now = func() time.Time { return now } 65 | logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") 66 | logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "2.example.com.") 67 | // Flush queue 68 | if err := logger.Close(); err != nil { 69 | t.Fatal(err) 70 | } 71 | // Multi-answer log entries are merged 72 | got, err := logger.Read(2) 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | want := []LogEntry{ 77 | { 78 | Time: now, 79 | RemoteAddr: net.IPv4(192, 0, 2, 100), 80 | Hijacked: true, 81 | Qtype: 1, 82 | Question: "example.com.", 83 | Answers: []string{"192.0.2.2", "192.0.2.1"}, 84 | }, 85 | { 86 | Time: now, 87 | RemoteAddr: net.IPv4(192, 0, 2, 100), 88 | Hijacked: true, 89 | Qtype: 1, 90 | Question: "2.example.com.", 91 | }} 92 | if 
!reflect.DeepEqual(want, got) { 93 | t.Errorf("Get(1) = %+v, want %+v", got, want) 94 | } 95 | } 96 | 97 | func TestLogPruning(t *testing.T) { 98 | logger := NewLogger(testClient(), LogAll, time.Hour) 99 | defer logger.Close() 100 | tt := time.Now() 101 | logger.now = func() time.Time { return tt } 102 | logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1") 103 | 104 | // Wait until queue is flushed 105 | ts := time.Now() 106 | var entries []LogEntry 107 | var err error 108 | for len(entries) == 0 { 109 | entries, err = logger.Read(1) 110 | if err != nil { 111 | t.Fatal(err) 112 | } 113 | time.Sleep(10 * time.Millisecond) 114 | if time.Since(ts) > 2*time.Second { 115 | t.Fatal("timed out waiting for log entry to be written") 116 | } 117 | } 118 | 119 | // Advance time beyond log TTL 120 | tt = tt.Add(time.Hour).Add(time.Second) 121 | // Trigger pruning by recording another entry 122 | logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "2.example.com.", "192.0.2.2") 123 | for len(entries) > 1 { 124 | entries, err = logger.Read(2) 125 | if err != nil { 126 | t.Fatal(err) 127 | } 128 | time.Sleep(10 * time.Millisecond) 129 | if time.Since(ts) > 2*time.Second { 130 | t.Fatal("timed out waiting for log entry to be removed") 131 | } 132 | } 133 | } 134 | 135 | func TestStats(t *testing.T) { 136 | var tests = []struct { 137 | interval time.Duration 138 | resolution time.Duration 139 | eventCount int 140 | }{ 141 | {time.Minute, 0, 3}, 142 | {time.Minute, time.Second, 3}, 143 | {time.Minute, time.Minute, 3}, 144 | {time.Minute, time.Minute * 2, 2}, 145 | {time.Minute, time.Minute * 3, 1}, 146 | {time.Minute, time.Minute * 5, 1}, 147 | } 148 | for i, tt := range tests { 149 | logger := NewLogger(testClient(), LogAll, time.Hour) 150 | now := time.Now() 151 | for i := 0; i < 3; i++ { 152 | logger.now = func() time.Time { return now.Add(time.Duration(i) * tt.interval) } 153 | logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1") 
154 | logger.Close() 155 | } 156 | stats, err := logger.Stats(tt.resolution) 157 | if err != nil { 158 | t.Fatal(err) 159 | } 160 | if got, want := len(stats.Events), tt.eventCount; got != want { 161 | t.Errorf("#%d: len(Events) = %d, want %d", i, got, want) 162 | } 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package zdns 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "sync" 12 | "time" 13 | 14 | "github.com/cenkalti/backoff/v4" 15 | "github.com/mpolden/zdns/dns" 16 | "github.com/mpolden/zdns/hosts" 17 | ) 18 | 19 | const ( 20 | // HijackZero returns the zero IP address to matching requests. 21 | HijackZero = iota 22 | // HijackEmpty returns an empty answer to matching requests. 23 | HijackEmpty 24 | // HijackHosts returns the value of the hoss entry to matching request. 25 | HijackHosts 26 | ) 27 | 28 | // A Server defines parameters for running a DNS server. 29 | type Server struct { 30 | Config Config 31 | hosts hosts.Hosts 32 | proxy *dns.Proxy 33 | done chan bool 34 | mu sync.RWMutex 35 | httpClient *http.Client 36 | } 37 | 38 | // NewServer returns a new server configured according to config. 
39 | func NewServer(proxy *dns.Proxy, config Config) (*Server, error) { 40 | server := &Server{ 41 | Config: config, 42 | done: make(chan bool, 1), 43 | proxy: proxy, 44 | httpClient: &http.Client{Timeout: 10 * time.Second}, 45 | } 46 | proxy.Handler = server.hijack 47 | 48 | // Periodically refresh hosts 49 | if interval := config.DNS.refreshInterval; interval > 0 { 50 | go server.reloadHosts(interval) 51 | } 52 | 53 | // Load initial hosts 54 | go server.loadHosts() 55 | return server, nil 56 | } 57 | 58 | func (s *Server) httpGet(url string) (io.ReadCloser, error) { 59 | var body io.ReadCloser 60 | policy := backoff.NewExponentialBackOff() 61 | policy.MaxInterval = 2 * time.Second 62 | policy.MaxElapsedTime = 30 * time.Second 63 | err := backoff.Retry(func() error { 64 | res, err := s.httpClient.Get(url) 65 | if err == nil { 66 | body = res.Body 67 | } 68 | return err 69 | }, policy) 70 | if err != nil { 71 | return nil, err 72 | } 73 | return body, nil 74 | } 75 | 76 | func (s *Server) readHosts(name string) (hosts.Hosts, error) { 77 | url, err := url.Parse(name) 78 | if err != nil { 79 | return nil, err 80 | } 81 | var rc io.ReadCloser 82 | switch url.Scheme { 83 | case "file": 84 | f, err := os.Open(url.Path) 85 | if err != nil { 86 | return nil, err 87 | } 88 | rc = f 89 | case "http", "https": 90 | rc, err = s.httpGet(url.String()) 91 | if err != nil { 92 | return nil, err 93 | } 94 | default: 95 | return nil, fmt.Errorf("%s: invalid scheme: %s", url, url.Scheme) 96 | } 97 | hosts, err := hosts.Parse(rc) 98 | if err1 := rc.Close(); err == nil { 99 | err = err1 100 | } 101 | return hosts, err 102 | } 103 | 104 | func nonFqdn(s string) string { 105 | sz := len(s) 106 | if sz > 0 && s[sz-1:] == "." 
{ 107 | return s[:sz-1] 108 | } 109 | return s 110 | } 111 | 112 | func (s *Server) reloadHosts(interval time.Duration) { 113 | for { 114 | select { 115 | case <-s.done: 116 | return 117 | case <-time.After(interval): 118 | s.loadHosts() 119 | } 120 | } 121 | } 122 | 123 | func (s *Server) loadHosts() { 124 | hs := make(hosts.Hosts) 125 | for _, h := range s.Config.Hosts { 126 | src := "inline hosts" 127 | hs1 := h.hosts 128 | if h.URL != "" { 129 | src = h.URL 130 | var err error 131 | hs1, err = s.readHosts(h.URL) 132 | if err != nil { 133 | log.Printf("failed to read hosts from %s: %s", h.URL, err) 134 | continue 135 | } 136 | } 137 | if h.Hijack { 138 | for name, ipAddrs := range hs1 { 139 | hs[name] = ipAddrs 140 | } 141 | log.Printf("loaded %d hosts from %s", len(hs1), src) 142 | } else { 143 | removed := 0 144 | for hostToRemove := range hs1 { 145 | if _, ok := hs.Get(hostToRemove); ok { 146 | removed++ 147 | hs.Del(hostToRemove) 148 | } 149 | } 150 | if removed > 0 { 151 | log.Printf("removed %d hosts from %s", removed, src) 152 | } 153 | } 154 | } 155 | s.mu.Lock() 156 | s.hosts = hs 157 | s.mu.Unlock() 158 | log.Printf("loaded %d hosts in total", len(hs)) 159 | } 160 | 161 | // Reload updates hosts entries of Server s. 162 | func (s *Server) Reload() { s.loadHosts() } 163 | 164 | // Close terminates all active operations and shuts down the DNS server. 
165 | func (s *Server) Close() error { 166 | s.done <- true 167 | return nil 168 | } 169 | 170 | func (s *Server) hijack(r *dns.Request) *dns.Reply { 171 | if r.Type != dns.TypeA && r.Type != dns.TypeAAAA { 172 | return nil // Type not applicable 173 | } 174 | s.mu.RLock() 175 | ipAddrs, ok := s.hosts.Get(nonFqdn(r.Name)) 176 | s.mu.RUnlock() 177 | if !ok { 178 | return nil // No match 179 | } 180 | switch s.Config.DNS.hijackMode { 181 | case HijackZero: 182 | switch r.Type { 183 | case dns.TypeA: 184 | return dns.ReplyA(r.Name, net.IPv4zero) 185 | case dns.TypeAAAA: 186 | return dns.ReplyAAAA(r.Name, net.IPv6zero) 187 | } 188 | case HijackEmpty: 189 | return &dns.Reply{} 190 | case HijackHosts: 191 | var ipv4Addr []net.IP 192 | var ipv6Addr []net.IP 193 | for _, ipAddr := range ipAddrs { 194 | if ipAddr.IP.To4() == nil { 195 | ipv6Addr = append(ipv6Addr, ipAddr.IP) 196 | } else { 197 | ipv4Addr = append(ipv4Addr, ipAddr.IP) 198 | } 199 | } 200 | switch r.Type { 201 | case dns.TypeA: 202 | return dns.ReplyA(r.Name, ipv4Addr...) 203 | case dns.TypeAAAA: 204 | return dns.ReplyAAAA(r.Name, ipv6Addr...) 205 | } 206 | } 207 | return nil 208 | } 209 | 210 | // ListenAndServe starts a server on configured address and protocol. 
211 | func (s *Server) ListenAndServe() error { 212 | log.Printf("dns server listening on %s [%s]", s.Config.DNS.Listen, s.Config.DNS.Protocol) 213 | return s.proxy.ListenAndServe(s.Config.DNS.Listen, s.Config.DNS.Protocol) 214 | } 215 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | package zdns 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestConfig(t *testing.T) { 11 | text := ` 12 | [dns] 13 | listen = "0.0.0.0:53" 14 | protocol = "udp" 15 | cache_size = 2048 16 | resolvers = [ 17 | "192.0.2.1:53", 18 | "192.0.2.2:53=example.com", 19 | ] 20 | hijack_mode = "zero" # or: empty, hosts 21 | hosts_refresh_interval = "48h" 22 | database = "/tmp/log.db" 23 | log_mode = "all" 24 | log_ttl = "72h" 25 | 26 | [resolver] 27 | protocol = "tcp-tls" # or: "", "udp", "tcp" 28 | timeout = "1s" 29 | 30 | [[hosts]] 31 | url = "file:///home/foo/hosts-good" 32 | hijack = false 33 | 34 | [[hosts]] 35 | url = "https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts" 36 | timeout = "10s" 37 | hijack = true 38 | 39 | [[hosts]] 40 | entries = [ 41 | "0.0.0.0 goodhost1", 42 | "0.0.0.0 goodhost2", 43 | ] 44 | hijack = false 45 | ` 46 | r := strings.NewReader(text) 47 | conf, err := ReadConfig(r) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | var intTests = []struct { 53 | field string 54 | got int 55 | want int 56 | }{ 57 | {"DNS.CacheSize", conf.DNS.CacheSize, 2048}, 58 | {"len(DNS.Resolvers)", len(conf.DNS.Resolvers), 2}, 59 | {"Resolver.Timeout", int(conf.Resolver.Timeout), int(time.Second)}, 60 | {"DNS.RefreshInterval", int(conf.DNS.refreshInterval), int(48 * time.Hour)}, 61 | {"len(Hosts)", len(conf.Hosts), 3}, 62 | {"DNS.LogTTL", int(conf.DNS.LogTTL), int(72 * time.Hour)}, 63 | } 64 | for i, tt := range intTests { 65 | if tt.got != tt.want { 66 | t.Errorf("#%d: %s = %d, want %d", i, 
tt.field, tt.got, tt.want) 67 | } 68 | } 69 | 70 | var stringTests = []struct { 71 | field string 72 | got string 73 | want string 74 | }{ 75 | {"DNS.Listen", conf.DNS.Listen, "0.0.0.0:53"}, 76 | {"DNS.Protocol", conf.DNS.Protocol, "udp"}, 77 | {"DNS.Resolvers[0]", conf.DNS.Resolvers[0], "192.0.2.1:53"}, 78 | {"DNS.Resolvers[1]", conf.DNS.Resolvers[1], "192.0.2.2:53=example.com"}, 79 | {"DNS.HijackMode", conf.DNS.HijackMode, "zero"}, 80 | {"DNS.Database", conf.DNS.Database, "/tmp/log.db"}, 81 | {"DNS.LogMode", conf.DNS.LogModeString, "all"}, 82 | {"DNS.LogTTL", conf.DNS.LogTTLString, "72h"}, 83 | {"Resolver.Protocol", conf.Resolver.Protocol, "tcp-tls"}, 84 | {"Hosts[0].Source", conf.Hosts[0].URL, "file:///home/foo/hosts-good"}, 85 | {"Hosts[1].Source", conf.Hosts[1].URL, "https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts"}, 86 | {"Hosts[1].Timeout", conf.Hosts[1].Timeout, "10s"}, 87 | {"Hosts[2].hosts", fmt.Sprintf("%+v", conf.Hosts[2].hosts), "map[goodhost1:[{IP:0.0.0.0 Zone:}] goodhost2:[{IP:0.0.0.0 Zone:}]]"}, 88 | } 89 | for i, tt := range stringTests { 90 | if tt.got != tt.want { 91 | t.Errorf("#%d: %s = %q, want %q", i, tt.field, tt.got, tt.want) 92 | } 93 | } 94 | 95 | var boolTests = []struct { 96 | field string 97 | got bool 98 | want bool 99 | }{ 100 | {"Hosts[0].Hijack", conf.Hosts[0].Hijack, false}, 101 | {"Hosts[1].Hijack", conf.Hosts[1].Hijack, true}, 102 | } 103 | for i, tt := range boolTests { 104 | if tt.got != tt.want { 105 | t.Errorf("#%d: %s = %t, want %t", i, tt.field, tt.got, tt.want) 106 | } 107 | } 108 | } 109 | 110 | func TestConfigErrors(t *testing.T) { 111 | baseConf := "[dns]\nlisten = \"0.0.0.0:53\"\n" 112 | conf0 := baseConf + "cache_size = -1" 113 | conf1 := baseConf + ` 114 | hijack_mode = "foo" 115 | ` 116 | conf2 := baseConf + ` 117 | hosts_refresh_interval = "foo" 118 | ` 119 | conf3 := baseConf + ` 120 | hosts_refresh_interval = "-1h" 121 | ` 122 | conf4 := baseConf + ` 123 | resolvers = ["foo"] 124 | ` 125 | conf5 
:= baseConf + ` 126 | [resolver] 127 | protocol = "foo" 128 | ` 129 | conf6 := baseConf + ` 130 | [resolver] 131 | timeout = "foo" 132 | ` 133 | conf7 := baseConf + ` 134 | [resolver] 135 | timeout = "-1s" 136 | ` 137 | conf8 := baseConf + ` 138 | [[hosts]] 139 | url = ":foo" 140 | ` 141 | conf9 := baseConf + ` 142 | [[hosts]] 143 | url = "foo://bar" 144 | ` 145 | conf10 := baseConf + ` 146 | [[hosts]] 147 | url = "file:///tmp/foo" 148 | timeout = "1s" 149 | ` 150 | conf11 := baseConf + ` 151 | [[hosts]] 152 | entries = ["0.0.0.0 host1"] 153 | timeout = "1s" 154 | ` 155 | conf12 := baseConf + ` 156 | log_mode = "foo" 157 | 158 | [resolver] 159 | timeout = "1s" 160 | ` 161 | conf13 := baseConf + ` 162 | log_mode = "hijacked" 163 | 164 | [resolver] 165 | timeout = "1s" 166 | ` 167 | conf14 := baseConf + ` 168 | resolvers = ["http://example.com"] 169 | [resolver] 170 | protocol = "https" 171 | ` 172 | conf15 := baseConf + ` 173 | cache_persist = true 174 | ` 175 | var tests = []struct { 176 | in string 177 | err string 178 | }{ 179 | 180 | {conf0, "cache size must be >= 0"}, 181 | {conf1, "invalid hijack mode: foo"}, 182 | {conf2, "invalid refresh interval: time: invalid duration \"foo\""}, 183 | {conf3, "refresh interval must be >= 0"}, 184 | {conf4, "invalid resolver: address foo: missing port in address"}, 185 | {conf5, "invalid resolver protocol: foo"}, 186 | {conf6, "invalid resolver timeout: foo"}, 187 | {conf7, "resolver timeout must be >= 0"}, 188 | {conf8, ":foo: invalid url: parse \":foo\": missing protocol scheme"}, 189 | {conf9, "foo://bar: unsupported scheme: foo"}, 190 | {conf10, "file:///tmp/foo: timeout cannot be set for file url"}, 191 | {conf11, "[0.0.0.0 host1]: timeout cannot be set for inline hosts"}, 192 | {conf12, "invalid log mode: foo"}, 193 | {conf13, `log_mode = "hijacked" requires 'database' to be set`}, 194 | {conf14, "protocol https requires https scheme for resolver http://example.com"}, 195 | {conf15, "cache_persist = true requires 
'database' to be set"}, 196 | } 197 | for i, tt := range tests { 198 | var got string 199 | _, err := ReadConfig(strings.NewReader(tt.in)) 200 | if err != nil { 201 | got = err.Error() 202 | } 203 | if got != tt.err { 204 | t.Errorf("#%d: want %q, got %q", i, tt.err, got) 205 | } 206 | } 207 | 208 | } 209 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package zdns 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "net" 7 | "net/http" 8 | "net/http/httptest" 9 | "os" 10 | "reflect" 11 | "testing" 12 | "time" 13 | 14 | "github.com/mpolden/zdns/cache" 15 | "github.com/mpolden/zdns/dns" 16 | "github.com/mpolden/zdns/hosts" 17 | ) 18 | 19 | func init() { 20 | log.SetOutput(ioutil.Discard) 21 | } 22 | 23 | const hostsFile1 = ` 24 | 192.0.2.1 badhost1 25 | 2001:db8::1 badhost1 26 | 192.0.2.2 badhost2 27 | 192.0.2.3 badhost3 28 | ` 29 | 30 | const hostsFile2 = ` 31 | 192.0.2.4 badhost4 32 | 192.0.2.5 badhost5 33 | 192.0.2.6 badhost6 34 | ` 35 | 36 | func httpHandler(t *testing.T, response string) http.Handler { 37 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | if _, err := w.Write([]byte(response)); err != nil { 39 | t.Fatal(err) 40 | } 41 | }) 42 | } 43 | 44 | func httpServer(t *testing.T, s string) *httptest.Server { 45 | return httptest.NewServer(httpHandler(t, s)) 46 | } 47 | 48 | func tempFile(t *testing.T, s string) (string, error) { 49 | f, err := ioutil.TempFile("", "zdns") 50 | if err != nil { 51 | return "", err 52 | } 53 | defer f.Close() 54 | if err := ioutil.WriteFile(f.Name(), []byte(s), 0644); err != nil { 55 | return "", err 56 | } 57 | return f.Name(), nil 58 | } 59 | 60 | func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { 61 | var ( 62 | httpSrv *httptest.Server 63 | srv *Server 64 | file string 65 | err error 66 | ) 67 | cleanup := func() { 68 | if httpSrv != 
nil { 69 | httpSrv.Close() 70 | } 71 | if file != "" { 72 | if err := os.Remove(file); err != nil { 73 | t.Error(err) 74 | } 75 | } 76 | if srv != nil { 77 | if err := srv.Close(); err != nil { 78 | t.Error(err) 79 | } 80 | } 81 | } 82 | httpSrv = httpServer(t, hostsFile1) 83 | file, err = tempFile(t, hostsFile2) 84 | if err != nil { 85 | defer cleanup() 86 | t.Fatal(err) 87 | } 88 | config := Config{ 89 | DNS: DNSOptions{Listen: "0.0.0.0:53", 90 | hijackMode: HijackZero, 91 | refreshInterval: refreshInterval, 92 | }, 93 | Resolver: ResolverOptions{TimeoutString: "0"}, 94 | Hosts: []Hosts{ 95 | {URL: httpSrv.URL, Hijack: true}, 96 | {URL: "file://" + file, Hijack: true}, 97 | {Hosts: []string{"192.0.2.5 badhost5"}}, 98 | }, 99 | } 100 | if err := config.load(); err != nil { 101 | t.Fatal(err) 102 | } 103 | proxy, err := dns.NewProxy(cache.New(0, nil), nil, nil) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | srv, err = NewServer(proxy, config) 108 | if err != nil { 109 | defer cleanup() 110 | t.Fatal(err) 111 | } 112 | ts := time.Now() 113 | for { 114 | srv.mu.RLock() 115 | hostsLoaded := srv.hosts != nil 116 | srv.mu.RUnlock() 117 | if hostsLoaded { 118 | break 119 | } 120 | time.Sleep(10 * time.Millisecond) 121 | if time.Since(ts) > 2*time.Second { 122 | t.Fatal("timed out waiting initial hosts to load") 123 | } 124 | } 125 | return srv, cleanup 126 | } 127 | 128 | func TestLoadHosts(t *testing.T) { 129 | s, cleanup := testServer(t, 10*time.Millisecond) 130 | defer cleanup() 131 | want := hosts.Hosts{ 132 | "badhost1": []net.IPAddr{{IP: net.ParseIP("192.0.2.1")}, {IP: net.ParseIP("2001:db8::1")}}, 133 | "badhost2": []net.IPAddr{{IP: net.ParseIP("192.0.2.2")}}, 134 | "badhost3": []net.IPAddr{{IP: net.ParseIP("192.0.2.3")}}, 135 | "badhost4": []net.IPAddr{{IP: net.ParseIP("192.0.2.4")}}, 136 | "badhost6": []net.IPAddr{{IP: net.ParseIP("192.0.2.6")}}, 137 | } 138 | got := s.hosts 139 | if !reflect.DeepEqual(want, got) { 140 | t.Errorf("got %+v, want %+v", 
got, want) 141 | } 142 | } 143 | 144 | func TestReloadHostsOnTick(t *testing.T) { 145 | s, cleanup := testServer(t, 10*time.Millisecond) 146 | defer cleanup() 147 | oldHosts := s.hosts 148 | if oldHosts == nil { 149 | t.Fatal("expected matcher to be initialized") 150 | } 151 | ts := time.Now() 152 | for &s.hosts == &oldHosts { 153 | time.Sleep(10 * time.Millisecond) 154 | if time.Since(ts) > 2*time.Second { 155 | t.Fatal("timed out waiting hosts to load") 156 | } 157 | } 158 | } 159 | 160 | func TestNonFqdn(t *testing.T) { 161 | var tests = []struct { 162 | in, out string 163 | }{ 164 | {"", ""}, 165 | {"foo", "foo"}, 166 | {"foo.", "foo"}, 167 | } 168 | for i, tt := range tests { 169 | got := nonFqdn(tt.in) 170 | if got != tt.out { 171 | t.Errorf("#%d: nonFqdn(%q) = %q, want %q", i, tt.in, got, tt.out) 172 | } 173 | } 174 | } 175 | 176 | func TestHijack(t *testing.T) { 177 | s := &Server{ 178 | Config: Config{}, 179 | hosts: hosts.Hosts{ 180 | "badhost1": []net.IPAddr{ 181 | {IP: net.ParseIP("192.0.2.1")}, 182 | {IP: net.ParseIP("2001:db8::1")}, 183 | }, 184 | }, 185 | } 186 | 187 | var tests = []struct { 188 | rtype uint16 189 | rname string 190 | mode int 191 | out string 192 | }{ 193 | {dns.TypeA, "goodhost1", HijackZero, ""}, // Unmatched host 194 | {dns.TypeAAAA, "goodhost1", HijackZero, ""}, // Unmatched host 195 | {15 /* MX */, "badhost1", HijackZero, ""}, // Unmatched type 196 | {dns.TypeA, "badhost1", HijackZero, "badhost1\t3600\tIN\tA\t0.0.0.0"}, 197 | {dns.TypeA, "badhost1", HijackEmpty, ""}, 198 | {dns.TypeA, "badhost1", HijackHosts, "badhost1\t3600\tIN\tA\t192.0.2.1"}, 199 | {dns.TypeAAAA, "badhost1", HijackZero, "badhost1\t3600\tIN\tAAAA\t::"}, 200 | {dns.TypeAAAA, "badhost1", HijackEmpty, ""}, 201 | {dns.TypeAAAA, "badhost1", HijackHosts, "badhost1\t3600\tIN\tAAAA\t2001:db8::1"}, 202 | } 203 | for i, tt := range tests { 204 | s.Config.DNS.hijackMode = tt.mode 205 | req := &dns.Request{Type: tt.rtype, Name: tt.rname} 206 | reply := 
s.hijack(&dns.Request{Type: tt.rtype, Name: tt.rname}) 207 | if reply == nil && tt.out == "" { 208 | reply = &dns.Reply{} 209 | } 210 | if reply.String() != tt.out { 211 | t.Errorf("#%d: hijack(%+v) = %q, want %q", i, req, reply.String(), tt.out) 212 | } 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 2 | filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= 3 | github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= 4 | github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= 5 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 6 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 7 | github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= 8 | github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= 9 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 10 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 11 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 12 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 13 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 14 | github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= 15 | github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= 16 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 17 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 18 | 
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= 19 | github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= 20 | github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= 21 | github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= 22 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 23 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 24 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 25 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 26 | github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 27 | github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 28 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 29 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 30 | github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 31 | github.com/mattn/go-sqlite3 v1.14.30 h1:bVreufq3EAIG1Quvws73du3/QgdeZ3myglJlrzSYYCY= 32 | github.com/mattn/go-sqlite3 v1.14.30/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 33 | github.com/miekg/dns v1.1.67 h1:kg0EHj0G4bfT5/oOys6HhZw4vmMlnoZ+gDu8tJ/AlI0= 34 | github.com/miekg/dns v1.1.67/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= 35 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 36 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 37 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 38 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 39 | github.com/prometheus/client_golang v1.23.2 
h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= 40 | github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= 41 | github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= 42 | github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= 43 | github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= 44 | github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= 45 | github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= 46 | github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= 47 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 48 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 49 | github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 50 | github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 51 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 52 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 53 | go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= 54 | go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= 55 | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= 56 | golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= 57 | golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= 58 | golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= 59 | golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= 60 | golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 61 | golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= 
62 | golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 63 | golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= 64 | golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= 65 | google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= 66 | google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= 67 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 68 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 69 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 70 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 71 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # zdns 2 | 3 | ![Build Status](https://github.com/mpolden/zdns/workflows/ci/badge.svg) 4 | 5 | `zdns` is a privacy-focused [DNS 6 | resolver](https://en.wikipedia.org/wiki/Domain_Name_System#DNS_resolvers) and 7 | [DNS sinkhole](https://en.wikipedia.org/wiki/DNS_sinkhole). 8 | 9 | Its primary focus is to allow easy filtering of unwanted content at the 10 | DNS-level, transport upstream requests securely, be portable and easy to 11 | configure. 12 | 13 | ## Contents 14 | 15 | * [Features](#features) 16 | * [Usage](#usage) 17 | * [Installation](#installation) 18 | * [Configuration](#configuration) 19 | * [Logging](#logging) 20 | * [Port redirection](#port-redirection) 21 | * [REST API](#rest-api) 22 | * [Why not Pi-hole?](#why-not-pi-hole) 23 | 24 | ## Features 25 | 26 | * **Control**: Filter unwanted content at the DNS-level. 
Similar to 27 | [Pi-hole](https://github.com/pi-hole/pi-hole). 28 | * **Fast**: Parallel resolving over multiple resolvers, efficient filtering and 29 | caching of DNS requests. With pre-fetching enabled, cached requests will never 30 | block waiting for the upstream resolver. Asynchronous persistent caching is 31 | also supported. 32 | * **Reliable**: Built with Go and [miekg/dns](https://github.com/miekg/dns) - a 33 | mature DNS library. 34 | * **Secure**: Protect your DNS requests from snooping and tampering using [DNS 35 | over TLS](https://en.wikipedia.org/wiki/DNS_over_TLS) or [DNS over 36 | HTTPS](https://en.wikipedia.org/wiki/DNS_over_HTTPS) for upstream resolvers. 37 | * **Self-contained**: Zero run-time dependencies make `zdns` easy to deploy and 38 | maintain. 39 | * **Observable**: `zdns` features DNS logging and metrics, which make it easy to 40 | observe what's going on in your network. 41 | * **Portable**: Run it on your VPS, container, laptop, Raspberry Pi or home 42 | router. Runs on all platforms supported by Go. 43 | 44 | ## Usage 45 | 46 | ### Installation 47 | 48 | `zdns` is a standard Go package. Install with: 49 | 50 | ``` shell 51 | $ go install github.com/mpolden/zdns/...@latest 52 | ``` 53 | 54 | ### Configuration 55 | 56 | `zdns` uses the [TOML](https://github.com/toml-lang/toml) configuration format 57 | and expects to find its configuration file in `~/.zdnsrc` by default. 58 | 59 | See [zdnsrc](zdnsrc) for an example configuration file. 60 | [zdns.service](zdns.service) contains an example systemd service file. 61 | 62 | An optional command line option, `-f`, allows specifying a custom configuration 63 | file path. 64 | 65 | ### Logging 66 | 67 | `zdns` supports logging of DNS requests. Logs are written to a SQLite database. 68 | 69 | Logs can be inspected through the built-in REST API or by querying the SQLite 70 | database directly. See `zdnsrc` for more details.
71 | 72 | ### Port redirection 73 | 74 | Most operating systems expect to find their DNS resolver on UDP port 53. 75 | However, as this is a well-known port, any program listening on this port must 76 | have special privileges. 77 | 78 | To work around this, we can configure the firewall to redirect 79 | connections to port 53 to a non-reserved port. 80 | 81 | The following examples assume that `zdns` is running on port 53000. See 82 | `zdnsrc` for port configuration. 83 | 84 | #### Linux (iptables) 85 | 86 | ``` shell 87 | # External requests (replace EXTERNAL_IP with the address zdns should answer on) 88 | $ iptables -t nat -A PREROUTING -d EXTERNAL_IP -p udp -m udp --dport 53 -j REDIRECT --to-ports 53000 89 | 90 | # Local requests 91 | $ iptables -t nat -A OUTPUT -d 127.0.0.1 -p udp -m udp --dport 53 -j REDIRECT --to-ports 53000 92 | ``` 93 | 94 | #### macOS (pf) 95 | 96 | 1. Edit `/etc/pf.conf` 97 | 2. Add `rdr pass inet proto udp from any to 127.0.0.1 port domain -> 127.0.0.1 port 53000` below the last `rdr-anchor` line. 98 | 3. Enable PF and load rules: `pfctl -ef /etc/pf.conf` 99 | 100 | ## REST API 101 | 102 | A basic REST API provides access to the request log and cache entries. The API is 103 | served by the built-in web server, which can be enabled in `zdnsrc`. 104 | 105 | ### Examples 106 | 107 | Read the log: 108 | ```shell 109 | $ curl -s 'http://127.0.0.1:8053/log/v1/?n=1' | jq . 110 | [ 111 | { 112 | "time": "2019-12-27T10:43:23Z", 113 | "remote_addr": "127.0.0.1", 114 | "hijacked": false, 115 | "type": "AAAA", 116 | "question": "discovery.syncthing.net.", 117 | "answers": [ 118 | "2400:6180:100:d0::741:a001", 119 | "2a03:b0c0:0:1010::bb:4001" 120 | ] 121 | } 122 | ] 123 | ``` 124 | 125 | Read the cache: 126 | ```shell 127 | $ curl -s 'http://127.0.0.1:8053/cache/v1/?n=1' | jq .
128 | [ 129 | { 130 | "time": "2019-12-27T10:46:11Z", 131 | "ttl": 18, 132 | "type": "A", 133 | "question": "gateway.fe.apple-dns.net.", 134 | "answers": [ 135 | "17.248.150.110", 136 | "17.248.150.113", 137 | "17.248.150.10", 138 | "17.248.150.40", 139 | "17.248.150.42", 140 | "17.248.150.51", 141 | "17.248.150.79", 142 | "17.248.150.108" 143 | ], 144 | "rcode": "NOERROR" 145 | } 146 | ] 147 | ``` 148 | 149 | Clear the cache: 150 | ```shell 151 | $ curl -s -XDELETE 'http://127.0.0.1:8053/cache/v1/' | jq . 152 | { 153 | "message": "Cleared cache." 154 | } 155 | ``` 156 | 157 | Metrics: 158 | 159 | ``` shell 160 | $ curl 'http://127.0.0.1:8053/metric/v1/?resolution=1m' | jq . 161 | { 162 | "summary": { 163 | "log": { 164 | "since": "2020-01-05T00:58:49Z", 165 | "total": 3816, 166 | "hijacked": 874, 167 | "pending_tasks": 0 168 | }, 169 | "cache": { 170 | "size": 845, 171 | "capacity": 4096, 172 | "pending_tasks": 0, 173 | "backend": { 174 | "pending_tasks": 0 175 | } 176 | } 177 | }, 178 | "requests": [ 179 | { 180 | "time": "2020-01-05T00:58:49Z", 181 | "count": 1 182 | } 183 | ] 184 | } 185 | ``` 186 | 187 | Note that `log_mode = "hijacked"` or `log_mode = "all"` is required to make 188 | metrics available. Choosing `hijacked` will only produce metrics for hijacked 189 | requests. 190 | 191 | The query parameter `resolution` controls the resolution of the data points in 192 | `requests`. It accepts the same values as 193 | [time.ParseDuration](https://golang.org/pkg/time/#ParseDuration) and defaults to 194 | `1m`. 195 | 196 | ## Why not Pi-hole? 197 | 198 | _This is my personal opinion and not an objective assessment of Pi-hole._ 199 | 200 | * Pi-hole has lots of dependencies and a large feature scope. 201 | 202 | * Buggy installation script. In my personal experience, the 4.3 installation 203 | script failed silently in both Debian stretch and buster LXC containers. 204 | 205 | * Installation method pipes `curl` to `bash`.
Not properly packaged for any 206 | distributions. 207 | -------------------------------------------------------------------------------- /http/http_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "io/ioutil" 5 | "net" 6 | "net/http" 7 | "net/http/httptest" 8 | "regexp" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/miekg/dns" 13 | "github.com/mpolden/zdns/cache" 14 | "github.com/mpolden/zdns/sql" 15 | ) 16 | 17 | func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { 18 | m := dns.Msg{} 19 | m.Id = dns.Id() 20 | m.SetQuestion(dns.Fqdn(name), dns.TypeA) 21 | rr := make([]dns.RR, 0, len(ipAddr)) 22 | for _, ip := range ipAddr { 23 | rr = append(rr, &dns.A{ 24 | A: ip, 25 | Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, 26 | }) 27 | } 28 | m.Answer = rr 29 | return &m 30 | } 31 | 32 | func testServer() (*httptest.Server, *Server) { 33 | sqlClient, err := sql.New(":memory:") 34 | if err != nil { 35 | panic(err) 36 | } 37 | logger := sql.NewLogger(sqlClient, sql.LogAll, 0) 38 | sqlCache := sql.NewCache(sqlClient) 39 | cache := cache.New(10, nil) 40 | server := NewServer(cache, logger, sqlCache, "") 41 | return httptest.NewServer(server.handler()), server 42 | } 43 | 44 | func httpGet(url string) (*http.Response, string, error) { 45 | res, err := http.Get(url) 46 | if err != nil { 47 | return nil, "", err 48 | } 49 | defer res.Body.Close() 50 | data, err := ioutil.ReadAll(res.Body) 51 | if err != nil { 52 | return nil, "", err 53 | } 54 | return res, string(data), nil 55 | } 56 | 57 | func httpRequest(method, url, body string) (*http.Response, string, error) { 58 | r, err := http.NewRequest(method, url, strings.NewReader(body)) 59 | if err != nil { 60 | return nil, "", err 61 | } 62 | res, err := http.DefaultClient.Do(r) 63 | if err != nil { 64 | return nil, "", err 65 | } 66 | defer res.Body.Close() 67 | data, err := ioutil.ReadAll(res.Body) 
68 | if err != nil { 69 | return nil, "", err 70 | } 71 | return res, string(data), nil 72 | } 73 | 74 | func httpDelete(url, body string) (*http.Response, string, error) { 75 | return httpRequest(http.MethodDelete, url, body) 76 | } 77 | 78 | func TestRequests(t *testing.T) { 79 | httpSrv, srv := testServer() 80 | defer httpSrv.Close() 81 | srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") 82 | srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") 83 | srv.logger.Close() // Flush 84 | srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200))) 85 | srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201))) 86 | 87 | cr1 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"},` + 88 | `{"time":"RFC3339","ttl":60,"type":"A","question":"1.example.com.","answers":["192.0.2.200"],"rcode":"NOERROR"}]` 89 | cr2 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"}]` 90 | lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` + 91 | `{"time":"RFC3339","remote_addr":"127.0.0.42","hijacked":false,"type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]` 92 | lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]` 93 | mr1 := `{"summary":{"log":{"since":"RFC3339","total":2,"hijacked":1,"pending_tasks":0},"cache":{"size":2,"capacity":10,"pending_tasks":0,"backend":{"pending_tasks":0}}},"requests":[{"time":"RFC3339","count":2}]}` 94 | mr2 := ` 95 | 96 | # HELP zdns_requests_hijacked The number of hijacked DNS requests. 97 | # TYPE zdns_requests_hijacked gauge 98 | zdns_requests_hijacked 1 99 | # HELP zdns_requests_total The total number of DNS requests. 
100 | # TYPE zdns_requests_total gauge 101 | zdns_requests_total 2 102 | ` 103 | var tests = []struct { 104 | method string 105 | url string 106 | response string 107 | status int 108 | contentType string 109 | }{ 110 | {http.MethodGet, "/not-found", `{"status":404,"message":"Resource not found"}`, 404, jsonMediaType}, 111 | {http.MethodGet, "/log/v1/", lr1, 200, jsonMediaType}, 112 | {http.MethodGet, "/log/v1/?n=foo", `{"status":400,"message":"invalid value for parameter n: foo"}`, 400, jsonMediaType}, 113 | {http.MethodGet, "/log/v1/?n=1", lr2, 200, jsonMediaType}, 114 | {http.MethodGet, "/cache/v1/", cr1, 200, jsonMediaType}, 115 | {http.MethodGet, "/cache/v1/?n=foo", `{"status":400,"message":"invalid value for parameter n: foo"}`, 400, jsonMediaType}, 116 | {http.MethodGet, "/cache/v1/?n=1", cr2, 200, jsonMediaType}, 117 | {http.MethodGet, "/metric/v1/", mr1, 200, jsonMediaType}, 118 | {http.MethodGet, "/metric/v1/?format=basic", mr1, 200, jsonMediaType}, 119 | {http.MethodGet, "/metric/v1/?format=prometheus", mr2, 200, "text/plain; version=0.0.4; charset=utf-8; escaping=underscores"}, 120 | {http.MethodGet, "/metric/v1/?resolution=1m", mr1, 200, jsonMediaType}, 121 | {http.MethodGet, "/metric/v1/?resolution=0", mr1, 200, jsonMediaType}, 122 | {http.MethodGet, "/metric/v1/?format=foo", `{"status":400,"message":"invalid metric format: foo"}`, 400, jsonMediaType}, 123 | {http.MethodGet, "/metric/v1/?resolution=foo", `{"status":400,"message":"time: invalid duration \"foo\""}`, 400, jsonMediaType}, 124 | {http.MethodDelete, "/cache/v1/", `{"message":"Cleared cache."}`, 200, jsonMediaType}, 125 | } 126 | 127 | for i, tt := range tests { 128 | var ( 129 | res *http.Response 130 | data string 131 | err error 132 | ) 133 | switch tt.method { 134 | case http.MethodGet: 135 | res, data, err = httpGet(httpSrv.URL + tt.url) 136 | case http.MethodDelete: 137 | res, data, err = httpDelete(httpSrv.URL+tt.url, "") 138 | default: 139 | t.Fatalf("#%d: invalid method: %s", i, 
tt.method) 140 | } 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | if got := res.StatusCode; got != tt.status { 145 | t.Errorf("#%d: %s %s returned status %d, want %d", i, tt.method, tt.url, got, tt.status) 146 | } 147 | 148 | if got, want := res.Header.Get("Content-Type"), tt.contentType; got != want { 149 | t.Errorf("#%d: got Content-Type %q, want %q", i, got, want) 150 | } 151 | 152 | got := string(data) 153 | want := regexp.QuoteMeta(tt.response) 154 | want = strings.ReplaceAll(want, "RFC3339", `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z`) 155 | want = strings.ReplaceAll(want, "", ".*") 156 | matched, err := regexp.MatchString(want, got) 157 | if err != nil { 158 | t.Fatal(err) 159 | } 160 | if !matched { 161 | t.Errorf("#%d: %s %s returned response %s, want %s", i, tt.method, tt.url, got, want) 162 | } 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /dns/proxy_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "net" 8 | "reflect" 9 | "sync" 10 | "testing" 11 | 12 | "github.com/miekg/dns" 13 | "github.com/mpolden/zdns/cache" 14 | ) 15 | 16 | func init() { 17 | log.SetOutput(ioutil.Discard) 18 | } 19 | 20 | type dnsWriter struct{ lastReply *dns.Msg } 21 | 22 | func (w *dnsWriter) LocalAddr() net.Addr { return nil } 23 | func (w *dnsWriter) RemoteAddr() net.Addr { 24 | return &net.UDPAddr{IP: net.IPv4(192, 0, 2, 100), Port: 50000} 25 | } 26 | func (w *dnsWriter) Network() string { return "udp" } 27 | func (w *dnsWriter) Write(b []byte) (int, error) { return 0, nil } 28 | func (w *dnsWriter) Close() error { return nil } 29 | func (w *dnsWriter) TsigStatus() error { return nil } 30 | func (w *dnsWriter) TsigTimersOnly(b bool) {} 31 | func (w *dnsWriter) Hijack() {} 32 | 33 | func (w *dnsWriter) WriteMsg(msg *dns.Msg) error { 34 | w.lastReply = msg 35 | return nil 36 | } 37 | 38 | type response struct 
{ 39 | answer *dns.Msg 40 | fail bool 41 | } 42 | 43 | type testResolver struct { 44 | mu sync.RWMutex 45 | response *response 46 | } 47 | 48 | func (e *testResolver) setResponse(response *response) { 49 | e.mu.Lock() 50 | defer e.mu.Unlock() 51 | e.response = response 52 | } 53 | 54 | func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { 55 | e.mu.RLock() 56 | defer e.mu.RUnlock() 57 | r := e.response 58 | if r == nil || r.fail { 59 | return nil, fmt.Errorf("SERVFAIL") 60 | } 61 | return r.answer, nil 62 | } 63 | 64 | func testProxy(t *testing.T) *Proxy { 65 | proxy, err := NewProxy(cache.New(0, nil), nil, nil) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | return proxy 70 | } 71 | 72 | func assertRR(t *testing.T, p *Proxy, msg *dns.Msg, answer string) { 73 | var ( 74 | qtype = msg.Question[0].Qtype 75 | qname = msg.Question[0].Name 76 | ) 77 | w := &dnsWriter{} 78 | p.ServeDNS(w, msg) 79 | 80 | qtypeString := dns.TypeToString[qtype] 81 | answers := w.lastReply.Answer 82 | if got, want := len(answers), 1; got != want { 83 | t.Fatalf("len(msg.Answer) = %d, want %d for %s %s", got, want, qtypeString, qname) 84 | } 85 | ans := answers[0] 86 | 87 | if got := w.lastReply.Id; got != msg.Id { 88 | t.Errorf("id = %d, want %d for %s %s", got, msg.Id, qtypeString, qname) 89 | } 90 | 91 | want := net.ParseIP(answer) 92 | var got net.IP 93 | switch qtype { 94 | case dns.TypeA: 95 | rr, ok := ans.(*dns.A) 96 | if !ok { 97 | t.Errorf("type = %q, want %q for %s %s", dns.TypeToString[dns.TypeA], dns.TypeToString[rr.Header().Rrtype], qtypeString, qname) 98 | } 99 | got = rr.A 100 | case dns.TypeAAAA: 101 | rr, ok := ans.(*dns.AAAA) 102 | if !ok { 103 | t.Errorf("type = %q, want %q for %s %s", dns.TypeToString[dns.TypeA], dns.TypeToString[rr.Header().Rrtype], qtypeString, qname) 104 | } 105 | got = rr.AAAA 106 | } 107 | if !reflect.DeepEqual(got, want) { 108 | t.Errorf("IP = %s, want %s", got, want) 109 | } 110 | } 111 | 112 | func assertFailure(t *testing.T, p 
*Proxy, rtype uint16, qname string) { 113 | m := dns.Msg{} 114 | m.Id = dns.Id() 115 | m.RecursionDesired = true 116 | m.SetQuestion(dns.Fqdn(qname), rtype) 117 | 118 | w := &dnsWriter{} 119 | p.ServeDNS(w, &m) 120 | 121 | if got, want := len(w.lastReply.Answer), 0; got != want { 122 | t.Errorf("len(msg.Answer) = %d, want %d for %s %s", got, want, dns.TypeToString[rtype], qname) 123 | } 124 | if got, want := w.lastReply.MsgHdr.Rcode, dns.RcodeServerFailure; got != want { 125 | t.Errorf("MsgHdr.Rcode = %s, want %s for %s %s", dns.RcodeToString[got], dns.RcodeToString[want], dns.TypeToString[rtype], qname) 126 | } 127 | } 128 | 129 | func TestProxy(t *testing.T) { 130 | var h Handler = func(r *Request) *Reply { 131 | switch r.Type { 132 | case TypeA: 133 | return ReplyA(r.Name, net.IPv4zero) 134 | case TypeAAAA: 135 | return ReplyAAAA(r.Name, net.IPv6zero) 136 | } 137 | return nil 138 | } 139 | p := testProxy(t) 140 | p.Handler = h 141 | defer p.Close() 142 | 143 | m := dns.Msg{} 144 | m.Id = dns.Id() 145 | m.RecursionDesired = true 146 | 147 | m.SetQuestion(dns.Fqdn("badhost1"), dns.TypeA) 148 | assertRR(t, p, &m, "0.0.0.0") 149 | 150 | m.SetQuestion(dns.Fqdn("badhost1"), dns.TypeAAAA) 151 | assertRR(t, p, &m, "::") 152 | } 153 | 154 | func TestProxyWithResolver(t *testing.T) { 155 | p := testProxy(t) 156 | r := &testResolver{} 157 | p.client = r 158 | defer p.Close() 159 | // No response 160 | assertFailure(t, p, TypeA, "host1") 161 | 162 | // Responds succesfully 163 | reply := ReplyA("host1", net.ParseIP("192.0.2.1")) 164 | m := dns.Msg{} 165 | m.Id = dns.Id() 166 | m.SetQuestion("host1.", dns.TypeA) 167 | m.Answer = reply.rr 168 | response1 := &response{answer: &m} 169 | r.setResponse(response1) 170 | assertRR(t, p, &m, "192.0.2.1") 171 | 172 | // Resolver fails 173 | response1.fail = true 174 | assertFailure(t, p, TypeA, "host1") 175 | } 176 | 177 | func TestProxyWithCache(t *testing.T) { 178 | p := testProxy(t) 179 | p.cache = cache.New(10, nil) 180 | r := 
&testResolver{} 181 | p.client = r 182 | defer p.Close() 183 | 184 | reply := ReplyA("host1", net.ParseIP("192.0.2.1")) 185 | m := dns.Msg{} 186 | m.Id = dns.Id() 187 | m.SetQuestion("host1.", dns.TypeA) 188 | m.Answer = reply.rr 189 | r.setResponse(&response{answer: &m}) 190 | assertRR(t, p, &m, "192.0.2.1") 191 | 192 | k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET) 193 | got, ok := p.cache.Get(k) 194 | if !ok { 195 | t.Errorf("cache.Get(%d) = (%+v, %t), want (%+v, %t)", k, got, ok, m, !ok) 196 | } 197 | } 198 | 199 | func TestReplyString(t *testing.T) { 200 | var tests = []struct { 201 | fn func(string, ...net.IP) *Reply 202 | fnName string 203 | name string 204 | ipAddrs []net.IP 205 | out string 206 | }{ 207 | {ReplyA, "ReplyA", "test-host", []net.IP{net.ParseIP("192.0.2.1")}, 208 | "test-host\t3600\tIN\tA\t192.0.2.1"}, 209 | {ReplyA, "ReplyA", "test-host", []net.IP{net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2")}, 210 | "test-host\t3600\tIN\tA\t192.0.2.1\ntest-host\t3600\tIN\tA\t192.0.2.2"}, 211 | {ReplyAAAA, "ReplyAAAA", "test-host", []net.IP{net.ParseIP("2001:db8::1")}, 212 | "test-host\t3600\tIN\tAAAA\t2001:db8::1"}, 213 | {ReplyAAAA, "ReplyAAAA", "test-host", []net.IP{net.ParseIP("2001:db8::1"), net.ParseIP("2001:db8::2")}, 214 | "test-host\t3600\tIN\tAAAA\t2001:db8::1\ntest-host\t3600\tIN\tAAAA\t2001:db8::2"}, 215 | } 216 | for i, tt := range tests { 217 | got := tt.fn(tt.name, tt.ipAddrs...).String() 218 | if got != tt.out { 219 | t.Errorf("#%d: %s(%q, %v) = %q, want %q", i, tt.fnName, tt.name, tt.ipAddrs, got, tt.out) 220 | } 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package zdns 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "net/url" 8 | "strings" 9 | "time" 10 | 11 | "github.com/BurntSushi/toml" 12 | "github.com/mpolden/zdns/hosts" 13 | "github.com/mpolden/zdns/sql" 14 | ) 15 | 16 | 
// Config specifies is the zdns configuration parameters. 17 | type Config struct { 18 | DNS DNSOptions 19 | Resolver ResolverOptions 20 | Hosts []Hosts 21 | } 22 | 23 | // DNSOptions controlers the behaviour of the DNS server. 24 | type DNSOptions struct { 25 | Listen string 26 | Protocol string `toml:"protocol"` 27 | CacheSize int `toml:"cache_size"` 28 | CachePrefetch bool `toml:"cache_prefetch"` 29 | CachePersist bool `toml:"cache_persist"` 30 | HijackMode string `toml:"hijack_mode"` 31 | hijackMode int 32 | RefreshInterval string `toml:"hosts_refresh_interval"` 33 | refreshInterval time.Duration 34 | Resolvers []string 35 | Database string `toml:"database"` 36 | LogModeString string `toml:"log_mode"` 37 | LogMode int 38 | LogTTLString string `toml:"log_ttl"` 39 | LogTTL time.Duration 40 | ListenHTTP string `toml:"listen_http"` 41 | } 42 | 43 | // ResolverOptions controls the behaviour of resolvers. 44 | type ResolverOptions struct { 45 | Protocol string `toml:"protocol"` 46 | TimeoutString string `toml:"timeout"` 47 | Timeout time.Duration 48 | } 49 | 50 | // Hosts controls how a hosts file should be retrieved. 
51 | type Hosts struct { 52 | URL string 53 | Hosts []string `toml:"entries"` 54 | hosts hosts.Hosts 55 | Hijack bool 56 | Timeout string 57 | timeout time.Duration 58 | } 59 | 60 | func newConfig() Config { 61 | c := Config{} 62 | // Default values 63 | c.DNS.Listen = "127.0.0.1:53000" 64 | c.DNS.ListenHTTP = "127.0.0.1:8053" 65 | c.DNS.Protocol = "udp" 66 | c.DNS.CacheSize = 4096 67 | c.DNS.CachePrefetch = true 68 | c.DNS.RefreshInterval = "48h" 69 | c.DNS.Resolvers = []string{ 70 | "1.1.1.1:853", 71 | "1.0.0.1:853", 72 | } 73 | c.DNS.LogTTLString = "168h" 74 | c.Resolver.TimeoutString = "2s" 75 | c.Resolver.Protocol = "tcp-tls" 76 | return c 77 | } 78 | 79 | func (c *Config) load() error { 80 | var err error 81 | if c.DNS.Listen == "" { 82 | return fmt.Errorf("invalid listening address: %s", c.DNS.Listen) 83 | } 84 | if c.DNS.Protocol == "" { 85 | c.DNS.Protocol = "udp" 86 | } 87 | if c.DNS.Protocol != "udp" { 88 | return fmt.Errorf("unsupported protocol: %s", c.DNS.Protocol) 89 | } 90 | if c.DNS.CacheSize < 0 { 91 | return fmt.Errorf("cache size must be >= 0") 92 | } 93 | if c.DNS.CachePersist && c.DNS.Database == "" { 94 | return fmt.Errorf("cache_persist = %t requires 'database' to be set", c.DNS.CachePersist) 95 | } 96 | switch c.DNS.HijackMode { 97 | case "", "zero": 98 | c.DNS.hijackMode = HijackZero 99 | case "empty": 100 | c.DNS.hijackMode = HijackEmpty 101 | case "hosts": 102 | c.DNS.hijackMode = HijackHosts 103 | default: 104 | return fmt.Errorf("invalid hijack mode: %s", c.DNS.HijackMode) 105 | } 106 | if c.DNS.RefreshInterval == "" { 107 | c.DNS.RefreshInterval = "0" 108 | } 109 | c.DNS.refreshInterval, err = time.ParseDuration(c.DNS.RefreshInterval) 110 | if err != nil { 111 | return fmt.Errorf("invalid refresh interval: %w", err) 112 | } 113 | if c.DNS.refreshInterval < 0 { 114 | return fmt.Errorf("refresh interval must be >= 0") 115 | } 116 | for i, hs := range c.Hosts { 117 | if (hs.URL == "") == (hs.Hosts == nil) { 118 | return 
fmt.Errorf("exactly one of url or hosts must be set") 119 | } 120 | if hs.URL != "" { 121 | url, err := url.Parse(hs.URL) 122 | if err != nil { 123 | return fmt.Errorf("%s: invalid url: %w", hs.URL, err) 124 | } 125 | switch url.Scheme { 126 | case "file", "http", "https": 127 | default: 128 | return fmt.Errorf("%s: unsupported scheme: %s", hs.URL, url.Scheme) 129 | } 130 | if url.Scheme == "file" && hs.Timeout != "" { 131 | return fmt.Errorf("%s: timeout cannot be set for %s url", hs.URL, url.Scheme) 132 | } 133 | if c.Hosts[i].Timeout == "" { 134 | c.Hosts[i].Timeout = "0" 135 | } 136 | c.Hosts[i].timeout, err = time.ParseDuration(c.Hosts[i].Timeout) 137 | if err != nil { 138 | return fmt.Errorf("%s: invalid timeout: %s", hs.URL, hs.Timeout) 139 | } 140 | } 141 | if hs.Hosts != nil { 142 | if hs.Timeout != "" { 143 | return fmt.Errorf("%s: timeout cannot be set for inline hosts", hs.Hosts) 144 | } 145 | var err error 146 | r := strings.NewReader(strings.Join(hs.Hosts, "\n")) 147 | c.Hosts[i].hosts, err = hosts.Parse(r) 148 | if err != nil { 149 | return err 150 | } 151 | } 152 | } 153 | for _, r := range c.DNS.Resolvers { 154 | if c.Resolver.Protocol == "https" { 155 | u, err := url.Parse(r) 156 | if err != nil { 157 | return fmt.Errorf("invalid resolver %s: %w", r, err) 158 | } 159 | if u.Scheme != "https" { 160 | return fmt.Errorf("protocol %s requires https scheme for resolver %s", c.Resolver.Protocol, r) 161 | } 162 | } else { 163 | if _, _, err := net.SplitHostPort(r); err != nil { 164 | return fmt.Errorf("invalid resolver: %w", err) 165 | } 166 | } 167 | } 168 | if c.Resolver.Protocol == "udp" { 169 | c.Resolver.Protocol = "" // Empty means UDP when passed to dns.ListenAndServe 170 | } 171 | switch c.Resolver.Protocol { 172 | case "", "tcp", "tcp-tls", "https": 173 | default: 174 | return fmt.Errorf("invalid resolver protocol: %s", c.Resolver.Protocol) 175 | } 176 | c.Resolver.Timeout, err = time.ParseDuration(c.Resolver.TimeoutString) 177 | if err != nil { 
178 | return fmt.Errorf("invalid resolver timeout: %s", c.Resolver.TimeoutString) 179 | } 180 | if c.Resolver.Timeout < 0 { 181 | return fmt.Errorf("resolver timeout must be >= 0") 182 | } 183 | if c.Resolver.Timeout == 0 { 184 | c.Resolver.Timeout = 5 * time.Second 185 | } 186 | switch c.DNS.LogModeString { 187 | case "": 188 | c.DNS.LogMode = sql.LogDiscard 189 | case "all": 190 | c.DNS.LogMode = sql.LogAll 191 | case "hijacked": 192 | c.DNS.LogMode = sql.LogHijacked 193 | default: 194 | return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString) 195 | } 196 | if c.DNS.LogModeString != "" && c.DNS.Database == "" { 197 | return fmt.Errorf("log_mode = %q requires 'database' to be set", c.DNS.LogModeString) 198 | } 199 | if c.DNS.LogTTLString == "" { 200 | c.DNS.LogTTLString = "0" 201 | } 202 | c.DNS.LogTTL, err = time.ParseDuration(c.DNS.LogTTLString) 203 | if err != nil { 204 | return fmt.Errorf("invalid log TTL: %s", c.DNS.LogTTLString) 205 | } 206 | return nil 207 | } 208 | 209 | // ReadConfig reads a zdns configuration from reader r. 210 | func ReadConfig(r io.Reader) (Config, error) { 211 | conf := newConfig() 212 | _, err := toml.NewDecoder(r).Decode(&conf) 213 | if err != nil { 214 | return Config{}, err 215 | } 216 | return conf, conf.load() 217 | } 218 | -------------------------------------------------------------------------------- /http/http.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "log" 8 | "net" 9 | "net/http" 10 | _ "net/http/pprof" // Registers debug handlers as a side effect. 11 | "strconv" 12 | "time" 13 | 14 | "github.com/mpolden/zdns/cache" 15 | "github.com/mpolden/zdns/dns/dnsutil" 16 | "github.com/mpolden/zdns/sql" 17 | ) 18 | 19 | const ( 20 | jsonMediaType = "application/json" 21 | ) 22 | 23 | // A Server defines parameters for running an HTTP server. 
The HTTP server serves an API for inspecting cache contents 24 | // and request log. 25 | type Server struct { 26 | cache *cache.Cache 27 | logger *sql.Logger 28 | sqlCache *sql.Cache 29 | server *http.Server 30 | } 31 | 32 | type entry struct { 33 | Time string `json:"time"` 34 | TTL int64 `json:"ttl,omitempty"` 35 | RemoteAddr net.IP `json:"remote_addr,omitempty"` 36 | Hijacked *bool `json:"hijacked,omitempty"` 37 | Qtype string `json:"type"` 38 | Question string `json:"question"` 39 | Answers []string `json:"answers,omitempty"` 40 | Rcode string `json:"rcode,omitempty"` 41 | } 42 | 43 | type stats struct { 44 | Summary summary `json:"summary"` 45 | Requests []request `json:"requests"` 46 | } 47 | 48 | type summary struct { 49 | Log logStats `json:"log"` 50 | Cache cacheStats `json:"cache"` 51 | } 52 | 53 | type request struct { 54 | Time string `json:"time"` 55 | Count int64 `json:"count"` 56 | } 57 | 58 | type logStats struct { 59 | Since string `json:"since"` 60 | Total int64 `json:"total"` 61 | Hijacked int64 `json:"hijacked"` 62 | PendingTasks int `json:"pending_tasks"` 63 | } 64 | 65 | type cacheStats struct { 66 | Size int `json:"size"` 67 | Capacity int `json:"capacity"` 68 | PendingTasks int `json:"pending_tasks"` 69 | BackendStats *backendStats `json:"backend,omitempty"` 70 | } 71 | 72 | type backendStats struct { 73 | PendingTasks int `json:"pending_tasks"` 74 | } 75 | 76 | type httpError struct { 77 | err error 78 | Status int `json:"status"` 79 | Message string `json:"message"` 80 | } 81 | 82 | func newHTTPError(err error) *httpError { 83 | return &httpError{ 84 | err: err, 85 | Status: http.StatusInternalServerError, 86 | } 87 | } 88 | 89 | func newHTTPBadRequest(err error) *httpError { 90 | return &httpError{ 91 | err: err, 92 | Status: http.StatusBadRequest, 93 | } 94 | } 95 | 96 | // NewServer creates a new HTTP server, serving logs from the given logger and listening on addr. 
97 | func NewServer(cache *cache.Cache, logger *sql.Logger, sqlCache *sql.Cache, addr string) *Server { 98 | server := &http.Server{Addr: addr} 99 | s := &Server{ 100 | server: server, 101 | cache: cache, 102 | logger: logger, 103 | sqlCache: sqlCache, 104 | } 105 | s.server.Handler = s.handler() 106 | return s 107 | } 108 | 109 | func (s *Server) handler() http.Handler { 110 | r := &router{} 111 | r.route(http.MethodGet, "/cache/v1/", s.cacheHandler) 112 | r.route(http.MethodDelete, "/cache/v1/", s.cacheResetHandler) 113 | if s.logger != nil { 114 | r.route(http.MethodGet, "/log/v1/", s.logHandler) 115 | r.route(http.MethodGet, "/metric/v1/", s.metricHandler) 116 | } 117 | return r.handler() 118 | } 119 | 120 | func countFrom(r *http.Request) (int, error) { 121 | param := r.URL.Query().Get("n") 122 | if param == "" { 123 | return 100, nil 124 | } 125 | n, err := strconv.Atoi(param) 126 | if err != nil || n < 0 { 127 | return 0, fmt.Errorf("invalid value for parameter n: %s", param) 128 | } 129 | return n, nil 130 | } 131 | 132 | func resolutionFrom(r *http.Request) (time.Duration, error) { 133 | param := r.URL.Query().Get("resolution") 134 | if param == "" { 135 | return time.Minute, nil 136 | } 137 | return time.ParseDuration(param) 138 | } 139 | 140 | func writeJSONHeader(w http.ResponseWriter) { w.Header().Set("Content-Type", jsonMediaType) } 141 | 142 | func writeJSON(w http.ResponseWriter, data interface{}) { 143 | b, err := json.Marshal(data) 144 | if err != nil { 145 | panic(err) 146 | } 147 | writeJSONHeader(w) 148 | w.Write(b) 149 | } 150 | 151 | func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) *httpError { 152 | count, err := countFrom(r) 153 | if err != nil { 154 | writeJSONHeader(w) 155 | return newHTTPBadRequest(err) 156 | } 157 | cacheValues := s.cache.List(count) 158 | entries := make([]entry, 0, len(cacheValues)) 159 | for _, v := range cacheValues { 160 | entries = append(entries, entry{ 161 | Time: 
v.CreatedAt.UTC().Format(time.RFC3339), 162 | TTL: int64(v.TTL().Truncate(time.Second).Seconds()), 163 | Qtype: dnsutil.TypeToString[v.Qtype()], 164 | Question: v.Question(), 165 | Answers: v.Answers(), 166 | Rcode: dnsutil.RcodeToString[v.Rcode()], 167 | }) 168 | } 169 | writeJSON(w, entries) 170 | return nil 171 | } 172 | 173 | func (s *Server) cacheResetHandler(w http.ResponseWriter, r *http.Request) *httpError { 174 | s.cache.Reset() 175 | writeJSON(w, struct { 176 | Message string `json:"message"` 177 | }{"Cleared cache."}) 178 | return nil 179 | } 180 | 181 | func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) *httpError { 182 | count, err := countFrom(r) 183 | if err != nil { 184 | writeJSONHeader(w) 185 | return newHTTPBadRequest(err) 186 | } 187 | logEntries, err := s.logger.Read(count) 188 | if err != nil { 189 | writeJSONHeader(w) 190 | return newHTTPError(err) 191 | } 192 | entries := make([]entry, 0, len(logEntries)) 193 | for _, le := range logEntries { 194 | hijacked := le.Hijacked 195 | entries = append(entries, entry{ 196 | Time: le.Time.UTC().Format(time.RFC3339), 197 | RemoteAddr: le.RemoteAddr, 198 | Hijacked: &hijacked, 199 | Qtype: dnsutil.TypeToString[le.Qtype], 200 | Question: le.Question, 201 | Answers: le.Answers, 202 | }) 203 | } 204 | writeJSON(w, entries) 205 | return nil 206 | } 207 | 208 | func (s *Server) basicMetricHandler(w http.ResponseWriter, r *http.Request) *httpError { 209 | resolution, err := resolutionFrom(r) 210 | if err != nil { 211 | writeJSONHeader(w) 212 | return newHTTPBadRequest(err) 213 | } 214 | lstats, err := s.logger.Stats(resolution) 215 | if err != nil { 216 | writeJSONHeader(w) 217 | return newHTTPError(err) 218 | } 219 | requests := make([]request, 0, len(lstats.Events)) 220 | for _, e := range lstats.Events { 221 | requests = append(requests, request{ 222 | Time: e.Time.Format(time.RFC3339), 223 | Count: e.Count, 224 | }) 225 | } 226 | cstats := s.cache.Stats() 227 | var bstats *backendStats 
228 | if s.sqlCache != nil { 229 | bstats = &backendStats{PendingTasks: s.sqlCache.Stats().PendingTasks} 230 | } 231 | stats := stats{ 232 | Summary: summary{ 233 | Log: logStats{ 234 | Since: lstats.Since.Format(time.RFC3339), 235 | Total: lstats.Total, 236 | Hijacked: lstats.Hijacked, 237 | }, 238 | Cache: cacheStats{ 239 | Capacity: cstats.Capacity, 240 | Size: cstats.Size, 241 | PendingTasks: cstats.PendingTasks, 242 | BackendStats: bstats, 243 | }, 244 | }, 245 | Requests: requests, 246 | } 247 | writeJSON(w, stats) 248 | return nil 249 | } 250 | 251 | func (s *Server) prometheusMetricHandler(w http.ResponseWriter, r *http.Request) *httpError { 252 | lstats, err := s.logger.Stats(time.Minute) 253 | if err != nil { 254 | return newHTTPError(err) 255 | } 256 | totalRequestsGauge.Set(float64(lstats.Total)) 257 | hijackedRequestsGauge.Set(float64(lstats.Hijacked)) 258 | prometheusHandler.ServeHTTP(w, r) 259 | return nil 260 | } 261 | 262 | func (s *Server) metricHandler(w http.ResponseWriter, r *http.Request) *httpError { 263 | format := "" 264 | if formatParams := r.URL.Query()["format"]; len(formatParams) > 0 { 265 | format = formatParams[0] 266 | } 267 | switch format { 268 | case "", "basic": 269 | return s.basicMetricHandler(w, r) 270 | case "prometheus": 271 | return s.prometheusMetricHandler(w, r) 272 | } 273 | writeJSONHeader(w) 274 | return newHTTPBadRequest(fmt.Errorf("invalid metric format: %s", format)) 275 | } 276 | 277 | // Close shuts down the HTTP server. 278 | func (s *Server) Close() error { return s.server.Shutdown(context.TODO()) } 279 | 280 | // ListenAndServe starts the HTTP server listening on the configured address. 
281 | func (s *Server) ListenAndServe() error { 282 | log.Printf("http server listening on http://%s", s.server.Addr) 283 | err := s.server.ListenAndServe() 284 | if err == http.ErrServerClosed { 285 | return nil // Do not treat server closing as an error 286 | } 287 | return err 288 | } 289 | -------------------------------------------------------------------------------- /cache/cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "container/list" 5 | "encoding/binary" 6 | "encoding/hex" 7 | "fmt" 8 | "hash/fnv" 9 | "strconv" 10 | "strings" 11 | "sync" 12 | "time" 13 | 14 | "github.com/miekg/dns" 15 | "github.com/mpolden/zdns/dns/dnsutil" 16 | ) 17 | 18 | // Backend is the interface for a cache backend. All write operations in a Cache are forwarded to a Backend. 19 | type Backend interface { 20 | Set(key uint32, value Value) 21 | Evict(key uint32) 22 | Read() []Value 23 | Reset() 24 | } 25 | 26 | type queue struct { 27 | tasks chan func() 28 | wg sync.WaitGroup 29 | } 30 | 31 | // Cache is a cache of DNS messages. 32 | type Cache struct { 33 | client dnsutil.Client 34 | backend Backend 35 | capacity int 36 | entries map[uint32]*list.Element 37 | values *list.List 38 | mu sync.RWMutex 39 | now func() time.Time 40 | queue *queue 41 | } 42 | 43 | // Value wraps a DNS message stored in the cache. 44 | type Value struct { 45 | Key uint32 46 | CreatedAt time.Time 47 | msg *dns.Msg 48 | } 49 | 50 | // Stats contains cache statistics. 51 | type Stats struct { 52 | Size int 53 | Capacity int 54 | PendingTasks int 55 | } 56 | 57 | // Rcode returns the response code of the cached value v. 58 | func (v *Value) Rcode() int { return v.msg.Rcode } 59 | 60 | // Question returns the first question the cached value v. 
61 | func (v *Value) Question() string { return v.msg.Question[0].Name } 62 | 63 | // Qtype returns the query type of the cached value v 64 | func (v *Value) Qtype() uint16 { return v.msg.Question[0].Qtype } 65 | 66 | // Answers returns the answers of the cached value v. 67 | func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) } 68 | 69 | // TTL returns the time to live of the cached value v. 70 | func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } 71 | 72 | // Pack returns a string representation of Value v. 73 | func (v *Value) Pack() (string, error) { 74 | var sb strings.Builder 75 | sb.WriteString(strconv.FormatUint(uint64(v.Key), 10)) 76 | sb.WriteString(" ") 77 | sb.WriteString(strconv.FormatInt(v.CreatedAt.Unix(), 10)) 78 | sb.WriteString(" ") 79 | data, err := v.msg.Pack() 80 | if err != nil { 81 | return "", err 82 | } 83 | sb.WriteString(hex.EncodeToString(data)) 84 | return sb.String(), nil 85 | } 86 | 87 | // Unpack converts a string value into a Value type. 88 | func Unpack(value string) (Value, error) { 89 | fields := strings.Fields(value) 90 | if len(fields) < 3 { 91 | return Value{}, fmt.Errorf("invalid number of fields: %q", value) 92 | } 93 | key, err := strconv.ParseUint(fields[0], 10, 32) 94 | if err != nil { 95 | return Value{}, err 96 | } 97 | secs, err := strconv.ParseInt(fields[1], 10, 64) 98 | if err != nil { 99 | return Value{}, err 100 | } 101 | data, err := hex.DecodeString(fields[2]) 102 | if err != nil { 103 | return Value{}, err 104 | } 105 | msg := &dns.Msg{} 106 | if err := msg.Unpack(data); err != nil { 107 | return Value{}, err 108 | } 109 | return Value{ 110 | Key: uint32(key), 111 | CreatedAt: time.Unix(secs, 0), 112 | msg: msg, 113 | }, nil 114 | } 115 | 116 | // New creates a new cache of given capacity. 117 | // 118 | // If client is non-nil, the cache will prefetch expired entries in an effort to serve results faster. 
119 | // 120 | // If backend is non-nil: 121 | // 122 | // - All cache write operations will be forward to the backend. 123 | // - The backed will be used to pre-populate the cache. 124 | func New(capacity int, client dnsutil.Client) *Cache { 125 | return NewWithBackend(capacity, client, nil) 126 | } 127 | 128 | // NewWithBackend creates a new cache that forwards entries to backend. 129 | func NewWithBackend(capacity int, client dnsutil.Client, backend Backend) *Cache { 130 | return newCache(capacity, client, backend, time.Now) 131 | } 132 | 133 | func newQueue(capacity int) *queue { return &queue{tasks: make(chan func(), capacity)} } 134 | 135 | func newCache(capacity int, client dnsutil.Client, backend Backend, now func() time.Time) *Cache { 136 | if capacity < 0 { 137 | capacity = 0 138 | } 139 | c := &Cache{ 140 | client: client, 141 | now: now, 142 | capacity: capacity, 143 | entries: make(map[uint32]*list.Element, capacity), 144 | values: list.New(), 145 | queue: newQueue(1024), 146 | } 147 | if backend != nil { 148 | c.load(backend) 149 | } 150 | go c.queue.consume() 151 | return c 152 | } 153 | 154 | // NewKey creates a new cache key for the DNS name, qtype and qclass 155 | func NewKey(name string, qtype, qclass uint16) uint32 { 156 | h := fnv.New32a() 157 | h.Write([]byte(name)) 158 | binary.Write(h, binary.BigEndian, qtype) 159 | binary.Write(h, binary.BigEndian, qclass) 160 | return h.Sum32() 161 | } 162 | 163 | func (c *Cache) load(backend Backend) { 164 | if c.capacity == 0 { 165 | backend.Reset() 166 | return 167 | } 168 | values := backend.Read() 169 | n := 0 170 | if c.capacity < len(values) { 171 | n = c.capacity 172 | } 173 | // Add the last n values from backend 174 | for _, v := range values[n:] { 175 | c.setValue(v) 176 | } 177 | if c.capacity < len(values) { 178 | // Remove older entries from backend 179 | for _, v := range values[:n] { 180 | backend.Evict(v.Key) 181 | } 182 | } 183 | c.backend = backend 184 | } 185 | 186 | // Close consumes 
any outstanding cache operations. 187 | func (c *Cache) Close() error { 188 | c.queue.wg.Wait() 189 | return nil 190 | } 191 | 192 | // Get returns the DNS message associated with key. 193 | func (c *Cache) Get(key uint32) (*dns.Msg, bool) { 194 | v, ok := c.getValue(key) 195 | if !ok { 196 | return nil, false 197 | } 198 | return v.msg, true 199 | } 200 | 201 | func (c *Cache) getValue(key uint32) (*Value, bool) { 202 | c.mu.RLock() 203 | defer c.mu.RUnlock() 204 | v, ok := c.entries[key] 205 | if !ok { 206 | return nil, false 207 | } 208 | value := v.Value.(Value) 209 | if c.isExpired(&value) { 210 | if !c.prefetch() { 211 | c.queue.add(func() { c.evictWithLock(key) }) 212 | return nil, false 213 | } 214 | c.queue.add(func() { c.refresh(key, value.msg) }) 215 | } 216 | return &value, true 217 | } 218 | 219 | // List returns the n most recent values in cache c. 220 | func (c *Cache) List(n int) []Value { 221 | values := make([]Value, 0, n) 222 | c.mu.RLock() 223 | defer c.mu.RUnlock() 224 | for el := c.values.Back(); el != nil; el = el.Prev() { 225 | if len(values) == n { 226 | break 227 | } 228 | v := el.Value.(Value) 229 | values = append(values, v) 230 | } 231 | return values 232 | } 233 | 234 | // Set associates key with the DNS message msg. 235 | // 236 | // If prefetching is disabled, the message will be evicted from the cache according to its TTL. 237 | // 238 | // If prefetching is enabled, the message will never be evicted, but it will be refreshed when its TTL passes. 239 | // 240 | // Setting a new key in a cache that has reached its capacity will evict values in a FIFO order. 241 | func (c *Cache) Set(key uint32, msg *dns.Msg) { 242 | c.mu.Lock() 243 | defer c.mu.Unlock() 244 | c.set(key, msg) 245 | } 246 | 247 | // Stats returns cache statistics. 
248 | func (c *Cache) Stats() Stats { 249 | c.mu.RLock() 250 | defer c.mu.RUnlock() 251 | return Stats{ 252 | Capacity: c.capacity, 253 | Size: len(c.entries), 254 | PendingTasks: len(c.queue.tasks), 255 | } 256 | } 257 | 258 | func (c *Cache) set(key uint32, msg *dns.Msg) bool { 259 | return c.setValue(Value{Key: key, CreatedAt: c.now(), msg: msg}) 260 | } 261 | 262 | func (c *Cache) setValue(value Value) bool { 263 | if c.capacity == 0 || !canCache(value.msg) { 264 | return false 265 | } 266 | if len(c.entries) == c.capacity { 267 | first := c.values.Front() 268 | key := first.Value.(Value).Key 269 | c.evict(key, first) 270 | } 271 | current, ok := c.entries[value.Key] 272 | if ok { 273 | c.values.Remove(current) 274 | } 275 | c.entries[value.Key] = c.values.PushBack(value) 276 | if c.hasBackend() { 277 | c.backend.Set(value.Key, value) 278 | } 279 | return true 280 | } 281 | 282 | // Reset removes all values contained in cache c. 283 | func (c *Cache) Reset() { 284 | c.mu.Lock() 285 | defer c.mu.Unlock() 286 | c.entries = make(map[uint32]*list.Element, c.capacity) 287 | c.values = c.values.Init() 288 | if c.hasBackend() { 289 | c.backend.Reset() 290 | } 291 | } 292 | 293 | func (c *Cache) prefetch() bool { return c.client != nil } 294 | 295 | func (c *Cache) hasBackend() bool { return c.backend != nil } 296 | 297 | func (c *Cache) refresh(key uint32, old *dns.Msg) { 298 | q := old.Question[0] 299 | msg := dns.Msg{} 300 | msg.SetQuestion(q.Name, q.Qtype) 301 | r, err := c.client.Exchange(&msg) 302 | if err != nil { 303 | return // Retry on next request 304 | } 305 | c.mu.Lock() 306 | defer c.mu.Unlock() 307 | if !c.set(key, r) { 308 | c.evict(key, c.entries[key]) 309 | } 310 | } 311 | 312 | func (c *Cache) evictWithLock(key uint32) { 313 | c.mu.Lock() 314 | defer c.mu.Unlock() 315 | c.evict(key, c.entries[key]) 316 | } 317 | 318 | func (c *Cache) evict(key uint32, element *list.Element) { 319 | if element == nil { 320 | return 321 | } 322 | delete(c.entries, key) 
323 | c.values.Remove(element) 324 | if c.hasBackend() { 325 | c.backend.Evict(key) 326 | } 327 | } 328 | 329 | func (c *Cache) isExpired(v *Value) bool { 330 | expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg)) 331 | return c.now().After(expiresAt) 332 | } 333 | 334 | func (q *queue) add(task func()) { 335 | q.wg.Add(1) 336 | q.tasks <- task 337 | } 338 | 339 | func (q *queue) consume() { 340 | for task := range q.tasks { 341 | task() 342 | q.wg.Done() 343 | } 344 | } 345 | 346 | func canCache(msg *dns.Msg) bool { 347 | if dnsutil.MinTTL(msg) == 0 { 348 | return false 349 | } 350 | return msg.Rcode == dns.RcodeSuccess || msg.Rcode == dns.RcodeNameError 351 | } 352 | -------------------------------------------------------------------------------- /sql/sql_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "reflect" 7 | "strings" 8 | "sync" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | type rowCount struct { 14 | table string 15 | rows int 16 | } 17 | 18 | var tests = []struct { 19 | question string 20 | qtype uint16 21 | hijacked bool 22 | answers []string 23 | t time.Time 24 | remoteAddr net.IP 25 | rowCounts []rowCount 26 | }{ 27 | {"foo.example.com", 1, false, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 15, 10, 0, time.UTC), net.IPv4(192, 0, 2, 100), 28 | []rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 1}, {"rr_type", 1}, {"remote_addr", 1}}}, 29 | {"foo.example.com", 1, true, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 16, 20, 0, time.UTC), net.IPv4(192, 0, 2, 100), 30 | []rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 2}, {"rr_type", 1}, {"remote_addr", 1}}}, 31 | {"bar.example.com", 1, false, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 17, 30, 0, time.UTC), net.IPv4(192, 0, 2, 101), 32 | []rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 3}, {"rr_type", 1}, {"remote_addr", 2}}}, 33 | {"bar.example.com", 1, false, 
[]string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 18, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), 34 | []rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 4}, {"rr_type", 1}, {"remote_addr", 3}}}, 35 | {"bar.example.com", 28, false, []string{"2001:db8::1"}, time.Date(2019, 6, 15, 23, 4, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), 36 | []rowCount{{"rr_question", 2}, {"rr_answer", 3}, {"log", 5}, {"rr_type", 2}, {"remote_addr", 3}}}, 37 | {"bar.example.com", 28, false, []string{"2001:db8::2", "2001:db8::3"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), 38 | []rowCount{{"rr_question", 2}, {"rr_answer", 5}, {"log", 6}, {"rr_type", 2}, {"remote_addr", 3}}}, 39 | {"baz.example.com", 28, false, []string{"2001:db8::4"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), 40 | []rowCount{{"rr_question", 3}, {"rr_answer", 6}, {"log", 7}, {"rr_type", 2}, {"remote_addr", 3}}}, 41 | {"baz.example.com", 28, false, nil, time.Date(2019, 6, 16, 1, 5, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), 42 | []rowCount{{"rr_question", 3}, {"rr_answer", 6}, {"log", 8}, {"rr_type", 2}, {"remote_addr", 3}}}, 43 | } 44 | 45 | func testClient() *Client { 46 | c, err := New(":memory:") 47 | if err != nil { 48 | panic(err) 49 | } 50 | return c 51 | } 52 | 53 | func count(t *testing.T, client *Client, query string, args ...interface{}) int { 54 | rows := 0 55 | if err := client.db.Get(&rows, query, args...); err != nil { 56 | t.Fatalf("query failed: %s: %s", query, err) 57 | } 58 | return rows 59 | } 60 | 61 | func writeTests(c *Client, t *testing.T) { 62 | for i, tt := range tests { 63 | if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { 64 | t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) 65 | } 66 | } 67 | } 68 | 69 | func TestWriteLog(t *testing.T) { 70 | c := testClient() 71 | for i, 
tt := range tests { 72 | if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { 73 | t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) 74 | } 75 | for _, rowCount := range tt.rowCounts { 76 | rows := count(t, c, "SELECT COUNT(*) FROM "+rowCount.table+" LIMIT 1") 77 | if rows != rowCount.rows { 78 | t.Errorf("#%d: got %d rows in %s, want %d", i, rows, rowCount.table, rowCount.rows) 79 | } 80 | } 81 | } 82 | } 83 | 84 | func TestReadLog(t *testing.T) { 85 | c := testClient() 86 | writeTests(c, t) 87 | allEntries := [][]logEntry{ 88 | {{ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, 89 | {{ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, 90 | { 91 | {ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::3", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 92 | {ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::2", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 93 | }, 94 | {{ID: 5, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::1", Time: 1560639880, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, 95 | {{ID: 4, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637120, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, 96 | {{ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)}}, 97 | {{ID: 2, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636980, RemoteAddr: net.IPv4(192, 0, 2, 100), Hijacked: true}}, 98 | {{ID: 1, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636910, RemoteAddr: net.IPv4(192, 0, 2, 100)}}, 99 | } 100 | for n := 1; n <= len(allEntries); n++ { 101 | var want []logEntry 102 | for _, entries := 
range allEntries[:n] { 103 | want = append(want, entries...) 104 | } 105 | got, err := c.readLog(n) 106 | if len(got) != len(want) { 107 | t.Errorf("len(got) = %d, want %d", len(got), len(want)) 108 | } 109 | if err != nil || !reflect.DeepEqual(got, want) { 110 | var sb1 strings.Builder 111 | for _, e := range got { 112 | sb1.WriteString(fmt.Sprintf(" %+v\n", e)) 113 | } 114 | var sb2 strings.Builder 115 | for _, e := range want { 116 | sb2.WriteString(fmt.Sprintf(" %+v\n", e)) 117 | } 118 | t.Errorf("ReadLog(%d) = (\n%s, %v),\nwant (\n%s, %v)", n, sb1.String(), err, sb2.String(), nil) 119 | } 120 | } 121 | } 122 | 123 | func TestDeleteLogBefore(t *testing.T) { 124 | c := testClient() 125 | writeTests(c, t) 126 | u := tests[1].t.Add(time.Second) 127 | if err := c.deleteLogBefore(u); err != nil { 128 | t.Fatalf("DeleteBefore(%s) = %v, want %v", u, err, nil) 129 | } 130 | 131 | want := []logEntry{ 132 | {ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 133 | {ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 134 | {ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::3", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 135 | {ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::2", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 136 | {ID: 5, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::1", Time: 1560639880, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 137 | {ID: 4, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637120, RemoteAddr: net.IPv4(192, 0, 2, 102)}, 138 | {ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)}, 139 | } 140 | n := 10 141 | got, err := c.readLog(n) 142 | if err != nil || !reflect.DeepEqual(got, want) { 143 | t.Errorf("ReadLog(%d) = (%+v, %v), want (%+v, %v)", n, got, err, 
want, nil) 144 | } 145 | 146 | question := "foo.example.com" 147 | if want, got := 0, count(t, c, "SELECT COUNT(*) FROM rr_question WHERE name = $1", question); got != want { 148 | t.Errorf("got %d rows for question %q, want %d", got, question, want) 149 | } 150 | 151 | answer := "192.0.2.1" 152 | if want, got := 0, count(t, c, "SELECT COUNT(*) FROM rr_answer WHERE name = $1", answer); got != want { 153 | t.Errorf("got %d rows for answer %q, want %d", got, question, want) 154 | } 155 | 156 | // Delete logs in the far past which matches 0 entries. 157 | oneYear := time.Hour * 8760 158 | if err := c.deleteLogBefore(u.Add(-oneYear)); err != nil { 159 | t.Fatal(err) 160 | } 161 | } 162 | 163 | func TestInterleavedRW(t *testing.T) { 164 | c := testClient() 165 | var wg sync.WaitGroup 166 | wg.Add(1) 167 | ch := make(chan bool, 10) 168 | var err error 169 | go func() { 170 | defer wg.Done() 171 | for range ch { 172 | err = c.writeLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1") 173 | } 174 | }() 175 | ch <- true 176 | close(ch) 177 | if _, err := c.readLog(1); err != nil { 178 | t.Fatal(err) 179 | } 180 | wg.Wait() 181 | if err != nil { 182 | t.Fatal(err) 183 | } 184 | } 185 | 186 | func TestReadLogStats(t *testing.T) { 187 | c := testClient() 188 | 189 | got, err := c.readLogStats() 190 | if err != nil { 191 | t.Fatal(err) 192 | } 193 | want := logStats{} 194 | if !reflect.DeepEqual(got, want) { 195 | t.Errorf("readLogStats() = (%+v, _), want (%+v, _)", got, want) 196 | } 197 | 198 | writeTests(c, t) 199 | got, err = c.readLogStats() 200 | if err != nil { 201 | t.Fatal(err) 202 | } 203 | want = logStats{ 204 | Since: 1560636910, 205 | Hijacked: 1, 206 | Total: 8, 207 | Events: []logEvent{ 208 | {Time: 1560636910, Count: 1}, 209 | {Time: 1560636980, Count: 1}, 210 | {Time: 1560637050, Count: 1}, 211 | {Time: 1560637120, Count: 1}, 212 | {Time: 1560639880, Count: 1}, 213 | {Time: 1560641700, Count: 2}, 214 | {Time: 1560647100, Count: 1}, 215 
| }, 216 | } 217 | if !reflect.DeepEqual(got, want) { 218 | t.Errorf("readLogStats() = (%+v, _), want (%+v, _)", got, want) 219 | } 220 | } 221 | 222 | func BenchmarkReadLog(b *testing.B) { 223 | c := testClient() 224 | for i := 0; i < 1000; i++ { 225 | if err := c.writeLog(time.Now(), net.ParseIP("127.0.0.1"), false, 1, "example.com.", "192.0.2.1"); err != nil { 226 | b.Fatal(err) 227 | } 228 | } 229 | b.ResetTimer() 230 | for n := 0; n < b.N; n++ { 231 | c.readLog(1000) 232 | } 233 | } 234 | 235 | func BenchmarkDeleteLogBefore(b *testing.B) { 236 | c := testClient() 237 | for n := 0; n < b.N; n++ { 238 | b.StopTimer() 239 | // Generate test data with many unique values for each column 240 | for i := 0; i < 16; i++ { 241 | for j := 0; j < 256; j++ { 242 | if err := c.writeLog(time.Now(), net.ParseIP(fmt.Sprintf("127.0.%d.%d", i, j)), false, 1, fmt.Sprintf("%d-%d.example.com.", i, j), fmt.Sprintf("127.1.%d.%d", i, j)); err != nil { 243 | b.Fatal(err) 244 | } 245 | } 246 | } 247 | b.StartTimer() 248 | c.deleteLogBefore(time.Now()) 249 | } 250 | } 251 | -------------------------------------------------------------------------------- /sql/sql.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "database/sql" 5 | "sync" 6 | "time" 7 | 8 | "github.com/jmoiron/sqlx" 9 | _ "github.com/mattn/go-sqlite3" // SQLite database driver 10 | ) 11 | 12 | const schema = ` 13 | CREATE TABLE IF NOT EXISTS rr_question ( 14 | id INTEGER PRIMARY KEY, 15 | name TEXT NOT NULL, 16 | CONSTRAINT name_unique UNIQUE(name) 17 | ); 18 | 19 | CREATE TABLE IF NOT EXISTS rr_answer ( 20 | id INTEGER PRIMARY KEY, 21 | name TEXT NOT NULL, 22 | CONSTRAINT name_unique UNIQUE(name) 23 | ); 24 | 25 | CREATE TABLE IF NOT EXISTS rr_type ( 26 | id INTEGER PRIMARY KEY, 27 | type INTEGER NOT NULL, 28 | CONSTRAINT type_unique UNIQUE(type) 29 | ); 30 | 31 | CREATE TABLE IF NOT EXISTS remote_addr ( 32 | id INTEGER PRIMARY KEY, 33 | addr BLOB NOT 
NULL, 34 | CONSTRAINT addr_unique UNIQUE(addr) 35 | ); 36 | 37 | CREATE TABLE IF NOT EXISTS log ( 38 | id INTEGER PRIMARY KEY, 39 | time INTEGER NOT NULL, 40 | hijacked INTEGER NOT NULL, 41 | remote_addr_id INTEGER NOT NULL, 42 | rr_type_id INTEGER NOT NULL, 43 | rr_question_id INTEGER NOT NULL, 44 | FOREIGN KEY (remote_addr_id) REFERENCES remote_addr(id), 45 | FOREIGN KEY (rr_question_id) REFERENCES rr_question(id), 46 | FOREIGN KEY (rr_type_id) REFERENCES rr_type(id) 47 | ); 48 | 49 | CREATE INDEX IF NOT EXISTS log_time ON log(time); 50 | CREATE INDEX IF NOT EXISTS log_remote_addr_id ON log(remote_addr_id); 51 | CREATE INDEX IF NOT EXISTS log_rr_question_id ON log(rr_question_id); 52 | CREATE INDEX IF NOT EXISTS log_rr_type_id ON log(rr_type_id); 53 | 54 | CREATE TABLE IF NOT EXISTS log_rr_answer ( 55 | id INTEGER PRIMARY KEY, 56 | log_id INTEGER NOT NULL, 57 | rr_answer_id INTEGER NOT NULL, 58 | FOREIGN KEY (log_id) REFERENCES log(id), 59 | FOREIGN KEY (rr_answer_id) REFERENCES rr_answer(id) 60 | ); 61 | 62 | CREATE INDEX IF NOT EXISTS log_rr_answer_log_id ON log_rr_answer(log_id); 63 | CREATE INDEX IF NOT EXISTS log_rr_answer_rr_answer_id ON log_rr_answer(rr_answer_id); 64 | 65 | CREATE TABLE IF NOT EXISTS cache ( 66 | id INTEGER PRIMARY KEY, 67 | key INTEGER NOT NULL, 68 | data TEXT NOT NULL, 69 | CONSTRAINT key_unique UNIQUE(key) 70 | ); 71 | ` 72 | 73 | // Client implements a client for a SQLite database. 
74 | type Client struct { 75 | db *sqlx.DB 76 | mu sync.RWMutex 77 | } 78 | 79 | type logEntry struct { 80 | ID int64 `db:"id"` 81 | Time int64 `db:"time"` 82 | RemoteAddr []byte `db:"remote_addr"` 83 | Hijacked bool `db:"hijacked"` 84 | Qtype uint16 `db:"type"` 85 | Question string `db:"question"` 86 | Answer string `db:"answer"` 87 | } 88 | 89 | type logStats struct { 90 | Since int64 `db:"since"` 91 | Hijacked int64 `db:"hijacked"` 92 | Total int64 `db:"total"` 93 | Events []logEvent 94 | } 95 | 96 | type logEvent struct { 97 | Time int64 `db:"time"` 98 | Count int64 `db:"count"` 99 | } 100 | 101 | type cacheEntry struct { 102 | Key uint32 `db:"key"` 103 | Data string `db:"data"` 104 | } 105 | 106 | // New creates a new database client for given filename. 107 | func New(filename string) (*Client, error) { 108 | db, err := sqlx.Connect("sqlite3", filename) 109 | if err != nil { 110 | return nil, err 111 | } 112 | // Ensure foreign keys are enabled (defaults to off) 113 | if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { 114 | return nil, err 115 | } 116 | if _, err := db.Exec("PRAGMA journal_mode = WAL"); err != nil { 117 | return nil, err 118 | } 119 | if _, err := db.Exec(schema); err != nil { 120 | return nil, err 121 | } 122 | return &Client{db: db}, nil 123 | } 124 | 125 | // Close waits for all queries to complete and then closes the database. 
126 | func (c *Client) Close() error { return c.db.Close() } 127 | 128 | func (c *Client) readLog(n int) ([]logEntry, error) { 129 | c.mu.RLock() 130 | defer c.mu.RUnlock() 131 | query := ` 132 | SELECT log.id AS id, 133 | time, 134 | remote_addr.addr AS remote_addr, 135 | hijacked, 136 | type, 137 | rr_question.name AS question, 138 | IFNULL(rr_answer.name, "") AS answer 139 | FROM log 140 | INNER JOIN remote_addr ON remote_addr.id = log.remote_addr_id 141 | INNER JOIN rr_question ON rr_question.id = rr_question_id 142 | INNER JOIN rr_type ON rr_type.id = rr_type_id 143 | LEFT JOIN log_rr_answer ON log_rr_answer.log_id = log.id 144 | LEFT JOIN rr_answer ON rr_answer.id = log_rr_answer.rr_answer_id 145 | WHERE log.id IN (SELECT id FROM log ORDER BY time DESC, id DESC LIMIT $1) 146 | ORDER BY time DESC, rr_answer.id DESC 147 | ` 148 | var entries []logEntry 149 | err := c.db.Select(&entries, query, n) 150 | return entries, err 151 | } 152 | 153 | func getOrInsert(tx *sqlx.Tx, table, column string, value interface{}) (int64, error) { 154 | var id int64 155 | err := tx.Get(&id, "SELECT id FROM "+table+" WHERE "+column+" = ?", value) 156 | if err == sql.ErrNoRows { 157 | res, err := tx.Exec("INSERT INTO "+table+" ("+column+") VALUES (?)", value) 158 | if err != nil { 159 | return 0, err 160 | } 161 | return res.LastInsertId() 162 | } 163 | return id, err 164 | } 165 | 166 | func (c *Client) writeLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error { 167 | c.mu.Lock() 168 | defer c.mu.Unlock() 169 | tx, err := c.db.Beginx() 170 | if err != nil { 171 | return err 172 | } 173 | defer tx.Rollback() 174 | typeID, err := getOrInsert(tx, "rr_type", "type", qtype) 175 | if err != nil { 176 | return err 177 | } 178 | questionID, err := getOrInsert(tx, "rr_question", "name", question) 179 | if err != nil { 180 | return err 181 | } 182 | remoteAddrID, err := getOrInsert(tx, "remote_addr", "addr", remoteAddr) 183 | if err != 
nil { 184 | return err 185 | } 186 | answerIDs := make([]int64, 0, len(answers)) 187 | for _, answer := range answers { 188 | answerID, err := getOrInsert(tx, "rr_answer", "name", answer) 189 | if err != nil { 190 | return err 191 | } 192 | answerIDs = append(answerIDs, answerID) 193 | } 194 | hijackedInt := 0 195 | if hijacked { 196 | hijackedInt = 1 197 | } 198 | res, err := tx.Exec("INSERT INTO log (time, hijacked, remote_addr_id, rr_type_id, rr_question_id) VALUES ($1, $2, $3, $4, $5)", time.Unix(), hijackedInt, remoteAddrID, typeID, questionID) 199 | if err != nil { 200 | return err 201 | } 202 | logID, err := res.LastInsertId() 203 | if err != nil { 204 | return err 205 | } 206 | for _, answerID := range answerIDs { 207 | if _, err := tx.Exec("INSERT INTO log_rr_answer (log_id, rr_answer_id) VALUES ($1, $2)", logID, answerID); err != nil { 208 | return err 209 | } 210 | } 211 | return tx.Commit() 212 | } 213 | 214 | func (c *Client) deleteLogBefore(t time.Time) (err error) { 215 | c.mu.Lock() 216 | defer c.mu.Unlock() 217 | tx, err := c.db.Beginx() 218 | if err != nil { 219 | return nil 220 | } 221 | defer tx.Rollback() 222 | var ids []int64 223 | // SQLite limits the number of variables to 999 (SQLITE_LIMIT_VARIABLE_NUMBER): 224 | // https://www.sqlite.org/limits.html 225 | if err := tx.Select(&ids, "SELECT id FROM log WHERE time < $1 ORDER BY time ASC LIMIT 999", t.Unix()); err != nil { 226 | return err 227 | } 228 | if len(ids) == 0 { 229 | return nil 230 | } 231 | deleteByIds := []string{ 232 | "DELETE FROM log_rr_answer WHERE log_id IN (?)", 233 | "DELETE FROM log WHERE id IN (?)", 234 | } 235 | for _, q := range deleteByIds { 236 | query, args, err := sqlx.In(q, ids) 237 | if err != nil { 238 | return err 239 | } 240 | if _, err := tx.Exec(query, args...); err != nil { 241 | return err 242 | } 243 | } 244 | deleteBySelection := []string{ 245 | "DELETE FROM rr_type WHERE id NOT IN (SELECT rr_type_id FROM log)", 246 | "DELETE FROM rr_question WHERE id NOT 
IN (SELECT rr_question_id FROM log)", 247 | "DELETE FROM rr_answer WHERE id NOT IN (SELECT rr_answer_id FROM log_rr_answer)", 248 | "DELETE FROM remote_addr WHERE id NOT IN (SELECT remote_addr_id FROM log)", 249 | } 250 | for _, q := range deleteBySelection { 251 | if _, err := tx.Exec(q); err != nil { 252 | return err 253 | } 254 | } 255 | return tx.Commit() 256 | } 257 | 258 | func (c *Client) readLogStats() (logStats, error) { 259 | c.mu.RLock() 260 | defer c.mu.RUnlock() 261 | var stats logStats 262 | q1 := `SELECT COUNT(*) as total, 263 | COUNT(CASE hijacked WHEN 1 THEN 1 ELSE NULL END) as hijacked, 264 | IFNULL(time, 0) AS since 265 | FROM log 266 | ORDER BY time ASC LIMIT 1` 267 | if err := c.db.Get(&stats, q1); err != nil { 268 | return logStats{}, err 269 | } 270 | var events []logEvent 271 | q2 := `SELECT time, 272 | COUNT(*) AS count 273 | FROM log 274 | GROUP BY time 275 | ORDER BY time ASC` 276 | if err := c.db.Select(&events, q2); err != nil { 277 | return logStats{}, err 278 | } 279 | stats.Events = events 280 | return stats, nil 281 | } 282 | 283 | func (c *Client) writeCacheValue(key uint32, data string) error { 284 | c.mu.Lock() 285 | defer c.mu.Unlock() 286 | tx, err := c.db.Beginx() 287 | if err != nil { 288 | return nil 289 | } 290 | defer tx.Rollback() 291 | if _, err := tx.Exec("DELETE FROM cache WHERE key = $1", key); err != nil { 292 | return err 293 | } 294 | if _, err := tx.Exec("INSERT INTO cache (key, data) VALUES ($1, $2)", key, data); err != nil { 295 | return err 296 | } 297 | return tx.Commit() 298 | } 299 | 300 | func (c *Client) removeCacheValue(key uint32) error { 301 | c.mu.Lock() 302 | defer c.mu.Unlock() 303 | tx, err := c.db.Beginx() 304 | if err != nil { 305 | return nil 306 | } 307 | defer tx.Rollback() 308 | if _, err := tx.Exec("DELETE FROM cache WHERE key = $1", key); err != nil { 309 | return err 310 | } 311 | return tx.Commit() 312 | } 313 | 314 | func (c *Client) truncateCache() error { 315 | c.mu.Lock() 316 | defer 
c.mu.Unlock() 317 | tx, err := c.db.Beginx() 318 | if err != nil { 319 | return nil 320 | } 321 | defer tx.Rollback() 322 | if _, err := tx.Exec("DELETE FROM cache"); err != nil { 323 | return err 324 | } 325 | return tx.Commit() 326 | } 327 | 328 | func (c *Client) readCache() ([]cacheEntry, error) { 329 | c.mu.RLock() 330 | defer c.mu.RUnlock() 331 | var entries []cacheEntry 332 | err := c.db.Select(&entries, "SELECT key, data FROM cache ORDER BY id ASC") 333 | return entries, err 334 | } 335 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /cache/cache_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "reflect" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/miekg/dns" 12 | "github.com/mpolden/zdns/dns/dnsutil" 13 | ) 14 | 15 | var testMsg *dns.Msg = newA("example.com.", 60, net.ParseIP("192.0.2.1")) 16 | 17 | type testClient struct { 18 | mu sync.RWMutex 19 | answers chan *dns.Msg 20 | } 21 | 22 | func newTestClient() *testClient { return &testClient{answers: make(chan *dns.Msg, 100)} } 23 | 24 | func (e *testClient) setAnswer(answer *dns.Msg) { 25 | e.mu.Lock() 26 | defer e.mu.Unlock() 27 | e.answers <- answer 28 | } 29 | 30 | func (e *testClient) reset() { 31 | e.mu.Lock() 32 | defer e.mu.Unlock() 33 | e.answers = make(chan *dns.Msg, 100) 34 | } 35 | 36 | func (e *testClient) Exchange(msg *dns.Msg) (*dns.Msg, error) { 37 | e.mu.RLock() 38 | defer e.mu.RUnlock() 39 | if len(e.answers) == 0 { 40 | return nil, fmt.Errorf("no answer pending") 41 | } 42 | return <-e.answers, 
nil 43 | } 44 | 45 | type testBackend struct { 46 | values []Value 47 | } 48 | 49 | func (b *testBackend) Set(key uint32, value Value) { 50 | b.values = append(b.values, value) 51 | } 52 | 53 | func (b *testBackend) Evict(key uint32) { 54 | var values []Value 55 | for _, v := range b.values { 56 | if v.Key == key { 57 | continue 58 | } 59 | values = append(values, v) 60 | } 61 | b.values = values 62 | } 63 | 64 | func (b *testBackend) Reset() { b.values = nil } 65 | 66 | func (b *testBackend) Read() []Value { return b.values } 67 | 68 | func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { 69 | m := dns.Msg{} 70 | m.Id = dns.Id() 71 | m.SetQuestion(dns.Fqdn(name), dns.TypeA) 72 | rr := make([]dns.RR, 0, len(ipAddr)) 73 | for _, ip := range ipAddr { 74 | rr = append(rr, &dns.A{ 75 | A: ip, 76 | Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, 77 | }) 78 | } 79 | m.Answer = rr 80 | return &m 81 | } 82 | 83 | func reverse(msgs []*dns.Msg) []*dns.Msg { 84 | reversed := make([]*dns.Msg, 0, len(msgs)) 85 | for i := len(msgs) - 1; i >= 0; i-- { 86 | reversed = append(reversed, msgs[i]) 87 | } 88 | return reversed 89 | } 90 | 91 | func TestNewKey(t *testing.T) { 92 | var tests = []struct { 93 | name string 94 | qtype, qclass uint16 95 | out uint32 96 | }{ 97 | {"foo.", dns.TypeA, dns.ClassINET, 2839090419}, 98 | {"foo.", dns.TypeAAAA, dns.ClassINET, 3344654668}, 99 | {"foo.", dns.TypeA, dns.ClassANY, 1731870733}, 100 | {"bar.", dns.TypeA, dns.ClassINET, 1951431764}, 101 | } 102 | for i, tt := range tests { 103 | got := NewKey(tt.name, tt.qtype, tt.qclass) 104 | if got != tt.out { 105 | t.Errorf("#%d: NewKey(%q, %d, %d) = %d, want %d", i, tt.name, tt.qtype, tt.qclass, got, tt.out) 106 | } 107 | } 108 | } 109 | 110 | func TestCache(t *testing.T) { 111 | msg := newA("1.example.com.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2")) 112 | msgWithZeroTTL := newA("2.example.com.", 0, net.ParseIP("192.0.2.2")) 113 | msgFailure := 
newA("3.example.com.", 60, net.ParseIP("192.0.2.2")) 114 | msgFailure.Rcode = dns.RcodeServerFailure 115 | msgNameError := &dns.Msg{} 116 | msgNameError.Id = dns.Id() 117 | msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA) 118 | msgNameError.Rcode = dns.RcodeNameError 119 | 120 | now := time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC) 121 | c := New(100, nil) 122 | var tests = []struct { 123 | msg *dns.Msg 124 | queriedAt time.Time 125 | ok bool 126 | value *Value 127 | }{ 128 | {msg, now, true, &Value{Key: 3517338631, CreatedAt: now, msg: msg}}, // Not expired when query time == create time 129 | {msg, now.Add(30 * time.Second), true, &Value{Key: 3517338631, CreatedAt: now, msg: msg}}, // Not expired when below TTL 130 | {msg, now.Add(60 * time.Second), true, &Value{Key: 3517338631, CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds 131 | {msgNameError, now, true, &Value{Key: 3980405151, CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached 132 | {msg, now.Add(61 * time.Second), false, nil}, // Expired due to TTL exceeded 133 | {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached 134 | {msgFailure, now, false, nil}, // Non-cacheable rcode 135 | } 136 | for i, tt := range tests { 137 | c.now = func() time.Time { return now } 138 | k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass) 139 | c.Set(k, tt.msg) 140 | c.now = func() time.Time { return tt.queriedAt } 141 | if msg, ok := c.Get(k); ok != tt.ok { 142 | t.Errorf("#%d: Get(%d) = (%+v, %t), want (_, %t)", i, k, msg, ok, tt.ok) 143 | } 144 | if v, ok := c.getValue(k); ok != tt.ok || !reflect.DeepEqual(v, tt.value) { 145 | t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok) 146 | } 147 | c.Close() 148 | c.mu.RLock() 149 | if _, ok := c.entries[k]; ok != tt.ok { 150 | t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok) 151 | } 152 | keyIdx := -1 153 | for el := c.values.Front(); el != nil; el = el.Next() { 154 | if 
el.Value.(Value).Key == k { 155 | keyIdx = i 156 | break 157 | } 158 | } 159 | c.mu.RUnlock() 160 | if (keyIdx != -1) != tt.ok { 161 | t.Errorf("#%d: keys[%d] = %d, found expired key", i, keyIdx, k) 162 | } 163 | } 164 | } 165 | 166 | func TestCacheCapacity(t *testing.T) { 167 | var tests = []struct { 168 | addCount, capacity, size int 169 | }{ 170 | {1, 0, 0}, 171 | {1, 2, 1}, 172 | {2, 2, 2}, 173 | {3, 2, 2}, 174 | } 175 | for i, tt := range tests { 176 | c := New(tt.capacity, nil) 177 | var msgs []*dns.Msg 178 | for i := 0; i < tt.addCount; i++ { 179 | m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) 180 | k := NewKey(m.Question[0].Name, m.Question[0].Qtype, m.Question[0].Qclass) 181 | msgs = append(msgs, m) 182 | c.Set(k, m) 183 | } 184 | if got := len(c.entries); got != tt.size { 185 | t.Errorf("#%d: len(values) = %d, want %d", i, got, tt.size) 186 | } 187 | if tt.capacity > 0 && tt.addCount > tt.capacity && tt.capacity == tt.size { 188 | lastAdded := msgs[tt.addCount-1].Question[0] 189 | lastK := NewKey(lastAdded.Name, lastAdded.Qtype, lastAdded.Qclass) 190 | if _, ok := c.Get(lastK); !ok { 191 | t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, lastAdded.Name, ok, !ok) 192 | } 193 | firstAdded := msgs[0].Question[0] 194 | firstK := NewKey(firstAdded.Name, firstAdded.Qtype, firstAdded.Qclass) 195 | if _, ok := c.Get(firstK); ok { 196 | t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, firstAdded.Name, ok, !ok) 197 | } 198 | } 199 | } 200 | } 201 | 202 | func TestCacheList(t *testing.T) { 203 | var tests = []struct { 204 | addCount, listCount, wantCount int 205 | expire bool 206 | }{ 207 | {0, 0, 0, false}, 208 | {1, 0, 0, false}, 209 | {1, 1, 1, false}, 210 | {2, 1, 1, false}, 211 | {2, 3, 2, false}, 212 | {2, 0, 0, true}, 213 | } 214 | for i, tt := range tests { 215 | c := New(1024, nil) 216 | var msgs []*dns.Msg 217 | for i := 0; i < tt.addCount; i++ { 218 | m := newA(fmt.Sprintf("r%d", i), 60, 
net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) 219 | k := NewKey(m.Question[0].Name, m.Question[0].Qtype, m.Question[0].Qclass) 220 | msgs = append(msgs, m) 221 | c.Set(k, m) 222 | } 223 | if tt.expire { 224 | c.now = func() time.Time { return time.Now().Add(time.Minute).Add(time.Second) } 225 | } 226 | values := c.List(tt.listCount) 227 | if got := len(values); got != tt.wantCount { 228 | t.Errorf("#%d: len(List(%d)) = %d, want %d", i, tt.listCount, got, tt.wantCount) 229 | } 230 | gotMsgs := make([]*dns.Msg, 0, len(values)) 231 | for _, v := range values { 232 | gotMsgs = append(gotMsgs, v.msg) 233 | } 234 | msgs = reverse(msgs) 235 | want := msgs[:tt.wantCount] 236 | if !reflect.DeepEqual(want, gotMsgs) { 237 | t.Errorf("#%d: got %+v, want %+v", i, gotMsgs, want) 238 | } 239 | } 240 | } 241 | 242 | func TestReset(t *testing.T) { 243 | c := New(10, nil) 244 | c.Set(uint32(1), &dns.Msg{}) 245 | c.Reset() 246 | if got, want := len(c.entries), 0; got != want { 247 | t.Errorf("len(values) = %d, want %d", got, want) 248 | } 249 | if got, want := c.values.Len(), 0; got != want { 250 | t.Errorf("len(keys) = %d, want %d", got, want) 251 | } 252 | } 253 | 254 | func TestCachePrefetch(t *testing.T) { 255 | client := newTestClient() 256 | now := time.Now() 257 | c := newCache(10, client, nil, func() time.Time { return now }) 258 | var tests = []struct { 259 | initialAnswer string 260 | refreshAnswer string 261 | initialTTL time.Duration 262 | refreshTTL time.Duration 263 | readDelay time.Duration 264 | answer string 265 | ok bool 266 | refetch bool 267 | }{ 268 | // Serves cached value before expiry 269 | {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 30 * time.Second, "192.0.2.1", true, true}, 270 | // Serves stale cached value after expiry and before refresh happens 271 | {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 61 * time.Second, "192.0.2.1", true, false}, 272 | // Serves refreshed value after expiry and refresh 273 | {"192.0.2.1", "192.0.2.42", 
time.Minute, time.Minute, 61 * time.Second, "192.0.2.42", true, true}, 274 | // Refreshed value can no longer be cached 275 | {"192.0.2.1", "192.0.2.42", time.Minute, 0, 61 * time.Second, "192.0.2.42", false, true}, 276 | } 277 | for i, tt := range tests { 278 | copy := testMsg.Copy() 279 | copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer) 280 | copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds()) 281 | client.reset() 282 | client.setAnswer(copy) 283 | 284 | // Add new value now 285 | c.now = func() time.Time { return now } 286 | var key uint32 = 1 287 | c.Set(key, testMsg) 288 | 289 | // Read value at some point in the future 290 | c.now = func() time.Time { return now.Add(tt.readDelay) } 291 | v, ok := c.getValue(key) 292 | c.Close() // Flush queued operations 293 | 294 | if tt.refetch { 295 | v, ok = c.getValue(key) 296 | } 297 | if ok != tt.ok { 298 | t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok) 299 | } 300 | if tt.ok { 301 | answers := dnsutil.Answers(v.msg) 302 | if answers[0] != tt.answer { 303 | t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer) 304 | } 305 | } 306 | } 307 | } 308 | 309 | func TestCacheEvictAndUpdate(t *testing.T) { 310 | client := newTestClient() 311 | now := time.Now() 312 | c := newCache(10, client, nil, func() time.Time { return now }) 313 | 314 | var key uint32 = 1 315 | c.Set(key, testMsg) 316 | 317 | // Initial prefetched answer can no longer be cached 318 | copy := testMsg.Copy() 319 | copy.Answer[0].(*dns.A).Hdr.Ttl = 0 320 | client.setAnswer(copy) 321 | copy = testMsg.Copy() 322 | copy.Answer[0].(*dns.A).Hdr.Ttl = 30 323 | client.setAnswer(copy) 324 | 325 | // Advance time so that msg is now considered expired. 
Query to trigger prefetch 326 | c.now = func() time.Time { return now.Add(61 * time.Second) } 327 | c.Get(key) 328 | 329 | // Query again, causing another prefetch with a non-zero TTL 330 | c.Get(key) 331 | 332 | // Last query refreshes key 333 | c.Close() 334 | keyExists := false 335 | for el := c.values.Front(); el != nil; el = el.Next() { 336 | if el.Value.(Value).Key == key { 337 | keyExists = true 338 | } 339 | } 340 | if !keyExists { 341 | t.Errorf("expected cache keys to contain %d", key) 342 | } 343 | } 344 | 345 | func TestPackValue(t *testing.T) { 346 | v := Value{ 347 | Key: 42, 348 | CreatedAt: time.Now().Truncate(time.Second), 349 | msg: testMsg, 350 | } 351 | packed, err := v.Pack() 352 | if err != nil { 353 | t.Fatal(err) 354 | } 355 | unpacked, err := Unpack(packed) 356 | if err != nil { 357 | t.Fatal(err) 358 | } 359 | if got, want := unpacked.Key, v.Key; got != want { 360 | t.Errorf("Key = %d, want %d", got, want) 361 | } 362 | if got, want := unpacked.CreatedAt, v.CreatedAt; !got.Equal(want) { 363 | t.Errorf("CreatedAt = %s, want %s", got, want) 364 | } 365 | if got, want := unpacked.msg.String(), v.msg.String(); got != want { 366 | t.Errorf("msg = %s, want %s", got, want) 367 | } 368 | } 369 | 370 | func TestCacheWithBackend(t *testing.T) { 371 | var tests = []struct { 372 | capacity int 373 | backendSize int 374 | cacheSize int 375 | }{ 376 | {0, 0, 0}, 377 | {0, 1, 0}, 378 | {1, 0, 0}, 379 | {1, 1, 1}, 380 | {1, 2, 1}, 381 | {2, 1, 1}, 382 | {2, 2, 2}, 383 | {3, 2, 2}, 384 | } 385 | for i, tt := range tests { 386 | backend := &testBackend{} 387 | for j := 0; j < tt.backendSize; j++ { 388 | v := Value{ 389 | Key: uint32(j), 390 | CreatedAt: time.Now(), 391 | msg: testMsg, 392 | } 393 | backend.Set(v.Key, v) 394 | } 395 | c := NewWithBackend(tt.capacity, nil, backend) 396 | if got, want := len(c.entries), tt.cacheSize; got != want { 397 | t.Errorf("#%d: len(values) = %d, want %d", i, got, want) 398 | } 399 | if tt.backendSize > tt.capacity { 400 
| if got, want := len(backend.Read()), tt.capacity; got != want { 401 | t.Errorf("#%d: len(backend.Read()) = %d, want %d", i, got, want) 402 | } 403 | } 404 | if tt.capacity == tt.backendSize { 405 | // Adding a new entry to a cache at capacity removes the oldest from backend 406 | c.Set(42, testMsg) 407 | if got, want := len(backend.Read()), tt.capacity; got != want { 408 | t.Errorf("#%d: len(backend.Read()) = %d, want %d", i, got, want) 409 | } 410 | } 411 | } 412 | } 413 | 414 | func TestCacheStats(t *testing.T) { 415 | c := New(10, nil) 416 | c.Set(1, testMsg) 417 | c.Set(2, testMsg) 418 | want := Stats{Capacity: 10, Size: 2} 419 | got := c.Stats() 420 | if !reflect.DeepEqual(got, want) { 421 | t.Errorf("Stats() = %+v, want %+v", got, want) 422 | } 423 | } 424 | 425 | func BenchmarkNewKey(b *testing.B) { 426 | for n := 0; n < b.N; n++ { 427 | NewKey("key", 1, 1) 428 | } 429 | } 430 | 431 | func BenchmarkSet(b *testing.B) { 432 | c := New(4096, nil) 433 | b.ResetTimer() 434 | for n := 0; n < b.N; n++ { 435 | c.Set(uint32(n), &dns.Msg{}) 436 | } 437 | } 438 | 439 | func BenchmarkGet(b *testing.B) { 440 | c := New(4096, nil) 441 | c.Set(uint32(1), &dns.Msg{}) 442 | b.ResetTimer() 443 | for n := 0; n < b.N; n++ { 444 | c.Get(uint32(1)) 445 | } 446 | } 447 | 448 | func BenchmarkEviction(b *testing.B) { 449 | c := New(1, nil) 450 | b.ResetTimer() 451 | for n := 0; n < b.N; n++ { 452 | c.Set(uint32(n), &dns.Msg{}) 453 | c.Get(uint32(n)) 454 | } 455 | } 456 | --------------------------------------------------------------------------------