├── .dockerignore ├── .github ├── FUNDING.yml └── workflows │ ├── codeql.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── VERSION ├── config.yaml ├── go.mod ├── go.sum ├── run.sh └── src ├── blacklist.go ├── blacklist_test.go ├── config.go ├── local.go ├── local_test.go ├── logging.go ├── main.go ├── main_test.go ├── server.go ├── server_test.go ├── upstream.go ├── upstream_test.go ├── upstreamcache.go └── version.go /.dockerignore: -------------------------------------------------------------------------------- 1 | build/ -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: virtualzone 2 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '25 10 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Use only 'java' to analyze code written in Java, Kotlin or both 38 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 39 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 40 | 41 | steps: 42 | - name: Checkout repository 43 | uses: actions/checkout@v3 44 | 45 | # Initializes the CodeQL tools for scanning. 46 | - name: Initialize CodeQL 47 | uses: github/codeql-action/init@v2 48 | with: 49 | languages: ${{ matrix.language }} 50 | # If you wish to specify custom queries, you can do so here or in a config file. 51 | # By default, queries listed here will override any specified in a config file. 52 | # Prefix the list here with "+" to use these queries and those in the config file. 53 | 54 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 55 | # queries: security-extended,security-and-quality 56 | 57 | 58 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 59 | # If this step fails, then you should remove it and run the build manually (see below) 60 | - name: Autobuild 61 | uses: github/codeql-action/autobuild@v2 62 | 63 | # ℹ️ Command-line programs to run using the OS shell. 
64 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 65 | 66 | # If the Autobuild fails above, remove it and uncomment the following three lines. 67 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 68 | 69 | # - run: | 70 | # echo "Run, Build Application using script" 71 | # ./location_of_script_within_repo/buildscript.sh 72 | 73 | - name: Perform CodeQL Analysis 74 | uses: github/codeql-action/analyze@v2 75 | with: 76 | category: "/language:${{matrix.language}}" 77 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | workflow_dispatch: 5 | branches: [ main ] 6 | 7 | env: 8 | REGISTRY: ghcr.io 9 | IMAGE_NAME: ${{ github.repository }} 10 | 11 | jobs: 12 | 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-go@v5 20 | with: 21 | go-version: '^1.22' 22 | - name: Set version env 23 | run: echo "CI_VERSION=$(cat VERSION | awk NF)" >> $GITHUB_ENV 24 | - name: Run build 25 | run: make 26 | - name: Set up QEMU 27 | uses: docker/setup-qemu-action@v3 28 | - name: Set up Docker Buildx 29 | uses: docker/setup-buildx-action@v3 30 | - name: Cache Docker layers 31 | uses: actions/cache@v3 32 | with: 33 | path: /tmp/.buildx-cache 34 | key: ${{ runner.os }}-buildx-${{ github.sha }} 35 | restore-keys: | 36 | ${{ runner.os }}-buildx- 37 | - name: Log into registry 38 | if: github.event_name != 'pull_request' 39 | uses: docker/login-action@v3 40 | with: 41 | registry: ${{ env.REGISTRY }} 42 | username: ${{ github.actor }} 43 | password: ${{ secrets.GITHUB_TOKEN }} 44 | - name: Build and push 45 | uses: docker/build-push-action@v5 46 | with: 47 | context: . 
48 | platforms: linux/amd64,linux/arm64,linux/arm/v7 49 | push: true 50 | tags: | 51 | ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ env.CI_VERSION }} 52 | ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest 53 | cache-from: type=local,src=/tmp/.buildx-cache 54 | cache-to: type=local,dest=/tmp/.buildx-cache-new 55 | - name: Create a GitHub release 56 | uses: ncipollo/release-action@v1 57 | with: 58 | token: ${{ secrets.GITHUB_TOKEN }} 59 | tag: ${{ env.CI_VERSION }} 60 | name: ${{ env.CI_VERSION }} 61 | artifacts: "build/go-hole_*" -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | workflow_dispatch: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | container-job: 12 | runs-on: ubuntu-latest 13 | container: golang:1.22-alpine 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Run tests 17 | working-directory: src 18 | run: go test -cover -v 19 | env: 20 | CGO_ENABLED: 0 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | .vscode -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.22-bookworm AS builder 2 | RUN export GOBIN=$HOME/work/bin 3 | WORKDIR /go/src/app 4 | ADD . . 
5 | RUN echo "package main\n\nconst AppVersion = \"`cat ./VERSION | awk NF`\"" > src/version.go 6 | RUN CGO_ENABLED=0 go build -ldflags="-w -s" -o go-hole src/*.go 7 | RUN apt-get update && apt-get install --yes libcap2-bin 8 | 9 | FROM gcr.io/distroless/base-debian12 10 | COPY --from=builder /go/src/app/go-hole /app/ 11 | COPY --from=builder /sbin/getcap /sbin/ 12 | COPY --from=builder /sbin/setcap /sbin/ 13 | COPY --from=builder /lib/*-linux-*/libcap.so.2 /lib/ 14 | RUN ["/sbin/setcap", "cap_net_bind_service=+ep", "/app/go-hole"] 15 | ADD config.yaml /app/ 16 | WORKDIR /app 17 | EXPOSE 53/udp 18 | USER 65532:65532 19 | CMD ["./go-hole"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Heiner Beck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | VERSION=`cat ./VERSION | awk NF`
2 | 
3 | .PHONY: all clean update_version_file linux macos windows linux_amd64 linux_arm64 linux_arm macos_amd64 macos_arm64 windows_amd64 windows_arm64
4 | 
5 | all: clean update_version_file linux macos windows
6 | 
7 | clean:
8 | 	rm -f build/*
9 | 
10 | # POSIX leaves echo's handling of backslash escapes implementation-defined (bash's builtin
11 | # echo prints "\n" literally); printf expands them the same way in every POSIX shell.
12 | update_version_file:
13 | 	printf '// File autogenerated from VERSION, do not change manually!\npackage main\n\nconst AppVersion = "%s"\n' "${VERSION}" > src/version.go
14 | 
15 | linux: linux_amd64 linux_arm64 linux_arm
16 | 
17 | macos: macos_amd64 macos_arm64
18 | 
19 | windows: windows_amd64 windows_arm64
20 | 
21 | linux_amd64:
22 | 	env GOOS=linux GOARCH=amd64 go build -ldflags="-w -s" -o build/go-hole_linux_amd64_${VERSION} src/*.go
23 | 
24 | linux_arm64:
25 | 	env GOOS=linux GOARCH=arm64 go build -ldflags="-w -s" -o build/go-hole_linux_arm64_${VERSION} src/*.go
26 | 
27 | linux_arm:
28 | 	env GOOS=linux GOARCH=arm go build -ldflags="-w -s" -o build/go-hole_linux_arm_${VERSION} src/*.go
29 | 
30 | macos_amd64:
31 | 	env GOOS=darwin GOARCH=amd64 go build -ldflags="-w -s" -o build/go-hole_macos_amd64_${VERSION} src/*.go
32 | 
33 | macos_arm64:
34 | 	env GOOS=darwin GOARCH=arm64 go build -ldflags="-w -s" -o build/go-hole_macos_arm64_${VERSION} src/*.go
35 | 
36 | windows_amd64:
37 | 	env GOOS=windows GOARCH=amd64 go build -ldflags="-w -s" -o build/go-hole_windows_amd64_${VERSION}.exe src/*.go
38 | 
39 | windows_arm64:
40 | 	env GOOS=windows GOARCH=arm64 go build -ldflags="-w -s" -o build/go-hole_windows_arm64_${VERSION}.exe src/*.go
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Go-hole
2 | [![](https://img.shields.io/github/v/release/virtualzone/go-hole)](https://github.com/virtualzone/go-hole/releases)
3 | [![](https://img.shields.io/github/release-date/virtualzone/go-hole)](https://github.com/virtualzone/go-hole/releases)
4 | [![Go Report
Card](https://goreportcard.com/badge/github.com/virtualzone/go-hole)](https://goreportcard.com/report/github.com/virtualzone/go-hole) 5 | [![](https://img.shields.io/github/license/virtualzone/go-hole)](https://github.com/virtualzone/go-hole/blob/master/LICENSE) 6 | 7 | Minimalistic DNS server which serves as an upstream proxy and ad blocker. Written in Go, inspired by [Pi-hole®](https://github.com/pi-hole/pi-hole). 8 | 9 | ## Features 10 | * Minimalistic DNS server, written in Golang, optimized for high performance 11 | * Blacklist DNS names via user-specific source lists 12 | * Whitelist DNS names that are actually blacklisted 13 | * Multiple user-settable upstream DNS servers 14 | * Caching of upstream query results 15 | * Local name resolution 16 | * Pre-built, minimalistic Docker image 17 | 18 | ## How it works 19 | Go-hole serves as DNS server on your (home) network. Instead of having your clients sending DNS queries directly to the internet or to your router, they are resolved by your local Go-hole instance. Go-hole sends these queries to one or more upstream DNS servers and caches the upstream query results for maximum performance. 20 | 21 | Incoming queries from your clients are checked against a list of unwanted domain names ("blacklist"), such as well-known ad serving domains and trackers. If a requested name matches a name on the blacklist, Go-hole responds with error code NXDOMAIN (non-existing domain). This leads to clients not being able to load ads and tracker codes. In case you want to access a blacklisted domain, you can easily add it to a whitelist. 22 | 23 | As an additional feature, you can set a list of custom hostnames/domain names to be resolved to specific IP addresses. This is useful for accessing services on your local network by name instead of their IP addresses. 24 | 25 | ## Usage 26 | 1. Create a ```config.yaml``` file. 
Use the [config.yaml](https://github.com/virtualzone/go-hole/blob/main/config.yaml) in this repository as a template and customize it according to your needs. 27 | 1. Run Go-hole using Docker and mount your previously created ```config.yaml```: 28 | ```bash 29 | docker run \ 30 | --rm \ 31 | --mount type=bind,source=${PWD}/config.yaml,target=/app/config.yaml \ 32 | -p 53:53/udp \ 33 | ghcr.io/virtualzone/go-hole:latest 34 | ``` 35 | 1. Set Go-hole as your network's DNS server (e.g. in your DHCP server's configuration). -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | v0.5.0 2 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # DNS Server Listen Address 2 | listen: 0.0.0.0:53 3 | 4 | # One or more DNS Upstream Servers (default: Google DNS) 5 | upstream: 6 | - 8.8.8.8:53 7 | - 8.8.4.4:53 8 | 9 | # One or more Blacklist sources 10 | blacklist: 11 | - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts 12 | 13 | # Blacklist renewal interval in minutes - set to 0 to disable 14 | blacklistRenewal: 1440 15 | 16 | # You can also choose to blacklist everything (only whitelist will work then) 17 | blacklistEverything: false 18 | 19 | # Domain names to be resolved upstream, even if they are blacklisted 20 | whitelist: 21 | - googleadservices.com 22 | - iadsdk.apple.com 23 | 24 | # Optional names to be resolved to specific IP addresses 25 | local: 26 | - name: service1.local 27 | target: 28 | - address: 192.168.178.1 29 | type: A 30 | - address: 192.168.179.1 31 | type: A 32 | - address: fe80::9656:d028:8652:1111 33 | type: AAAA 34 | - name: service2.local 35 | target: 36 | - address: 192.168.178.2 37 | type: A 38 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 
-------------------------------------------------------------------------------- 1 | module github.com/virtualzone/go-hole 2 | 3 | go 1.22.0 4 | 5 | toolchain go1.23.4 6 | 7 | require ( 8 | github.com/jellydator/ttlcache/v3 v3.3.0 9 | github.com/miekg/dns v1.1.63 10 | gopkg.in/yaml.v3 v3.0.1 11 | ) 12 | 13 | require ( 14 | golang.org/x/mod v0.23.0 // indirect 15 | golang.org/x/net v0.35.0 // indirect 16 | golang.org/x/sync v0.11.0 // indirect 17 | golang.org/x/sys v0.30.0 // indirect 18 | golang.org/x/tools v0.30.0 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 4 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 5 | github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc= 6 | github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw= 7 | github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= 8 | github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 12 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 13 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 14 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 15 | golang.org/x/mod v0.23.0 
h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
16 | golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
17 | golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
18 | golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
19 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
20 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
21 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
22 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
23 | golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
24 | golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
25 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
27 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
28 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
29 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | LISTEN_ADDR=0.0.0.0:5300 go run `ls src/*.go | grep -v _test.go`
--------------------------------------------------------------------------------
/src/blacklist.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"errors"
7 | 	"io"
8 | 	"log"
9 | 	"net/http"
10 | 	"regexp"
11 | 	"slices"
12 | 	"strconv"
13 | 	"strings"
14 | 	"sync"
15 | 	"time"
16 | 
17 | 	"github.com/miekg/dns"
18 | )
19 | 
20 | // blacklistRecords and whitelistRecords hold lower-cased, fully-qualified
21 | // names (with trailing dot). The blacklist is swapped by the renewal goroutine
22 | // while queries are being answered, so this file only reads or replaces
23 | // either slice while holding recordListsMutex.
24 | var (
25 | 	blacklistRecords = []string{}
26 | 	whitelistRecords = []string{}
27 | 	recordListsMutex sync.RWMutex
28 | )
29 | 
30 | // blacklistFieldSeparator splits a hosts-file line into its fields.
31 | var blacklistFieldSeparator = regexp.MustCompile(`\s+`)
32 | 
33 | // blacklistHTTPClient downloads blacklist sources. The timeout keeps an
34 | // unresponsive source from blocking startup or the renewal goroutine forever.
35 | var blacklistHTTPClient = &http.Client{Timeout: 2 * time.Minute}
36 | 
37 | // queryBlacklist reports whether name should be blocked. A nil error (with an
38 | // empty, non-nil answer slice) means "blocked"; a non-nil error means the name
39 | // is whitelisted or not on the blacklist and must be resolved elsewhere. qtype
40 | // is unused: blocking applies to every record type.
41 | func queryBlacklist(name string, qtype uint16) ([]dns.RR, error) {
42 | 	if isWhitelisted(name) {
43 | 		return nil, errors.New("record is whitelisted, not checking against blacklist database")
44 | 	}
45 | 	if !isBlacklisted(name) {
46 | 		return nil, errors.New("record not found in blacklist database")
47 | 	}
48 | 	return []dns.RR{}, nil
49 | }
50 | 
51 | // isBlacklisted reports whether name is on the blacklist. Every name counts as
52 | // blacklisted when BlacklistEverything is set.
53 | func isBlacklisted(name string) bool {
54 | 	if GetConfig().BlacklistEverything {
55 | 		return true
56 | 	}
57 | 	recordListsMutex.RLock()
58 | 	defer recordListsMutex.RUnlock()
59 | 	return slices.Contains(blacklistRecords, name)
60 | }
61 | 
62 | // isWhitelisted reports whether name is on the whitelist.
63 | func isWhitelisted(name string) bool {
64 | 	recordListsMutex.RLock()
65 | 	defer recordListsMutex.RUnlock()
66 | 	return slices.Contains(whitelistRecords, name)
67 | }
68 | 
69 | // updateBlacklistRecords downloads every configured blacklist source and
70 | // replaces the blacklist with their combined entries. Failing sources are
71 | // logged and skipped. If all sources fail, the previous blacklist is kept so a
72 | // temporary network problem does not disable blocking until the next renewal.
73 | func updateBlacklistRecords() {
74 | 	log.Println("Updating blacklist database...")
75 | 	sources := GetConfig().BlacklistSources
76 | 	list := make([]string, 0)
77 | 	failed := 0
78 | 	for _, url := range sources {
79 | 		if err := processBlacklistSource(url, &list); err != nil {
80 | 			failed++
81 | 			log.Printf("Could not load blacklist source %s: %s\n", url, err.Error())
82 | 		}
83 | 	}
84 | 	recordListsMutex.Lock()
85 | 	defer recordListsMutex.Unlock()
86 | 	if failed > 0 && failed == len(sources) && len(blacklistRecords) > 0 {
87 | 		log.Printf("All blacklist sources failed, keeping previous database with %d records\n", len(blacklistRecords))
88 | 		return
89 | 	}
90 | 	blacklistRecords = list
91 | 	log.Printf("Blacklist database updated, %d records\n", len(blacklistRecords))
92 | }
93 | 
94 | // initBlacklistRenewal starts a goroutine that reloads the blacklist every
95 | // BlacklistRenewal minutes for the lifetime of the process. Values below 1
96 | // disable renewal.
97 | func initBlacklistRenewal() {
98 | 	if GetConfig().BlacklistRenewal < 1 {
99 | 		return
100 | 	}
101 | 	ticker := time.NewTicker(time.Minute * time.Duration(GetConfig().BlacklistRenewal))
102 | 	go func() {
103 | 		for range ticker.C {
104 | 			updateBlacklistRecords()
105 | 		}
106 | 	}()
107 | }
108 | 
109 | // updateWhitelistRecords rebuilds the whitelist from the configuration,
110 | // normalizing each entry to a lower-cased, fully-qualified name. Empty
111 | // entries are ignored.
112 | func updateWhitelistRecords() {
113 | 	log.Println("Updating whitelist database...")
114 | 	list := make([]string, 0, len(GetConfig().Whitelist))
115 | 	for _, name := range GetConfig().Whitelist {
116 | 		name = strings.ToLower(strings.TrimSpace(name))
117 | 		if name == "" {
118 | 			continue
119 | 		}
120 | 		list = append(list, dns.Fqdn(name))
121 | 	}
122 | 	recordListsMutex.Lock()
123 | 	whitelistRecords = list
124 | 	recordListsMutex.Unlock()
125 | }
126 | 
127 | // processBlacklistSource downloads url and appends the names it lists to
128 | // *list; see parseBlacklistData for the accepted format.
129 | func processBlacklistSource(url string, list *[]string) error {
130 | 	data, err := getUrlData(url)
131 | 	if err != nil {
132 | 		return err
133 | 	}
134 | 	return parseBlacklistData(data, list)
135 | }
136 | 
137 | // parseBlacklistData appends the names found in hosts-file formatted data to
138 | // *list, lower-cased and fully qualified. A usable line is either a bare name
139 | // or "0.0.0.0 name"; anything after '#' is a comment. Other lines, such as
140 | // "127.0.0.1 localhost", are skipped. The scanner's error, if any, is returned.
141 | func parseBlacklistData(data []byte, list *[]string) error {
142 | 	scanner := bufio.NewScanner(bytes.NewReader(data))
143 | 	for scanner.Scan() {
144 | 		line := scanner.Text()
145 | 		// Drop full-line and trailing comments ("0.0.0.0 ads.example # tracker").
146 | 		if idx := strings.IndexByte(line, '#'); idx >= 0 {
147 | 			line = line[:idx]
148 | 		}
149 | 		line = strings.TrimSpace(line)
150 | 		if line == "" {
151 | 			continue
152 | 		}
153 | 		fields := blacklistFieldSeparator.Split(line, -1)
154 | 		if !isValidBlacklistSourceRecord(fields) {
155 | 			continue
156 | 		}
157 | 		// In both accepted forms the blocked name is the last field.
158 | 		*list = append(*list, dns.Fqdn(strings.ToLower(fields[len(fields)-1])))
159 | 	}
160 | 	return scanner.Err()
161 | }
162 | 
163 | // isValidBlacklistSourceRecord reports whether the fields of a hosts line
164 | // describe a name to block: a single name, or "0.0.0.0" followed by a name
165 | // other than "0.0.0.0".
166 | func isValidBlacklistSourceRecord(split []string) bool {
167 | 	switch len(split) {
168 | 	case 1:
169 | 		return true
170 | 	case 2:
171 | 		return split[0] == "0.0.0.0" && split[1] != "0.0.0.0"
172 | 	default:
173 | 		return false
174 | 	}
175 | }
176 | 
177 | // getUrlData fetches url and returns the response body; any status other than
178 | // 200 OK is reported as an error.
179 | func getUrlData(url string) ([]byte, error) {
180 | 	resp, err := blacklistHTTPClient.Get(url)
181 | 	if err != nil {
182 | 		return nil, err
183 | 	}
184 | 	defer resp.Body.Close()
185 | 	if resp.StatusCode != http.StatusOK {
186 | 		return nil, errors.New("received invalid http response code " + strconv.Itoa(resp.StatusCode) + " for url " + url)
187 | 	}
188 | 	data, err := io.ReadAll(resp.Body)
189 | 	if err != nil {
190 | 		return nil, err
191 | 	}
192 | 	return data, nil
193 | }
194 | 
--------------------------------------------------------------------------------
/src/blacklist_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/miekg/dns"
7 | )
8 | 
9 | func TestBlacklistSuccess(t *testing.T) {
10 | 	res, err := queryBlacklist("googleads.g.doubleclick.net.", dns.TypeA)
11 | 	checkTestBool(t, true, err == nil)
12 | 	checkTestBool(t, false, res == nil)
13 | 	checkTestInt(t, 0, len(res))
14 | }
15 | 
16 | func TestBlacklistNonExistent(t *testing.T) {
17 | 	res, err := queryBlacklist("www.apple.com.", dns.TypeA)
18 | 	checkTestBool(t, false, err == nil)
19 | 	checkTestBool(t, true, res == nil)
20 | }
21 | 
22 | func TestBlacklistWhitelisted(t *testing.T) {
23 | 	res, err := queryBlacklist("googleadservices.com.", dns.TypeA)
24 | 	checkTestBool(t, false, err == nil)
25 | 	checkTestBool(t, true, res == nil)
26 | }
27 | 
28 | func TestIsValidBlacklistSourceRecord(t *testing.T) {
29 | 	checkTestBool(t, true, isValidBlacklistSourceRecord([]string{"ads.example.com"}))
30 | 	checkTestBool(t, true, isValidBlacklistSourceRecord([]string{"0.0.0.0", "ads.example.com"}))
31 | 	checkTestBool(t, false, isValidBlacklistSourceRecord([]string{"127.0.0.1", "localhost"}))
32 | 	checkTestBool(t, false, isValidBlacklistSourceRecord([]string{"0.0.0.0", "0.0.0.0"}))
33 | 	checkTestBool(t, false, isValidBlacklistSourceRecord([]string{}))
34 | 	checkTestBool(t, false, isValidBlacklistSourceRecord([]string{"0.0.0.0", "a", "b"}))
35 | }
36 | 
37 | func TestParseBlacklistData(t *testing.T) {
38 | 	data := []byte("# comment\n0.0.0.0 Ads.Example.com # tracker\n127.0.0.1 localhost\n0.0.0.0 0.0.0.0\nplain.example.org\n\n")
39 | 	list := []string{}
40 | 	checkTestBool(t, true, parseBlacklistData(data, &list) == nil)
41 | 	checkTestInt(t, 2, len(list))
42 | 	checkTestString(t, "ads.example.com.", list[0])
43 | 	checkTestString(t, "plain.example.org.", list[1])
44 | }
45 | 
--------------------------------------------------------------------------------
/src/config.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"log"
5 | 	"os"
6 | 	"path/filepath"
7 | 	"strconv"
8 | 
9 | 	"gopkg.in/yaml.v3"
10 | )
11 | 
12 | // ConfigLocalAddressTarget is one record of a local name, e.g.
13 | // {address: 192.168.178.1, type: A}.
14 | type ConfigLocalAddressTarget struct {
15 | 	Address string `yaml:"address"`
16 | 	Type    string `yaml:"type"`
17 | }
18 | 
19 | // ConfigLocalAddress maps a host name to the records it resolves to locally.
20 | type ConfigLocalAddress struct {
21 | 	Name   string                     `yaml:"name"`
22 | 	Target []ConfigLocalAddressTarget `yaml:"target"`
23 | }
24 | 
25 | // Config is the application configuration, read from config.yaml and then
26 | // adjusted by environment variables (see ReadEnv).
27 | type Config struct {
28 | 	ListenAddr          string               `yaml:"listen"`
29 | 	UpstreamDNS         []string             `yaml:"upstream"`
30 | 	BlacklistSources    []string             `yaml:"blacklist"`
31 | 	BlacklistRenewal    int                  `yaml:"blacklistRenewal"` // minutes; < 1 disables renewal
32 | 	BlacklistEverything bool                 `yaml:"blacklistEverything"`
33 | 	Whitelist           []string             `yaml:"whitelist"`
34 | 	LocalAddresses      []ConfigLocalAddress `yaml:"local"`
35 | }
36 | 
37 | // ConfigInstance is the process-wide configuration returned by GetConfig.
38 | var ConfigInstance = &Config{}
39 | 
40 | // GetConfig returns the process-wide configuration.
41 | func GetConfig() *Config {
42 | 	return ConfigInstance
43 | }
44 | 
45 | // ReadConfig loads config.yaml from the current working directory and then
46 | // applies the environment overrides. Any failure terminates the process.
47 | func (c *Config) ReadConfig() {
48 | 	dir, err := os.Getwd()
49 | 	if err != nil || dir == "" {
50 | 		log.Fatalf("could not get current working dir: %v\n", err)
51 | 	}
52 | 	configPath := filepath.Join(dir, "config.yaml")
53 | 	data, err := os.ReadFile(configPath)
54 | 	if err != nil {
55 | 		log.Fatalf("could not read config yaml from %s: %s\n", configPath, err.Error())
56 | 	}
57 | 	c.ReadConfigData(data)
58 | 	c.ReadEnv()
59 | }
60 | 
61 | // ReadConfigData parses the YAML document data into c. A parse error
62 | // terminates the process.
63 | func (c *Config) ReadConfigData(data []byte) {
64 | 	if err := yaml.Unmarshal(data, c); err != nil {
65 | 		log.Fatalf("could not parse config yaml: %s\n", err.Error())
66 | 	}
67 | }
68 | 
69 | // ReadEnv applies environment overrides: LISTEN_ADDR replaces the listen
70 | // address, and UPSTREAM_DNS_1 to UPSTREAM_DNS_10 are appended to the
71 | // configured upstream servers.
72 | func (c *Config) ReadEnv() {
73 | 	if listenAddr := c.getEnv("LISTEN_ADDR", ""); listenAddr != "" {
74 | 		c.ListenAddr = listenAddr
75 | 	}
76 | 	for i := 1; i <= 10; i++ {
77 | 		if server := c.getEnv("UPSTREAM_DNS_"+strconv.Itoa(i), ""); server != "" {
78 | 			c.UpstreamDNS = append(c.UpstreamDNS, server)
79 | 		}
80 | 	}
81 | }
82 | 
83 | // Print logs the effective configuration as YAML.
84 | func (c *Config) Print() {
85 | 	s, err := yaml.Marshal(c)
86 | 	if err != nil {
87 | 		log.Printf("could not serialize config: %s\n", err.Error())
88 | 		return
89 | 	}
90 | 	log.Println("Using config:\n" + string(s))
91 | }
92 | 
93 | // getEnv returns the value of the environment variable key, or defaultValue
94 | // if it is unset or empty.
95 | func (c *Config) getEnv(key, defaultValue string) string {
96 | 	if res := os.Getenv(key); res != "" {
97 | 		return res
98 | 	}
99 | 	return defaultValue
100 | }
101 | 
--------------------------------------------------------------------------------
/src/local.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"log"
7 | 	"strings"
8 | 
9 | 	"github.com/miekg/dns"
10 | )
11 | 
12 | // LocalRecordTarget is one locally configured answer for a name: the record
13 | // data (e.g. an IP address) and its record type code.
14 | type LocalRecordTarget struct {
15 | 	Target string
16 | 	Qtype  uint16
17 | }
18 | 
19 | // localRecords maps lower-cased, fully-qualified names (with trailing dot) to
20 | // their configured records; it is rebuilt by updateLocalRecords.
21 | var localRecords = map[string][]LocalRecordTarget{}
22 | 
23 | // queryLocal answers name/qtype from the local record database. It returns an
24 | // error if the name is unknown, if none of its records has type qtype, or if a
25 | // configured record cannot be parsed; the answer slice is nil whenever the
26 | // error is non-nil.
27 | //
28 | // NOTE(review): a known local name without a record of the requested type is
29 | // treated as "not found", so the caller goes on to the blacklist and upstream
30 | // lookups instead of answering NODATA — confirm this is intended.
31 | func queryLocal(name string, qtype uint16) ([]dns.RR, error) {
32 | 	targets, ok := localRecords[name]
33 | 	if !ok {
34 | 		return nil, errors.New("record not found in local database")
35 | 	}
36 | 	var res []dns.RR
37 | 	for _, record := range targets {
38 | 		if record.Qtype != qtype {
39 | 			continue
40 | 		}
41 | 		// Build the record in zone-file syntax ("name TYPE data") and let the
42 | 		// dns package parse it.
43 | 		rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, getQueryTypeText(record.Qtype), record.Target))
44 | 		if err != nil {
45 | 			log.Println(err)
46 | 			return nil, err
47 | 		}
48 | 		res = append(res, rr)
49 | 	}
50 | 	if len(res) == 0 {
51 | 		return nil, errors.New("no record for requested query type found in local database")
52 | 	}
53 | 	return res, nil
54 | }
55 | 
56 | // updateLocalRecords rebuilds localRecords from the "local" section of the
57 | // configuration. Targets with an unknown record type are logged and skipped;
58 | // entries that repeat a name are merged instead of replacing each other.
59 | func updateLocalRecords() {
60 | 	log.Println("Updating local address database...")
61 | 	records := make(map[string][]LocalRecordTarget, len(GetConfig().LocalAddresses))
62 | 	for _, item := range GetConfig().LocalAddresses {
63 | 		name := dns.Fqdn(strings.ToLower(strings.TrimSpace(item.Name)))
64 | 		for _, target := range item.Target {
65 | 			qtype, err := getQueryTypeUint(target.Type)
66 | 			if err != nil {
67 | 				log.Printf("Ignoring unknown record type %s for %s\n", target.Type, item.Name)
68 | 				continue
69 | 			}
70 | 			records[name] = append(records[name], LocalRecordTarget{
71 | 				Target: target.Address,
72 | 				Qtype:  qtype,
73 | 			})
74 | 		}
75 | 	}
76 | 	localRecords = records
77 | }
78 | 
--------------------------------------------------------------------------------
/src/local_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/miekg/dns"
7 | )
8 | 
9 | func TestLocalASuccess(t *testing.T) {
10 | 	res, err := queryLocal("service1.local.", dns.TypeA)
11 | 	checkTestBool(t, true, err == nil)
12 | 	checkTestBool(t, false, res == nil)
13 | 	checkTestInt(t, 2, len(res))
14 | 	aRecord1 := res[0].(*dns.A)
15 | 	aRecord2 := res[1].(*dns.A)
16 | 	checkTestString(t, "192.168.178.1", aRecord1.A.String())
17 | 	checkTestString(t, "192.168.179.1", aRecord2.A.String())
18 | }
19 | 
20 | func TestLocalAAAASuccess(t *testing.T) {
21 | 	res, err := queryLocal("service1.local.", dns.TypeAAAA)
22 | 	checkTestBool(t, true, err == nil)
23 | 	checkTestBool(t, false, res == nil)
24 | 	checkTestInt(t, 1, len(res))
25 | 	aRecord1 := res[0].(*dns.AAAA)
26 | 	checkTestString(t, "fe80::9656:d028:8652:1111", aRecord1.AAAA.String())
27 | }
28 | 
29 | func TestLocalNonExistent(t *testing.T) {
30 | 	res, err := queryLocal("nonexistentrecord.local.", dns.TypeA)
31 | 	checkTestBool(t, false, err == nil)
32 | 	checkTestBool(t, true, res == nil)
33 | }
34 | 
35 | func TestLocalMissingType(t *testing.T) {
36 | 	res, err := queryLocal("service2.local.", dns.TypeAAAA)
37 | 	checkTestBool(t, false, err == nil)
38 | 	checkTestBool(t, true, res == nil)
39 | }
40 | 
--------------------------------------------------------------------------------
/src/logging.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"log"
6 | 	"net"
7 | 	"strings"
8 | 
9 | 	"github.com/miekg/dns"
10 | )
11 | // queryTypeNames maps DNS record type codes to their mnemonics (used for log output and to build local records).
12 | var queryTypeNames = map[uint16]string{
13 | 	dns.TypeNone:       "None",
14 | 	dns.TypeA:          "A",
15 | 	dns.TypeNS:         "NS",
16 | 	dns.TypeMD:         "MD",
17 | 	dns.TypeMF:         "MF",
18 | 	dns.TypeCNAME:      "CNAME",
19 | 	dns.TypeSOA:        "SOA",
20 | 	dns.TypeMB:         "MB",
21 | 	dns.TypeMG:         "MG",
22 | 	dns.TypeMR:         "MR",
23 | 	dns.TypeNULL:       "NULL",
24 | 	dns.TypePTR:        "PTR",
25 | 	dns.TypeHINFO:      "HINFO",
26 | 	dns.TypeMINFO:      "MINFO",
27 | 	dns.TypeMX:         "MX",
28 | 	dns.TypeTXT:        "TXT",
29 | 	dns.TypeRP:         "RP",
30 | 	dns.TypeAFSDB:      "AFSDB",
31 | 	dns.TypeX25:        "X25",
32 | 	dns.TypeISDN:       "ISDN",
33 | 	dns.TypeRT:         "RT",
34 | 	dns.TypeNSAPPTR:    "NSAPPTR",
35 | 	dns.TypeSIG:        "SIG",
36 | 	dns.TypeKEY:        "KEY",
37 | 	dns.TypePX:         "PX",
38 | 	dns.TypeGPOS:       "GPOS",
39 | 	dns.TypeAAAA:       "AAAA",
40 | 	dns.TypeLOC:        "LOC",
41 | 	dns.TypeNXT:        "NXT",
42 | 	dns.TypeEID:        "EID",
43 | 	dns.TypeNIMLOC:     "NIMLOC",
44 | 	dns.TypeSRV:        "SRV",
45 | 	dns.TypeATMA:       "ATMA",
46 | 	dns.TypeNAPTR:      "NAPTR",
47 | 	dns.TypeKX:         "KX",
48 | 	dns.TypeCERT:       "CERT",
49 | 	dns.TypeDNAME:      "DNAME",
50 | 	dns.TypeOPT:        "OPT",
51 | 	dns.TypeAPL:        "APL",
52 | 	dns.TypeDS:         "DS",
53 | 	dns.TypeSSHFP:      "SSHFP",
54 | 	dns.TypeRRSIG:      "RRSIG",
55 | 	dns.TypeNSEC:       "NSEC",
56 | 	dns.TypeDNSKEY:     "DNSKEY",
57 | 	dns.TypeDHCID:      "DHCID",
58 | 	dns.TypeNSEC3:      "NSEC3",
59 | 	dns.TypeNSEC3PARAM: "NSEC3PARAM",
60 | 	dns.TypeTLSA:       "TLSA",
61 | 	dns.TypeSMIMEA:     "SMIMEA",
62 | 	dns.TypeHIP:        "HIP",
63 | 	dns.TypeNINFO:      "NINFO",
64 | 	dns.TypeRKEY:       "RKEY",
65 | 	dns.TypeTALINK:     "TALINK",
66 | 	dns.TypeCDS:        "CDS",
67 | 	dns.TypeCDNSKEY:    "CDNSKEY",
68 | 	dns.TypeOPENPGPKEY: "OPENPGPKEY",
69 | 	dns.TypeCSYNC:      "CSYNC",
70 | 	dns.TypeZONEMD:     "ZONEMD",
71 | 	dns.TypeSVCB:       "SVCB",
72 | 	dns.TypeHTTPS:      "HTTPS",
73 | 	dns.TypeSPF:        "SPF",
74 | 	dns.TypeUINFO:      "UINFO",
75 | 	dns.TypeUID:        "UID",
76 | 	dns.TypeGID:        "GID",
77 | 	dns.TypeUNSPEC:     "UNSPEC",
78 | 	dns.TypeNID:        "NID",
79 | 	dns.TypeL32:        "L32",
80 | 	dns.TypeL64:        "L64",
81 | 	dns.TypeLP:         "LP",
82 | 	dns.TypeEUI48:      "EUI48",
83 | 	dns.TypeEUI64:      "EUI64",
84 | 	dns.TypeURI:        "URI",
85 | 	dns.TypeCAA:        "CAA",
86 | 	dns.TypeAVC:        "AVC",
87 | 	dns.TypeTKEY:       "TKEY",
88 | 	dns.TypeTSIG:       "TSIG",
89 | 	dns.TypeIXFR:       "IXFR",
90 | 	dns.TypeAXFR:       "AXFR",
91 | 	dns.TypeMAILB:      "MAILB",
92 | 	dns.TypeMAILA:      "MAILA",
93 | 	dns.TypeANY:        "ANY",
94 | 	dns.TypeTA:         "TA",
95 | 	dns.TypeDLV:        "DLV",
96 | 	dns.TypeReserved:   "Reserved",
97 | }
98 | // queryNameTypes is the reverse lookup of queryTypeNames (mnemonic -> code), filled by initLogging.
99 | var queryNameTypes = map[string]uint16{}
100 | 
101 | // initLogging builds queryNameTypes, the reverse index of queryTypeNames,
102 | // so query types can also be looked up by their textual mnemonic.
103 | func initLogging() {
104 | 	queryNameTypes = make(map[string]uint16, len(queryTypeNames))
105 | 	for k, v := range queryTypeNames {
106 | 		queryNameTypes[v] = k
107 | 	}
108 | }
109 | 
110 | // getQueryTypeText returns the mnemonic for qtype (e.g. "AAAA"), or
111 | // "Unknown" if qtype is not listed in queryTypeNames.
112 | func getQueryTypeText(qtype uint16) string {
113 | 	if res := queryTypeNames[qtype]; res != "" {
114 | 		return res
115 | 	}
116 | 	return "Unknown"
117 | }
118 | 
119 | // getQueryTypeUint returns the numeric query type for a case-insensitive
120 | // mnemonic such as "aaaa". It relies on initLogging having filled
121 | // queryNameTypes; otherwise every lookup fails.
122 | func getQueryTypeUint(qtype string) (uint16, error) {
123 | 	res, ok := queryNameTypes[strings.ToUpper(qtype)]
124 | 	if !ok {
125 | 		return 0, errors.New("query type not found")
126 | 	}
127 | 	return res, nil
128 | }
129 | 
130 | // logQueryResult writes one log line describing how a client query was handled.
131 | func logQueryResult(source net.Addr, name string, qtype uint16, result string) {
132 | 	log.Printf("Query from %s for %s type %s %s\n", source.String(), name, getQueryTypeText(qtype), result)
133 | }
134 | 
--------------------------------------------------------------------------------
/src/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import "log"
4 | 
5 | func main() {
6 | 	log.Printf("Starting Go-hole %s...\n", AppVersion)
7 | 	GetConfig().ReadConfig()
8 | 	GetConfig().Print()
9 | 	initServer()
10 | 	initBlacklistRenewal()
11 | 	listenAndServe()
12 | }
13 | 
14 | // initServer prepares resolver state: the query-type lookup tables, the
15 | // upstream cache, and the local, blacklist and whitelist records.
16 | func initServer() {
17 | 	initLogging()
18 | 	GetUpstreamCache().Init()
19 | 	updateLocalRecords()
20 | 	updateBlacklistRecords()
21 | 	updateWhitelistRecords()
22 | }
23 | 
--------------------------------------------------------------------------------
/src/main_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"os"
5 | 	"runtime/debug"
6 | 	"testing"
7 | )
8 | 
9 | var config = `
10 | listen: 0.0.0.0:5300
11 | upstream:
12 |   - 8.8.8.8:53
13 |   - 8.8.4.4:53
14 | blacklist:
15 |   - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts
16 | whitelist:
17 |   - googleadservices.com
18 |   - iadsdk.apple.com
19 | local:
20 |   - name: service1.local
21 |     target:
22 |       - address: 192.168.178.1
23 |         type: A
24 |       - address: 192.168.179.1
25 |         type: A
26 |       - address: fe80::9656:d028:8652:1111
27 |         type: AAAA
28 |   - name: service2.local
29 |     target:
30 |       - address: 192.168.178.2
31 |         type: A
32 | `
33 | 
34 | func TestMain(m *testing.M) {
35 | 	GetConfig().ReadConfigData([]byte(config))
36 | 	initServer()
37 | 	code := m.Run()
38 | 	os.Exit(code)
39 | }
40 | 
41 | func checkTestBool(t *testing.T, expected, actual bool) {
42 | 	if expected != actual {
43 | 		t.Fatalf("Expected '%t', but got '%t' at:\n%s", expected, actual, debug.Stack())
44 | 	}
45 | }
46 | 
47 | func checkTestInt(t *testing.T, expected, actual int) {
48 | 	if expected != actual {
49 | 		t.Fatalf("Expected '%d', but got '%d' at:\n%s", expected, actual, debug.Stack())
50 | 	}
51 | }
52 | 
53 | func checkTestString(t *testing.T, expected, actual string) {
54 | 	if expected != actual {
55 | 		t.Fatalf("Expected '%s', but got '%s' at:\n%s", expected, actual, debug.Stack())
56 | 	}
57 | }
58 | 
--------------------------------------------------------------------------------
/src/server.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"log"
6 | 	"net"
7 | 	"strings"
8 | 
9 | 	"github.com/miekg/dns"
10 | )
11 | 
12 | // parseQuery answers each question of m (a reply prepared by the caller),
13 | // appending the answers to m.Answer and setting m.Rcode from the result.
14 | func parseQuery(source net.Addr, m *dns.Msg) {
15 | 	for _, q := range m.Question {
16 | 		// Names are matched in lower case throughout the resolver.
17 | 		name := strings.ToLower(q.Name)
18 | 		res, errCode := processDnsQuery(name, q.Qtype, source)
19 | 		m.Answer = append(m.Answer, res...)
20 | 		m.Rcode = errCode
21 | 	}
22 | }
23 | 
24 | // processDnsQuery resolves one question by trying, in order, queryLocal,
25 | // queryBlacklist and queryUpstream. It returns the answer records and the
26 | // response code to send: RcodeSuccess for local and upstream answers,
27 | // RcodeNameError for blacklisted names and names upstream could not find,
28 | // and RcodeServerFailure when no upstream server returned a usable response.
29 | func processDnsQuery(name string, qtype uint16, source net.Addr) ([]dns.RR, int) {
30 | 	arr, err := queryLocal(name, qtype)
31 | 	if err == nil {
32 | 		logQueryResult(source, name, qtype, "resolved as local address")
33 | 		return arr, dns.RcodeSuccess
34 | 	}
35 | 	arr, err = queryBlacklist(name, qtype)
36 | 	if err == nil {
37 | 		logQueryResult(source, name, qtype, "resolved as blacklisted name")
38 | 		return arr, dns.RcodeNameError
39 | 	}
40 | 	arr, err = queryUpstream(name, qtype)
41 | 	if err == nil {
42 | 		logQueryResult(source, name, qtype, "resolved via upstream")
43 | 		return arr, dns.RcodeSuccess
44 | 	}
45 | 	if errors.Is(err, errUpstreamUnavailable) {
46 | 		// No upstream produced a usable answer (network error, SERVFAIL,
47 | 		// REFUSED, ...). Report a server failure instead of NXDOMAIN so clients
48 | 		// do not cache "does not exist" for a name that may well exist.
49 | 		logQueryResult(source, name, qtype, "failed: no upstream DNS server answered")
50 | 		return []dns.RR{}, dns.RcodeServerFailure
51 | 	}
52 | 	// NOTE(review): NODATA answers (name exists, no records of qtype) also end
53 | 	// up here and are reported as NXDOMAIN, because the upstream cache keeps
54 | 	// only the answer section, not the upstream rcode.
55 | 	logQueryResult(source, name, qtype, "did not resolve")
56 | 	return []dns.RR{}, dns.RcodeNameError
57 | }
58 | 
59 | // handleDnsRequest is the handler registered for all names. It answers
60 | // standard queries and replies NOTIMP to every other opcode.
61 | func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
62 | 	m := new(dns.Msg)
63 | 	m.SetReply(r)
64 | 	m.Compress = false
65 | 
66 | 	switch r.Opcode {
67 | 	case dns.OpcodeQuery:
68 | 		parseQuery(w.RemoteAddr(), m)
69 | 	default:
70 | 		// Other opcodes (NOTIFY, UPDATE, ...) are not supported.
71 | 		m.Rcode = dns.RcodeNotImplemented
72 | 	}
73 | 
74 | 	if err := w.WriteMsg(m); err != nil {
75 | 		log.Printf("Failed to write response to %s: %s\n", w.RemoteAddr().String(), err.Error())
76 | 	}
77 | 	w.Close()
78 | }
79 | 
80 | // listenAndServe registers handleDnsRequest and serves DNS over UDP on the
81 | // configured listen address. If serving fails, the process exits via log.Fatalf.
82 | func listenAndServe() {
83 | 	dns.HandleFunc(".", handleDnsRequest)
84 | 
85 | 	server := &dns.Server{
86 | 		Addr: GetConfig().ListenAddr,
87 | 		Net:  "udp",
88 | 	}
89 | 	log.Printf("Starting at %s\n", GetConfig().ListenAddr)
90 | 	// ListenAndServe only returns once the server has stopped, so a deferred
91 | 	// Shutdown here would be pointless (and log.Fatalf skips defers anyway).
92 | 	if err := server.ListenAndServe(); err != nil {
93 | 		log.Fatalf("Failed to start server: %s\n", err.Error())
94 | 	}
95 | }
96 | 
--------------------------------------------------------------------------------
/src/server_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"net"
5 | 	"testing"
6 | 
7 | 	"github.com/miekg/dns"
8 | )
9 | 
10 | func TestProcessDnsQueryLocalA(t *testing.T) {
11 | 	res, errCode := processDnsQuery("service1.local.", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
12 | 	checkTestInt(t, dns.RcodeSuccess, errCode)
13 | 	checkTestInt(t, 2, len(res))
14 | 	aRecord1 := res[0].(*dns.A)
15 | 	aRecord2 := res[1].(*dns.A)
16 | 	checkTestString(t, "192.168.178.1", aRecord1.A.String())
17 | 	checkTestString(t, "192.168.179.1", aRecord2.A.String())
18 | }
19 | 
20 | func TestProcessDnsQueryLocalAAAA(t *testing.T) {
21 | 	res, errCode := processDnsQuery("service1.local.", dns.TypeAAAA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
22 | 	checkTestInt(t, dns.RcodeSuccess, errCode)
23 | 	checkTestInt(t, 1, len(res))
24 | 	aRecord1 := res[0].(*dns.AAAA)
25 | 	checkTestString(t, "fe80::9656:d028:8652:1111", aRecord1.AAAA.String())
26 | }
27 | 
28 | func TestProcessDnsQueryBlacklist(t *testing.T) {
29 | 	res, errCode := processDnsQuery("googleads.g.doubleclick.net.", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
30 | 	checkTestInt(t, dns.RcodeNameError, errCode)
31 | 	checkTestInt(t, 0, len(res))
32 | }
33 | 
34 | func TestProcessDnsQueryBlacklistWhitelisted(t *testing.T) {
35 | 	res, errCode := processDnsQuery("iadsdk.apple.com.", dns.TypeCNAME, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
36 | 	checkTestInt(t, dns.RcodeSuccess, errCode)
37 | 	checkTestInt(t, 1, len(res))
38 | 	cnameRecord1 := res[0].(*dns.CNAME)
39 | 	checkTestString(t, "iadsdk.apple.com.akadns.net.", cnameRecord1.Target)
40 | }
41 | 
42 | func TestProcessDnsQueryUpstreamSuccess(t *testing.T) {
43 | 	res, errCode := processDnsQuery("dns.google.", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
44 | 	checkTestInt(t, dns.RcodeSuccess, errCode)
45 | 	checkTestInt(t, 2, len(res))
46 | 	aRecord1 := res[0].(*dns.A)
47 | 	aRecord2 := res[1].(*dns.A)
48 | 	checkTestBool(t, true, aRecord1.A.String() == "8.8.8.8" || aRecord1.A.String() == "8.8.4.4")
49 | 	checkTestBool(t, true, aRecord2.A.String() == "8.8.8.8" || aRecord2.A.String() == "8.8.4.4")
50 | 	checkTestBool(t, true, aRecord1.A.String() != aRecord2.A.String())
51 | }
52 | 
53 | func TestProcessDnsQueryUpstreamNonExistent(t *testing.T) {
54 | 	res, errCode := processDnsQuery("nonexistentrecord.virtualzone.de.", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
55 | 	checkTestInt(t, dns.RcodeNameError, errCode)
56 | 	checkTestInt(t, 0, len(res))
57 | }
58 | 
59 | func TestProcessDnsQueryEmptyName(t *testing.T) {
60 | 	res, errCode := processDnsQuery(".", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
61 | 	checkTestInt(t, dns.RcodeNameError, errCode)
62 | 	checkTestInt(t, 0, len(res))
63 | }
64 | 
65 | func TestProcessDnsQueryWildcard(t *testing.T) {
66 | 	res, errCode := processDnsQuery("*.", dns.TypeA, &net.IPAddr{IP: []byte{127, 0, 0, 1}})
67 | 	checkTestInt(t, dns.RcodeNameError, errCode)
68 | 	checkTestInt(t, 0, len(res))
69 | }
70 | 
--------------------------------------------------------------------------------
/src/upstream.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"log"
6 | 
7 | 	"github.com/miekg/dns"
8 | )
9 | 
10 | // errUpstreamNotFound reports that an upstream server answered definitively
11 | // (NOERROR or NXDOMAIN) but returned no answer records.
12 | var errUpstreamNotFound = errors.New("record not found via upstream DNS server")
13 | 
14 | // errUpstreamUnavailable reports that none of the configured upstream
15 | // servers returned a usable response.
16 | var errUpstreamUnavailable = errors.New("could not resolve query via any upstream DNS server")
17 | 
18 | // queryUpstream resolves name/qtype through the configured upstream servers,
19 | // serving from and populating the upstream cache. Definitive responses
20 | // (NOERROR and NXDOMAIN) are cached even when empty, so repeated misses are
21 | // not re-sent upstream. It returns errUpstreamNotFound for an empty answer
22 | // and errUpstreamUnavailable when no server produced a usable response.
23 | func queryUpstream(name string, qtype uint16) ([]dns.RR, error) {
24 | 	// Check cache first
25 | 	res, err := GetUpstreamCache().Get(name, qtype)
26 | 	if err == nil {
27 | 		// Record found in cache (an empty entry is a cached negative answer)
28 | 		log.Printf("query for %s %s resolved via cache\n", getQueryTypeText(qtype), name)
29 | 		if len(res) == 0 {
30 | 			return nil, errUpstreamNotFound
31 | 		}
32 | 		return res, nil
33 | 	}
34 | 
35 | 	// If not cached, perform actual upstream query
36 | 	m1 := new(dns.Msg)
37 | 	m1.Id = dns.Id()
38 | 	m1.RecursionDesired = true
39 | 	m1.Question = []dns.Question{{
40 | 		Name:   name,
41 | 		Qtype:  qtype,
42 | 		Qclass: dns.ClassINET,
43 | 	}}
44 | 
45 | 	for _, server := range GetConfig().UpstreamDNS {
46 | 		in, err := doUpstreamQuery(m1, server)
47 | 		if err != nil {
48 | 			log.Printf("upstream %s failed for %s %s: %s\n", server, getQueryTypeText(qtype), name, err.Error())
49 | 			continue
50 | 		}
51 | 		// Only NOERROR and NXDOMAIN are definitive. SERVFAIL, REFUSED etc.
52 | 		// mean this server could not answer: try the next one rather than
53 | 		// caching its empty answer section as a negative result.
54 | 		if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
55 | 			log.Printf("upstream %s returned %s for %s %s\n", server, dns.RcodeToString[in.Rcode], getQueryTypeText(qtype), name)
56 | 			continue
57 | 		}
58 | 		GetUpstreamCache().Set(name, qtype, in.Answer)
59 | 		if len(in.Answer) == 0 {
60 | 			return nil, errUpstreamNotFound
61 | 		}
62 | 		return in.Answer, nil
63 | 	}
64 | 
65 | 	return nil, errUpstreamUnavailable
66 | }
67 | 
68 | // doUpstreamQuery sends m to the upstream server at address over UDP. If the
69 | // reply comes back truncated it repeats the query over TCP, so a partial
70 | // record set is never returned (and cached) as if it were complete.
71 | func doUpstreamQuery(m *dns.Msg, address string) (*dns.Msg, error) {
72 | 	c := new(dns.Client)
73 | 	in, _, err := c.Exchange(m, address)
74 | 	if err == nil && in.Truncated {
75 | 		c.Net = "tcp"
76 | 		in, _, err = c.Exchange(m, address)
77 | 	}
78 | 	return in, err
79 | }
80 | 
--------------------------------------------------------------------------------
/src/upstream_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/miekg/dns"
7 | )
8 | 
9 | func TestUpstreamSuccess(t *testing.T) {
10 | 	res, err := queryUpstream("dns.google.", dns.TypeA)
11 | 	checkTestBool(t, true, err == nil)
12 | 	checkTestBool(t, false, res == nil)
13 | 	checkTestInt(t, 2, len(res))
14 | 	aRecord1 := res[0].(*dns.A)
15 | 	aRecord2 := res[1].(*dns.A)
16 | 	checkTestBool(t, true, aRecord1.A.String() == "8.8.8.8" || aRecord1.A.String() == "8.8.4.4")
17 | 	checkTestBool(t, true, aRecord2.A.String() == "8.8.8.8" || aRecord2.A.String() == "8.8.4.4")
18 | 	checkTestBool(t, true, aRecord1.A.String() != aRecord2.A.String())
19 | }
20 | 
21 | func TestUpstreamNonExistent(t *testing.T) {
22 | 	res, err := queryUpstream("nonexistentrecord.virtualzone.de.", dns.TypeA)
23 | 	checkTestBool(t, false, err == nil)
24 | 	checkTestBool(t, true, res == nil)
25 | }
26 | 
27 | func TestUpstreamCname(t *testing.T) {
28 | 	res, err := queryUpstream("iadsdk.apple.com.", dns.TypeCNAME)
29 | 	checkTestBool(t, true, err == nil)
30 | 	checkTestBool(t, false, res == nil)
31 | 	checkTestInt(t, 1, len(res))
32 | 	cnameRecord1 := res[0].(*dns.CNAME)
33 | 	checkTestString(t, "iadsdk.apple.com.akadns.net.", cnameRecord1.Target)
34 | }
35 | 
36 | func TestUpstreamCacheExistent(t *testing.T) {
37 | 	// Clear cache
38 | 	GetUpstreamCache().Clear()
39 | 
40 | 	// Verify record is not cached
41 | 	res, err := GetUpstreamCache().Get("dns.google.", dns.TypeA)
42 | 	checkTestBool(t, false, err == nil)
43 | 	checkTestBool(t, true, res == nil)
44 | 
45 | 	// Perform upstream query
46 | 	res, err = queryUpstream("dns.google.", dns.TypeA)
47 | 	checkTestBool(t, true, err == nil)
48 | 	checkTestBool(t, false, res == nil)
49 | 	checkTestInt(t, 2, len(res))
50 | 	aRecord1 := res[0].(*dns.A)
51 | 	aRecord2 := res[1].(*dns.A)
52 | 	checkTestBool(t, true, aRecord1.A.String() == "8.8.8.8" || aRecord1.A.String() == "8.8.4.4")
53 | 	checkTestBool(t, true, aRecord2.A.String() == "8.8.8.8" || aRecord2.A.String() == "8.8.4.4")
54 | 	checkTestBool(t, true, aRecord1.A.String() != aRecord2.A.String())
55 | 
56 | 	// Verify record is cached now
57 | 	res, err = GetUpstreamCache().Get("dns.google.", dns.TypeA)
58 | 	checkTestBool(t, true, err == nil)
59 | 	checkTestBool(t, false, res == nil)
60 | 	checkTestInt(t, 2, len(res))
61 | 
62 | 	// Perform upstream query again (should resolve via cache)
63 | 	res, err = queryUpstream("dns.google.", dns.TypeA)
64 | 	checkTestBool(t, true, err == nil)
65 | 	checkTestBool(t, false, res == nil)
66 | 	checkTestInt(t, 2, len(res))
67 | 	aRecord1 = res[0].(*dns.A)
68 | 	aRecord2 = res[1].(*dns.A)
69 | 	checkTestBool(t, true, aRecord1.A.String() == "8.8.8.8" || aRecord1.A.String() == "8.8.4.4")
70 | 	checkTestBool(t, true, aRecord2.A.String() == "8.8.8.8" || aRecord2.A.String() == "8.8.4.4")
71 | 	checkTestBool(t, true, aRecord1.A.String() != aRecord2.A.String())
72 | }
73 | 
74 | func TestUpstreamCacheNonExistent(t *testing.T) {
75 | 	// Clear cache
76 | 	GetUpstreamCache().Clear()
77 | 
78 | 	// Verify record is not cached
79 | 	res, err := GetUpstreamCache().Get("nonexistentrecord.virtualzone.de.", dns.TypeA)
80 | 	checkTestBool(t, false, err == nil)
81 | 	checkTestBool(t, true, res == nil)
82 | 
83 | 	// Perform upstream query
84 | 	res, err = queryUpstream("nonexistentrecord.virtualzone.de.", dns.TypeA)
85 | 	checkTestBool(t, false, err == nil)
86 | 	checkTestBool(t, true, res == nil)
87 | 
88 | 	// Verify record is cached now
89 | 	res, err = GetUpstreamCache().Get("nonexistentrecord.virtualzone.de.", dns.TypeA)
90 | 	checkTestBool(t, true, err == nil)
91 | 	checkTestBool(t, true, res == nil)
92 | }
93 | 
--------------------------------------------------------------------------------
/src/upstreamcache.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"strconv"
6 | 	"time"
7 | 
8 | 	"github.com/jellydator/ttlcache/v3"
9 | 	"github.com/miekg/dns"
10 | )
11 | 
12 | // UpstreamCache caches upstream answer sections per (name, qtype), each entry
13 | // living for the lowest TTL among its records.
14 | type UpstreamCache struct {
15 | 	Cache *ttlcache.Cache[string, []dns.RR]
16 | }
17 | 
18 | // UpstreamCacheInstance is the shared cache returned by GetUpstreamCache.
19 | var UpstreamCacheInstance *UpstreamCache = &UpstreamCache{}
20 | 
21 | // GetUpstreamCache returns the shared UpstreamCache. Init must be called
22 | // before it is used, since Cache is nil until then.
23 | func GetUpstreamCache() *UpstreamCache {
24 | 	return UpstreamCacheInstance
25 | }
26 | 
27 | // Init creates the underlying TTL cache (reads do not extend an entry's
28 | // lifetime) and runs its expiry loop in a background goroutine.
29 | func (c *UpstreamCache) Init() {
30 | 	c.Cache = ttlcache.New(
31 | 		ttlcache.WithDisableTouchOnHit[string, []dns.RR](),
32 | 	)
33 | 	go c.Cache.Start()
34 | }
35 | 
36 | // Set caches rr for name/qtype; an empty rr records a negative answer. The
37 | // entry lives for the smallest TTL in rr, or 30 minutes when rr is empty.
38 | // A smallest TTL of 0 means "do not cache", so nothing is stored then.
39 | func (c *UpstreamCache) Set(name string, qtype uint16, rr []dns.RR) {
40 | 	key := c.getKey(name, qtype)
41 | 	ttl := time.Duration(c.getMinTtl(rr)) * time.Second
42 | 	if ttl <= 0 {
43 | 		// A zero TTL makes ttlcache fall back to the cache default, which is
44 | 		// "never expire" here (no WithTTL option); drop any old entry instead.
45 | 		c.Cache.Delete(key)
46 | 		return
47 | 	}
48 | 	c.Cache.Set(key, rr, ttl)
49 | }
50 | 
51 | // Get returns the cached answer for name/qtype, or an error if there is no
52 | // live entry. A nil/empty slice with a nil error is a cached negative answer.
53 | // Records are returned as copies whose TTLs are capped at the entry's
54 | // remaining lifetime, so clients never cache longer than upstream allowed.
55 | func (c *UpstreamCache) Get(name string, qtype uint16) ([]dns.RR, error) {
56 | 	item := c.Cache.Get(c.getKey(name, qtype))
57 | 	if item == nil || item.IsExpired() {
58 | 		return nil, errors.New("record not found in cache")
59 | 	}
60 | 	cached := item.Value()
61 | 	expiresAt := item.ExpiresAt()
62 | 	if len(cached) == 0 || expiresAt.IsZero() {
63 | 		return cached, nil
64 | 	}
65 | 	remaining := time.Until(expiresAt)
66 | 	if remaining < 0 {
67 | 		remaining = 0
68 | 	}
69 | 	maxTtl := uint32(remaining / time.Second)
70 | 	out := make([]dns.RR, len(cached))
71 | 	for i, record := range cached {
72 | 		// Copy so the records shared through the cache are never mutated.
73 | 		cp := dns.Copy(record)
74 | 		if cp.Header().Ttl > maxTtl {
75 | 			cp.Header().Ttl = maxTtl
76 | 		}
77 | 		out[i] = cp
78 | 	}
79 | 	return out, nil
80 | }
81 | 
82 | // Clear removes every cached entry.
83 | func (c *UpstreamCache) Clear() {
84 | 	c.Cache.DeleteAll()
85 | }
86 | 
87 | // getKey builds the cache key for name/qtype, e.g. "example.com._1".
88 | func (c *UpstreamCache) getKey(name string, qtype uint16) string {
89 | 	return name + "_" + strconv.Itoa(int(qtype))
90 | }
91 | 
92 | // getMinTtl returns the smallest TTL (seconds) among rr, capped at the
93 | // 1800-second default that is also returned when rr is empty.
94 | func (c *UpstreamCache) getMinTtl(rr []dns.RR) uint32 {
95 | 	var res uint32 = 1800 // Default: 30 minutes
96 | 	for _, record := range rr {
97 | 		if record.Header().Ttl < res {
98 | 			res = record.Header().Ttl
99 | 		}
100 | 	}
101 | 	return res
102 | }
103 | 
--------------------------------------------------------------------------------
/src/version.go:
--------------------------------------------------------------------------------
1 | // File autogenerated from VERSION, do not change manually!
2 | package main
3 | 
4 | const AppVersion = "v0.5.0"
5 | 
--------------------------------------------------------------------------------