├── .codecov.yml ├── .dockerignore ├── .gitattributes ├── .github └── workflows │ ├── build.yaml │ ├── docker.yml │ └── lint.yaml ├── .gitignore ├── .markdownlint.json ├── LICENSE ├── Makefile ├── README.md ├── bamboo-specs └── bamboo.yaml ├── config.yaml.dist ├── docker ├── Dockerfile └── README.md ├── fastip ├── cache.go ├── cache_internal_test.go ├── fastest.go ├── fastest_internal_test.go ├── ping.go └── ping_internal_test.go ├── go.mod ├── go.sum ├── internal ├── bootstrap │ ├── bootstrap.go │ ├── bootstrap_test.go │ ├── error.go │ ├── resolver.go │ └── resolver_test.go ├── cmd │ ├── args.go │ ├── cmd.go │ ├── config.go │ ├── flag.go │ ├── proxy.go │ └── tls.go ├── dnsmsg │ └── constructor.go ├── dnsproxytest │ ├── dnsproxytest.go │ ├── interface.go │ └── interface_test.go ├── handler │ ├── constructor.go │ ├── default.go │ ├── default_internal_test.go │ ├── handler.go │ ├── hosts.go │ ├── ipv6halt.go │ └── testdata │ │ └── TestDefault_resolveFromHosts │ │ └── hosts ├── netutil │ ├── listenconfig.go │ ├── listenconfig_unix.go │ ├── listenconfig_windows.go │ ├── netutil.go │ ├── paths.go │ ├── paths_unix.go │ ├── paths_windows.go │ ├── testdata │ │ └── TestHosts │ │ │ ├── bad_file │ │ │ └── hosts │ │ │ └── good_file │ │ │ └── hosts │ ├── udp.go │ ├── udp_unix.go │ ├── udp_windows.go │ ├── udpoob_darwin.go │ └── udpoob_others.go └── version │ └── version.go ├── main.go ├── proxy ├── beforerequest.go ├── beforerequest_internal_test.go ├── bogusnxdomain.go ├── bogusnxdomain_internal_test.go ├── cache.go ├── cache_internal_test.go ├── config.go ├── constructor.go ├── dns64.go ├── dns64_internal_test.go ├── dnscontext.go ├── errors.go ├── errors_internal_test.go ├── errors_plan9.go ├── exchange.go ├── exchange_internal_test.go ├── handler_internal_test.go ├── helpers.go ├── lookup.go ├── lookup_internal_test.go ├── optimisticresolver.go ├── optimisticresolver_internal_test.go ├── pending.go ├── pending_test.go ├── proxy.go ├── proxy_internal_test.go ├── 
proxy_test.go ├── proxycache.go ├── ratelimit.go ├── ratelimit_internal_test.go ├── recursiondetector.go ├── recursiondetector_internal_test.go ├── retry.go ├── retry_internal_test.go ├── server.go ├── serverdnscrypt.go ├── serverdnscrypt_internal_test.go ├── serverhttps.go ├── serverhttps_internal_test.go ├── serverquic.go ├── serverquic_internal_test.go ├── servertcp.go ├── servertcp_internal_test.go ├── serverudp.go ├── serverudp_internal_test.go ├── stats.go ├── stats_test.go ├── upstreammode.go ├── upstreammode_test.go ├── upstreams.go └── upstreams_internal_test.go ├── proxyutil └── dns.go ├── scripts ├── hooks │ └── pre-commit └── make │ ├── build-docker.sh │ ├── build-release.sh │ ├── go-build.sh │ ├── go-deps.sh │ ├── go-lint.sh │ ├── go-test.sh │ ├── go-tools.sh │ ├── go-upd-tools.sh │ ├── helper.sh │ ├── md-lint.sh │ ├── sh-lint.sh │ └── txt-lint.sh ├── staticcheck.conf └── upstream ├── dnscrypt.go ├── dnscrypt_internal_test.go ├── doh.go ├── doh_internal_test.go ├── doq.go ├── doq_internal_test.go ├── dot.go ├── dot_internal_test.go ├── dot_unix.go ├── dot_windows.go ├── hostsresolver.go ├── hostsresolver_test.go ├── parallel.go ├── parallel_internal_test.go ├── plain.go ├── plain_internal_test.go ├── resolver.go ├── resolver_internal_test.go ├── resolver_test.go ├── upstream.go └── upstream_internal_test.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 40% 6 | threshold: null 7 | patch: false 8 | changes: false 9 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore everything except for explicitly allowed stuff. 
2 | * 3 | !build/docker 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | vendor/** binary 2 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | 'env': 4 | 'GO_VERSION': '1.24.2' 5 | 6 | 'on': 7 | 'push': 8 | 'tags': 9 | - 'v*' 10 | 'branches': 11 | - '*' 12 | 'pull_request': 13 | 14 | jobs: 15 | tests: 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | matrix: 19 | os: 20 | - windows-latest 21 | - macos-latest 22 | - ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@master 25 | - uses: actions/setup-go@v2 26 | with: 27 | go-version: '${{ env.GO_VERSION }}' 28 | - name: Run tests 29 | env: 30 | CI: "1" 31 | run: |- 32 | make test 33 | - name: Upload coverage 34 | uses: codecov/codecov-action@v1 35 | if: "success() && matrix.os == 'ubuntu-latest'" 36 | with: 37 | token: ${{ secrets.CODECOV_TOKEN }} 38 | file: ./coverage.txt 39 | 40 | build: 41 | needs: 42 | - tests 43 | runs-on: ubuntu-latest 44 | steps: 45 | - uses: actions/checkout@master 46 | - uses: actions/setup-go@v2 47 | with: 48 | go-version: '${{ env.GO_VERSION }}' 49 | - name: Build release 50 | run: |- 51 | set -e -u -x 52 | 53 | RELEASE_VERSION="${GITHUB_REF##*/}" 54 | if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi 55 | echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV 56 | 57 | make VERBOSE=1 VERSION="${RELEASE_VERSION}" release 58 | 59 | ls -l build/dnsproxy-* 60 | - name: Create release 61 | if: startsWith(github.ref, 'refs/tags/v') 62 | id: create_release 63 | uses: actions/create-release@v1 64 | env: 65 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 66 | with: 67 | tag_name: ${{ github.ref }} 68 | release_name: Release ${{ github.ref }} 69 | draft: false 70 | 
prerelease: false 71 | - name: Upload 72 | if: startsWith(github.ref, 'refs/tags/v') 73 | uses: xresloader/upload-to-github-release@v1.3.12 74 | env: 75 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 76 | with: 77 | file: "build/dnsproxy-*.tar.gz;build/dnsproxy-*.zip" 78 | tags: true 79 | draft: false 80 | 81 | notify: 82 | needs: 83 | - build 84 | if: 85 | ${{ always() && 86 | ( 87 | github.event_name == 'push' || 88 | github.event.pull_request.head.repo.full_name == github.repository 89 | ) 90 | }} 91 | runs-on: ubuntu-latest 92 | steps: 93 | - name: Conclusion 94 | uses: technote-space/workflow-conclusion-action@v1 95 | - name: Send Slack notif 96 | uses: 8398a7/action-slack@v3 97 | with: 98 | status: ${{ env.WORKFLOW_CONCLUSION }} 99 | fields: workflow, repo, message, commit, author, eventName,ref 100 | env: 101 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 102 | SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} 103 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | 'name': Docker 2 | 3 | 'env': 4 | 'GO_VERSION': '1.24.2' 5 | 6 | 'on': 7 | 'push': 8 | 'tags': 9 | - 'v*' 10 | # Builds from the master branch will be pushed with the `dev` tag. 
11 | 'branches': 12 | - 'master' 13 | 14 | 'jobs': 15 | 'docker': 16 | 'runs-on': 'ubuntu-latest' 17 | 'steps': 18 | - 'name': 'Checkout' 19 | 'uses': 'actions/checkout@v3' 20 | 'with': 21 | 'fetch-depth': 0 22 | - 'name': 'Set up Go' 23 | 'uses': 'actions/setup-go@v3' 24 | 'with': 25 | 'go-version': '${{ env.GO_VERSION }}' 26 | - 'name': 'Set up Go modules cache' 27 | 'uses': 'actions/cache@v4' 28 | 'with': 29 | 'path': '~/go/pkg/mod' 30 | 'key': "${{ runner.os }}-go-${{ hashFiles('go.sum') }}" 31 | 'restore-keys': '${{ runner.os }}-go-' 32 | - 'name': 'Set up QEMU' 33 | 'uses': 'docker/setup-qemu-action@v1' 34 | - 'name': 'Set up Docker Buildx' 35 | 'uses': 'docker/setup-buildx-action@v1' 36 | - 'name': 'Publish to Docker Hub' 37 | 'env': 38 | 'DOCKER_USER': ${{ secrets.DOCKER_USER }} 39 | 'DOCKER_PASSWORD': ${{ secrets.DOCKER_PASSWORD }} 40 | 'run': |- 41 | set -e -u -x 42 | 43 | RELEASE_VERSION="${GITHUB_REF##*/}" 44 | if [[ "${RELEASE_VERSION}" != v* ]]; then RELEASE_VERSION='dev'; fi 45 | echo "RELEASE_VERSION=\"${RELEASE_VERSION}\"" >> $GITHUB_ENV 46 | 47 | docker login \ 48 | -u="${DOCKER_USER}" \ 49 | -p="${DOCKER_PASSWORD}" 50 | 51 | make \ 52 | VERSION="${RELEASE_VERSION}" \ 53 | DOCKER_IMAGE_NAME="adguard/dnsproxy" \ 54 | DOCKER_OUTPUT="type=image,name=adguard/dnsproxy,push=true" \ 55 | VERBOSE="1" \ 56 | docker 57 | 58 | 'notify': 59 | 'needs': 60 | - 'docker' 61 | 'if': 62 | ${{ always() && 63 | ( 64 | github.event_name == 'push' || 65 | github.event.pull_request.head.repo.full_name == github.repository 66 | ) 67 | }} 68 | 'runs-on': ubuntu-latest 69 | 'steps': 70 | - 'name': Conclusion 71 | 'uses': technote-space/workflow-conclusion-action@v1 72 | - 'name': Send Slack notif 73 | 'uses': 8398a7/action-slack@v3 74 | 'with': 75 | 'status': ${{ env.WORKFLOW_CONCLUSION }} 76 | 'fields': workflow, repo, message, commit, author, eventName,ref 77 | 'env': 78 | 'GITHUB_TOKEN': ${{ secrets.GITHUB_TOKEN }} 79 | 'SLACK_WEBHOOK_URL': ${{ secrets.SLACK_WEBHOOK_URL 
}} 80 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | 'name': 'lint' 2 | 3 | 'env': 4 | 'GO_VERSION': '1.24.2' 5 | 6 | 'on': 7 | 'push': 8 | 'tags': 9 | - 'v*' 10 | 'branches': 11 | - '*' 12 | 'pull_request': 13 | 14 | 'jobs': 15 | 'go-lint': 16 | 'runs-on': 'ubuntu-latest' 17 | 'steps': 18 | - 'uses': 'actions/checkout@v2' 19 | - 'name': 'Set up Go' 20 | 'uses': 'actions/setup-go@v3' 21 | 'with': 22 | 'go-version': '${{ env.GO_VERSION }}' 23 | - 'name': 'run-lint' 24 | 'run': > 25 | make go-deps go-tools go-lint 26 | 27 | 'notify': 28 | 'needs': 29 | - 'go-lint' 30 | # Secrets are not passed to workflows that are triggered by a pull request 31 | # from a fork. 32 | # 33 | # Use always() to signal to the runner that this job must run even if the 34 | # previous ones failed. 35 | 'if': 36 | ${{ 37 | always() && 38 | github.repository_owner == 'AdguardTeam' && 39 | ( 40 | github.event_name == 'push' || 41 | github.event.pull_request.head.repo.full_name == github.repository 42 | ) 43 | }} 44 | 'runs-on': 'ubuntu-latest' 45 | 'steps': 46 | - 'name': 'Conclusion' 47 | 'uses': 'technote-space/workflow-conclusion-action@v1' 48 | - 'name': 'Send Slack notif' 49 | 'uses': '8398a7/action-slack@v3' 50 | 'with': 51 | 'status': '${{ env.WORKFLOW_CONCLUSION }}' 52 | 'fields': 'workflow, repo, message, commit, author, eventName, ref' 53 | 'env': 54 | 'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}' 55 | 'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}' 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Please, DO NOT put your text editors' temporary files here. The more are 2 | # added, the harder it gets to maintain and manage projects' gitignores. 
Put 3 | # them into your global gitignore file instead. 4 | # 5 | # See https://stackoverflow.com/a/7335487/1892060. 6 | # 7 | # Only build, run, and test outputs here. Sorted. With negations at the 8 | # bottom to make sure they take effect. 9 | *.out 10 | *.test 11 | /bin/ 12 | build 13 | dnsproxy 14 | dnsproxy.exe 15 | example.crt 16 | example.key 17 | coverage.txt 18 | config.yaml 19 | -------------------------------------------------------------------------------- /.markdownlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "ul-indent": { 3 | "indent": 4 4 | }, 5 | "ul-style": { 6 | "style": "dash" 7 | }, 8 | "emphasis-style": { 9 | "style": "asterisk" 10 | }, 11 | "no-duplicate-heading": { 12 | "siblings_only": true 13 | }, 14 | "no-inline-html": { 15 | "allowed_elements": [ 16 | "a" 17 | ] 18 | }, 19 | "no-trailing-spaces": { 20 | "br_spaces": 0 21 | }, 22 | "line-length": false, 23 | "no-bare-urls": false, 24 | "link-fragments": false 25 | } 26 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Keep the Makefile POSIX-compliant. We currently allow hyphens in target 2 | # names, but that may change in the future. 3 | # 4 | # See https://pubs.opengroup.org/onlinepubs/9799919799/utilities/make.html. 5 | .POSIX: 6 | 7 | # This comment is used to simplify checking local copies of the Makefile. Bump 8 | # this number every time a significant change is made to this Makefile. 9 | # 10 | # AdGuard-Project-Version: 9 11 | 12 | # Don't name these macros "GO" etc., because GNU Make apparently makes them 13 | # exported environment variables with the literal value of "${GO:-go}" and so 14 | # on, which is not what we need. Use a dot in the name to make sure that users 15 | # don't have an environment variable with the same name. 
16 | # 17 | # See https://unix.stackexchange.com/q/646255/105635. 18 | GO.MACRO = $${GO:-go} 19 | VERBOSE.MACRO = $${VERBOSE:-0} 20 | 21 | BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)} 22 | DIST_DIR = build 23 | GOAMD64 = v1 24 | GOPROXY = https://proxy.golang.org|direct 25 | GOTELEMETRY = off 26 | GOTOOLCHAIN = go1.24.3 27 | OUT = dnsproxy 28 | RACE = 0 29 | REVISION = $${REVISION:-$$(git rev-parse --short HEAD)} 30 | VERSION = 0 31 | 32 | ENV = env\ 33 | BRANCH="$(BRANCH)"\ 34 | DIST_DIR='$(DIST_DIR)'\ 35 | GO="$(GO.MACRO)"\ 36 | GOAMD64='$(GOAMD64)'\ 37 | GOPROXY='$(GOPROXY)'\ 38 | GOTELEMETRY='$(GOTELEMETRY)'\ 39 | GOTOOLCHAIN='$(GOTOOLCHAIN)'\ 40 | OUT='$(OUT)'\ 41 | PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\ 42 | RACE='$(RACE)'\ 43 | REVISION="$(REVISION)"\ 44 | VERBOSE="$(VERBOSE.MACRO)"\ 45 | VERSION="$(VERSION)"\ 46 | 47 | # Keep the line above blank. 48 | 49 | ENV_MISC = env\ 50 | PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\ 51 | VERBOSE="$(VERBOSE.MACRO)"\ 52 | 53 | # Keep the line above blank. 54 | 55 | # Keep this target first, so that a naked make invocation triggers a full build. 56 | build: go-deps go-build 57 | 58 | init: ; git config core.hooksPath ./scripts/hooks 59 | 60 | test: go-test 61 | 62 | go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh 63 | go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh 64 | go-env: ; $(ENV) "$(GO.MACRO)" env 65 | go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh 66 | go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh 67 | go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh 68 | go-upd-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-upd-tools.sh 69 | 70 | go-check: go-tools go-lint go-test 71 | 72 | # A quick check to make sure that all operating systems relevant to the 73 | # development of the project can be typechecked and built successfully. 74 | go-os-check: 75 | $(ENV) GOOS='darwin' "$(GO.MACRO)" vet ./... 
76 | $(ENV) GOOS='freebsd' "$(GO.MACRO)" vet ./... 77 | $(ENV) GOOS='openbsd' "$(GO.MACRO)" vet ./... 78 | $(ENV) GOOS='linux' "$(GO.MACRO)" vet ./... 79 | $(ENV) GOOS='windows' "$(GO.MACRO)" vet ./... 80 | 81 | txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh 82 | 83 | md-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/md-lint.sh 84 | sh-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/sh-lint.sh 85 | 86 | clean: ; $(ENV) $(GO.MACRO) clean && rm -f -r '$(DIST_DIR)' 87 | 88 | release: clean 89 | $(ENV) "$(SHELL)" ./scripts/make/build-release.sh 90 | 91 | docker: release 92 | $(ENV) "$(SHELL)" ./scripts/make/build-docker.sh 93 | -------------------------------------------------------------------------------- /bamboo-specs/bamboo.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 'version': 2 3 | 'plan': 4 | 'project-key': 'GO' 5 | 'key': 'DNSPROXY' 6 | 'name': 'dnsproxy - Build and run tests' 7 | 'variables': 8 | 'dockerFpm': 'alanfranz/fpm-within-docker:ubuntu-bionic' 9 | # When there is a patch release of Go available, set this property to an 10 | # exact patch version as opposed to a minor one to make sure that this exact 11 | # version is actually used and not whatever the docker daemon on the CI has 12 | # cached a few months ago. 13 | 'dockerGo': 'golang:1.24.2' 14 | 'maintainer': 'Adguard Go Team' 15 | 'name': 'dnsproxy' 16 | 17 | 'stages': 18 | # TODO(e.burkov): Add separate lint stage for texts. 
19 | - 'Lint': 20 | 'manual': false 21 | 'final': false 22 | 'jobs': 23 | - 'Lint' 24 | - 'Test': 25 | 'manual': false 26 | 'final': false 27 | 'jobs': 28 | - 'Test' 29 | 30 | 'Lint': 31 | 'docker': 32 | 'image': '${bamboo.dockerGo}' 33 | 'volumes': 34 | '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}' 35 | '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}' 36 | 'key': 'LINT' 37 | 'other': 38 | 'clean-working-dir': true 39 | 'requirements': 40 | - 'adg-docker': true 41 | 'tasks': 42 | - 'checkout': 43 | 'force-clean-build': true 44 | - 'script': 45 | 'interpreter': 'SHELL' 46 | 'scripts': 47 | - | 48 | set -e -f -u -x 49 | 50 | make VERBOSE=1 GOMAXPROCS=1 go-tools go-lint 51 | 52 | 'Test': 53 | 'docker': 54 | 'image': '${bamboo.dockerGo}' 55 | 'volumes': 56 | '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}' 57 | '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}' 58 | 'key': 'TEST' 59 | 'other': 60 | 'clean-working-dir': true 61 | 'requirements': 62 | - 'adg-docker': true 63 | 'tasks': 64 | - 'checkout': 65 | 'force-clean-build': true 66 | - 'script': 67 | 'interpreter': 'SHELL' 68 | # Projects that have go-bench and/or go-fuzz targets should add them 69 | # here as well. 
70 | 'scripts': 71 | - | 72 | set -e -f -u -x 73 | 74 | make VERBOSE=1 go-deps go-test 75 | 76 | 'branches': 77 | 'create': 'for-pull-request' 78 | 'delete': 79 | 'after-deleted-days': 1 80 | 'after-inactive-days': 5 81 | 'link-to-jira': true 82 | 83 | 'notifications': 84 | - 'events': 85 | - 'plan-status-changed' 86 | 'recipients': 87 | - 'webhook': 88 | 'name': 'Build webhook' 89 | 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo' 90 | 91 | 'labels': [] 92 | 93 | 'other': 94 | 'concurrent-build-plugin': 'system-default' 95 | -------------------------------------------------------------------------------- /config.yaml.dist: -------------------------------------------------------------------------------- 1 | # This is the YAML configuration file for dnsproxy with a minimal working 2 | # configuration. All the available options can be seen with ./dnsproxy --help. 3 | # To use it within dnsproxy specify the --config-path=/<path-to-config.yaml> 4 | # option. Any other command-line options specified will override the values 5 | # from the config file. 6 | --- 7 | bootstrap: 8 | - "8.8.8.8:53" 9 | listen-addrs: 10 | - "0.0.0.0" 11 | listen-ports: 12 | - 53 13 | max-go-routines: 0 14 | ratelimit: 0 15 | ratelimit-subnet-len-ipv4: 24 16 | ratelimit-subnet-len-ipv6: 64 17 | udp-buf-size: 0 18 | upstream: 19 | - "1.1.1.1:53" 20 | timeout: '10s' 21 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # A docker file for scripts/make/build-docker.sh.
2 | 3 | FROM alpine:3.18 4 | 5 | ARG BUILD_DATE 6 | ARG VERSION 7 | ARG VCS_REF 8 | 9 | LABEL\ 10 | maintainer="AdGuard Team " \ 11 | org.opencontainers.image.authors="AdGuard Team " \ 12 | org.opencontainers.image.created=$BUILD_DATE \ 13 | org.opencontainers.image.description="Simple DNS proxy with DoH, DoT, DoQ and DNSCrypt support" \ 14 | org.opencontainers.image.documentation="https://github.com/AdguardTeam/dnsproxy" \ 15 | org.opencontainers.image.licenses="Apache-2.0" \ 16 | org.opencontainers.image.revision=$VCS_REF \ 17 | org.opencontainers.image.source="https://github.com/AdguardTeam/dnsproxy" \ 18 | org.opencontainers.image.title="dnsproxy" \ 19 | org.opencontainers.image.url="https://github.com/AdguardTeam/dnsproxy" \ 20 | org.opencontainers.image.vendor="AdGuard" \ 21 | org.opencontainers.image.version=$VERSION 22 | 23 | # Update certificates. 24 | RUN apk --no-cache add ca-certificates libcap tzdata && \ 25 | mkdir -p /opt/dnsproxy && chown -R nobody: /opt/dnsproxy 26 | 27 | ARG DIST_DIR 28 | ARG TARGETARCH 29 | ARG TARGETOS 30 | ARG TARGETVARIANT 31 | 32 | COPY --chown=nobody:nogroup\ 33 | ./${DIST_DIR}/docker/dnsproxy_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\ 34 | /opt/dnsproxy/dnsproxy 35 | COPY --chown=nobody:nogroup\ 36 | ./${DIST_DIR}/docker/config.yaml\ 37 | /opt/dnsproxy/config.yaml 38 | 39 | RUN setcap 'cap_net_bind_service=+eip' /opt/dnsproxy/dnsproxy 40 | 41 | # 53 : TCP, UDP : DNS 42 | # 80 : TCP : HTTP 43 | # 443 : TCP, UDP : HTTPS, DNS-over-HTTPS (incl. 
HTTP/3), DNSCrypt (main) 44 | # 853 : TCP, UDP : DNS-over-TLS, DNS-over-QUIC 45 | # 5443 : TCP, UDP : DNSCrypt (alt) 46 | # 6060 : TCP : HTTP (pprof) 47 | EXPOSE 53/tcp 53/udp \ 48 | 80/tcp \ 49 | 443/tcp 443/udp \ 50 | 853/tcp 853/udp \ 51 | 5443/tcp 5443/udp \ 52 | 6060/tcp 53 | 54 | WORKDIR /opt/dnsproxy 55 | 56 | ENTRYPOINT ["/opt/dnsproxy/dnsproxy"] 57 | CMD ["--config-path=/opt/dnsproxy/config.yaml"] 58 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # DNS Proxy 2 | 3 | A simple DNS proxy server that supports all existing DNS protocols including 4 | `DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover, 5 | it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server. 6 | 7 | Learn more about dnsproxy and its full capabilities in 8 | its [Github repo][dnsproxy]. 9 | 10 | [dnsproxy]: https://github.com/AdguardTeam/dnsproxy 11 | 12 | ## Quick start 13 | 14 | ### Pull the Docker image 15 | 16 | This command will pull the latest stable version: 17 | 18 | ```shell 19 | docker pull adguard/dnsproxy 20 | ``` 21 | 22 | ### Run the container 23 | 24 | Run the container with the default configuration (see `config.yaml.dist` in the 25 | repository) and expose DNS ports. 26 | 27 | ```shell 28 | docker run --name dnsproxy \ 29 | -p 53:53/tcp -p 53:53/udp \ 30 | adguard/dnsproxy 31 | ``` 32 | 33 | Run the container with command-line args configuration and expose DNS ports. 34 | 35 | ```shell 36 | docker run --name dnsproxy_google_dns \ 37 | -p 53:53/tcp -p 53:53/udp \ 38 | adguard/dnsproxy \ 39 | -u 8.8.8.8:53 40 | ``` 41 | 42 | Run the container with a configuration file and expose DNS ports. 
43 | 44 | ```shell 45 | docker run --name dnsproxy_google_dns \ 46 | -p 53:53/tcp -p 53:53/udp \ 47 | -v $PWD/config.yaml:/opt/dnsproxy/config.yaml \ 48 | adguard/dnsproxy 49 | ``` 50 | -------------------------------------------------------------------------------- /fastip/cache.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "encoding/binary" 5 | "net/netip" 6 | "time" 7 | ) 8 | 9 | const ( 10 | // fastestAddrCacheTTLSec is the cache TTL for IP addresses, in seconds. 11 | fastestAddrCacheTTLSec = 10 * 60 12 | ) 13 | 14 | // cacheEntry represents an item that will be stored in the cache. 15 | // 16 | // TODO(e.burkov): Rewrite the cache using zero-values instead of storing 17 | // useless boolean as an integer. 18 | type cacheEntry struct { 19 | // status is 0 if the ping succeeded and 1 if it timed out. 20 | status int 21 | latencyMsec uint 22 | } 23 | 24 | // packCacheEntry packs the cache entry and the TTL to bytes in the following 25 | // order: 26 | // 27 | // - expire [4]byte (Unix time, seconds), 28 | // - status byte (0 for ok, 1 for timed out), 29 | // - latency [2]byte (milliseconds, truncated to 16 bits). 30 | func packCacheEntry(ent *cacheEntry, ttl uint32) (d []byte) { 31 | expire := uint32(time.Now().Unix()) + ttl 32 | 33 | d = make([]byte, 4+1+2) 34 | binary.BigEndian.PutUint32(d, expire) 35 | i := 4 36 | 37 | d[i] = byte(ent.status) 38 | i++ 39 | 40 | binary.BigEndian.PutUint16(d[i:], uint16(ent.latencyMsec)) 41 | // i += 2 42 | 43 | return d 44 | } 45 | 46 | // unpackCacheEntry unpacks bytes to a cache entry and checks its TTL. It 47 | // returns nil if the record is expired.
48 | func unpackCacheEntry(data []byte) (ent *cacheEntry) { 49 | now := time.Now().Unix() 50 | expire := binary.BigEndian.Uint32(data[:4]) 51 | if int64(expire) <= now { 52 | return nil 53 | } 54 | 55 | ent = &cacheEntry{} 56 | i := 4 57 | 58 | ent.status = int(data[i]) 59 | i++ 60 | 61 | ent.latencyMsec = uint(binary.BigEndian.Uint16(data[i:])) 62 | // i += 2 63 | 64 | return ent 65 | } 66 | 67 | // cacheFind returns the cache entry for the given IP address. It returns nil 68 | // if nothing is found or if the record is expired. 69 | func (f *FastestAddr) cacheFind(ip netip.Addr) (ent *cacheEntry) { 70 | val := f.ipCache.Get(ip.AsSlice()) 71 | if val == nil { 72 | return nil 73 | } 74 | 75 | return unpackCacheEntry(val) 76 | } 77 | 78 | // cacheAddFailure stores a failed attempt unless ip has an unexpired entry. 79 | func (f *FastestAddr) cacheAddFailure(ip netip.Addr) { 80 | ent := cacheEntry{ 81 | status: 1, 82 | } 83 | 84 | f.ipCacheLock.Lock() 85 | defer f.ipCacheLock.Unlock() 86 | 87 | if f.cacheFind(ip) == nil { 88 | f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec) 89 | } 90 | } 91 | 92 | // cacheAddSuccessful stores a successful ping result in the cache. It is a 93 | // no-op when the cached result is successful and not slower than latency. 94 | func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) { 95 | ent := cacheEntry{ 96 | latencyMsec: latency, 97 | } 98 | 99 | f.ipCacheLock.Lock() 100 | defer f.ipCacheLock.Unlock() 101 | 102 | entCached := f.cacheFind(ip) 103 | if entCached == nil || entCached.status != 0 || entCached.latencyMsec > latency { 104 | f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec) 105 | } 106 | } 107 | 108 | // cacheAdd stores ent in the cache under ip with the given TTL in seconds.
109 | func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) { 110 | val := packCacheEntry(ent, ttl) 111 | f.ipCache.Set(ip.AsSlice(), val) 112 | } 113 | -------------------------------------------------------------------------------- /fastip/cache_internal_test.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "testing" 7 | "time" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestCacheAdd(t *testing.T) { 14 | f := New(&Config{Logger: slogutil.NewDiscardLogger()}) 15 | ent := cacheEntry{ 16 | status: 0, 17 | latencyMsec: 111, 18 | } 19 | 20 | ip := netip.MustParseAddr("1.1.1.1") 21 | f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec) 22 | 23 | // check that it's there 24 | assert.NotNil(t, f.cacheFind(ip)) 25 | } 26 | 27 | func TestCacheTtl(t *testing.T) { 28 | f := New(&Config{Logger: slogutil.NewDiscardLogger()}) 29 | ent := cacheEntry{ 30 | status: 0, 31 | latencyMsec: 111, 32 | } 33 | 34 | ip := netip.MustParseAddr("1.1.1.1") 35 | f.cacheAdd(&ent, ip, 1) 36 | 37 | // check that it's there 38 | assert.NotNil(t, f.cacheFind(ip)) 39 | 40 | // wait for more than one second 41 | time.Sleep(time.Millisecond * 1001) 42 | 43 | // check that now it returns nil 44 | assert.Nil(t, f.cacheFind(ip)) 45 | } 46 | 47 | func TestCacheAddSuccessfulOverwrite(t *testing.T) { 48 | f := New(&Config{Logger: slogutil.NewDiscardLogger()}) 49 | 50 | ip := netip.MustParseAddr("1.1.1.1") 51 | f.cacheAddFailure(ip) 52 | 53 | // check that it's there 54 | ent := f.cacheFind(ip) 55 | assert.NotNil(t, ent) 56 | assert.Equal(t, 1, ent.status) 57 | 58 | // check that it will overwrite existing rec 59 | f.cacheAddSuccessful(ip, 11) 60 | 61 | // check that it's there now 62 | ent = f.cacheFind(ip) 63 | assert.NotNil(t, ent) 64 | assert.Equal(t, 0, ent.status) 65 | assert.Equal(t, uint(11), ent.latencyMsec) 66
| } 67 | 68 | func TestCacheAddFailureNoOverwrite(t *testing.T) { 69 | f := New(&Config{Logger: slogutil.NewDiscardLogger()}) 70 | 71 | ip := netip.MustParseAddr("1.1.1.1") 72 | f.cacheAddSuccessful(ip, 11) 73 | 74 | // check that it's there 75 | ent := f.cacheFind(ip) 76 | assert.NotNil(t, ent) 77 | assert.Equal(t, 0, ent.status) 78 | 79 | // check that it will not overwrite existing rec 80 | f.cacheAddFailure(ip) 81 | 82 | // check that the old record is still there 83 | ent = f.cacheFind(ip) 84 | assert.NotNil(t, ent) 85 | assert.Equal(t, 0, ent.status) 86 | assert.Equal(t, uint(11), ent.latencyMsec) 87 | } 88 | 89 | // TODO(ameshkov): Actually test something. 90 | func TestCache(_ *testing.T) { 91 | f := New(&Config{Logger: slogutil.NewDiscardLogger()}) 92 | ent := cacheEntry{ 93 | status: 0, 94 | latencyMsec: 111, 95 | } 96 | 97 | val := packCacheEntry(&ent, 1) 98 | f.ipCache.Set(net.ParseIP("1.1.1.1").To4(), val) 99 | ent = cacheEntry{ 100 | status: 0, 101 | latencyMsec: 222, 102 | } 103 | 104 | f.cacheAdd(&ent, netip.MustParseAddr("2.2.2.2"), fastestAddrCacheTTLSec) 105 | } 106 | -------------------------------------------------------------------------------- /fastip/fastest_internal_test.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "net/netip" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/dnsproxy/upstream" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/miekg/dns" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestFastestAddr_ExchangeFastest(t *testing.T) { 16 | l := slogutil.NewDiscardLogger() 17 | 18 | t.Run("error", func(t *testing.T) { 19 | const errDesired errors.Error = "this is expected" 20 | 21 | u := &errUpstream{ 22 | err: errDesired, 23 | } 24 | f := New(&Config{ 25 | Logger: l, 26 | PingWaitTimeout: DefaultPingWaitTimeout, 27 | }) 28 | 29 | resp,
up, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{u}) 30 | require.Error(t, err) 31 | 32 | assert.ErrorIs(t, err, errDesired) 33 | assert.Nil(t, resp) 34 | assert.Nil(t, up) 35 | }) 36 | 37 | t.Run("one_dead", func(t *testing.T) { 38 | port := listen(t, netip.IPv4Unspecified()) 39 | 40 | f := New(&Config{ 41 | Logger: l, 42 | PingWaitTimeout: DefaultPingWaitTimeout, 43 | }) 44 | f.pingPorts = []uint{port} 45 | 46 | // The alive IP is the just created local listener's address. The dead 47 | // one is known as TEST-NET-1 which shouldn't be routed at all. See 48 | // RFC-5737 (https://datatracker.ietf.org/doc/html/rfc5737). 49 | aliveAddr := netip.MustParseAddr("127.0.0.1") 50 | 51 | alive := &testAUpstream{ 52 | recs: []*dns.A{newTestRec(t, aliveAddr)}, 53 | } 54 | dead := &testAUpstream{ 55 | recs: []*dns.A{newTestRec(t, netip.MustParseAddr("192.0.2.1"))}, 56 | } 57 | 58 | rep, ups, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{dead, alive}) 59 | require.NoError(t, err) 60 | 61 | assert.Equal(t, ups, alive) 62 | 63 | require.NotNil(t, rep) 64 | require.NotEmpty(t, rep.Answer) 65 | require.IsType(t, new(dns.A), rep.Answer[0]) 66 | 67 | ip := rep.Answer[0].(*dns.A).A 68 | assert.Equal(t, aliveAddr.AsSlice(), []byte(ip)) 69 | }) 70 | 71 | t.Run("all_dead", func(t *testing.T) { 72 | f := New(&Config{ 73 | Logger: l, 74 | PingWaitTimeout: DefaultPingWaitTimeout, 75 | }) 76 | f.pingPorts = []uint{getFreePort(t)} 77 | 78 | firstIP := netip.MustParseAddr("127.0.0.1") 79 | ups := &testAUpstream{ 80 | recs: []*dns.A{ 81 | newTestRec(t, firstIP), 82 | newTestRec(t, netip.MustParseAddr("127.0.0.2")), 83 | newTestRec(t, netip.MustParseAddr("127.0.0.3")), 84 | }, 85 | } 86 | 87 | resp, _, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{ups}) 88 | require.NoError(t, err) 89 | 90 | require.NotNil(t, resp) 91 | require.NotEmpty(t, resp.Answer) 92 | require.IsType(t, new(dns.A), resp.Answer[0]) 93 | 94 | ip := resp.Answer[0].(*dns.A).A 95 |
assert.Equal(t, firstIP.AsSlice(), []byte(ip)) 96 | }) 97 | } 98 | 99 | // errUpstream is a mock upstream for tests that always fails with err. 100 | type errUpstream struct { 101 | err error 102 | closeErr error 103 | } 104 | 105 | // Address implements the [upstream.Upstream] interface for *errUpstream. 106 | func (u *errUpstream) Address() string { 107 | return "bad_upstream" 108 | } 109 | 110 | // Exchange implements the [upstream.Upstream] interface for *errUpstream. 111 | func (u *errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) { 112 | return nil, u.err 113 | } 114 | 115 | // Close implements the [upstream.Upstream] interface for *errUpstream. 116 | func (u *errUpstream) Close() error { 117 | return u.closeErr 118 | } 119 | 120 | // testAUpstream is a mock A upstream structure for tests. 121 | type testAUpstream struct { 122 | recs []*dns.A 123 | } 124 | 125 | // type check 126 | var _ upstream.Upstream = (*testAUpstream)(nil) 127 | 128 | // Exchange implements the [upstream.Upstream] interface for *testAUpstream. 129 | func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { 130 | resp = &dns.Msg{} 131 | resp.SetReply(m) 132 | 133 | for _, a := range u.recs { 134 | resp.Answer = append(resp.Answer, a) 135 | } 136 | 137 | return resp, nil 138 | } 139 | 140 | // Address implements the [upstream.Upstream] interface for *testAUpstream. 141 | func (u *testAUpstream) Address() (addr string) { 142 | return "" 143 | } 144 | 145 | // Close implements the [upstream.Upstream] interface for *testAUpstream. 146 | func (u *testAUpstream) Close() (err error) { 147 | return nil 148 | } 149 | 150 | // newTestRec returns a new test A record. 151 | func newTestRec(t *testing.T, addr netip.Addr) (rr *dns.A) { 152 | return &dns.A{ 153 | Hdr: dns.RR_Header{ 154 | Rrtype: dns.TypeA, 155 | Name: dns.Fqdn(t.Name()), 156 | Ttl: 60, 157 | }, 158 | A: addr.AsSlice(), 159 | } 160 | } 161 | 162 | // newTestReq returns a new test A request.
163 | func newTestReq(t *testing.T) (req *dns.Msg) { 164 | return &dns.Msg{ 165 | MsgHdr: dns.MsgHdr{ 166 | Id: dns.Id(), 167 | RecursionDesired: true, 168 | }, 169 | Question: []dns.Question{{ 170 | Name: dns.Fqdn(t.Name()), 171 | Qtype: dns.TypeA, 172 | Qclass: dns.ClassINET, 173 | }}, 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /fastip/ping.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "net/netip" 5 | "time" 6 | 7 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 8 | "github.com/AdguardTeam/golibs/logutil/slogutil" 9 | ) 10 | 11 | // pingTCPTimeout is a TCP connection timeout. It's higher than pingWaitTimeout 12 | // since the slower connections will be cached anyway. 13 | const pingTCPTimeout = 4 * time.Second 14 | 15 | // pingResult is the result of dialing the address. 16 | type pingResult struct { 17 | // addrPort is the address-port pair the result is related to. 18 | addrPort netip.AddrPort 19 | 20 | // latency is the duration of dialing process in milliseconds. 21 | latency uint 22 | 23 | // success is true when the dialing succeeded. 24 | success bool 25 | } 26 | 27 | // schedulePings returns the result with the fastest IP address from the cache, 28 | // if it's found, and starts pinging other IPs which are not cached or outdated. 29 | // Returns scheduled flag which indicates that some goroutines have been 30 | // scheduled. 
31 | func (f *FastestAddr) schedulePings( 32 | resCh chan *pingResult, 33 | ips []netip.Addr, 34 | host string, 35 | ) (pr *pingResult, scheduled bool) { 36 | for _, ip := range ips { 37 | cached := f.cacheFind(ip) 38 | if cached == nil { 39 | scheduled = true 40 | for _, port := range f.pingPorts { 41 | go f.pingDoTCP(host, netip.AddrPortFrom(ip, uint16(port)), resCh) 42 | } 43 | 44 | continue 45 | } 46 | 47 | if cached.status == 0 && (pr == nil || cached.latencyMsec < pr.latency) { 48 | pr = &pingResult{ 49 | addrPort: netip.AddrPortFrom(ip, 0), 50 | latency: cached.latencyMsec, 51 | success: true, 52 | } 53 | } 54 | } 55 | 56 | return pr, scheduled 57 | } 58 | 59 | // pingAll pings all ips concurrently and returns as soon as the fastest one is 60 | // found or the timeout is exceeded. 61 | func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) { 62 | ipN := len(ips) 63 | switch ipN { 64 | case 0: 65 | return nil 66 | case 1: 67 | return &pingResult{ 68 | addrPort: netip.AddrPortFrom(ips[0], 0), 69 | success: true, 70 | } 71 | } 72 | 73 | resCh := make(chan *pingResult, ipN*len(f.pingPorts)) 74 | pr, scheduled := f.schedulePings(resCh, ips, host) 75 | if !scheduled { 76 | if pr != nil { 77 | f.logger.Debug( 78 | "pinging all returns cached response", 79 | "host", host, 80 | "addr", pr.addrPort, 81 | ) 82 | } else { 83 | f.logger.Debug("pinging all returns nothing", "host", host) 84 | } 85 | 86 | return pr 87 | } 88 | 89 | res := f.firstSuccessRes(resCh, host) 90 | if res == nil { 91 | // In case of timeout return cached or nil. 92 | return pr 93 | } 94 | 95 | if pr == nil || res.latency <= pr.latency { 96 | // Cache wasn't found or is worse than res. 97 | return res 98 | } 99 | 100 | // Return cached result. 101 | return pr 102 | } 103 | 104 | // firstSuccessRes waits and returns the first successful ping result or nil in 105 | // case of timeout. 
106 | func (f *FastestAddr) firstSuccessRes(resCh chan *pingResult, host string) (res *pingResult) { 107 | after := time.After(f.pingWaitTimeout) 108 | for { 109 | select { 110 | case res = <-resCh: 111 | f.logger.Debug( 112 | "pinging all got result", 113 | "host", host, 114 | "addr", res.addrPort, 115 | "status", res.success, 116 | ) 117 | 118 | if !res.success { 119 | continue 120 | } 121 | 122 | return res 123 | case <-after: 124 | f.logger.Debug("pinging all timed out", "host", host) 125 | 126 | return nil 127 | } 128 | } 129 | } 130 | 131 | // pingDoTCP sends the result of dialing the specified address into resCh. 132 | func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) { 133 | l := f.logger.With("host", host, "addr", addrPort) 134 | l.Debug("open tcp connection") 135 | 136 | start := time.Now() 137 | conn, err := f.pinger.Dial(bootstrap.NetworkTCP, addrPort.String()) 138 | elapsed := time.Since(start) 139 | 140 | success := err == nil 141 | if success { 142 | if cErr := conn.Close(); cErr != nil { 143 | l.Debug("closing tcp connection", slogutil.KeyError, cErr) 144 | } 145 | } 146 | 147 | latency := uint(elapsed.Milliseconds()) 148 | 149 | resCh <- &pingResult{ 150 | addrPort: addrPort, 151 | latency: latency, 152 | success: success, 153 | } 154 | 155 | addr := addrPort.Addr().Unmap() 156 | if success { 157 | l.Debug("tcp ping success", "elapsed", elapsed) 158 | f.cacheAddSuccessful(addr, latency) 159 | } else { 160 | l.Debug("tcp ping failed to connect", "elapsed", elapsed, slogutil.KeyError, err) 161 | f.cacheAddFailure(addr) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/AdguardTeam/dnsproxy 2 | 3 | go 1.24.3 4 | 5 | require ( 6 | github.com/AdguardTeam/golibs v0.32.11 7 | github.com/ameshkov/dnscrypt/v2 v2.4.0 8 | github.com/ameshkov/dnsstamps 
v1.0.3 9 | github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 10 | github.com/bluele/gcache v0.0.2 11 | github.com/miekg/dns v1.1.66 12 | github.com/patrickmn/go-cache v2.1.0+incompatible 13 | // TODO(s.chzhen): Update after investigation of the 0-RTT bug/behavior 14 | // when TestUpstreamDoH_serverRestart/http3/second_try keeps failing. 15 | github.com/quic-go/quic-go v0.52.0 16 | github.com/stretchr/testify v1.10.0 17 | golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b // indirect 18 | golang.org/x/net v0.40.0 19 | golang.org/x/sys v0.33.0 20 | gonum.org/v1/gonum v0.16.0 21 | gopkg.in/yaml.v3 v3.0.1 22 | ) 23 | 24 | require ( 25 | cloud.google.com/go v0.121.2 // indirect 26 | cloud.google.com/go/ai v0.12.1 // indirect 27 | cloud.google.com/go/auth v0.16.2 // indirect 28 | cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect 29 | cloud.google.com/go/compute/metadata v0.7.0 // indirect 30 | cloud.google.com/go/longrunning v0.6.7 // indirect 31 | github.com/BurntSushi/toml v1.5.0 // indirect 32 | github.com/ccojocar/zxcvbn-go v1.0.4 // indirect 33 | github.com/davecgh/go-spew v1.1.1 // indirect 34 | github.com/felixge/httpsnoop v1.0.4 // indirect 35 | github.com/fzipp/gocyclo v0.6.0 // indirect 36 | github.com/go-logr/logr v1.4.3 // indirect 37 | github.com/go-logr/stdr v1.2.2 // indirect 38 | github.com/go-task/slim-sprig/v3 v3.0.0 // indirect 39 | github.com/golangci/misspell v0.7.0 // indirect 40 | github.com/google/generative-ai-go v0.20.1 // indirect 41 | github.com/google/go-cmp v0.7.0 // indirect 42 | github.com/google/pprof v0.0.0-20250602020802-c6617b811d0e // indirect 43 | github.com/google/renameio/v2 v2.0.0 // indirect 44 | github.com/google/s2a-go v0.1.9 // indirect 45 | github.com/google/uuid v1.6.0 // indirect 46 | github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect 47 | github.com/googleapis/gax-go/v2 v2.14.2 // indirect 48 | github.com/gookit/color v1.5.4 // indirect 49 | github.com/gordonklaus/ineffassign v0.1.0 
// indirect 50 | github.com/jstemmer/go-junit-report/v2 v2.1.0 // indirect 51 | github.com/kisielk/errcheck v1.9.0 // indirect 52 | github.com/onsi/ginkgo/v2 v2.23.4 // indirect 53 | github.com/pmezard/go-difflib v1.0.0 // indirect 54 | github.com/quic-go/qpack v0.5.1 // indirect 55 | github.com/robfig/cron/v3 v3.0.1 // indirect 56 | github.com/rogpeppe/go-internal v1.14.1 // indirect 57 | github.com/securego/gosec/v2 v2.22.4 // indirect 58 | github.com/uudashr/gocognit v1.2.0 // indirect 59 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 60 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 61 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect 62 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect 63 | go.opentelemetry.io/otel v1.36.0 // indirect 64 | go.opentelemetry.io/otel/metric v1.36.0 // indirect 65 | go.opentelemetry.io/otel/trace v1.36.0 // indirect 66 | go.uber.org/automaxprocs v1.6.0 // indirect 67 | go.uber.org/mock v0.5.2 // indirect 68 | golang.org/x/crypto v0.38.0 // indirect 69 | golang.org/x/exp/typeparams v0.0.0-20250531010427-b6e5de432a8b // indirect 70 | golang.org/x/mod v0.25.0 // indirect 71 | golang.org/x/oauth2 v0.30.0 // indirect 72 | golang.org/x/sync v0.15.0 // indirect 73 | golang.org/x/telemetry v0.0.0-20250605140807-cd7dbf5ade20 // indirect 74 | golang.org/x/term v0.32.0 // indirect 75 | golang.org/x/text v0.26.0 // indirect 76 | golang.org/x/time v0.12.0 // indirect 77 | golang.org/x/tools v0.33.0 // indirect 78 | golang.org/x/vuln v1.1.4 // indirect 79 | google.golang.org/api v0.236.0 // indirect 80 | google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect 81 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect 82 | google.golang.org/grpc v1.73.0 // indirect 83 | google.golang.org/protobuf v1.36.6 // indirect 84 | honnef.co/go/tools v0.6.1 // indirect 85 | 
mvdan.cc/editorconfig v0.3.0 // indirect 86 | mvdan.cc/gofumpt v0.8.0 // indirect 87 | mvdan.cc/sh/v3 v3.11.0 // indirect 88 | mvdan.cc/unparam v0.0.0-20250301125049-0df0534333a4 // indirect 89 | ) 90 | 91 | tool ( 92 | github.com/fzipp/gocyclo/cmd/gocyclo 93 | github.com/golangci/misspell/cmd/misspell 94 | github.com/gordonklaus/ineffassign 95 | github.com/jstemmer/go-junit-report/v2 96 | github.com/kisielk/errcheck 97 | github.com/securego/gosec/v2/cmd/gosec 98 | github.com/uudashr/gocognit/cmd/gocognit 99 | golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment 100 | golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness 101 | golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow 102 | golang.org/x/vuln/cmd/govulncheck 103 | honnef.co/go/tools/cmd/staticcheck 104 | mvdan.cc/gofumpt 105 | mvdan.cc/sh/v3/cmd/shfmt 106 | mvdan.cc/unparam 107 | ) 108 | -------------------------------------------------------------------------------- /internal/bootstrap/bootstrap.go: -------------------------------------------------------------------------------- 1 | // Package bootstrap provides types and functions to resolve upstream hostnames 2 | // and to dial retrieved addresses. 3 | package bootstrap 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log/slog" 9 | "net" 10 | "net/netip" 11 | "net/url" 12 | "slices" 13 | "time" 14 | 15 | "github.com/AdguardTeam/golibs/errors" 16 | "github.com/AdguardTeam/golibs/logutil/slogutil" 17 | "github.com/AdguardTeam/golibs/netutil" 18 | ) 19 | 20 | // Network is a network type for use in [Resolver]'s methods. 21 | type Network = string 22 | 23 | const ( 24 | // NetworkIP is a network type for both address families. 25 | NetworkIP Network = "ip" 26 | 27 | // NetworkIP4 is a network type for IPv4 address family. 28 | NetworkIP4 Network = "ip4" 29 | 30 | // NetworkIP6 is a network type for IPv6 address family. 31 | NetworkIP6 Network = "ip6" 32 | 33 | // NetworkTCP is a network type for TCP connections. 
34 | NetworkTCP Network = "tcp" 35 | 36 | // NetworkUDP is a network type for UDP connections. 37 | NetworkUDP Network = "udp" 38 | ) 39 | 40 | // DialHandler is a dial function for creating unencrypted network connections 41 | // to the upstream server. It establishes the connection to the server 42 | // specified at initialization and ignores the addr. network must be one of 43 | // [NetworkTCP] or [NetworkUDP]. 44 | type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error) 45 | 46 | // ResolveDialContext returns a DialHandler that uses addresses resolved from u 47 | // using resolver. l and u must not be nil. 48 | func ResolveDialContext( 49 | u *url.URL, 50 | timeout time.Duration, 51 | r Resolver, 52 | preferV6 bool, 53 | l *slog.Logger, 54 | ) (h DialHandler, err error) { 55 | defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }() 56 | 57 | host, port, err := netutil.SplitHostPort(u.Host) 58 | if err != nil { 59 | // Don't wrap the error since it's informative enough as is and there is 60 | // already deferred annotation here. 61 | return nil, err 62 | } 63 | 64 | if r == nil { 65 | return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers) 66 | } 67 | 68 | ctx := context.Background() 69 | if timeout > 0 { 70 | var cancel func() 71 | ctx, cancel = context.WithTimeout(ctx, timeout) 72 | defer cancel() 73 | } 74 | 75 | // TODO(e.burkov): Use network properly, perhaps, pass it through options. 
76 | ips, err := r.LookupNetIP(ctx, NetworkIP, host) 77 | if err != nil { 78 | return nil, fmt.Errorf("resolving hostname: %w", err) 79 | } 80 | 81 | if preferV6 { 82 | slices.SortStableFunc(ips, netutil.PreferIPv6) 83 | } else { 84 | slices.SortStableFunc(ips, netutil.PreferIPv4) 85 | } 86 | 87 | addrs := make([]string, 0, len(ips)) 88 | for _, ip := range ips { 89 | addrs = append(addrs, netip.AddrPortFrom(ip, port).String()) 90 | } 91 | 92 | return NewDialContext(timeout, l, addrs...), nil 93 | } 94 | 95 | // NewDialContext returns a DialHandler that dials addrs and returns the first 96 | // successful connection. At least a single addr should be specified. l must 97 | // not be nil. 98 | func NewDialContext(timeout time.Duration, l *slog.Logger, addrs ...string) (h DialHandler) { 99 | addrLen := len(addrs) 100 | if addrLen == 0 { 101 | l.Debug("no addresses to dial") 102 | 103 | return func(_ context.Context, _, _ string) (conn net.Conn, err error) { 104 | return nil, errors.Error("no addresses") 105 | } 106 | } 107 | 108 | dialer := &net.Dialer{ 109 | Timeout: timeout, 110 | } 111 | 112 | return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) { 113 | var errs []error 114 | 115 | // Return first succeeded connection. Note that we're using addrs 116 | // instead of what's passed to the function. 117 | for i, addr := range addrs { 118 | a := l.With("addr", addr) 119 | a.DebugContext(ctx, "dialing", "idx", i+1, "total", addrLen) 120 | 121 | start := time.Now() 122 | conn, err = dialer.DialContext(ctx, network, addr) 123 | elapsed := time.Since(start) 124 | if err != nil { 125 | a.DebugContext(ctx, "connection failed", "elapsed", elapsed, slogutil.KeyError, err) 126 | errs = append(errs, err) 127 | 128 | continue 129 | } 130 | 131 | a.DebugContext(ctx, "connection succeeded", "elapsed", elapsed) 132 | 133 | return conn, nil 134 | } 135 | 136 | return nil, errors.Join(errs...) 
137 | } 138 | } 139 | -------------------------------------------------------------------------------- /internal/bootstrap/error.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import "github.com/AdguardTeam/golibs/errors" 4 | 5 | // ErrNoResolvers is returned when zero resolvers specified. 6 | const ErrNoResolvers errors.Error = "no resolvers specified" 7 | -------------------------------------------------------------------------------- /internal/bootstrap/resolver.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "net" 8 | "net/netip" 9 | "slices" 10 | 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/AdguardTeam/golibs/logutil/slogutil" 13 | ) 14 | 15 | // Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] 16 | // from standard library also implements this interface. 17 | type Resolver interface { 18 | // LookupNetIP looks up the IP addresses for the given host. network should 19 | // be one of [NetworkIP], [NetworkIP4] or [NetworkIP6]. The response may be 20 | // empty even if err is nil. All the addrs must be valid. 21 | LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error) 22 | } 23 | 24 | // type check 25 | var _ Resolver = &net.Resolver{} 26 | 27 | // ParallelResolver is a slice of resolvers that are queried concurrently. The 28 | // first successful response is returned. 29 | type ParallelResolver []Resolver 30 | 31 | // type check 32 | var _ Resolver = ParallelResolver(nil) 33 | 34 | // LookupNetIP implements the [Resolver] interface for ParallelResolver. 
35 | func (r ParallelResolver) LookupNetIP( 36 | ctx context.Context, 37 | network Network, 38 | host string, 39 | ) (addrs []netip.Addr, err error) { 40 | resolversNum := len(r) 41 | switch resolversNum { 42 | case 0: 43 | return nil, ErrNoResolvers 44 | case 1: 45 | return r[0].LookupNetIP(ctx, network, host) 46 | default: 47 | // Go on. 48 | } 49 | 50 | // Size of channel must accommodate results of lookups from all resolvers, 51 | // sending into channel will block otherwise. 52 | ch := make(chan any, resolversNum) 53 | for _, rslv := range r { 54 | go lookupAsync(ctx, rslv, network, host, ch) 55 | } 56 | 57 | var errs []error 58 | for range r { 59 | switch result := <-ch; result := result.(type) { 60 | case error: 61 | errs = append(errs, result) 62 | case []netip.Addr: 63 | return result, nil 64 | } 65 | } 66 | 67 | return nil, errors.Join(errs...) 68 | } 69 | 70 | // recoverAndLog is a deferred helper that recovers from a panic and logs the 71 | // panic value with the logger from context or with a default logger. Sends the 72 | // recovered value into resCh. 73 | // 74 | // TODO(a.garipov): Move this helper to golibs. 75 | func recoverAndLog(ctx context.Context, resCh chan<- any) { 76 | v := recover() 77 | if v == nil { 78 | return 79 | } 80 | 81 | err, ok := v.(error) 82 | if !ok { 83 | err = fmt.Errorf("error value: %v", v) 84 | } 85 | 86 | l, ok := slogutil.LoggerFromContext(ctx) 87 | if !ok { 88 | l = slog.Default() 89 | } 90 | 91 | l.ErrorContext(ctx, "recovered panic", slogutil.KeyError, err) 92 | slogutil.PrintStack(ctx, l, slog.LevelError) 93 | 94 | resCh <- err 95 | } 96 | 97 | // lookupAsync performs a lookup for ip of host with r and sends the result into 98 | // resCh. It is intended to be used as a goroutine. 99 | func lookupAsync(ctx context.Context, r Resolver, network, host string, resCh chan<- any) { 100 | // TODO(d.kolyshev): Propose better solution to recover without requiring 101 | // logger in the context. 
102 | defer recoverAndLog(ctx, resCh) 103 | 104 | addrs, err := r.LookupNetIP(ctx, network, host) 105 | if err != nil { 106 | resCh <- err 107 | } else { 108 | resCh <- addrs 109 | } 110 | } 111 | 112 | // ConsequentResolver is a slice of resolvers that are queried in order until 113 | // the first successful non-empty response, as opposed to just successful 114 | // response requirement in [ParallelResolver]. 115 | type ConsequentResolver []Resolver 116 | 117 | // type check 118 | var _ Resolver = ConsequentResolver(nil) 119 | 120 | // LookupNetIP implements the [Resolver] interface for ConsequentResolver. 121 | func (resolvers ConsequentResolver) LookupNetIP( 122 | ctx context.Context, 123 | network Network, 124 | host string, 125 | ) (addrs []netip.Addr, err error) { 126 | if len(resolvers) == 0 { 127 | return nil, ErrNoResolvers 128 | } 129 | 130 | var errs []error 131 | for _, r := range resolvers { 132 | addrs, err = r.LookupNetIP(ctx, network, host) 133 | if err == nil && len(addrs) > 0 { 134 | return addrs, nil 135 | } 136 | 137 | errs = append(errs, err) 138 | } 139 | 140 | return nil, errors.Join(errs...) 141 | } 142 | 143 | // StaticResolver is a resolver which always responds with an underlying slice 144 | // of IP addresses regardless of host and network. 145 | type StaticResolver []netip.Addr 146 | 147 | // type check 148 | var _ Resolver = StaticResolver(nil) 149 | 150 | // LookupNetIP implements the [Resolver] interface for StaticResolver. 
151 | func (r StaticResolver) LookupNetIP( 152 | _ context.Context, 153 | _ Network, 154 | _ string, 155 | ) (addrs []netip.Addr, err error) { 156 | return slices.Clone(r), nil 157 | } 158 | -------------------------------------------------------------------------------- /internal/bootstrap/resolver_test.go: -------------------------------------------------------------------------------- 1 | package bootstrap_test 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 10 | "github.com/AdguardTeam/golibs/netutil" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // testResolver is the [Resolver] interface implementation for testing purposes. 17 | type testResolver struct { 18 | onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error) 19 | } 20 | 21 | // LookupNetIP implements the [Resolver] interface for *testResolver. 
22 | func (r *testResolver) LookupNetIP( 23 | ctx context.Context, 24 | network string, 25 | host string, 26 | ) (addrs []netip.Addr, err error) { 27 | return r.onLookupNetIP(ctx, network, host) 28 | } 29 | 30 | func TestLookupParallel(t *testing.T) { 31 | const hostname = "host.name" 32 | 33 | t.Run("no_resolvers", func(t *testing.T) { 34 | addrs, err := bootstrap.ParallelResolver(nil).LookupNetIP(context.Background(), "ip", "") 35 | assert.ErrorIs(t, err, bootstrap.ErrNoResolvers) 36 | assert.Nil(t, addrs) 37 | }) 38 | 39 | pt := testutil.PanicT{} 40 | hostAddrs := []netip.Addr{netutil.IPv4Localhost()} 41 | 42 | immediate := &testResolver{ 43 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 44 | require.Equal(pt, hostname, host) 45 | require.Equal(pt, "ip", network) 46 | 47 | return hostAddrs, nil 48 | }, 49 | } 50 | 51 | t.Run("one_resolver", func(t *testing.T) { 52 | addrs, err := bootstrap.ParallelResolver{immediate}.LookupNetIP( 53 | context.Background(), 54 | "ip", 55 | hostname, 56 | ) 57 | require.NoError(t, err) 58 | 59 | assert.Equal(t, hostAddrs, addrs) 60 | }) 61 | 62 | t.Run("two_resolvers", func(t *testing.T) { 63 | delayCh := make(chan struct{}, 1) 64 | delayed := &testResolver{ 65 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 66 | require.Equal(pt, hostname, host) 67 | require.Equal(pt, "ip", network) 68 | 69 | testutil.RequireReceive(pt, delayCh, testTimeout) 70 | 71 | return []netip.Addr{netutil.IPv6Localhost()}, nil 72 | }, 73 | } 74 | 75 | addrs, err := bootstrap.ParallelResolver{immediate, delayed}.LookupNetIP( 76 | context.Background(), 77 | "ip", 78 | hostname, 79 | ) 80 | require.NoError(t, err) 81 | testutil.RequireSend(t, delayCh, struct{}{}, testTimeout) 82 | 83 | assert.Equal(t, hostAddrs, addrs) 84 | }) 85 | 86 | t.Run("all_errors", func(t *testing.T) { 87 | err := assert.AnError 88 | errStr := err.Error() 89 | wantErrMsg := strings.Join([]string{errStr, 
errStr, errStr}, "\n") 90 | 91 | r := &testResolver{ 92 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 93 | return nil, assert.AnError 94 | }, 95 | } 96 | 97 | addrs, err := bootstrap.ParallelResolver{r, r, r}.LookupNetIP( 98 | context.Background(), 99 | "ip", 100 | hostname, 101 | ) 102 | testutil.AssertErrorMsg(t, wantErrMsg, err) 103 | assert.Nil(t, addrs) 104 | }) 105 | } 106 | -------------------------------------------------------------------------------- /internal/cmd/cmd.go: -------------------------------------------------------------------------------- 1 | // Package cmd is the dnsproxy CLI entry point. 2 | package cmd 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "log/slog" 8 | "net/http" 9 | "net/http/pprof" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | "time" 14 | 15 | "github.com/AdguardTeam/dnsproxy/internal/version" 16 | "github.com/AdguardTeam/dnsproxy/proxy" 17 | "github.com/AdguardTeam/golibs/errors" 18 | "github.com/AdguardTeam/golibs/logutil/slogutil" 19 | "github.com/AdguardTeam/golibs/osutil" 20 | ) 21 | 22 | // Main is the entrypoint of dnsproxy CLI. Main may accept arguments, such as 23 | // embedded assets and command-line arguments. 24 | func Main() { 25 | conf, exitCode, err := parseConfig() 26 | if err != nil { 27 | _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("parsing options: %w", err)) 28 | } 29 | 30 | if conf == nil { 31 | os.Exit(exitCode) 32 | } 33 | 34 | logOutput := os.Stdout 35 | if conf.LogOutput != "" { 36 | // #nosec G302 -- Trust the file path that is given in the 37 | // configuration. 
38 | logOutput, err = os.OpenFile(conf.LogOutput, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644) 39 | if err != nil { 40 | _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("cannot create a log file: %s", err)) 41 | 42 | os.Exit(osutil.ExitCodeArgumentError) 43 | } 44 | 45 | defer func() { _ = logOutput.Close() }() 46 | } 47 | 48 | lvl := slog.LevelInfo 49 | if conf.Verbose { 50 | lvl = slog.LevelDebug 51 | } 52 | 53 | l := slogutil.New(&slogutil.Config{ 54 | Output: logOutput, 55 | Format: slogutil.FormatDefault, 56 | Level: lvl, 57 | // TODO(d.kolyshev): Consider making configurable. 58 | AddTimestamp: true, 59 | }) 60 | 61 | ctx := context.Background() 62 | 63 | if conf.Pprof { 64 | runPprof(l) 65 | } 66 | 67 | err = runProxy(ctx, l, conf) 68 | if err != nil { 69 | l.ErrorContext(ctx, "running dnsproxy", slogutil.KeyError, err) 70 | 71 | // As defers are skipped in case of os.Exit, close logOutput manually. 72 | // 73 | // TODO(a.garipov): Consider making logger.Close method. 74 | if logOutput != os.Stdout { 75 | _ = logOutput.Close() 76 | } 77 | 78 | os.Exit(osutil.ExitCodeFailure) 79 | } 80 | } 81 | 82 | // runProxy starts and runs the proxy. l must not be nil. 83 | // 84 | // TODO(e.burkov): Move into separate dnssvc package. 85 | func runProxy(ctx context.Context, l *slog.Logger, conf *configuration) (err error) { 86 | var ( 87 | buildVersion = version.Version() 88 | revision = version.Revision() 89 | branch = version.Branch() 90 | commitTime = version.CommitTime() 91 | ) 92 | 93 | l.InfoContext( 94 | ctx, 95 | "dnsproxy starting", 96 | "version", buildVersion, 97 | "revision", revision, 98 | "branch", branch, 99 | "commit_time", commitTime, 100 | ) 101 | 102 | // Prepare the proxy server and its configuration. 
103 | proxyConf, err := createProxyConfig(ctx, l, conf) 104 | if err != nil { 105 | return fmt.Errorf("configuring proxy: %w", err) 106 | } 107 | 108 | dnsProxy, err := proxy.New(proxyConf) 109 | if err != nil { 110 | return fmt.Errorf("creating proxy: %w", err) 111 | } 112 | 113 | // Start the proxy server. 114 | err = dnsProxy.Start(ctx) 115 | if err != nil { 116 | return fmt.Errorf("starting dnsproxy: %w", err) 117 | } 118 | 119 | // TODO(e.burkov): Use [service.SignalHandler]. 120 | signalChannel := make(chan os.Signal, 1) 121 | signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) 122 | <-signalChannel 123 | 124 | // Stopping the proxy. 125 | err = dnsProxy.Shutdown(ctx) 126 | if err != nil { 127 | return fmt.Errorf("stopping dnsproxy: %w", err) 128 | } 129 | 130 | return nil 131 | } 132 | 133 | // runPprof runs pprof server on localhost:6060. 134 | // 135 | // TODO(e.burkov): Use [httputil.RoutePprof]. 136 | func runPprof(l *slog.Logger) { 137 | mux := http.NewServeMux() 138 | mux.HandleFunc("/debug/pprof/", pprof.Index) 139 | mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 140 | mux.HandleFunc("/debug/pprof/profile", pprof.Profile) 141 | mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 142 | mux.HandleFunc("/debug/pprof/trace", pprof.Trace) 143 | mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) 144 | mux.Handle("/debug/pprof/block", pprof.Handler("block")) 145 | mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) 146 | mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) 147 | mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) 148 | mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) 149 | 150 | go func() { 151 | // TODO(d.kolyshev): Consider making configurable. 
152 | pprofAddr := "localhost:6060" 153 | l.Info("starting pprof", "addr", pprofAddr) 154 | 155 | srv := &http.Server{ 156 | Addr: pprofAddr, 157 | ReadTimeout: 60 * time.Second, 158 | Handler: mux, 159 | } 160 | 161 | err := srv.ListenAndServe() 162 | if err != nil && !errors.Is(err, http.ErrServerClosed) { 163 | l.Error("pprof failed to listen %v", "addr", pprofAddr, slogutil.KeyError, err) 164 | } 165 | }() 166 | } 167 | -------------------------------------------------------------------------------- /internal/cmd/flag.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/AdguardTeam/golibs/stringutil" 10 | ) 11 | 12 | // uint32Value is an uint32 that can be defined as a flag for [flag.FlagSet]. 13 | type uint32Value uint32 14 | 15 | // type check 16 | var _ flag.Value = (*uint32Value)(nil) 17 | 18 | // Set implements the [flag.Value] interface for *uint32Value. 19 | func (i *uint32Value) Set(s string) (err error) { 20 | v, err := strconv.ParseUint(s, 0, 32) 21 | *i = uint32Value(v) 22 | 23 | return err 24 | } 25 | 26 | // String implements the [flag.Value] interface for *uint32Value. 27 | func (i *uint32Value) String() (out string) { 28 | return strconv.FormatUint(uint64(*i), 10) 29 | } 30 | 31 | // float32Value is an float32 that can be defined as a flag for [flag.FlagSet]. 32 | type float32Value float32 33 | 34 | // type check 35 | var _ flag.Value = (*float32Value)(nil) 36 | 37 | // Set implements the [flag.Value] interface for *float32Value. 38 | func (i *float32Value) Set(s string) (err error) { 39 | v, err := strconv.ParseFloat(s, 32) 40 | *i = float32Value(v) 41 | 42 | return err 43 | } 44 | 45 | // String implements the [flag.Value] interface for *float32Value. 
46 | func (i *float32Value) String() (out string) { 47 | return strconv.FormatFloat(float64(*i), 'f', 3, 32) 48 | } 49 | 50 | // intSliceValue represent a struct with a slice of integers that can be defined 51 | // as a flag for [flag.FlagSet]. 52 | type intSliceValue struct { 53 | // values is the pointer to a slice of integers to store parsed values. 54 | values *[]int 55 | 56 | // isSet is false until the corresponding flag is met for the first time. 57 | // When the flag is found, the default value is overwritten with zero value. 58 | isSet bool 59 | } 60 | 61 | // newIntSliceValue returns a pointer to intSliceValue with the given value. 62 | func newIntSliceValue(p *[]int) (out *intSliceValue) { 63 | return &intSliceValue{ 64 | values: p, 65 | isSet: false, 66 | } 67 | } 68 | 69 | // type check 70 | var _ flag.Value = (*intSliceValue)(nil) 71 | 72 | // Set implements the [flag.Value] interface for *intSliceValue. 73 | func (i *intSliceValue) Set(s string) (err error) { 74 | v, err := strconv.Atoi(s) 75 | if err != nil { 76 | return fmt.Errorf("parsing integer slice arg %q: %w", s, err) 77 | } 78 | 79 | if !i.isSet { 80 | i.isSet = true 81 | *i.values = []int{} 82 | } 83 | 84 | *i.values = append(*i.values, v) 85 | 86 | return nil 87 | } 88 | 89 | // String implements the [flag.Value] interface for *intSliceValue. 90 | func (i *intSliceValue) String() (out string) { 91 | if i == nil || i.values == nil { 92 | return "" 93 | } 94 | 95 | sb := &strings.Builder{} 96 | for idx, v := range *i.values { 97 | if idx > 0 { 98 | stringutil.WriteToBuilder(sb, ",") 99 | } 100 | 101 | stringutil.WriteToBuilder(sb, strconv.Itoa(v)) 102 | } 103 | 104 | return sb.String() 105 | } 106 | 107 | // stringSliceValue represent a struct with a slice of strings that can be 108 | // defined as a flag for [flag.FlagSet]. 109 | type stringSliceValue struct { 110 | // values is the pointer to a slice of string to store parsed values. 
111 | values *[]string 112 | 113 | // isSet is false until the corresponding flag is met for the first time. 114 | // When the flag is found, the default value is overwritten with zero value. 115 | isSet bool 116 | } 117 | 118 | // newStringSliceValue returns a pointer to stringSliceValue with the given 119 | // value. 120 | func newStringSliceValue(p *[]string) (out *stringSliceValue) { 121 | return &stringSliceValue{ 122 | values: p, 123 | isSet: false, 124 | } 125 | } 126 | 127 | // type check 128 | var _ flag.Value = (*stringSliceValue)(nil) 129 | 130 | // Set implements the [flag.Value] interface for *stringSliceValue. 131 | func (i *stringSliceValue) Set(s string) (err error) { 132 | if !i.isSet { 133 | i.isSet = true 134 | *i.values = []string{} 135 | } 136 | 137 | *i.values = append(*i.values, s) 138 | 139 | return nil 140 | } 141 | 142 | // String implements the [flag.Value] interface for *stringSliceValue. 143 | func (i *stringSliceValue) String() (out string) { 144 | if i == nil || i.values == nil { 145 | return "" 146 | } 147 | 148 | sb := &strings.Builder{} 149 | for idx, v := range *i.values { 150 | if idx > 0 { 151 | stringutil.WriteToBuilder(sb, ",") 152 | } 153 | 154 | stringutil.WriteToBuilder(sb, v) 155 | } 156 | 157 | return sb.String() 158 | } 159 | -------------------------------------------------------------------------------- /internal/cmd/tls.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "os" 7 | ) 8 | 9 | // newTLSConfig returns the TLS config that includes a certificate. Use it for 10 | // server TLS configuration or for a client certificate. The certificate is 11 | // loaded from conf.TLSCertPath and conf.TLSKeyPath; conf must not be nil.
12 | func newTLSConfig(conf *configuration) (c *tls.Config, err error) { 13 | // Set default TLS min/max versions 14 | tlsMinVersion := tls.VersionTLS10 15 | tlsMaxVersion := tls.VersionTLS13 16 | 17 | switch conf.TLSMinVersion { 18 | case 1.1: 19 | tlsMinVersion = tls.VersionTLS11 20 | case 1.2: 21 | tlsMinVersion = tls.VersionTLS12 22 | case 1.3: 23 | tlsMinVersion = tls.VersionTLS13 24 | } 25 | 26 | switch conf.TLSMaxVersion { 27 | case 1.0: 28 | tlsMaxVersion = tls.VersionTLS10 29 | case 1.1: 30 | tlsMaxVersion = tls.VersionTLS11 31 | case 1.2: 32 | tlsMaxVersion = tls.VersionTLS12 33 | } 34 | 35 | cert, err := loadX509KeyPair(conf.TLSCertPath, conf.TLSKeyPath) 36 | if err != nil { 37 | return nil, fmt.Errorf("loading TLS cert: %w", err) 38 | } 39 | 40 | // #nosec G402 -- TLS MinVersion is configured by user. 41 | return &tls.Config{ 42 | Certificates: []tls.Certificate{cert}, 43 | MinVersion: uint16(tlsMinVersion), 44 | MaxVersion: uint16(tlsMaxVersion), 45 | }, nil 46 | } 47 | 48 | // loadX509KeyPair reads and parses a public/private key pair from a pair of 49 | // files. The files must contain PEM encoded data. The certificate file may 50 | // contain intermediate certificates following the leaf certificate to form a 51 | // certificate chain. On successful return, Certificate.Leaf will be nil 52 | // because the parsed form of the certificate is not retained. 53 | func loadX509KeyPair(certFile, keyFile string) (crt tls.Certificate, err error) { 54 | // #nosec G304 -- Trust the file path that is given in the configuration. 55 | certPEMBlock, err := os.ReadFile(certFile) 56 | if err != nil { 57 | return tls.Certificate{}, err 58 | } 59 | 60 | // #nosec G304 -- Trust the file path that is given in the configuration.
61 | keyPEMBlock, err := os.ReadFile(keyFile) 62 | if err != nil { 63 | return tls.Certificate{}, err 64 | } 65 | 66 | return tls.X509KeyPair(certPEMBlock, keyPEMBlock) 67 | } 68 | -------------------------------------------------------------------------------- /internal/dnsmsg/constructor.go: -------------------------------------------------------------------------------- 1 | // Package dnsmsg contains common constants, functions, and types for inspecting 2 | // and constructing DNS messages. 3 | package dnsmsg 4 | 5 | import ( 6 | "strings" 7 | 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // MessageConstructor creates DNS messages. 12 | type MessageConstructor interface { 13 | // NewMsgNXDOMAIN creates a new response message replying to req with the 14 | // NXDOMAIN code. 15 | NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) 16 | 17 | // NewMsgSERVFAIL creates a new response message replying to req with the 18 | // SERVFAIL code. 19 | NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) 20 | 21 | // NewMsgNOTIMPLEMENTED creates a new response message replying to req with 22 | // the NOTIMPLEMENTED code. 23 | NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) 24 | 25 | // NewMsgNODATA creates a new empty response message replying to req with 26 | // the NOERROR code. 27 | // 28 | // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. 29 | NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) 30 | } 31 | 32 | // DefaultMessageConstructor is a default implementation of 33 | // [MessageConstructor]. 34 | type DefaultMessageConstructor struct{} 35 | 36 | // type check 37 | var _ MessageConstructor = DefaultMessageConstructor{} 38 | 39 | // NewMsgNXDOMAIN implements the [MessageConstructor] interface for 40 | // DefaultMessageConstructor. 41 | func (DefaultMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) { 42 | return reply(req, dns.RcodeNameError) 43 | } 44 | 45 | // NewMsgSERVFAIL implements the [MessageConstructor] interface for 46 | // DefaultMessageConstructor. 
47 | func (DefaultMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) { 48 | return reply(req, dns.RcodeServerFailure) 49 | } 50 | 51 | // NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for 52 | // DefaultMessageConstructor. 53 | func (DefaultMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) { 54 | resp = reply(req, dns.RcodeNotImplemented) 55 | 56 | // Most of the Internet and especially the inner core has an MTU of at least 57 | // 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet 58 | // is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)). 59 | // 60 | // See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17. 61 | const maxUDPPayload = 1452 62 | 63 | // NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so 64 | // explicitly set it. 65 | resp.SetEdns0(maxUDPPayload, false) 66 | 67 | return resp 68 | } 69 | 70 | // NewMsgNODATA implements the [MessageConstructor] interface for 71 | // DefaultMessageConstructor. 72 | func (DefaultMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) { 73 | resp = reply(req, dns.RcodeSuccess) 74 | 75 | zone := req.Question[0].Name 76 | soa := &dns.SOA{ 77 | // Values copied from verisign's nonexistent .com domain. 78 | // 79 | // Their exact values are not important in our use case because they are 80 | // used for domain transfers between primary/secondary DNS servers. 
81 | Refresh: 1800, 82 | Retry: 60, 83 | Expire: 604800, 84 | Minttl: 86400, 85 | // copied from AdGuard DNS 86 | Ns: "fake-for-negative-caching.adguard.com.", 87 | Serial: 100500, 88 | Mbox: "hostmaster.", 89 | // rest is request-specific 90 | Hdr: dns.RR_Header{ 91 | Name: zone, 92 | Rrtype: dns.TypeSOA, 93 | Ttl: 10, 94 | Class: dns.ClassINET, 95 | }, 96 | } 97 | 98 | if !strings.HasPrefix(zone, ".") { 99 | soa.Mbox += zone 100 | } 101 | 102 | resp.Ns = append(resp.Ns, soa) 103 | 104 | return resp 105 | } 106 | 107 | // reply creates a new response message replying to req with the given code. 108 | func reply(req *dns.Msg, code int) (resp *dns.Msg) { 109 | resp = (&dns.Msg{}).SetRcode(req, code) 110 | resp.RecursionAvailable = true 111 | 112 | return resp 113 | } 114 | -------------------------------------------------------------------------------- /internal/dnsproxytest/dnsproxytest.go: -------------------------------------------------------------------------------- 1 | // Package dnsproxytest provides a set of test utilities for the dnsproxy 2 | // module. 3 | package dnsproxytest 4 | -------------------------------------------------------------------------------- /internal/dnsproxytest/interface.go: -------------------------------------------------------------------------------- 1 | package dnsproxytest 2 | 3 | import ( 4 | "github.com/miekg/dns" 5 | ) 6 | 7 | // FakeUpstream is a fake [proxy.Upstream] implementation for tests. 8 | // 9 | // TODO(e.burkov): Move this to the golibs some time later. 10 | type FakeUpstream struct { 11 | OnAddress func() (addr string) 12 | OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) 13 | OnClose func() (err error) 14 | } 15 | 16 | // Address implements the [proxy.Upstream] interface for *FakeUpstream. 17 | func (u *FakeUpstream) Address() (addr string) { 18 | return u.OnAddress() 19 | } 20 | 21 | // Exchange implements the [proxy.Upstream] interface for *FakeUpstream. 
22 | func (u *FakeUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { 23 | return u.OnExchange(req) 24 | } 25 | 26 | // Close implements the [proxy.Upstream] interface for *FakeUpstream. 27 | func (u *FakeUpstream) Close() (err error) { 28 | return u.OnClose() 29 | } 30 | 31 | // TestMessageConstructor is a fake [proxy.MessageConstructor] implementation 32 | // for tests. Each method calls the On* field of the same name. 33 | type TestMessageConstructor struct { 34 | OnNewMsgNXDOMAIN func(req *dns.Msg) (resp *dns.Msg) 35 | OnNewMsgSERVFAIL func(req *dns.Msg) (resp *dns.Msg) 36 | OnNewMsgNOTIMPLEMENTED func(req *dns.Msg) (resp *dns.Msg) 37 | OnNewMsgNODATA func(req *dns.Msg) (resp *dns.Msg) 38 | } 39 | 40 | // NewTestMessageConstructor creates a new *TestMessageConstructor with all its 41 | // methods set to panic. 42 | func NewTestMessageConstructor() (c *TestMessageConstructor) { 43 | return &TestMessageConstructor{ 44 | OnNewMsgNXDOMAIN: func(_ *dns.Msg) (_ *dns.Msg) { 45 | panic("unexpected call of TestMessageConstructor.NewMsgNXDOMAIN") 46 | }, 47 | OnNewMsgSERVFAIL: func(_ *dns.Msg) (_ *dns.Msg) { 48 | panic("unexpected call of TestMessageConstructor.NewMsgSERVFAIL") 49 | }, 50 | OnNewMsgNOTIMPLEMENTED: func(_ *dns.Msg) (_ *dns.Msg) { 51 | panic("unexpected call of TestMessageConstructor.NewMsgNOTIMPLEMENTED") 52 | }, 53 | OnNewMsgNODATA: func(_ *dns.Msg) (_ *dns.Msg) { 54 | panic("unexpected call of TestMessageConstructor.NewMsgNODATA") 55 | }, 56 | } 57 | } 58 | 59 | // NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for 60 | // *TestMessageConstructor. 61 | func (c *TestMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) { 62 | return c.OnNewMsgNXDOMAIN(req) 63 | } 64 | 65 | // NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for 66 | // *TestMessageConstructor.
67 | func (c *TestMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) { 68 | return c.OnNewMsgSERVFAIL(req) 69 | } 70 | 71 | // NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for 72 | // *TestMessageConstructor. 73 | func (c *TestMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) { 74 | return c.OnNewMsgNOTIMPLEMENTED(req) 75 | } 76 | 77 | // NewMsgNODATA implements the [MessageConstructor] interface for 78 | // *TestMessageConstructor. 79 | func (c *TestMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) { 80 | return c.OnNewMsgNODATA(req) 81 | } 82 | -------------------------------------------------------------------------------- /internal/dnsproxytest/interface_test.go: -------------------------------------------------------------------------------- 1 | package dnsproxytest_test 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/internal/dnsmsg" 5 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 6 | "github.com/AdguardTeam/dnsproxy/upstream" 7 | ) 8 | 9 | // type checks 10 | var ( 11 | _ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil) 12 | _ dnsmsg.MessageConstructor = (*dnsproxytest.TestMessageConstructor)(nil) 13 | ) 14 | -------------------------------------------------------------------------------- /internal/handler/constructor.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/netip" 5 | 6 | "github.com/AdguardTeam/dnsproxy/proxy" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // messageConstructor is an extension of the [proxy.MessageConstructor] 11 | // interface that also provides methods for creating DNS responses. 12 | type messageConstructor interface { 13 | proxy.MessageConstructor 14 | 15 | // NewCompressedResponse creates a new compressed response message for req 16 | // with the given response code. 
17 | NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) 18 | 19 | // NewPTRAnswer creates a new resource record for PTR response with the 20 | // given FQDN and PTR domain. Arguments must be fully qualified domain 21 | // names. 22 | NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) 23 | 24 | // NewIPResponse creates a new A/AAAA response message for req with the 25 | // given IP addresses. All IP addresses must be of the same family. 26 | NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) 27 | } 28 | 29 | // defaultConstructor is a wrapper for [proxy.MessageConstructor] that also 30 | // implements the [messageConstructor] interface. 31 | // 32 | // TODO(e.burkov): This implementation reflects the one from AdGuard Home, 33 | // consider moving it to [golibs]. 34 | type defaultConstructor struct { 35 | proxy.MessageConstructor 36 | } 37 | 38 | // type check 39 | var _ messageConstructor = defaultConstructor{} 40 | 41 | // NewCompressedResponse implements the [messageConstructor] interface for 42 | // defaultConstructor. 43 | func (defaultConstructor) NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) { 44 | resp = reply(req, code) 45 | resp.Compress = true 46 | 47 | return resp 48 | } 49 | 50 | // NewPTRAnswer implements the [messageConstructor] interface for 51 | // [defaultConstructor]. 
52 | func (defaultConstructor) NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) { 53 | return &dns.PTR{ 54 | Hdr: hdr(fqdn, dns.TypePTR), 55 | Ptr: dns.Fqdn(ptrFQDN), 56 | } 57 | } 58 | 59 | // NewIPResponse implements the [messageConstructor] interface for 60 | // [defaultConstructor] 61 | func (c defaultConstructor) NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) { 62 | var ans []dns.RR 63 | switch req.Question[0].Qtype { 64 | case dns.TypeA: 65 | ans = genAnswersWithIPv4s(req, ips) 66 | case dns.TypeAAAA: 67 | for _, ip := range ips { 68 | if ip.Is6() { 69 | ans = append(ans, newAnswerAAAA(req, ip)) 70 | } 71 | } 72 | default: 73 | // Go on and return an empty response. 74 | } 75 | 76 | resp = c.NewCompressedResponse(req, dns.RcodeSuccess) 77 | resp.Answer = ans 78 | 79 | return resp 80 | } 81 | 82 | // defaultResponseTTL is the default TTL for the DNS responses in seconds. 83 | const defaultResponseTTL = 10 84 | 85 | // hdr creates a new DNS header with the given name and RR type. 86 | func hdr(name string, rrType uint16) (h dns.RR_Header) { 87 | return dns.RR_Header{ 88 | Name: name, 89 | Rrtype: rrType, 90 | Ttl: defaultResponseTTL, 91 | Class: dns.ClassINET, 92 | } 93 | } 94 | 95 | // reply creates a DNS response for req. 96 | func reply(req *dns.Msg, code int) (resp *dns.Msg) { 97 | resp = (&dns.Msg{}).SetRcode(req, code) 98 | resp.RecursionAvailable = true 99 | 100 | return resp 101 | } 102 | 103 | // newAnswerA creates a DNS A answer for req with the given IP address. 104 | func newAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) { 105 | return &dns.A{ 106 | Hdr: hdr(req.Question[0].Name, dns.TypeA), 107 | A: ip.AsSlice(), 108 | } 109 | } 110 | 111 | // newAnswerAAAA creates a DNS AAAA answer for req with the given IP address. 
112 | func newAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) { 113 | return &dns.AAAA{ 114 | Hdr: hdr(req.Question[0].Name, dns.TypeAAAA), 115 | AAAA: ip.AsSlice(), 116 | } 117 | } 118 | 119 | // genAnswersWithIPv4s generates DNS A answers for the provided IPv4 addresses. 120 | // If any of the IPs isn't an IPv4 address, genAnswersWithIPv4s returns nil, 121 | // discarding the answers generated so far. Nothing is logged in that case. 122 | func genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) { 123 | for _, ip := range ips { 124 | if !ip.Is4() { 125 | return nil 126 | } 127 | 128 | ans = append(ans, newAnswerA(req, ip)) 129 | } 130 | 131 | return ans 132 | } 133 | -------------------------------------------------------------------------------- /internal/handler/default.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | 7 | "github.com/AdguardTeam/dnsproxy/proxy" 8 | "github.com/AdguardTeam/golibs/hostsfile" 9 | ) 10 | 11 | // DefaultConfig is the configuration for [Default]. 12 | type DefaultConfig struct { 13 | // MessageConstructor constructs DNS messages. It must not be nil. 14 | MessageConstructor proxy.MessageConstructor 15 | 16 | // Logger is the logger. It must not be nil. 17 | Logger *slog.Logger 18 | 19 | // HostsFiles is the index containing the records of the hosts files. It must not be nil. 20 | HostsFiles hostsfile.Storage 21 | 22 | // HaltIPv6 halts the processing of AAAA requests and makes the handler 23 | // reply with NODATA to them. 24 | HaltIPv6 bool 25 | } 26 | 27 | // Default implements the default configurable [proxy.RequestHandler]. 28 | type Default struct { 29 | messages messageConstructor 30 | hosts hostsfile.Storage 31 | logger *slog.Logger 32 | isIPv6Halted bool 33 | } 34 | 35 | // NewDefault creates a new [Default] handler.
36 | func NewDefault(conf *DefaultConfig) (d *Default) { 37 | mc, ok := conf.MessageConstructor.(messageConstructor) 38 | if !ok { 39 | mc = defaultConstructor{ 40 | MessageConstructor: conf.MessageConstructor, 41 | } 42 | } 43 | 44 | return &Default{ 45 | logger: conf.Logger, 46 | isIPv6Halted: conf.HaltIPv6, 47 | messages: mc, 48 | hosts: conf.HostsFiles, 49 | } 50 | } 51 | 52 | // HandleRequest resolves the DNS request within proxyCtx. It only calls 53 | // [proxy.Proxy.Resolve] if the request isn't handled by any of the internal 54 | // handlers. 55 | func (h *Default) HandleRequest(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) { 56 | // TODO(e.burkov): Use the [*context.Context] instead of 57 | // [*proxy.DNSContext] when the interface-based handler is implemented. 58 | ctx := context.TODO() 59 | 60 | h.logger.DebugContext(ctx, "handling request", "req", &proxyCtx.Req.Question[0]) 61 | 62 | if proxyCtx.Res = h.haltAAAA(ctx, proxyCtx.Req); proxyCtx.Res != nil { 63 | return nil 64 | } 65 | 66 | if proxyCtx.Res = h.resolveFromHosts(ctx, proxyCtx.Req); proxyCtx.Res != nil { 67 | return nil 68 | } 69 | 70 | return p.Resolve(proxyCtx) 71 | } 72 | -------------------------------------------------------------------------------- /internal/handler/handler.go: -------------------------------------------------------------------------------- 1 | // Package handler provides some customizable DNS request handling logic used in 2 | // the proxy. 
3 | package handler 4 | -------------------------------------------------------------------------------- /internal/handler/hosts.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/netip" 7 | "os" 8 | "slices" 9 | "strings" 10 | 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/AdguardTeam/golibs/hostsfile" 13 | "github.com/AdguardTeam/golibs/logutil/slogutil" 14 | "github.com/AdguardTeam/golibs/netutil" 15 | "github.com/miekg/dns" 16 | ) 17 | 18 | // emptyStorage is a [hostsfile.Storage] that contains no records. 19 | // 20 | // TODO(e.burkov): Move to [hostsfile]. 21 | type emptyStorage [0]hostsfile.Record 22 | 23 | // type check 24 | var _ hostsfile.Storage = emptyStorage{} 25 | 26 | // ByAddr implements the [hostsfile.Storage] interface for [emptyStorage]. 27 | func (emptyStorage) ByAddr(_ netip.Addr) (names []string) { 28 | return nil 29 | } 30 | 31 | // ByName implements the [hostsfile.Storage] interface for [emptyStorage]. 32 | func (emptyStorage) ByName(_ string) (addrs []netip.Addr) { 33 | return nil 34 | } 35 | 36 | // ReadHosts reads the hosts files from the file system and returns a storage 37 | // with parsed records. strg is always usable even if an error occurred. 38 | func ReadHosts(paths []string) (strg hostsfile.Storage, err error) { 39 | // Don't check the error since it may only appear when any readers used. 40 | defaultStrg, _ := hostsfile.NewDefaultStorage() 41 | 42 | var errs []error 43 | for _, path := range paths { 44 | err = readHostsFile(defaultStrg, path) 45 | if err != nil { 46 | // Don't wrap the error since it's informative enough as is. 47 | errs = append(errs, err) 48 | } 49 | } 50 | 51 | // TODO(e.burkov): Add method for length. 
52 | isEmpty := true 53 | defaultStrg.RangeAddrs(func(_ string, _ []netip.Addr) (cont bool) { 54 | isEmpty = false 55 | 56 | return false 57 | }) 58 | 59 | if isEmpty { 60 | return emptyStorage{}, errors.Join(errs...) 61 | } 62 | 63 | return defaultStrg, errors.Join(errs...) 64 | } 65 | 66 | // readHostsFile reads the hosts file at path and parses it into strg. 67 | func readHostsFile(strg *hostsfile.DefaultStorage, path string) (err error) { 68 | // #nosec G304 -- Trust the file path from the configuration file. 69 | f, err := os.Open(path) 70 | if err != nil { 71 | // Don't wrap the error since it's informative enough as is. 72 | return err 73 | } 74 | 75 | defer func() { err = errors.WithDeferred(err, f.Close()) }() 76 | 77 | err = hostsfile.Parse(strg, f, nil) 78 | if err != nil { 79 | return fmt.Errorf("parsing hosts file %q: %w", path, err) 80 | } 81 | 82 | return nil 83 | } 84 | 85 | // resolveFromHosts resolves the DNS query from the hosts file. It fills the 86 | // response with the A, AAAA, and PTR records from the hosts file. 
87 | func (h *Default) resolveFromHosts(ctx context.Context, req *dns.Msg) (resp *dns.Msg) { 88 | var addrs []netip.Addr 89 | var ptrs []string 90 | 91 | q := req.Question[0] 92 | name := strings.TrimSuffix(q.Name, ".") 93 | switch q.Qtype { 94 | case dns.TypeA: 95 | addrs = slices.Clone(h.hosts.ByName(name)) 96 | addrs = slices.DeleteFunc(addrs, netip.Addr.Is6) 97 | case dns.TypeAAAA: 98 | addrs = slices.Clone(h.hosts.ByName(name)) 99 | addrs = slices.DeleteFunc(addrs, netip.Addr.Is4) 100 | case dns.TypePTR: 101 | addr, err := netutil.IPFromReversedAddr(name) 102 | if err != nil { 103 | h.logger.DebugContext(ctx, "failed parsing ptr", slogutil.KeyError, err) 104 | 105 | return nil 106 | } 107 | 108 | ptrs = h.hosts.ByAddr(addr) 109 | default: 110 | return nil 111 | } 112 | 113 | switch { 114 | case len(addrs) > 0: 115 | resp = h.messages.NewIPResponse(req, addrs) 116 | case len(ptrs) > 0: 117 | resp = h.messages.NewCompressedResponse(req, dns.RcodeSuccess) 118 | name = req.Question[0].Name 119 | for _, ptr := range ptrs { 120 | resp.Answer = append(resp.Answer, h.messages.NewPTRAnswer(name, dns.Fqdn(ptr))) 121 | } 122 | default: 123 | h.logger.DebugContext(ctx, "no hosts records found", "name", name, "qtype", q.Qtype) 124 | } 125 | 126 | return resp 127 | } 128 | -------------------------------------------------------------------------------- /internal/handler/ipv6halt.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // haltAAAA halts the processing of AAAA requests if IPv6 is disabled. req must 10 | // not be nil. 
11 | func (h *Default) haltAAAA(ctx context.Context, req *dns.Msg) (resp *dns.Msg) { 12 | if h.isIPv6Halted && req.Question[0].Qtype == dns.TypeAAAA { 13 | h.logger.DebugContext( 14 | ctx, 15 | "ipv6 is disabled; replying with empty response", 16 | "req", req.Question[0].Name, 17 | ) 18 | 19 | return h.messages.NewMsgNODATA(req) 20 | } 21 | 22 | return nil 23 | } 24 | -------------------------------------------------------------------------------- /internal/handler/testdata/TestDefault_resolveFromHosts/hosts: -------------------------------------------------------------------------------- 1 | 1.2.3.4 ipv4.domain.example 2 | 2001:db8::1 ipv6.domain.example 3 | # comment 4 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | import ( 4 | "log/slog" 5 | "net" 6 | ) 7 | 8 | // ListenConfig returns the default [net.ListenConfig] used by the plain-DNS 9 | // servers in this module. l must not be nil. 10 | // 11 | // TODO(a.garipov): Add tests. 12 | // 13 | // TODO(a.garipov): Add an option to not set SO_REUSEPORT on Unix to prevent 14 | // issues with OpenWrt. 15 | // 16 | // See https://github.com/AdguardTeam/AdGuardHome/issues/5872. 17 | // 18 | // TODO(a.garipov): DRY with AdGuard DNS when we can. 19 | func ListenConfig(l *slog.Logger) (lc *net.ListenConfig) { 20 | return &net.ListenConfig{ 21 | Control: listenControl{logger: l}.defaultListenControl, 22 | } 23 | } 24 | 25 | // listenControl is a wrapper struct with logger. 
26 | type listenControl struct { 27 | logger *slog.Logger 28 | } 29 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package netutil 4 | 5 | import ( 6 | "fmt" 7 | "syscall" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "golang.org/x/sys/unix" 12 | ) 13 | 14 | // defaultListenControl is used as a [net.ListenConfig.Control] function to set 15 | // the SO_REUSEADDR and SO_REUSEPORT socket options on all sockets used by the 16 | // DNS servers in this module. An ENOPROTOOPT error for SO_REUSEPORT is only logged as a warning. 17 | func (lc listenControl) defaultListenControl(_, _ string, c syscall.RawConn) (err error) { 18 | var opErr error 19 | err = c.Control(func(fd uintptr) { 20 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) 21 | if opErr != nil { 22 | opErr = fmt.Errorf("setting SO_REUSEADDR: %w", opErr) 23 | 24 | return 25 | } 26 | 27 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) 28 | if opErr != nil { 29 | if errors.Is(opErr, unix.ENOPROTOOPT) { 30 | // Some Linux OSs do not seem to support SO_REUSEPORT, including 31 | // some varieties of OpenWrt. Issue a warning. 32 | lc.logger.Warn("SO_REUSEPORT not supported", slogutil.KeyError, opErr) 33 | opErr = nil 34 | } else { 35 | opErr = fmt.Errorf("setting SO_REUSEPORT: %w", opErr) 36 | } 37 | } 38 | }) 39 | 40 | return errors.WithDeferred(opErr, err) 41 | } 42 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netutil 4 | 5 | import "syscall" 6 | 7 | // defaultListenControl is a no-op on Windows, because it doesn't support 8 | // SO_REUSEPORT.
9 | func (listenControl) defaultListenControl(_, _ string, _ syscall.RawConn) (err error) { 10 | return nil 11 | } 12 | -------------------------------------------------------------------------------- /internal/netutil/netutil.go: -------------------------------------------------------------------------------- 1 | // Package netutil contains network-related utilities common among dnsproxy 2 | // packages. 3 | // 4 | // TODO(a.garipov): Move improved versions of these into netutil in module 5 | // golibs. 6 | package netutil 7 | 8 | import ( 9 | "net/netip" 10 | "strings" 11 | ) 12 | 13 | // ParseSubnet parses s either as a CIDR prefix itself, or as an IP address, 14 | // returning the corresponding single-IP CIDR prefix. 15 | // 16 | // TODO(e.burkov): Replace usages with [netutil.Prefix]. 17 | func ParseSubnet(s string) (p netip.Prefix, err error) { 18 | if strings.Contains(s, "/") { 19 | p, err = netip.ParsePrefix(s) 20 | if err != nil { 21 | return netip.Prefix{}, err 22 | } 23 | } else { 24 | var ip netip.Addr 25 | ip, err = netip.ParseAddr(s) 26 | if err != nil { 27 | return netip.Prefix{}, err 28 | } 29 | 30 | p = netip.PrefixFrom(ip, ip.BitLen()) 31 | } 32 | 33 | return p, nil 34 | } 35 | -------------------------------------------------------------------------------- /internal/netutil/paths.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | // DefaultHostsPaths returns the slice of default paths to system hosts files. 4 | // 5 | // TODO(s.chzhen): Since [fs.FS] is no longer needed, update the 6 | // [hostsfile.DefaultHostsPaths] from golibs. 
7 | func DefaultHostsPaths() (paths []string, err error) { 8 | return defaultHostsPaths() 9 | } 10 | -------------------------------------------------------------------------------- /internal/netutil/paths_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package netutil 4 | 5 | import "github.com/AdguardTeam/golibs/hostsfile" 6 | 7 | // defaultHostsPaths returns default paths to hosts files for UNIX. 8 | func defaultHostsPaths() (paths []string, err error) { 9 | paths, err = hostsfile.DefaultHostsPaths() 10 | if err != nil { 11 | // Should not happen because error is always nil. 12 | panic(err) 13 | } 14 | 15 | res := make([]string, 0, len(paths)) 16 | for _, p := range paths { 17 | res = append(res, "/"+p) 18 | } 19 | 20 | return res, nil 21 | } 22 | -------------------------------------------------------------------------------- /internal/netutil/paths_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netutil 4 | 5 | import ( 6 | "fmt" 7 | "path" 8 | 9 | "golang.org/x/sys/windows" 10 | ) 11 | 12 | // defaultHostsPaths returns default paths to hosts files for Windows. 
13 | func defaultHostsPaths() (paths []string, err error) { 14 | sysDir, err := windows.GetSystemDirectory() 15 | if err != nil { 16 | return []string{}, fmt.Errorf("getting system directory: %w", err) 17 | } 18 | 19 | p := path.Join(sysDir, "drivers", "etc", "hosts") 20 | 21 | return []string{p}, nil 22 | } 23 | -------------------------------------------------------------------------------- /internal/netutil/testdata/TestHosts/bad_file/hosts: -------------------------------------------------------------------------------- 1 | # comment about the following empty line 2 | 3 | # comment about the above empty line 4 | 5 | 1.2.3.256 a.b # invalid address 6 | 1.2.3.4 a.123 # invalid top-level domain 7 | 1.2.3.4 .a.b # empty domain 8 | -------------------------------------------------------------------------------- /internal/netutil/testdata/TestHosts/good_file/hosts: -------------------------------------------------------------------------------- 1 | # IPv4 2 | 3 | # 1st host. 4 | 0.0.0.1 Host.One 5 | 6 | # 2nd host. 7 | 0.0.0.2 Host.Two 8 | 9 | # 1st host full duplicate. 10 | 0.0.0.1 host.one 11 | 12 | # 2nd host duplicate with new name. 13 | 0.0.0.2 host.two Host.New 14 | 15 | # 1st host with foreign name. 16 | 0.0.0.1 host.new 17 | 18 | # 2nd host new name. 19 | 0.0.0.2 Again.Host.Two 20 | 21 | # Mapped 22 | 23 | # 1st host. 24 | ::ffff:0.0.0.1 Host.One 25 | 26 | # 2nd host. 27 | ::ffff:0.0.0.2 Host.Two 28 | 29 | # 1st host full duplicate. 30 | ::ffff:0.0.0.1 host.one 31 | 32 | # 2nd host duplicate with new name. 33 | ::ffff:0.0.0.2 host.two Host.New 34 | 35 | # 1st host with foreign name. 36 | ::ffff:0.0.0.1 host.new 37 | 38 | # 2nd host new name. 39 | ::ffff:0.0.0.2 Again.Host.Two 40 | 41 | # IPv6 42 | 43 | # 1st host. 44 | ::1 Host.One 45 | 46 | # 2nd host. 47 | ::2 Host.Two 48 | 49 | # 1st host full duplicate. 50 | ::1 host.one 51 | 52 | # 2nd host duplicate with new name. 53 | ::2 host.two Host.New 54 | 55 | # 1st host with foreign name. 
56 | ::1 host.new 57 | 58 | # 2nd host new name. 59 | ::2 Again.Host.Two 60 | -------------------------------------------------------------------------------- /internal/netutil/udp.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | ) 7 | 8 | // UDPGetOOBSize returns maximum size of the received OOB data. 9 | func UDPGetOOBSize() (oobSize int) { 10 | return udpGetOOBSize() 11 | } 12 | 13 | // UDPSetOptions sets flag options on a UDP socket to be able to receive the 14 | // necessary OOB data. 15 | func UDPSetOptions(c *net.UDPConn) (err error) { 16 | return udpSetOptions(c) 17 | } 18 | 19 | // UDPRead reads the message from conn using buf and receives a control-message 20 | // payload of size udpOOBSize from it. It returns the number of bytes copied 21 | // into buf and the source address of the message. 22 | // 23 | // TODO(s.chzhen): Consider using netip.Addr. 24 | func UDPRead( 25 | conn *net.UDPConn, 26 | buf []byte, 27 | udpOOBSize int, 28 | ) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) { 29 | return udpRead(conn, buf, udpOOBSize) 30 | } 31 | 32 | // UDPWrite writes the data to the remoteAddr using conn. 33 | // 34 | // TODO(s.chzhen): Consider using netip.Addr. 
35 | func UDPWrite( 36 | data []byte, 37 | conn *net.UDPConn, 38 | remoteAddr *net.UDPAddr, 39 | localIP netip.Addr, 40 | ) (n int, err error) { 41 | return udpWrite(data, conn, remoteAddr, localIP) 42 | } 43 | -------------------------------------------------------------------------------- /internal/netutil/udp_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package netutil 4 | 5 | import ( 6 | "fmt" 7 | "net" 8 | "net/netip" 9 | 10 | "github.com/AdguardTeam/golibs/netutil" 11 | "golang.org/x/net/ipv4" 12 | "golang.org/x/net/ipv6" 13 | ) 14 | 15 | // These are the set of socket option flags for configuring an IPv[46] UDP 16 | // connection to receive an appropriate OOB data. For both versions the flags 17 | // are: 18 | // 19 | // - FlagDst 20 | // - FlagInterface 21 | const ( 22 | ipv4Flags ipv4.ControlFlags = ipv4.FlagDst | ipv4.FlagInterface 23 | ipv6Flags ipv6.ControlFlags = ipv6.FlagDst | ipv6.FlagInterface 24 | ) 25 | 26 | // udpGetOOBSize returns the size of the larger of the IPv4 and IPv6 control messages created for the flags above. 27 | func udpGetOOBSize() (oobSize int) { 28 | return max(len(ipv4.NewControlMessage(ipv4Flags)), len(ipv6.NewControlMessage(ipv6Flags))) 29 | } 30 | // udpSetOptions enables the control-message flags above on c. It fails only if both the IPv6 and the IPv4 calls fail. 31 | func udpSetOptions(c *net.UDPConn) (err error) { 32 | err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6Flags, true) 33 | err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4Flags, true) 34 | if err6 != nil && err4 != nil { 35 | return fmt.Errorf("failed to call SetControlMessage: ipv4: %w; ipv6: %w", err4, err6) 36 | } 37 | 38 | return nil 39 | } 40 | // udpGetDstFromOOB returns the destination address parsed from oob, trying the IPv6 control-message format first. 41 | func udpGetDstFromOOB(oob []byte) (dst netip.Addr, err error) { 42 | cm6 := &ipv6.ControlMessage{} 43 | if cm6.Parse(oob) == nil && cm6.Dst != nil { 44 | // Linux maps IPv4 addresses to IPv6 ones by default, so we can get an 45 | // IPv4 dst from an IPv6 control-message.
// udpRead reads a single datagram from c into buf.  This implementation
// doesn't request any OOB data, so the OOB size argument is ignored and the
// returned local IP is always the zero [netip.Addr].  It returns the number of
// bytes read, the zero local address, the sender's address, and the read error,
// if any.
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, netip.Addr, *net.UDPAddr, error) {
	// Use the typed method, which yields a *net.UDPAddr directly, instead of
	// asserting the dynamic type of the net.Addr returned by ReadFrom.
	n, udpAddr, err := c.ReadFromUDP(buf)

	return n, netip.Addr{}, udpAddr, err
}
/internal/netutil/udpoob_darwin.go: -------------------------------------------------------------------------------- 1 | //go:build darwin 2 | 3 | package netutil 4 | 5 | import ( 6 | "net/netip" 7 | 8 | "golang.org/x/net/ipv6" 9 | ) 10 | 11 | // udpMakeOOBWithSrc makes the OOB data with the specified source IP. 12 | func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) { 13 | if ip.Is4() { 14 | // Do not set the IPv4 source address via OOB, because it can cause the 15 | // address to become unspecified on darwin. 16 | // 17 | // See https://github.com/AdguardTeam/AdGuardHome/issues/2807. 18 | // 19 | // TODO(e.burkov): Develop a workaround to make it write OOB only when 20 | // listening on an unspecified address. 21 | return []byte{} 22 | } 23 | 24 | return (&ipv6.ControlMessage{ 25 | Src: ip.AsSlice(), 26 | }).Marshal() 27 | } 28 | -------------------------------------------------------------------------------- /internal/netutil/udpoob_others.go: -------------------------------------------------------------------------------- 1 | //go:build !darwin 2 | 3 | package netutil 4 | 5 | import ( 6 | "net/netip" 7 | 8 | "golang.org/x/net/ipv4" 9 | "golang.org/x/net/ipv6" 10 | ) 11 | 12 | // udpMakeOOBWithSrc makes the OOB data with the specified source IP. 13 | func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) { 14 | if ip.Is4() { 15 | return (&ipv4.ControlMessage{ 16 | Src: ip.AsSlice(), 17 | }).Marshal() 18 | } 19 | 20 | return (&ipv6.ControlMessage{ 21 | Src: ip.AsSlice(), 22 | }).Marshal() 23 | } 24 | -------------------------------------------------------------------------------- /internal/version/version.go: -------------------------------------------------------------------------------- 1 | // Package version contains dnsproxy version information. 2 | package version 3 | 4 | // Versions 5 | 6 | // These are set by the linker. 
Unfortunately, we cannot set constants during 7 | // linking, and Go doesn't have a concept of immutable variables, so to be 8 | // thorough we have to only export them through getters. 9 | var ( 10 | branch string 11 | committime string 12 | revision string 13 | version string 14 | ) 15 | 16 | // Branch returns the compiled-in value of the Git branch. 17 | func Branch() (b string) { 18 | return branch 19 | } 20 | 21 | // CommitTime returns the compiled-in value of the build time as a string. 22 | func CommitTime() (t string) { 23 | return committime 24 | } 25 | 26 | // Revision returns the compiled-in value of the Git revision. 27 | func Revision() (r string) { 28 | return revision 29 | } 30 | 31 | // Version returns the compiled-in value of the build version as a string. 32 | func Version() (v string) { 33 | return version 34 | } 35 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/internal/cmd" 5 | ) 6 | 7 | func main() { 8 | cmd.Main() 9 | } 10 | -------------------------------------------------------------------------------- /proxy/beforerequest.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // BeforeRequestError is an error that signals that the request should be 12 | // responded with the given response message. 13 | type BeforeRequestError struct { 14 | // Err is the error that caused the response. It must not be nil. 15 | Err error 16 | 17 | // Response is the response message to be sent to the client. It must be a 18 | // valid response message. 
19 | Response *dns.Msg 20 | } 21 | 22 | // type check 23 | var _ error = (*BeforeRequestError)(nil) 24 | 25 | // Error implements the [error] interface for *BeforeRequestError. 26 | func (e *BeforeRequestError) Error() (msg string) { 27 | return fmt.Sprintf("%s; respond with %s", e.Err, dns.RcodeToString[e.Response.Rcode]) 28 | } 29 | 30 | // type check 31 | var _ errors.Wrapper = (*BeforeRequestError)(nil) 32 | 33 | // Unwrap implements the [errors.Wrapper] interface for *BeforeRequestError. 34 | func (e *BeforeRequestError) Unwrap() (unwrapped error) { 35 | return e.Err 36 | } 37 | 38 | // BeforeRequestHandler is an object that can handle the request before it's 39 | // processed by [Proxy]. 40 | type BeforeRequestHandler interface { 41 | // HandleBefore is called before each DNS request is started processing. 42 | // The passed [DNSContext] contains the Req, Addr, and IsLocalClient fields 43 | // set accordingly. 44 | // 45 | // If returned err is a [BeforeRequestError], the given response message is 46 | // used. If err is nil, the request is processed further. [Proxy] assumes 47 | // a handler itself doesn't set the [DNSContext.Res] field. 48 | HandleBefore(p *Proxy, dctx *DNSContext) (err error) 49 | } 50 | 51 | // noopRequestHandler is a no-op implementation of [BeforeRequestHandler] that 52 | // always returns nil. 53 | type noopRequestHandler struct{} 54 | 55 | // type check 56 | var _ BeforeRequestHandler = noopRequestHandler{} 57 | 58 | // HandleBefore implements the [BeforeRequestHandler] interface for 59 | // noopRequestHandler. 60 | func (noopRequestHandler) HandleBefore(_ *Proxy, _ *DNSContext) (err error) { 61 | return nil 62 | } 63 | 64 | // handleBefore calls the [BeforeRequestHandler] if it's set. If the returned 65 | // error is nil, it returns true and the request is processed further. If the 66 | // returned error has type [BeforeRequestError], the specified response is sent 67 | // to the client. Otherwise, the request just ignored. 
68 | func (p *Proxy) handleBefore(d *DNSContext) (cont bool) { 69 | err := p.beforeRequestHandler.HandleBefore(p, d) 70 | if err == nil { 71 | return true 72 | } 73 | 74 | p.logger.Debug("handling before request", slogutil.KeyError, err) 75 | 76 | if befReqErr := (&BeforeRequestError{}); errors.As(err, &befReqErr) { 77 | d.Res = befReqErr.Response 78 | 79 | p.logDNSMessage(d.Res) 80 | p.respond(d) 81 | } 82 | 83 | return false 84 | } 85 | -------------------------------------------------------------------------------- /proxy/beforerequest_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 11 | "github.com/AdguardTeam/dnsproxy/upstream" 12 | "github.com/AdguardTeam/golibs/errors" 13 | "github.com/AdguardTeam/golibs/logutil/slogutil" 14 | "github.com/AdguardTeam/golibs/netutil" 15 | "github.com/AdguardTeam/golibs/testutil" 16 | "github.com/miekg/dns" 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | ) 20 | 21 | // testBeforeRequestHandler is a mock before request handler implementation to 22 | // simplify testing. 23 | type testBeforeRequestHandler struct { 24 | onHandleBefore func(p *Proxy, dctx *DNSContext) (err error) 25 | } 26 | 27 | // type check 28 | var _ BeforeRequestHandler = (*testBeforeRequestHandler)(nil) 29 | 30 | // HandleBefore implements the [BeforeRequestHandler] interface for 31 | // *testBeforeRequestHandler. 
32 | func (h *testBeforeRequestHandler) HandleBefore(p *Proxy, dctx *DNSContext) (err error) { 33 | return h.onHandleBefore(p, dctx) 34 | } 35 | 36 | func TestProxy_HandleDNSRequest_beforeRequestHandler(t *testing.T) { 37 | t.Parallel() 38 | 39 | const ( 40 | allowedID = iota 41 | droppedID 42 | errorID 43 | ) 44 | 45 | allowedRequest := (&dns.Msg{}).SetQuestion("allowed.", dns.TypeA) 46 | allowedRequest.Id = allowedID 47 | allowedResponse := (&dns.Msg{}).SetReply(allowedRequest) 48 | 49 | droppedRequest := (&dns.Msg{}).SetQuestion("dropped.", dns.TypeA) 50 | droppedRequest.Id = droppedID 51 | 52 | errorRequest := (&dns.Msg{}).SetQuestion("error.", dns.TypeA) 53 | errorRequest.Id = errorID 54 | errorResponse := (&dns.Msg{}).SetReply(errorRequest) 55 | 56 | p := mustNew(t, &Config{ 57 | Logger: slogutil.NewDiscardLogger(), 58 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 59 | UpstreamConfig: &UpstreamConfig{ 60 | Upstreams: []upstream.Upstream{&dnsproxytest.FakeUpstream{ 61 | OnExchange: func(m *dns.Msg) (resp *dns.Msg, err error) { 62 | return allowedResponse.Copy(), nil 63 | }, 64 | OnAddress: func() (addr string) { return "general" }, 65 | OnClose: func() (err error) { return nil }, 66 | }}, 67 | }, 68 | TrustedProxies: defaultTrustedProxies, 69 | PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed), 70 | BeforeRequestHandler: &testBeforeRequestHandler{ 71 | onHandleBefore: func(p *Proxy, dctx *DNSContext) (err error) { 72 | switch dctx.Req.Id { 73 | case allowedID: 74 | return nil 75 | case droppedID: 76 | return errors.Error("just drop") 77 | case errorID: 78 | return &BeforeRequestError{ 79 | Err: errors.Error("just error"), 80 | Response: errorResponse, 81 | } 82 | default: 83 | panic(fmt.Sprintf("unexpected request id: %d", dctx.Req.Id)) 84 | } 85 | }, 86 | }, 87 | }) 88 | ctx := context.Background() 89 | require.NoError(t, p.Start(ctx)) 90 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return 
p.Shutdown(ctx) }) 91 | 92 | client := &dns.Client{ 93 | Net: string(ProtoTCP), 94 | Timeout: 200 * time.Millisecond, 95 | } 96 | addr := p.Addr(ProtoTCP).String() 97 | 98 | t.Run("allowed", func(t *testing.T) { 99 | t.Parallel() 100 | 101 | resp, _, err := client.Exchange(allowedRequest, addr) 102 | require.NoError(t, err) 103 | assert.Equal(t, allowedResponse, resp) 104 | }) 105 | 106 | t.Run("dropped", func(t *testing.T) { 107 | t.Parallel() 108 | 109 | resp, _, err := client.Exchange(droppedRequest, addr) 110 | 111 | wantErr := &net.OpError{} 112 | require.ErrorAs(t, err, &wantErr) 113 | assert.True(t, wantErr.Timeout()) 114 | 115 | assert.Nil(t, resp) 116 | }) 117 | 118 | t.Run("error", func(t *testing.T) { 119 | t.Parallel() 120 | 121 | resp, _, err := client.Exchange(errorRequest, addr) 122 | require.NoError(t, err) 123 | assert.Equal(t, errorResponse, resp) 124 | }) 125 | } 126 | -------------------------------------------------------------------------------- /proxy/bogusnxdomain.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/proxyutil" 5 | "github.com/AdguardTeam/golibs/netutil" 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // isBogusNXDomain returns true if m contains at least a single IP address in 10 | // the Answer section contained in BogusNXDomain subnets of p. 
11 | func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) { 12 | if m == nil || len(p.BogusNXDomain) == 0 || len(m.Question) == 0 { 13 | return false 14 | } else if qt := m.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA { 15 | return false 16 | } 17 | 18 | set := netutil.SliceSubnetSet(p.BogusNXDomain) 19 | for _, rr := range m.Answer { 20 | ip := proxyutil.IPFromRR(rr) 21 | if set.Contains(ip) { 22 | return true 23 | } 24 | } 25 | 26 | return false 27 | } 28 | -------------------------------------------------------------------------------- /proxy/bogusnxdomain_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/netip" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/dnsproxy/upstream" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/miekg/dns" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestProxy_IsBogusNXDomain(t *testing.T) { 18 | prx := mustNew(t, &Config{ 19 | Logger: slogutil.NewDiscardLogger(), 20 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 21 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 22 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 23 | TrustedProxies: defaultTrustedProxies, 24 | RatelimitSubnetLenIPv4: 24, 25 | RatelimitSubnetLenIPv6: 64, 26 | CacheEnabled: true, 27 | BogusNXDomain: []netip.Prefix{ 28 | netip.MustParsePrefix("4.3.2.1/24"), 29 | netip.MustParsePrefix("1.2.3.4/8"), 30 | netip.MustParsePrefix("10.11.12.13/32"), 31 | netip.MustParsePrefix("102:304:506:708:90a:b0c:d0e:f10/120"), 32 | }, 33 | }) 34 | 35 | testCases := []struct { 36 | name string 37 | ans []dns.RR 38 | wantRcode int 39 | }{{ 40 | name: "bogus_subnet", 41 | ans: []dns.RR{&dns.A{ 42 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: 
"host.", Ttl: 10}, 43 | A: net.ParseIP("4.3.2.1"), 44 | }}, 45 | wantRcode: dns.RcodeNameError, 46 | }, { 47 | name: "bogus_big_subnet", 48 | ans: []dns.RR{&dns.A{ 49 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 50 | A: net.ParseIP("1.254.254.254"), 51 | }}, 52 | wantRcode: dns.RcodeNameError, 53 | }, { 54 | name: "bogus_single_ip", 55 | ans: []dns.RR{&dns.A{ 56 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 57 | A: net.ParseIP("10.11.12.13"), 58 | }}, 59 | wantRcode: dns.RcodeNameError, 60 | }, { 61 | name: "bogus_6", 62 | ans: []dns.RR{&dns.AAAA{ 63 | Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, 64 | AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99}, 65 | }}, 66 | wantRcode: dns.RcodeNameError, 67 | }, { 68 | name: "non-bogus", 69 | ans: []dns.RR{&dns.A{ 70 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 71 | A: net.ParseIP("10.11.12.14"), 72 | }}, 73 | wantRcode: dns.RcodeSuccess, 74 | }, { 75 | name: "non-bogus_6", 76 | ans: []dns.RR{&dns.AAAA{ 77 | Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, 78 | AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15}, 79 | }}, 80 | wantRcode: dns.RcodeSuccess, 81 | }} 82 | 83 | u := testUpstream{} 84 | prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u} 85 | 86 | ctx := context.Background() 87 | err := prx.Start(ctx) 88 | require.NoError(t, err) 89 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return prx.Shutdown(ctx) }) 90 | 91 | d := &DNSContext{ 92 | Req: newHostTestMessage("host"), 93 | } 94 | 95 | for _, tc := range testCases { 96 | u.ans = tc.ans 97 | 98 | t.Run(tc.name, func(t *testing.T) { 99 | err = prx.Resolve(d) 100 | require.NoError(t, err) 101 | require.NotNil(t, d.Res) 102 | 103 | assert.Equal(t, tc.wantRcode, d.Res.Rcode) 104 | }) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /proxy/constructor.go: 
// isEPIPE reports whether err is, or wraps, [syscall.EPIPE].  syscall.EPIPE
// exists on all OSes except for Plan 9.  Validate with:
//
//	$ for os in $(go tool dist list | cut -d / -f 1 | sort -u)
//	do
//	        echo -n "$os"
//	        env GOOS="$os" go doc syscall.EPIPE | grep -F -e EPIPE
//	done
//
// For the Plan 9 version see ./errors_plan9.go.
func isEPIPE(err error) (ok bool) {
	ok = errors.Is(err, syscall.EPIPE)

	return ok
}
// isEPIPE checks if the underlying error is EPIPE.  Plan 9 relies on error
// strings instead of error codes.  The exact constant with the text returned
// by a write on a closed socket wasn't found, but it seems to be "sys: write
// on closed pipe".  See Plan 9's "man 2 notify".
//
// A nil err is never EPIPE, which matches the behavior of the version for
// other operating systems.
//
// Plan 9 isn't currently supported, so it's not critical, but when it is, this
// needs to be rechecked.
func isEPIPE(err error) (ok bool) {
	if err == nil {
		// Calling Error on a nil error would panic.
		return false
	}

	return strings.Contains(err.Error(), "write on closed pipe")
}
17 | func (p *Proxy) exchangeUpstreams( 18 | req *dns.Msg, 19 | ups []upstream.Upstream, 20 | ) (resp *dns.Msg, u upstream.Upstream, err error) { 21 | switch p.UpstreamMode { 22 | case UpstreamModeParallel: 23 | return upstream.ExchangeParallel(ups, req) 24 | case UpstreamModeFastestAddr: 25 | switch req.Question[0].Qtype { 26 | case dns.TypeA, dns.TypeAAAA: 27 | return p.fastestAddr.ExchangeFastest(req, ups) 28 | default: 29 | // Go on to the load-balancing mode. 30 | } 31 | default: 32 | // Go on to the load-balancing mode. 33 | } 34 | 35 | if len(ups) == 1 { 36 | u = ups[0] 37 | resp, _, err = p.exchange(u, req) 38 | if err != nil { 39 | return nil, nil, err 40 | } 41 | 42 | // TODO(e.burkov): Consider updating the RTT of a single upstream. 43 | 44 | return resp, u, err 45 | } 46 | 47 | w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc) 48 | var errs []error 49 | for i, ok := w.Take(); ok; i, ok = w.Take() { 50 | u = ups[i] 51 | 52 | var elapsed time.Duration 53 | resp, elapsed, err = p.exchange(u, req) 54 | if err == nil { 55 | p.updateRTT(u.Address(), elapsed) 56 | 57 | return resp, u, nil 58 | } 59 | 60 | errs = append(errs, err) 61 | 62 | // TODO(e.burkov): Use the actual configured timeout or, perhaps, the 63 | // actual measured elapsed time. 64 | p.updateRTT(u.Address(), defaultTimeout) 65 | } 66 | 67 | err = fmt.Errorf("all upstreams failed to exchange request: %w", errors.Join(errs...)) 68 | 69 | return nil, nil, err 70 | } 71 | 72 | // exchange returns the result of the DNS request exchange with the given 73 | // upstream and the elapsed time in milliseconds. It uses the given clock to 74 | // measure the request duration. 75 | func (p *Proxy) exchange( 76 | u upstream.Upstream, 77 | req *dns.Msg, 78 | ) (resp *dns.Msg, dur time.Duration, err error) { 79 | startTime := p.time.Now() 80 | resp, err = u.Exchange(req) 81 | 82 | // Don't use [time.Since] because it uses [time.Now]. 
// upstreamRTTStats accumulates the round-trip time measurements of a single
// upstream.
type upstreamRTTStats struct {
	// rttSum is the total of all the recorded round-trip times, in
	// microseconds.  The float64 type is used since it's capable of
	// representing about 285 years in microseconds.
	rttSum float64

	// reqNum is the number of requests made to the upstream.  The float64
	// type is used to avoid unnecessary type conversions.
	reqNum float64
}

// update returns a copy of stats with rtt added to the sum and the number of
// requests incremented.  The receiver itself is left intact.
func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) {
	updated = stats
	updated.rttSum += float64(rtt.Microseconds())
	updated.reqNum++

	return updated
}
139 | weights = append(weights, 1) 140 | } else { 141 | weights = append(weights, 1/(stat.rttSum/stat.reqNum)) 142 | } 143 | } 144 | 145 | return weights 146 | } 147 | 148 | // updateRTT updates the round-trip time in [upstreamRTTStats] for given 149 | // address. 150 | func (p *Proxy) updateRTT(address string, rtt time.Duration) { 151 | p.rttLock.Lock() 152 | defer p.rttLock.Unlock() 153 | 154 | if p.upstreamRTTStats == nil { 155 | p.upstreamRTTStats = map[string]upstreamRTTStats{} 156 | } 157 | 158 | p.upstreamRTTStats[address] = p.upstreamRTTStats[address].update(rtt) 159 | } 160 | -------------------------------------------------------------------------------- /proxy/handler_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/AdguardTeam/golibs/testutil" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestFilteringHandler(t *testing.T) { 17 | // Initializing the test middleware 18 | m := &sync.RWMutex{} 19 | blockResponse := false 20 | 21 | // Prepare the proxy server 22 | dnsProxy := mustNew(t, &Config{ 23 | Logger: slogutil.NewDiscardLogger(), 24 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 25 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 26 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 27 | TrustedProxies: defaultTrustedProxies, 28 | RatelimitSubnetLenIPv4: 24, 29 | RatelimitSubnetLenIPv6: 64, 30 | RequestHandler: func(p *Proxy, d *DNSContext) error { 31 | m.Lock() 32 | defer m.Unlock() 33 | 34 | if !blockResponse { 35 | // Use the default Resolve method if response is not blocked 36 | return p.Resolve(d) 37 | } 38 | 39 | resp := dns.Msg{} 40 | resp.SetRcode(d.Req, 
dns.RcodeNotImplemented) 41 | resp.RecursionAvailable = true 42 | 43 | // Set the response right away 44 | d.Res = &resp 45 | return nil 46 | }, 47 | }) 48 | 49 | // Start listening 50 | ctx := context.Background() 51 | err := dnsProxy.Start(ctx) 52 | require.NoError(t, err) 53 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) }) 54 | 55 | // Create a DNS-over-UDP client connection 56 | addr := dnsProxy.Addr(ProtoUDP) 57 | client := &dns.Client{ 58 | Net: string(ProtoUDP), 59 | Timeout: testTimeout, 60 | } 61 | 62 | // Send the first message (not blocked) 63 | req := newTestMessage() 64 | 65 | r, _, err := client.Exchange(req, addr.String()) 66 | require.NoError(t, err) 67 | requireResponse(t, req, r) 68 | 69 | // Now send the second and make sure it is blocked 70 | m.Lock() 71 | blockResponse = true 72 | m.Unlock() 73 | 74 | r, _, err = client.Exchange(req, addr.String()) 75 | require.NoError(t, err) 76 | assert.Equal(t, dns.RcodeNotImplemented, r.Rcode) 77 | } 78 | -------------------------------------------------------------------------------- /proxy/helpers.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/AdguardTeam/golibs/netutil" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // ecsFromMsg returns the subnet from EDNS Client Subnet option of m if any. 
11 | func ecsFromMsg(m *dns.Msg) (subnet *net.IPNet, scope int) { 12 | opt := m.IsEdns0() 13 | if opt == nil { 14 | return nil, 0 15 | } 16 | 17 | var ip net.IP 18 | var mask net.IPMask 19 | for _, e := range opt.Option { 20 | sn, ok := e.(*dns.EDNS0_SUBNET) 21 | if !ok { 22 | continue 23 | } 24 | 25 | switch sn.Family { 26 | case 1: 27 | ip = sn.Address.To4() 28 | mask = net.CIDRMask(int(sn.SourceNetmask), netutil.IPv4BitLen) 29 | case 2: 30 | ip = sn.Address 31 | mask = net.CIDRMask(int(sn.SourceNetmask), netutil.IPv6BitLen) 32 | default: 33 | continue 34 | } 35 | 36 | return &net.IPNet{IP: ip, Mask: mask}, int(sn.SourceScope) 37 | } 38 | 39 | return nil, 0 40 | } 41 | 42 | // setECS sets the EDNS client subnet option based on ip and scope into m. It 43 | // returns masked IP and mask length. 44 | func setECS(m *dns.Msg, ip net.IP, scope uint8) (subnet *net.IPNet) { 45 | const ( 46 | // defaultECSv4 is the default length of network mask for IPv4 address 47 | // in ECS option. 48 | defaultECSv4 = 24 49 | 50 | // defaultECSv6 is the default length of network mask for IPv6 address 51 | // in ECS. The size of 7 octets is chosen as a reasonable minimum since 52 | // at least Google's public DNS refuses requests containing the options 53 | // with longer network masks. 54 | defaultECSv6 = 56 55 | ) 56 | 57 | e := &dns.EDNS0_SUBNET{ 58 | Code: dns.EDNS0SUBNET, 59 | SourceScope: scope, 60 | } 61 | 62 | subnet = &net.IPNet{} 63 | if ip4 := ip.To4(); ip4 != nil { 64 | e.Family = 1 65 | e.SourceNetmask = defaultECSv4 66 | subnet.Mask = net.CIDRMask(defaultECSv4, netutil.IPv4BitLen) 67 | ip = ip4 68 | } else { 69 | // Assume the IP address has already been validated. 70 | e.Family = 2 71 | e.SourceNetmask = defaultECSv6 72 | subnet.Mask = net.CIDRMask(defaultECSv6, netutil.IPv6BitLen) 73 | } 74 | subnet.IP = ip.Mask(subnet.Mask) 75 | e.Address = subnet.IP 76 | 77 | // If OPT record already exists so just add EDNS option inside it. 
Note 78 | // that servers may return FORMERR if they meet several OPT RRs. 79 | if opt := m.IsEdns0(); opt != nil { 80 | opt.Option = append(opt.Option, e) 81 | 82 | return subnet 83 | } 84 | 85 | // Create an OPT record and add EDNS option inside it. 86 | o := &dns.OPT{ 87 | Hdr: dns.RR_Header{ 88 | Name: ".", 89 | Rrtype: dns.TypeOPT, 90 | }, 91 | Option: []dns.EDNS0{e}, 92 | } 93 | o.SetUDPSize(4096) 94 | m.Extra = append(m.Extra, o) 95 | 96 | return subnet 97 | } 98 | -------------------------------------------------------------------------------- /proxy/lookup.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "slices" 7 | 8 | "github.com/AdguardTeam/dnsproxy/proxyutil" 9 | "github.com/AdguardTeam/dnsproxy/upstream" 10 | "github.com/AdguardTeam/golibs/errors" 11 | "github.com/AdguardTeam/golibs/logutil/slogutil" 12 | "github.com/AdguardTeam/golibs/netutil" 13 | "github.com/miekg/dns" 14 | ) 15 | 16 | // helper struct to pass results of lookupIPAddr function 17 | type lookupResult struct { 18 | resp *dns.Msg 19 | err error 20 | } 21 | 22 | // lookupIPAddr resolves the specified host IP addresses. It is intended to be 23 | // used as a goroutine. 24 | func (p *Proxy) lookupIPAddr( 25 | ctx context.Context, 26 | host string, 27 | qtype uint16, 28 | ch chan *lookupResult, 29 | ) { 30 | defer slogutil.RecoverAndLog(ctx, p.logger) 31 | 32 | req := (&dns.Msg{}).SetQuestion(host, qtype) 33 | 34 | // TODO(d.kolyshev): Investigate why the client address is not defined. 35 | d := p.newDNSContext(ProtoUDP, req, netip.AddrPort{}) 36 | err := p.Resolve(d) 37 | ch <- &lookupResult{ 38 | resp: d.Res, 39 | err: err, 40 | } 41 | } 42 | 43 | // ErrEmptyHost is returned by LookupIPAddr when the host is empty and can't be 44 | // resolved. 
45 | const ErrEmptyHost = errors.Error("host is empty") 46 | 47 | // type check 48 | var _ upstream.Resolver = (*Proxy)(nil) 49 | 50 | // LookupNetIP implements the [upstream.Resolver] interface for *Proxy. It 51 | // resolves the specified host IP addresses by sending two DNS queries (A and 52 | // AAAA) in parallel. It returns both results for those two queries. 53 | func (p *Proxy) LookupNetIP( 54 | ctx context.Context, 55 | _ string, 56 | host string, 57 | ) (addrs []netip.Addr, err error) { 58 | if host == "" { 59 | return nil, ErrEmptyHost 60 | } 61 | 62 | host = dns.Fqdn(host) 63 | 64 | ch := make(chan *lookupResult) 65 | go p.lookupIPAddr(ctx, host, dns.TypeA, ch) 66 | go p.lookupIPAddr(ctx, host, dns.TypeAAAA, ch) 67 | 68 | var errs []error 69 | for range 2 { 70 | result := <-ch 71 | if result.err != nil { 72 | errs = append(errs, result.err) 73 | 74 | continue 75 | } 76 | 77 | addrs = appendAnswerAddrs(addrs, result.resp.Answer) 78 | } 79 | 80 | if len(addrs) == 0 && len(errs) != 0 { 81 | return addrs, errors.Join(errs...) 82 | } 83 | 84 | if p.Config.PreferIPv6 { 85 | slices.SortStableFunc(addrs, netutil.PreferIPv6) 86 | } else { 87 | slices.SortStableFunc(addrs, netutil.PreferIPv4) 88 | } 89 | 90 | return addrs, nil 91 | } 92 | 93 | // appendAnswerAddrs returns addrs with addresses appended from the given ans. 
94 | func appendAnswerAddrs(addrs []netip.Addr, ans []dns.RR) (res []netip.Addr) { 95 | for _, ansRR := range ans { 96 | a := proxyutil.IPFromRR(ansRR) 97 | if a != (netip.Addr{}) { 98 | addrs = append(addrs, a) 99 | } 100 | } 101 | 102 | return addrs 103 | } 104 | -------------------------------------------------------------------------------- /proxy/lookup_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/dnsproxy/upstream" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestLookupNetIP(t *testing.T) { 15 | // Use AdGuard DNS here. 16 | dnsUpstream, err := upstream.AddressToUpstream( 17 | "94.140.14.14", 18 | &upstream.Options{ 19 | Logger: slogutil.NewDiscardLogger(), 20 | Timeout: defaultTimeout, 21 | }, 22 | ) 23 | require.NoError(t, err) 24 | 25 | conf := &Config{ 26 | Logger: slogutil.NewDiscardLogger(), 27 | UpstreamConfig: &UpstreamConfig{ 28 | Upstreams: []upstream.Upstream{dnsUpstream}, 29 | }, 30 | } 31 | 32 | p, err := New(conf) 33 | require.NoError(t, err) 34 | 35 | // Now let's try doing some lookups. 
36 | addrs, err := p.LookupNetIP(context.Background(), "", "dns.google") 37 | require.NoError(t, err) 38 | require.NotEmpty(t, addrs) 39 | 40 | assert.Contains(t, addrs, netip.MustParseAddr("8.8.8.8")) 41 | assert.Contains(t, addrs, netip.MustParseAddr("8.8.4.4")) 42 | if len(addrs) > 2 { 43 | assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8888")) 44 | assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8844")) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /proxy/optimisticresolver.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "log/slog" 7 | "sync" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | ) 11 | 12 | // cachingResolver is the DNS resolver that is also able to cache responses. 13 | type cachingResolver interface { 14 | // replyFromUpstream returns true if the request from dctx is successfully 15 | // resolved and the response may be cached. 16 | // 17 | // TODO(e.burkov): Find out when ok can be false with nil err. 18 | replyFromUpstream(dctx *DNSContext) (ok bool, err error) 19 | 20 | // cacheResp caches the response from dctx. 21 | cacheResp(dctx *DNSContext) 22 | } 23 | 24 | // type check 25 | var _ cachingResolver = (*Proxy)(nil) 26 | 27 | // optimisticResolver is used to eventually resolve expired cached requests. 28 | type optimisticResolver struct { 29 | reqs *sync.Map 30 | cr cachingResolver 31 | } 32 | 33 | // newOptimisticResolver returns the new resolver for expired cached requests. 34 | // cr must not be nil. 35 | func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) { 36 | return &optimisticResolver{ 37 | reqs: &sync.Map{}, 38 | cr: cr, 39 | } 40 | } 41 | 42 | // unit is a convenient alias for struct{}. 
43 | type unit = struct{} 44 | 45 | // resolveOnce tries to resolve the request from dctx but only a single request 46 | // with the same key at the same period of time. It runs in a separate 47 | // goroutine. Do not pass the *DNSContext which is used elsewhere since it 48 | // isn't intended to be used concurrently. 49 | func (s *optimisticResolver) resolveOnce(dctx *DNSContext, key []byte, l *slog.Logger) { 50 | defer slogutil.RecoverAndLog(context.TODO(), l) 51 | 52 | keyHexed := hex.EncodeToString(key) 53 | if _, ok := s.reqs.LoadOrStore(keyHexed, unit{}); ok { 54 | return 55 | } 56 | defer s.reqs.Delete(keyHexed) 57 | 58 | ok, err := s.cr.replyFromUpstream(dctx) 59 | if err != nil { 60 | l.Debug("resolving request for optimistic cache", slogutil.KeyError, err) 61 | } 62 | 63 | if ok { 64 | s.cr.cacheResp(dctx) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /proxy/optimisticresolver_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "log/slog" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | // testCachingResolver is a stub implementation of the cachingResolver interface 15 | // to simplify testing. 16 | type testCachingResolver struct { 17 | onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error) 18 | onCacheResp func(dctx *DNSContext) 19 | } 20 | 21 | // replyFromUpstream implements the cachingResolver interface for 22 | // *testCachingResolver. 23 | func (tcr *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) { 24 | return tcr.onReplyFromUpstream(dctx) 25 | } 26 | 27 | // cacheResp implements the cachingResolver interface for *testCachingResolver. 
// TestOptimisticResolver_ResolveOnce checks that concurrent resolveOnce calls
// with the same key result in a single resolution and a single cache write.
// The primary goroutine is held inside the caching callback while the
// secondary goroutines run, so that its key stays registered for the whole
// time.
func TestOptimisticResolver_ResolveOnce(t *testing.T) {
	// out signals that the primary goroutine has reached the caching
	// callback; in releases it from there.
	in, out := make(chan unit), make(chan unit)
	var timesResolved, timesSet int

	tcr := &testCachingResolver{
		onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) {
			timesResolved++

			return true, nil
		},
		onCacheResp: func(_ *DNSContext) {
			timesSet++

			// Pass the signal to begin running secondary goroutines.
			out <- unit{}
			// Block until all the secondary goroutines finish.
			<-in
		},
	}

	s := newOptimisticResolver(tcr)
	sameKey := []byte{1, 2, 3}

	// Start the primary goroutine.
	go s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger())
	// Block until the primary goroutine reaches the caching callback.
	<-out

	wg := &sync.WaitGroup{}

	const secondaryNum = 10
	wg.Add(secondaryNum)
	for range secondaryNum {
		go func() {
			defer wg.Done()

			s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger())
		}()
	}

	// Wait until all the secondary goroutines are finished.
	wg.Wait()
	// Pass the signal to terminate the primary goroutine.
	in <- unit{}

	// Only the primary goroutine must have resolved and cached.
	assert.Equal(t, 1, timesResolved)
	assert.Equal(t, 1, timesSet)
}
86 | logOutput := &bytes.Buffer{} 87 | l := slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{ 88 | AddSource: false, 89 | Level: slog.LevelDebug, 90 | ReplaceAttr: nil, 91 | })) 92 | 93 | const rErr errors.Error = "sample resolving error" 94 | 95 | cached := false 96 | s := newOptimisticResolver(&testCachingResolver{ 97 | onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return true, rErr }, 98 | onCacheResp: func(_ *DNSContext) { cached = true }, 99 | }) 100 | s.resolveOnce(nil, key, l) 101 | 102 | assert.True(t, cached) 103 | assert.Contains(t, logOutput.String(), rErr.Error()) 104 | }) 105 | 106 | t.Run("not_ok", func(t *testing.T) { 107 | cached := false 108 | s := newOptimisticResolver(&testCachingResolver{ 109 | onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return false, nil }, 110 | onCacheResp: func(_ *DNSContext) { cached = true }, 111 | }) 112 | s.resolveOnce(nil, key, slogutil.NewDiscardLogger()) 113 | 114 | assert.False(t, cached) 115 | }) 116 | } 117 | -------------------------------------------------------------------------------- /proxy/pending.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/AdguardTeam/golibs/errors" 9 | ) 10 | 11 | // pendingRequests handles identical requests that are in progress. It is used 12 | // to avoid sending the same request multiple times to the upstream server. The 13 | // implementations are: 14 | // - [defaultPendingRequests]. 15 | // - [emptyPendingRequests]. 16 | type pendingRequests interface { 17 | // queue is called for each request. It returns false if there are no 18 | // identical requests in progress. Otherwise it blocks until the first 19 | // request is completed and returns the error that occurred during its 20 | // resolution. 
21 | queue(ctx context.Context, dctx *DNSContext) (loaded bool, err error) 22 | 23 | // done must be called after the request is completed, if queue returned 24 | // false for it. 25 | done(ctx context.Context, dctx *DNSContext, err error) 26 | } 27 | 28 | // defaultPendingRequests is a default implementation of the [pendingRequests] 29 | // interface. It must be created with [newDefaultPendingRequests]. 30 | type defaultPendingRequests struct { 31 | storage *sync.Map 32 | } 33 | 34 | // pendingRequest is a structure that stores the query state and result. 35 | type pendingRequest struct { 36 | // finish is a channel that is closed when the request is completed. It is 37 | // used to block request processing for any but the first one. 38 | finish chan struct{} 39 | 40 | // resolveErr is the error that occurred during the request processing. It 41 | // may be nil. It must only be accessed for reading after the finish 42 | // channel is closed. 43 | resolveErr error 44 | 45 | // cloneDNSCtx is a clone of the DNSContext that was used to create the 46 | // pendingRequest and store its result. It must only be accessed for 47 | // reading after the finish channel is closed. 48 | cloneDNSCtx *DNSContext 49 | } 50 | 51 | // newDefaultPendingRequests creates a new instance of DefaultPendingRequests. 52 | func newDefaultPendingRequests() (pr *defaultPendingRequests) { 53 | return &defaultPendingRequests{ 54 | storage: &sync.Map{}, 55 | } 56 | } 57 | 58 | // type check 59 | var _ pendingRequests = (*defaultPendingRequests)(nil) 60 | 61 | // queue implements the [pendingRequests] interface for 62 | // [defaultPendingRequests]. 
63 | func (pr *defaultPendingRequests) queue( 64 | ctx context.Context, 65 | dctx *DNSContext, 66 | ) (loaded bool, err error) { 67 | var key []byte 68 | if dctx.ReqECS != nil { 69 | ones, _ := dctx.ReqECS.Mask.Size() 70 | key = msgToKeyWithSubnet(dctx.Req, dctx.ReqECS.IP, ones) 71 | } else { 72 | key = msgToKey(dctx.Req) 73 | } 74 | 75 | req := &pendingRequest{ 76 | finish: make(chan struct{}), 77 | } 78 | 79 | pendingVal, loaded := pr.storage.LoadOrStore(string(key), req) 80 | if !loaded { 81 | return false, nil 82 | } 83 | 84 | pending := pendingVal.(*pendingRequest) 85 | <-pending.finish 86 | 87 | origDNSCtx := pending.cloneDNSCtx 88 | 89 | // TODO(a.garipov): Perhaps, statistics should be calculated separately for 90 | // each request. 91 | dctx.queryStatistics = origDNSCtx.queryStatistics 92 | dctx.Upstream = origDNSCtx.Upstream 93 | if origDNSCtx.Res != nil { 94 | // TODO(e.burkov): Add cloner for DNS messages. 95 | dctx.Res = origDNSCtx.Res.Copy().SetReply(dctx.Req) 96 | } 97 | 98 | return loaded, pending.resolveErr 99 | } 100 | 101 | // done implements the [pendingRequests] interface for [defaultPendingRequests]. 
102 | func (pr *defaultPendingRequests) done(ctx context.Context, dctx *DNSContext, err error) { 103 | var key []byte 104 | if dctx.ReqECS != nil { 105 | ones, _ := dctx.ReqECS.Mask.Size() 106 | key = msgToKeyWithSubnet(dctx.Req, dctx.ReqECS.IP, ones) 107 | } else { 108 | key = msgToKey(dctx.Req) 109 | } 110 | 111 | pendingVal, ok := pr.storage.Load(string(key)) 112 | if !ok { 113 | panic(fmt.Errorf("loading pending request: key %x: %w", key, errors.ErrNoValue)) 114 | } 115 | 116 | pending := pendingVal.(*pendingRequest) 117 | pending.resolveErr = err 118 | 119 | cloneCtx := &DNSContext{ 120 | Upstream: dctx.Upstream, 121 | queryStatistics: dctx.queryStatistics, 122 | } 123 | 124 | if dctx.Res != nil { 125 | cloneCtx.Res = dctx.Res.Copy() 126 | } 127 | 128 | pending.cloneDNSCtx = cloneCtx 129 | 130 | pr.storage.Delete(string(key)) 131 | close(pending.finish) 132 | } 133 | 134 | // emptyPendingRequests is a no-op implementation of PendingRequests. It is 135 | // used when pending requests are not needed. 136 | type emptyPendingRequests struct{} 137 | 138 | // type check 139 | var _ pendingRequests = emptyPendingRequests{} 140 | 141 | // queue implements the [pendingRequests] interface for [emptyPendingRequests]. 142 | // It always returns false and does not block. 143 | func (emptyPendingRequests) queue(_ context.Context, _ *DNSContext) (loaded bool, err error) { 144 | return false, nil 145 | } 146 | 147 | // done implements the [pendingRequests] interface for [emptyPendingRequests]. 
148 | func (emptyPendingRequests) done(_ context.Context, _ *DNSContext, _ error) {} 149 | -------------------------------------------------------------------------------- /proxy/pending_test.go: -------------------------------------------------------------------------------- 1 | package proxy_test 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 11 | "github.com/AdguardTeam/dnsproxy/proxy" 12 | "github.com/AdguardTeam/dnsproxy/upstream" 13 | "github.com/AdguardTeam/golibs/logutil/slogutil" 14 | "github.com/AdguardTeam/golibs/netutil" 15 | "github.com/AdguardTeam/golibs/testutil" 16 | "github.com/miekg/dns" 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | ) 20 | 21 | // TODO(e.burkov): Merge those with the ones in internal tests and move to 22 | // dnsproxytest. 23 | 24 | const ( 25 | // testTimeout is the common timeout for tests and contexts. 26 | testTimeout = 1 * time.Second 27 | 28 | // testCacheSize is the default size of the cache in bytes. 29 | testCacheSize = 64 * 1024 30 | ) 31 | 32 | var ( 33 | // localhostAnyPort is a localhost address with an arbitrary port. 34 | localhostAnyPort = netip.AddrPortFrom(netutil.IPv4Localhost(), 0) 35 | 36 | // testTrustedProxies is a set of trusted proxies that includes all 37 | // addresses used in tests. 38 | testTrustedProxies = netutil.SliceSubnetSet{ 39 | netip.MustParsePrefix("0.0.0.0/0"), 40 | netip.MustParsePrefix("::0/0"), 41 | } 42 | ) 43 | 44 | // assertEqualResponses is a helper function that checks if two DNS messages are 45 | // equal, excluding their ID. 46 | // 47 | // TODO(e.burkov): Consider using go-cmp. 
48 | func assertEqualResponses(tb testing.TB, expected, actual *dns.Msg) { 49 | tb.Helper() 50 | 51 | if expected == nil { 52 | require.Nil(tb, actual) 53 | 54 | return 55 | } 56 | 57 | require.NotNil(tb, actual) 58 | 59 | expectedHdr, actualHdr := expected.MsgHdr, actual.MsgHdr 60 | expectedHdr.Id, actualHdr.Id = 0, 0 61 | assert.Equal(tb, expectedHdr, actualHdr) 62 | 63 | assert.Equal(tb, expected.Question, actual.Question) 64 | assert.Equal(tb, expected.Answer, actual.Answer) 65 | assert.Equal(tb, expected.Ns, actual.Ns) 66 | assert.Equal(tb, expected.Extra, actual.Extra) 67 | } 68 | 69 | func TestPendingRequests(t *testing.T) { 70 | t.Parallel() 71 | 72 | const reqsNum = 100 73 | 74 | // workloadWG is used to hold the upstream response until as many requests 75 | // as possible reach the [proxy.Resolve] method. This is a best-effort 76 | // approach, so it's not strictly guaranteed to hold all requests, but it 77 | // works for the test. 78 | workloadWG := &sync.WaitGroup{} 79 | workloadWG.Add(reqsNum) 80 | 81 | once := &sync.Once{} 82 | u := &dnsproxytest.FakeUpstream{ 83 | OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { 84 | once.Do(func() { 85 | resp = (&dns.Msg{}).SetReply(req) 86 | }) 87 | 88 | // Only allow a single request to be processed. 
89 | require.NotNil(testutil.PanicT{}, resp) 90 | 91 | workloadWG.Wait() 92 | 93 | return resp, nil 94 | }, 95 | OnAddress: func() (addr string) { return "" }, 96 | OnClose: func() (err error) { return nil }, 97 | } 98 | 99 | p, err := proxy.New(&proxy.Config{ 100 | Logger: slogutil.NewDiscardLogger(), 101 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 102 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 103 | UpstreamConfig: &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}}, 104 | TrustedProxies: testTrustedProxies, 105 | RatelimitSubnetLenIPv4: 24, 106 | RatelimitSubnetLenIPv6: 64, 107 | Ratelimit: 0, 108 | CacheEnabled: true, 109 | CacheSizeBytes: testCacheSize, 110 | EnableEDNSClientSubnet: true, 111 | PendingRequests: &proxy.PendingRequestsConfig{ 112 | Enabled: true, 113 | }, 114 | RequestHandler: func(prx *proxy.Proxy, dctx *proxy.DNSContext) (err error) { 115 | workloadWG.Done() 116 | 117 | return prx.Resolve(dctx) 118 | }, 119 | }) 120 | require.NoError(t, err) 121 | 122 | ctx := testutil.ContextWithTimeout(t, testTimeout) 123 | err = p.Start(ctx) 124 | require.NoError(t, err) 125 | testutil.CleanupAndRequireSuccess(t, func() (err error) { 126 | ctx = testutil.ContextWithTimeout(t, testTimeout) 127 | 128 | return p.Shutdown(ctx) 129 | }) 130 | 131 | addr := p.Addr(proxy.ProtoTCP).String() 132 | client := &dns.Client{ 133 | Net: string(proxy.ProtoTCP), 134 | Timeout: testTimeout, 135 | } 136 | 137 | resolveWG := &sync.WaitGroup{} 138 | responses := make([]*dns.Msg, reqsNum) 139 | errs := make([]error, reqsNum) 140 | 141 | for i := range reqsNum { 142 | resolveWG.Add(1) 143 | 144 | req := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA) 145 | 146 | go func() { 147 | defer resolveWG.Done() 148 | 149 | reqCtx := testutil.ContextWithTimeout(t, testTimeout) 150 | responses[i], _, errs[i] = client.ExchangeContext(reqCtx, req, addr) 151 | }() 152 | } 153 | 154 | resolveWG.Wait() 155 | 156 | 
require.NoError(t, errs[0]) 157 | 158 | for i, resp := range responses[:len(responses)-1] { 159 | assert.Equal(t, errs[i], errs[i+1]) 160 | assertEqualResponses(t, resp, responses[i+1]) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /proxy/proxycache.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | "slices" 6 | ) 7 | 8 | // cacheForContext returns cache object for the given context. 9 | func (p *Proxy) cacheForContext(d *DNSContext) (c *cache) { 10 | if d.CustomUpstreamConfig != nil && d.CustomUpstreamConfig.cache != nil { 11 | return d.CustomUpstreamConfig.cache 12 | } 13 | 14 | return p.cache 15 | } 16 | 17 | // replyFromCache tries to get the response from general or subnet cache. In 18 | // case the cache is present in d, it's used first. Returns true on success. 19 | func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { 20 | dctxCache := p.cacheForContext(d) 21 | 22 | var ci *cacheItem 23 | var cacheSource string 24 | var expired bool 25 | var key []byte 26 | 27 | // TODO(d.kolyshev): Use EnableEDNSClientSubnet from dctxCache. 28 | if p.Config.EnableEDNSClientSubnet && d.ReqECS != nil { 29 | ci, expired, key = dctxCache.getWithSubnet(d.Req, d.ReqECS) 30 | cacheSource = "subnet cache" 31 | } else { 32 | ci, expired, key = dctxCache.get(d.Req) 33 | cacheSource = "general cache" 34 | } 35 | 36 | if hit = ci != nil; !hit { 37 | return hit 38 | } 39 | 40 | d.Res = ci.m 41 | d.queryStatistics = cachedQueryStatistics(ci.u) 42 | 43 | p.logger.Debug( 44 | "replying from cache", 45 | "source", cacheSource, 46 | "ecs_enabled", p.Config.EnableEDNSClientSubnet, 47 | ) 48 | 49 | if dctxCache.optimistic && expired { 50 | // Build a reduced clone of the current context to avoid data race. 51 | minCtxClone := &DNSContext{ 52 | // It is only read inside the optimistic resolver. 
// cloneIPNet returns a deep clone of n, so that the IP and Mask of the result
// don't share memory with those of n.  It returns nil if n is nil.
func cloneIPNet(n *net.IPNet) (clone *net.IPNet) {
	if n != nil {
		clone = &net.IPNet{}
		clone.IP = slices.Clone(n.IP)
		clone.Mask = slices.Clone(n.Mask)
	}

	return clone
}
119 | if scope < reqOnes { 120 | ecs.Mask = net.CIDRMask(scope, bits) 121 | ecs.IP = ecs.IP.Mask(ecs.Mask) 122 | } 123 | 124 | p.logger.Debug("caching response", "ecs", ecs) 125 | 126 | dctxCache.setWithSubnet(d.Res, d.Upstream, ecs, p.logger) 127 | case d.ReqECS != nil: 128 | // Cache the response for all subnets since the server doesn't support 129 | // EDNS Client Subnet option. 130 | dctxCache.setWithSubnet(d.Res, d.Upstream, &net.IPNet{IP: nil, Mask: nil}, p.logger) 131 | default: 132 | dctxCache.set(d.Res, d.Upstream, p.logger) 133 | } 134 | } 135 | 136 | // ClearCache clears the DNS cache of p. 137 | func (p *Proxy) ClearCache() { 138 | if p.cache == nil { 139 | return 140 | } 141 | 142 | p.cache.clearItems() 143 | p.cache.clearItemsWithSubnet() 144 | p.logger.Debug("cache cleared") 145 | } 146 | -------------------------------------------------------------------------------- /proxy/ratelimit.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | "slices" 7 | "time" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | rate "github.com/beefsack/go-rate" 11 | gocache "github.com/patrickmn/go-cache" 12 | ) 13 | 14 | func (p *Proxy) limiterForIP(ip string) interface{} { 15 | p.ratelimitLock.Lock() 16 | defer p.ratelimitLock.Unlock() 17 | if p.ratelimitBuckets == nil { 18 | p.ratelimitBuckets = gocache.New(time.Hour, time.Hour) 19 | } 20 | 21 | // check if ratelimiter for that IP already exists, if not, create 22 | value, found := p.ratelimitBuckets.Get(ip) 23 | if !found { 24 | value = rate.New(p.Ratelimit, time.Second) 25 | p.ratelimitBuckets.Set(ip, value, time.Hour) 26 | } 27 | 28 | return value 29 | } 30 | 31 | func (p *Proxy) isRatelimited(addr netip.Addr) (ok bool) { 32 | if p.Ratelimit <= 0 { 33 | // The ratelimit is disabled. 34 | return false 35 | } 36 | 37 | addr = addr.Unmap() 38 | // Already sorted by [Proxy.Init]. 
39 | _, ok = slices.BinarySearchFunc(p.RatelimitWhitelist, addr, netip.Addr.Compare) 40 | if ok { 41 | return false 42 | } 43 | 44 | var pref netip.Prefix 45 | if addr.Is4() { 46 | pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv4) 47 | } else { 48 | pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv6) 49 | } 50 | pref = pref.Masked() 51 | 52 | // TODO(s.chzhen): Improve caching. Decrease allocations. 53 | ipStr := pref.Addr().String() 54 | value := p.limiterForIP(ipStr) 55 | rl, ok := value.(*rate.RateLimiter) 56 | if !ok { 57 | p.logger.Error( 58 | "invalid value found in ratelimit cache", 59 | slogutil.KeyError, 60 | fmt.Errorf("bad type %T", value), 61 | ) 62 | 63 | return false 64 | } 65 | 66 | allow, _ := rl.Try() 67 | 68 | return !allow 69 | } 70 | -------------------------------------------------------------------------------- /proxy/ratelimit_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/netip" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/AdguardTeam/golibs/testutil" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestRatelimitingProxy(t *testing.T) { 16 | dnsProxy := mustNew(t, &Config{ 17 | Logger: slogutil.NewDiscardLogger(), 18 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 19 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 20 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 21 | TrustedProxies: defaultTrustedProxies, 22 | RatelimitSubnetLenIPv4: 24, 23 | RatelimitSubnetLenIPv6: 64, 24 | Ratelimit: 1, 25 | }) 26 | 27 | // Start listening 28 | ctx := context.Background() 29 | err := dnsProxy.Start(ctx) 30 | require.NoError(t, err) 31 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) }) 32 | 33 | // Create a 
DNS-over-UDP client connection 34 | addr := dnsProxy.Addr(ProtoUDP) 35 | client := &dns.Client{ 36 | Net: string(ProtoUDP), 37 | Timeout: testTimeout, 38 | } 39 | 40 | // Send the first message (not blocked) 41 | req := newTestMessage() 42 | 43 | r, _, err := client.Exchange(req, addr.String()) 44 | if err != nil { 45 | t.Fatalf("error in the first request: %s", err) 46 | } 47 | requireResponse(t, req, r) 48 | 49 | // Send the second message (blocked) 50 | req = newTestMessage() 51 | 52 | _, _, err = client.Exchange(req, addr.String()) 53 | if err == nil { 54 | t.Fatalf("second request was not blocked") 55 | } 56 | } 57 | 58 | func TestRatelimiting(t *testing.T) { 59 | // rate limit is 1 per sec 60 | p := Proxy{} 61 | p.Ratelimit = 1 62 | 63 | addr := netip.MustParseAddr("127.0.0.1") 64 | 65 | limited := p.isRatelimited(addr) 66 | 67 | if limited { 68 | t.Fatal("First request must have been allowed") 69 | } 70 | 71 | limited = p.isRatelimited(addr) 72 | 73 | if !limited { 74 | t.Fatal("Second request must have been ratelimited") 75 | } 76 | } 77 | 78 | func TestWhitelist(t *testing.T) { 79 | // rate limit is 1 per sec with whitelist 80 | p := Proxy{} 81 | p.Ratelimit = 1 82 | p.RatelimitWhitelist = []netip.Addr{ 83 | netip.MustParseAddr("127.0.0.1"), 84 | netip.MustParseAddr("127.0.0.2"), 85 | netip.MustParseAddr("127.0.0.125"), 86 | } 87 | 88 | addr := netip.MustParseAddr("127.0.0.1") 89 | 90 | limited := p.isRatelimited(addr) 91 | 92 | if limited { 93 | t.Fatal("First request must have been allowed") 94 | } 95 | 96 | limited = p.isRatelimited(addr) 97 | 98 | if limited { 99 | t.Fatal("Second request must have been allowed due to whitelist") 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /proxy/recursiondetector.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding/binary" 5 | "time" 6 | 7 | glcache 
"github.com/AdguardTeam/golibs/cache" 8 | "github.com/AdguardTeam/golibs/netutil" 9 | "github.com/miekg/dns" 10 | ) 11 | 12 | // uint* sizes in bytes to improve readability. 13 | // 14 | // TODO(e.burkov): Remove when there will be a more regardful way to define 15 | // those. See https://github.com/golang/go/issues/29982. 16 | const ( 17 | uint16sz = 2 18 | uint64sz = 8 19 | ) 20 | 21 | // TODO(e.burkov): Consider making configurable. 22 | const ( 23 | // recursionTTL is the time recursive request is cached for. 24 | recursionTTL = 1 * time.Second 25 | 26 | // cachedRecurrentReqNum is the maximum number of cached recurrent requests. 27 | cachedRecurrentReqNum = 1000 28 | ) 29 | 30 | // recursionDetector detects recursion in DNS forwarding. 31 | type recursionDetector struct { 32 | recentRequests glcache.Cache 33 | ttl time.Duration 34 | } 35 | 36 | // check checks if the passed req was already sent by the server. 37 | func (rd *recursionDetector) check(msg *dns.Msg) (ok bool) { 38 | if len(msg.Question) == 0 { 39 | return false 40 | } 41 | 42 | key := msgToSignature(msg) 43 | expireData := rd.recentRequests.Get(key) 44 | if expireData == nil { 45 | return false 46 | } 47 | 48 | expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData))) 49 | 50 | return time.Now().Before(expire) 51 | } 52 | 53 | // add caches the msg if it has anything in the questions section. 54 | func (rd *recursionDetector) add(msg *dns.Msg) { 55 | now := time.Now() 56 | 57 | if len(msg.Question) == 0 { 58 | return 59 | } 60 | 61 | key := msgToSignature(msg) 62 | expire64 := uint64(now.Add(rd.ttl).UnixNano()) 63 | expire := make([]byte, uint64sz) 64 | binary.BigEndian.PutUint64(expire, expire64) 65 | 66 | rd.recentRequests.Set(key, expire) 67 | } 68 | 69 | // clear clears the recent requests cache. 70 | func (rd *recursionDetector) clear() { 71 | rd.recentRequests.Clear() 72 | } 73 | 74 | // newRecursionDetector returns the initialized *recursionDetector. 
75 | func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) { 76 | return &recursionDetector{ 77 | recentRequests: glcache.New(glcache.Config{ 78 | EnableLRU: true, 79 | MaxCount: suspectsNum, 80 | }), 81 | ttl: ttl, 82 | } 83 | } 84 | 85 | // msgToSignature converts msg into it's signature represented in bytes. 86 | func msgToSignature(msg *dns.Msg) (sig []byte) { 87 | sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen) 88 | // The binary.BigEndian byte order is used everywhere except when the real 89 | // machine's endianness is needed. 90 | byteOrder := binary.BigEndian 91 | byteOrder.PutUint16(sig[0:], msg.Id) 92 | q := msg.Question[0] 93 | byteOrder.PutUint16(sig[uint16sz:], q.Qtype) 94 | copy(sig[2*uint16sz:], []byte(q.Name)) 95 | 96 | return sig 97 | } 98 | -------------------------------------------------------------------------------- /proxy/recursiondetector_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "log/slog" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/netutil" 12 | "github.com/miekg/dns" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestRecursionDetector_Check(t *testing.T) { 17 | rd := newRecursionDetector(0, 2) 18 | 19 | const ( 20 | recID = 1234 21 | recTTL = time.Hour * 1 22 | ) 23 | 24 | const nonRecID = recID * 2 25 | 26 | sampleQuestion := dns.Question{ 27 | Name: "some.domain", 28 | Qtype: dns.TypeAAAA, 29 | } 30 | sampleMsg := &dns.Msg{ 31 | MsgHdr: dns.MsgHdr{ 32 | Id: recID, 33 | }, 34 | Question: []dns.Question{sampleQuestion}, 35 | } 36 | 37 | // Manually add the message with big ttl. 
38 | key := msgToSignature(sampleMsg) 39 | expire := make([]byte, uint64sz) 40 | binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano())) 41 | rd.recentRequests.Set(key, expire) 42 | 43 | // Add an expired message. 44 | sampleMsg.Id = nonRecID 45 | rd.add(sampleMsg) 46 | 47 | testCases := []struct { 48 | name string 49 | questions []dns.Question 50 | id uint16 51 | want bool 52 | }{{ 53 | name: "recurrent", 54 | questions: []dns.Question{sampleQuestion}, 55 | id: recID, 56 | want: true, 57 | }, { 58 | name: "not_suspected", 59 | questions: []dns.Question{sampleQuestion}, 60 | id: recID + 1, 61 | want: false, 62 | }, { 63 | name: "expired", 64 | questions: []dns.Question{sampleQuestion}, 65 | id: nonRecID, 66 | want: false, 67 | }, { 68 | name: "empty", 69 | questions: []dns.Question{}, 70 | id: nonRecID, 71 | want: false, 72 | }} 73 | 74 | for _, tc := range testCases { 75 | sampleMsg.Id = tc.id 76 | sampleMsg.Question = tc.questions 77 | t.Run(tc.name, func(t *testing.T) { 78 | detected := rd.check(sampleMsg) 79 | assert.Equal(t, tc.want, detected) 80 | }) 81 | } 82 | } 83 | 84 | func TestRecursionDetector_Suspect(t *testing.T) { 85 | rd := newRecursionDetector(0, 1) 86 | 87 | testCases := []struct { 88 | msg *dns.Msg 89 | name string 90 | want int 91 | }{{ 92 | msg: &dns.Msg{ 93 | MsgHdr: dns.MsgHdr{ 94 | Id: 1234, 95 | }, 96 | Question: []dns.Question{{ 97 | Name: "some.domain", 98 | Qtype: dns.TypeA, 99 | }}, 100 | }, 101 | name: "simple", 102 | want: 1, 103 | }, { 104 | msg: &dns.Msg{}, 105 | name: "unencumbered", 106 | want: 0, 107 | }} 108 | 109 | for _, tc := range testCases { 110 | t.Run(tc.name, func(t *testing.T) { 111 | t.Cleanup(rd.clear) 112 | rd.add(tc.msg) 113 | assert.Equal(t, tc.want, rd.recentRequests.Stats().Count) 114 | }) 115 | } 116 | } 117 | 118 | func BenchmarkMsgToSignature(b *testing.B) { 119 | const name = "some.not.very.long.host.name" 120 | 121 | msg := &dns.Msg{ 122 | MsgHdr: dns.MsgHdr{ 123 | Id: 1234, 124 | }, 
125 | Question: []dns.Question{{ 126 | Name: name, 127 | Qtype: dns.TypeAAAA, 128 | }}, 129 | } 130 | 131 | var sigData []byte 132 | 133 | b.Run("efficient", func(b *testing.B) { 134 | b.ReportAllocs() 135 | 136 | for b.Loop() { 137 | sigData = msgToSignature(msg) 138 | } 139 | 140 | assert.NotEmpty(b, sigData) 141 | }) 142 | 143 | b.Run("inefficient", func(b *testing.B) { 144 | b.ReportAllocs() 145 | 146 | for b.Loop() { 147 | sigData = msgToSignatureSlow(msg) 148 | } 149 | 150 | assert.NotEmpty(b, sigData) 151 | }) 152 | 153 | // Most recent results: 154 | // 155 | // goos: darwin 156 | // goarch: amd64 157 | // pkg: github.com/AdguardTeam/dnsproxy/proxy 158 | // cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz 159 | // BenchmarkMsgToSignature/efficient-12 18789852 61.07 ns/op 288 B/op 1 allocs/op 160 | // BenchmarkMsgToSignature/inefficient-12 582990 2016 ns/op 624 B/op 3 allocs/op 161 | } 162 | 163 | // msgToSignatureSlow converts msg into it's signature represented in bytes in 164 | // the less efficient way. 165 | // 166 | // See [BenchmarkMsgToSignature]. 
167 | func msgToSignatureSlow(msg *dns.Msg) (sig []byte) { 168 | type msgSignature struct { 169 | name [netutil.MaxDomainNameLen]byte 170 | id uint16 171 | qtype uint16 172 | } 173 | 174 | b := bytes.NewBuffer(sig) 175 | q := msg.Question[0] 176 | signature := msgSignature{ 177 | id: msg.Id, 178 | qtype: q.Qtype, 179 | } 180 | copy(signature.name[:], q.Name) 181 | if err := binary.Write(b, binary.BigEndian, signature); err != nil { 182 | slog.Default().Debug("writing message signature", slogutil.KeyError, err) 183 | } 184 | 185 | return b.Bytes() 186 | } 187 | -------------------------------------------------------------------------------- /proxy/retry.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | ) 9 | 10 | // BindRetryConfig contains configuration for the listeners binding retry 11 | // mechanism. 12 | type BindRetryConfig struct { 13 | // Interval is the minimum time to wait after the latest failure. It must 14 | // not be negative if Enabled is true. 15 | Interval time.Duration 16 | 17 | // Count is the maximum number of retries after the first attempt. 18 | Count uint 19 | 20 | // Enabled indicates whether the binding should be retried. 21 | Enabled bool 22 | } 23 | 24 | // bindWithRetry calls f until it returns no error or the retries limit is 25 | // reached, sleeping for configured interval between attempts. bindFunc must 26 | // not be nil and should carry the result of the binding operation itself. 
27 | func (p *Proxy) bindWithRetry(ctx context.Context, bindFunc func() (err error)) (err error) { 28 | err = bindFunc() 29 | if err == nil { 30 | return nil 31 | } 32 | 33 | p.logger.WarnContext(ctx, "binding", "attempt", 1, slogutil.KeyError, err) 34 | 35 | for attempt := uint(1); attempt <= p.bindRetryCount; attempt++ { 36 | time.Sleep(p.bindRetryIvl) 37 | 38 | retryErr := bindFunc() 39 | if retryErr == nil { 40 | return nil 41 | } 42 | 43 | p.logger.WarnContext(ctx, "binding", "attempt", attempt+1, slogutil.KeyError, retryErr) 44 | } 45 | 46 | return err 47 | } 48 | -------------------------------------------------------------------------------- /proxy/retry_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | "github.com/AdguardTeam/golibs/testutil" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestWithRetry(t *testing.T) { 13 | t.Parallel() 14 | 15 | const ( 16 | errA errors.Error = "error about a" 17 | errB errors.Error = "error about b" 18 | ) 19 | 20 | var ( 21 | good = func() (err error) { 22 | return nil 23 | } 24 | 25 | badOne = func() (err error) { 26 | return errA 27 | } 28 | 29 | // Don't protect against concurrent access since the closure is expected 30 | // to be used in a single case. 31 | returnedA = false 32 | badBoth = func() (err error) { 33 | if !returnedA { 34 | returnedA = true 35 | 36 | return errA 37 | } 38 | 39 | return errB 40 | } 41 | 42 | // Don't protect against concurrent access since the closure is expected 43 | // to be used in a single case. 
44 | returnedErr = false 45 | badThenOk = func() (err error) { 46 | if !returnedErr { 47 | returnedErr = true 48 | 49 | return assert.AnError 50 | } 51 | 52 | return nil 53 | } 54 | ) 55 | 56 | testCases := []struct { 57 | f func() (err error) 58 | wantErr error 59 | name string 60 | }{{ 61 | f: good, 62 | wantErr: nil, 63 | name: "no_error", 64 | }, { 65 | f: badOne, 66 | wantErr: errA, 67 | name: "one_error", 68 | }, { 69 | f: badBoth, 70 | wantErr: errA, 71 | name: "two_errors", 72 | }, { 73 | f: badThenOk, 74 | wantErr: nil, 75 | name: "error_then_ok", 76 | }} 77 | 78 | p := &Proxy{ 79 | logger: slogutil.NewDiscardLogger(), 80 | bindRetryCount: 1, 81 | bindRetryIvl: 0, 82 | } 83 | 84 | for _, tc := range testCases { 85 | t.Run(tc.name, func(t *testing.T) { 86 | t.Parallel() 87 | 88 | ctx := testutil.ContextWithTimeout(t, testTimeout) 89 | 90 | err := p.bindWithRetry(ctx, tc.f) 91 | assert.ErrorIs(t, err, tc.wantErr) 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /proxy/serverdnscrypt.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | 8 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/netutil" 11 | "github.com/AdguardTeam/golibs/syncutil" 12 | "github.com/ameshkov/dnscrypt/v2" 13 | "github.com/miekg/dns" 14 | ) 15 | 16 | func (p *Proxy) initDNSCryptListeners(ctx context.Context) (err error) { 17 | if len(p.DNSCryptUDPListenAddr) == 0 && len(p.DNSCryptTCPListenAddr) == 0 { 18 | // Do nothing if DNSCrypt listen addresses are not specified. 
19 | return nil 20 | } 21 | 22 | if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" { 23 | return errors.Error("invalid dnscrypt configuration: no certificate or provider name") 24 | } 25 | 26 | p.logger.InfoContext(ctx, "initializing dnscrypt", "provider", p.DNSCryptProviderName) 27 | p.dnsCryptServer = &dnscrypt.Server{ 28 | ProviderName: p.DNSCryptProviderName, 29 | ResolverCert: p.DNSCryptResolverCert, 30 | Handler: &dnsCryptHandler{ 31 | proxy: p, 32 | reqSema: p.requestsSema, 33 | }, 34 | Logger: p.logger, 35 | } 36 | 37 | for _, addr := range p.DNSCryptUDPListenAddr { 38 | udp, lErr := p.listenDNSCryptUDP(ctx, addr) 39 | if lErr != nil { 40 | return fmt.Errorf("listening to dnscrypt udp on addr %s: %w", addr, lErr) 41 | } 42 | 43 | p.dnsCryptUDPListen = append(p.dnsCryptUDPListen, udp) 44 | } 45 | 46 | for _, addr := range p.DNSCryptTCPListenAddr { 47 | tcp, lErr := p.listenDNSCryptTCP(ctx, addr) 48 | if lErr != nil { 49 | return fmt.Errorf("listening to dnscrypt tcp on addr %s: %w", addr, lErr) 50 | } 51 | 52 | p.dnsCryptTCPListen = append(p.dnsCryptTCPListen, tcp) 53 | } 54 | 55 | return nil 56 | } 57 | 58 | // listenDNSCryptUDP returns a new UDP connection for DNSCrypt listening on 59 | // addr. 60 | func (p *Proxy) listenDNSCryptUDP( 61 | ctx context.Context, 62 | addr *net.UDPAddr, 63 | ) (conn *net.UDPConn, err error) { 64 | addrStr := addr.String() 65 | p.logger.InfoContext(ctx, "creating dnscrypt udp server socket", "addr", addrStr) 66 | 67 | err = p.bindWithRetry(ctx, func() (listenErr error) { 68 | conn, listenErr = net.ListenUDP(bootstrap.NetworkUDP, addr) 69 | 70 | return listenErr 71 | }) 72 | if err != nil { 73 | return nil, fmt.Errorf("listening to udp socket: %w", err) 74 | } 75 | 76 | p.logger.InfoContext(ctx, "listening for dnscrypt messages on udp", "addr", conn.LocalAddr()) 77 | 78 | return conn, nil 79 | } 80 | 81 | // listenDNSCryptTCP returns a new TCP listener for DNSCrypt listening on addr. 
82 | func (p *Proxy) listenDNSCryptTCP( 83 | ctx context.Context, 84 | addr *net.TCPAddr, 85 | ) (conn *net.TCPListener, err error) { 86 | addrStr := addr.String() 87 | p.logger.InfoContext(ctx, "creating dnscrypt tcp server socket", "addr", addrStr) 88 | 89 | err = p.bindWithRetry(ctx, func() (listenErr error) { 90 | conn, listenErr = net.ListenTCP(bootstrap.NetworkTCP, addr) 91 | 92 | return listenErr 93 | }) 94 | if err != nil { 95 | return nil, fmt.Errorf("listening to tcp socket: %w", err) 96 | } 97 | 98 | p.logger.InfoContext(ctx, "listening for dnscrypt messages on tcp", "addr", conn.Addr()) 99 | 100 | return conn, nil 101 | } 102 | 103 | // dnsCryptHandler - dnscrypt.Handler implementation 104 | type dnsCryptHandler struct { 105 | proxy *Proxy 106 | 107 | reqSema syncutil.Semaphore 108 | } 109 | 110 | // compile-time type check 111 | var _ dnscrypt.Handler = &dnsCryptHandler{} 112 | 113 | // ServeDNS - processes the DNS query 114 | func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (err error) { 115 | d := h.proxy.newDNSContext(ProtoDNSCrypt, req, netutil.NetAddrToAddrPort(rw.RemoteAddr())) 116 | d.DNSCryptResponseWriter = rw 117 | 118 | // TODO(d.kolyshev): Pass and use context from above. 
119 | err = h.reqSema.Acquire(context.Background()) 120 | if err != nil { 121 | return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err) 122 | } 123 | defer h.reqSema.Release() 124 | 125 | return h.proxy.handleDNSRequest(d) 126 | } 127 | 128 | // Writes a response to the UDP client 129 | func (p *Proxy) respondDNSCrypt(d *DNSContext) error { 130 | if d.Res == nil { 131 | // If no response has been written, do nothing and let it drop 132 | return nil 133 | } 134 | 135 | return d.DNSCryptResponseWriter.WriteMsg(d.Res) 136 | } 137 | -------------------------------------------------------------------------------- /proxy/serverdnscrypt_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/ameshkov/dnscrypt/v2" 13 | "github.com/ameshkov/dnsstamps" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | // TODO(d.kolyshev): Remove this after quic-go has migrated to slog. 
19 | func TestMain(m *testing.M) { 20 | testutil.DiscardLogOutput(m) 21 | } 22 | 23 | func getFreePort() uint { 24 | l, _ := net.Listen("tcp", "127.0.0.1:0") 25 | port := uint(l.Addr().(*net.TCPAddr).Port) 26 | 27 | // stop listening immediately 28 | _ = l.Close() 29 | 30 | // sleep for 100ms (may be necessary on Windows) 31 | time.Sleep(100 * time.Millisecond) 32 | return port 33 | } 34 | 35 | func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) { 36 | rc, err := dnscrypt.GenerateResolverConfig("example.org", nil) 37 | assert.NoError(t, err) 38 | 39 | cert, err := rc.CreateCert() 40 | assert.NoError(t, err) 41 | 42 | port := getFreePort() 43 | p := mustNew(t, &Config{ 44 | Logger: slogutil.NewDiscardLogger(), 45 | DNSCryptUDPListenAddr: []*net.UDPAddr{{ 46 | Port: int(port), IP: net.ParseIP(listenIP), 47 | }}, 48 | DNSCryptTCPListenAddr: []*net.TCPAddr{{ 49 | Port: int(port), IP: net.ParseIP(listenIP), 50 | }}, 51 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 52 | TrustedProxies: defaultTrustedProxies, 53 | RatelimitSubnetLenIPv4: 24, 54 | RatelimitSubnetLenIPv6: 64, 55 | EnableEDNSClientSubnet: true, 56 | CacheEnabled: true, 57 | CacheMinTTL: 20, 58 | CacheMaxTTL: 40, 59 | DNSCryptProviderName: rc.ProviderName, 60 | DNSCryptResolverCert: cert, 61 | }) 62 | 63 | return p, rc 64 | } 65 | 66 | func TestDNSCryptProxy(t *testing.T) { 67 | // Prepare the proxy server 68 | dnsProxy, rc := createTestDNSCryptProxy(t) 69 | 70 | // Start listening 71 | ctx := context.Background() 72 | err := dnsProxy.Start(ctx) 73 | require.NoError(t, err) 74 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) }) 75 | 76 | // Generate a DNS stamp 77 | addr := fmt.Sprintf("%s:%d", listenIP, dnsProxy.Addr(ProtoDNSCrypt).(*net.UDPAddr).Port) 78 | stamp, err := rc.CreateStamp(addr) 79 | assert.Nil(t, err) 80 | 81 | // Test DNSCrypt proxy on both UDP and TCP 82 | checkDNSCryptProxy(t, "udp", 
stamp) 83 | checkDNSCryptProxy(t, "tcp", stamp) 84 | } 85 | 86 | func checkDNSCryptProxy(t *testing.T, proto string, stamp dnsstamps.ServerStamp) { 87 | // Create a DNSCrypt client 88 | c := &dnscrypt.Client{ 89 | Timeout: defaultTimeout, 90 | Net: proto, 91 | } 92 | 93 | // Fetch the server certificate 94 | ri, err := c.DialStamp(stamp) 95 | assert.Nil(t, err) 96 | 97 | // Send the test message 98 | msg := newTestMessage() 99 | reply, err := c.Exchange(msg, ri) 100 | assert.Nil(t, err) 101 | requireResponse(t, msg, reply) 102 | } 103 | -------------------------------------------------------------------------------- /proxy/servertcp_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "crypto/x509" 7 | "net" 8 | "testing" 9 | 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/miekg/dns" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestTcpProxy(t *testing.T) { 17 | dnsProxy := mustStartDefaultProxy(t) 18 | 19 | // Create a DNS-over-TCP client connection 20 | addr := dnsProxy.Addr(ProtoTCP) 21 | conn, err := dns.Dial("tcp", addr.String()) 22 | require.NoError(t, err) 23 | 24 | sendTestMessages(t, conn) 25 | } 26 | 27 | func TestTlsProxy(t *testing.T) { 28 | serverConfig, caPem := newTLSConfig(t) 29 | dnsProxy := mustNew(t, &Config{ 30 | Logger: slogutil.NewDiscardLogger(), 31 | TLSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 32 | HTTPSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 33 | QUICListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 34 | TLSConfig: serverConfig, 35 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 36 | TrustedProxies: defaultTrustedProxies, 37 | RatelimitSubnetLenIPv4: 24, 38 | RatelimitSubnetLenIPv6: 64, 39 | }) 40 | 41 | // Start listening 
42 | ctx := context.Background() 43 | err := dnsProxy.Start(ctx) 44 | require.NoError(t, err) 45 | testutil.CleanupAndRequireSuccess(t, func() (err error) { return dnsProxy.Shutdown(ctx) }) 46 | 47 | roots := x509.NewCertPool() 48 | roots.AppendCertsFromPEM(caPem) 49 | tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots} 50 | 51 | // Create a DNS-over-TLS client connection 52 | addr := dnsProxy.Addr(ProtoTLS) 53 | conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) 54 | require.NoError(t, err) 55 | 56 | sendTestMessages(t, conn) 57 | } 58 | -------------------------------------------------------------------------------- /proxy/serverudp_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestUdpProxy(t *testing.T) { 11 | dnsProxy := mustStartDefaultProxy(t) 12 | 13 | // Create a DNS-over-UDP client connection 14 | addr := dnsProxy.Addr(ProtoUDP) 15 | conn, err := dns.Dial("udp", addr.String()) 16 | require.NoError(t, err) 17 | 18 | sendTestMessages(t, conn) 19 | } 20 | -------------------------------------------------------------------------------- /proxy/upstreammode.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding" 5 | "fmt" 6 | ) 7 | 8 | // UpstreamMode is an enumeration of upstream mode representations. 9 | // 10 | // TODO(d.kolyshev): Set uint8 as underlying type. 11 | type UpstreamMode string 12 | 13 | const ( 14 | // UpstreamModeLoadBalance is the default upstream mode. It balances the 15 | // upstreams load. 16 | UpstreamModeLoadBalance UpstreamMode = "load_balance" 17 | 18 | // UpstreamModeParallel makes server to query all configured upstream 19 | // servers in parallel. 
20 | UpstreamModeParallel UpstreamMode = "parallel" 21 | 22 | // UpstreamModeFastestAddr controls whether the server should respond to A 23 | // or AAAA requests only with the fastest IP address detected by ICMP 24 | // response time or TCP connection time. 25 | UpstreamModeFastestAddr UpstreamMode = "fastest_addr" 26 | ) 27 | 28 | // type check 29 | var _ encoding.TextUnmarshaler = (*UpstreamMode)(nil) 30 | 31 | // UnmarshalText implements [encoding.TextUnmarshaler] interface for 32 | // *UpstreamMode. 33 | func (m *UpstreamMode) UnmarshalText(b []byte) (err error) { 34 | switch um := UpstreamMode(b); um { 35 | case 36 | UpstreamModeLoadBalance, 37 | UpstreamModeParallel, 38 | UpstreamModeFastestAddr: 39 | *m = um 40 | default: 41 | return fmt.Errorf( 42 | "invalid upstream mode %q, supported: %q, %q, %q", 43 | b, 44 | UpstreamModeLoadBalance, 45 | UpstreamModeParallel, 46 | UpstreamModeFastestAddr, 47 | ) 48 | } 49 | 50 | return nil 51 | } 52 | 53 | // type check 54 | var _ encoding.TextMarshaler = UpstreamMode("") 55 | 56 | // MarshalText implements [encoding.TextMarshaler] interface for UpstreamMode. 
// AddPrefix returns a new slice holding b preceded by its length encoded as a
// big-endian 16-bit integer.  b itself is not modified.
func AddPrefix(b []byte) (m []byte) {
	const prefixLen = 2

	// Allocate the exact size up front, so neither append below reallocates.
	m = make([]byte, 0, prefixLen+len(b))
	m = binary.BigEndian.AppendUint16(m, uint16(len(b)))

	return append(m, b...)
}
22 | func IPFromRR(rr dns.RR) (ip netip.Addr) { 23 | var data []byte 24 | switch rr := rr.(type) { 25 | case *dns.A: 26 | data = rr.A.To4() 27 | case *dns.AAAA: 28 | data = rr.AAAA 29 | default: 30 | return netip.Addr{} 31 | } 32 | 33 | ip, _ = netip.AddrFromSlice(data) 34 | 35 | return ip 36 | } 37 | -------------------------------------------------------------------------------- /scripts/hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e -f -u 4 | 5 | # This comment is used to simplify checking local copies of the script. 6 | # Bump this number every time a significant change is made to this 7 | # script. 8 | # 9 | # AdGuard-Project-Version: 4 10 | 11 | # TODO(a.garipov): Add pre-merge-commit. 12 | 13 | # Only show interactive prompts if there a terminal is attached to 14 | # stdout. While this technically doesn't guarantee that reading from 15 | # /dev/tty works, this should work reasonably well on all of our 16 | # supported development systems and in most terminal emulators. 17 | is_tty='0' 18 | if [ -t '1' ]; then 19 | is_tty='1' 20 | fi 21 | readonly is_tty 22 | 23 | # prompt is a helper that prompts the user for interactive input if that 24 | # can be done. If there is no terminal attached, it sleeps for two 25 | # seconds, giving the programmer some time to react, and returns with 26 | # a zero exit code. 27 | prompt() { 28 | if [ "$is_tty" -eq '0' ]; then 29 | sleep 2 30 | 31 | return 0 32 | fi 33 | 34 | while true; do 35 | printf 'commit anyway? y/[n]: ' 36 | read -r ans &2 19 | fi 20 | } 21 | 22 | log 'starting to build dnsproxy release' 23 | 24 | version="${VERSION:-}" 25 | readonly version 26 | 27 | log "version '$version'" 28 | 29 | dist="${DIST_DIR:-build}" 30 | readonly dist 31 | 32 | out="${OUT:-dnsproxy}" 33 | 34 | log "checking tools" 35 | 36 | for tool in tar zip; do 37 | if ! 
# build compiles the binary for a single platform and packs it, together with
# the license and the readme, into an archive inside $dist.
#
# Arguments:
#   1. The name of the release, which is also used as the name of the build
#      directory inside $dist.
#   2. The target operating system, passed to the build as GOOS.
#   3. The target architecture, passed to the build as GOARCH.
#   4. The ARM version, passed to the build as GOARM; 0 means none.
#   5. The MIPS float mode, passed to the build as GOMIPS; 0 means none.
build() {
	# Get the arguments.  Here and below, use the "build_" prefix for all
	# variables local to function build.
	build_dir="${dist}/${1}" \
		build_name="$1" \
		build_os="$2" \
		build_arch="$3" \
		build_arm="$4" \
		build_mips="$5" \
		;

	# Use the ".exe" filename extension if we build a Windows release.
	if [ "$build_os" = 'windows' ]; then
		build_output="./${build_dir}/${out}.exe"
	else
		build_output="./${build_dir}/${out}"
	fi

	mkdir -p "./${build_dir}"

	# Build the binary.
	#
	# Set GOARM and GOMIPS to an empty string if $build_arm and $build_mips
	# are zero by removing the zero as if it's a prefix.
	#
	# Use $build_os for GOOS, like the rest of the function does, so that
	# build depends only on its own arguments and not on the variables of its
	# caller.
	env GOARCH="$build_arch" \
		GOARM="${build_arm#0}" \
		GOMIPS="${build_mips#0}" \
		GOOS="$build_os" \
		VERBOSE="$((verbose - 1))" \
		VERSION="$version" \
		OUT="$build_output" \
		sh ./scripts/make/go-build.sh

	log "$build_output"

	# Prepare the build directory for archiving.
	cp ./LICENSE ./README.md "$build_dir"

	# Make archives.  Windows prefers ZIP archives; the rest, gzipped tarballs.
	case "$build_os" in
	'windows')
		build_archive="./${dist}/${out}-${build_name}-${version}.zip"
		# TODO(a.garipov): Find an option similar to the -C option of tar for
		# zip.
		(cd "${dist}" && zip -9 -q -r "../${build_archive}" "./${build_name}")
		;;
	*)
		build_archive="./${dist}/${out}-${build_name}-${version}.tar.gz"
		tar -C "./${dist}" -c -f - "./${build_name}" | gzip -9 - >"$build_archive"
		;;
	esac

	log "$build_archive"
}
Bump 10 | # this number every time a significant change is made to this script. 11 | # 12 | # AdGuard-Project-Version: 2 13 | 14 | # The default verbosity level is 0. Show every command that is run and every 15 | # package that is processed if the caller requested verbosity level greater than 16 | # 0. Also show subcommands if the requested verbosity level is greater than 1. 17 | # Otherwise, do nothing. 18 | verbose="${VERBOSE:-0}" 19 | readonly verbose 20 | 21 | if [ "$verbose" -gt '1' ]; then 22 | env 23 | set -x 24 | v_flags='-v=1' 25 | x_flags='-x=1' 26 | elif [ "$verbose" -gt '0' ]; then 27 | set -x 28 | v_flags='-v=1' 29 | x_flags='-x=0' 30 | else 31 | set +x 32 | v_flags='-v=0' 33 | x_flags='-x=0' 34 | fi 35 | readonly x_flags v_flags 36 | 37 | # Exit the script if a pipeline fails (-e), prevent accidental filename 38 | # expansion (-f), and consider undefined variables as errors (-u). 39 | set -e -f -u 40 | 41 | # Allow users to override the go command from environment. For example, to 42 | # build two releases with two different Go versions and test the difference. 43 | go="${GO:-go}" 44 | readonly go 45 | 46 | # Set the build parameters unless already set. 47 | branch="${BRANCH:-$(git rev-parse --abbrev-ref HEAD)}" 48 | revision="${REVISION:-$(git rev-parse --short HEAD)}" 49 | version="${VERSION:-0}" 50 | readonly branch revision version 51 | 52 | # Set date and time of the latest commit unless already set. 53 | committime="${SOURCE_DATE_EPOCH:-$(git log -1 --pretty=%ct)}" 54 | readonly committime 55 | 56 | # Compile them in. 
57 | version_pkg='github.com/AdguardTeam/dnsproxy/internal/version' 58 | ldflags="-s -w" 59 | ldflags="${ldflags} -X ${version_pkg}.branch=${branch}" 60 | ldflags="${ldflags} -X ${version_pkg}.committime=${committime}" 61 | ldflags="${ldflags} -X ${version_pkg}.revision=${revision}" 62 | ldflags="${ldflags} -X ${version_pkg}.version=${version}" 63 | readonly ldflags version_pkg 64 | 65 | # Allow users to limit the build's parallelism. 66 | parallelism="${PARALLELISM:-}" 67 | readonly parallelism 68 | 69 | # Use GOFLAGS for -p, because -p=0 simply disables the build instead of leaving 70 | # the default value. 71 | if [ "${parallelism}" != '' ]; then 72 | GOFLAGS="${GOFLAGS:-} -p=${parallelism}" 73 | fi 74 | readonly GOFLAGS 75 | export GOFLAGS 76 | 77 | # Allow users to specify a different output name. 78 | out="${OUT:-dnsproxy}" 79 | readonly out 80 | 81 | o_flags="-o=${out}" 82 | readonly o_flags 83 | 84 | # Allow users to enable the race detector. Unfortunately, that means that cgo 85 | # must be enabled. 86 | if [ "${RACE:-0}" -eq '0' ]; then 87 | CGO_ENABLED='0' 88 | race_flags='--race=0' 89 | else 90 | CGO_ENABLED='1' 91 | race_flags='--race=1' 92 | fi 93 | readonly CGO_ENABLED race_flags 94 | export CGO_ENABLED 95 | 96 | if [ "$verbose" -gt '0' ]; then 97 | "$go" env 98 | fi 99 | 100 | "$go" build \ 101 | --ldflags="$ldflags" \ 102 | "$race_flags" \ 103 | --trimpath \ 104 | "$o_flags" \ 105 | "$v_flags" \ 106 | "$x_flags" 107 | -------------------------------------------------------------------------------- /scripts/make/go-deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 
5 | # 6 | # AdGuard-Project-Version: 2 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" mod download "$x_flags" 30 | -------------------------------------------------------------------------------- /scripts/make/go-test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 6 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Verbosity levels: 12 | # 0 = Don't print anything except for errors. 13 | # 1 = Print commands, but not nested commands. 14 | # 2 = Print everything. 15 | if [ "$verbose" -gt '1' ]; then 16 | set -x 17 | v_flags='-v=1' 18 | x_flags='-x=1' 19 | elif [ "$verbose" -gt '0' ]; then 20 | set -x 21 | v_flags='-v=1' 22 | x_flags='-x=0' 23 | else 24 | set +x 25 | v_flags='-v=0' 26 | x_flags='-x=0' 27 | fi 28 | readonly v_flags x_flags 29 | 30 | set -e -f -u 31 | 32 | if [ "${RACE:-1}" -eq '0' ]; then 33 | race_flags='--race=0' 34 | else 35 | race_flags='--race=1' 36 | fi 37 | readonly race_flags 38 | 39 | count_flags='--count=2' 40 | cover_flags='--coverprofile=./cover.out' 41 | go="${GO:-go}" 42 | shuffle_flags='--shuffle=on' 43 | timeout_flags="${TIMEOUT_FLAGS:---timeout=2m}" 44 | readonly count_flags cover_flags go shuffle_flags timeout_flags 45 | 46 | go_test() { 47 | "$go" test \ 48 | "$count_flags" \ 49 | "$cover_flags" \ 50 | "$race_flags" \ 51 | "$shuffle_flags" \ 52 | "$timeout_flags" \ 53 | "$v_flags" \ 54 | "$x_flags" \ 55 | ./... 
56 | } 57 | 58 | test_reports_dir="${TEST_REPORTS_DIR:-}" 59 | readonly test_reports_dir 60 | 61 | if [ "$test_reports_dir" = '' ]; then 62 | go_test 63 | 64 | exit "$?" 65 | fi 66 | 67 | mkdir -p "$test_reports_dir" 68 | 69 | # NOTE: The pipe ignoring the exit code here is intentional, as go-junit-report 70 | # will set the exit code to be saved. 71 | go_test 2>&1 \ 72 | | tee "${test_reports_dir}/test-output.txt" 73 | 74 | # Don't fail on errors in exporting, because TEST_REPORTS_DIR is generally only 75 | # not empty in CI, and so the exit code must be preserved to exit with it later. 76 | set +e 77 | go-junit-report \ 78 | --in "${test_reports_dir}/test-output.txt" \ 79 | --set-exit-code \ 80 | >"${test_reports_dir}/test-report.xml" 81 | printf '%s\n' "$?" \ 82 | >"${test_reports_dir}/test-exit-code.txt" 83 | -------------------------------------------------------------------------------- /scripts/make/go-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 7 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | set -x 13 | v_flags='-v=1' 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | v_flags='-v=1' 18 | x_flags='-x=0' 19 | else 20 | set +x 21 | v_flags='-v=0' 22 | x_flags='-x=0' 23 | fi 24 | readonly v_flags x_flags 25 | 26 | set -e -f -u 27 | 28 | # Reset GOARCH and GOOS to make sure we install the tools for the native 29 | # architecture even when we're cross-compiling the main binary, and also to 30 | # prevent the "cannot install cross-compiled binaries when GOBIN is set" error. 
31 | env \ 32 | GOARCH="" \ 33 | GOBIN="${PWD}/bin" \ 34 | GOOS="" \ 35 | GOWORK='off' \ 36 | "${GO:-go}" install "$v_flags" "$x_flags" tool \ 37 | ; 38 | -------------------------------------------------------------------------------- /scripts/make/go-upd-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 4 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" get -u "$x_flags" tool 30 | "$go" mod tidy "$x_flags" 31 | -------------------------------------------------------------------------------- /scripts/make/helper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Common script helpers 4 | # 5 | # This file contains common script helpers. It should be sourced in scripts 6 | # right after the initial environment processing. 7 | 8 | # This comment is used to simplify checking local copies of the script. Bump 9 | # this number every time a remarkable change is made to this script. 10 | # 11 | # AdGuard-Project-Version: 4 12 | 13 | # Deferred helpers 14 | 15 | not_found_msg=' 16 | looks like a binary not found error. 17 | make sure you have installed the linter binaries using: 18 | 19 | $ make go-tools 20 | ' 21 | readonly not_found_msg 22 | 23 | not_found() { 24 | if [ "$?" -eq '127' ]; then 25 | # Code 127 is the exit status a shell uses when a command or a file is 26 | # not found, according to the Bash Hackers wiki. 
27 | # 28 | # See https://wiki.bash-hackers.org/dict/terms/exit_status. 29 | echo "$not_found_msg" 1>&2 30 | fi 31 | } 32 | trap not_found EXIT 33 | 34 | # Helpers 35 | 36 | # run_linter runs the given linter with two additions: 37 | # 38 | # 1. If the first argument is "-e", run_linter exits with a nonzero exit code 39 | # if there is anything in the command's combined output. 40 | # 41 | # 2. In any case, run_linter adds the program's name to its combined output. 42 | run_linter() ( 43 | set +e 44 | 45 | if [ "${VERBOSE:-0}" -lt '2' ]; then 46 | set +x 47 | fi 48 | 49 | cmd="${1:?run_linter: provide a command}" 50 | shift 51 | 52 | exit_on_output='0' 53 | if [ "$cmd" = '-e' ]; then 54 | exit_on_output='1' 55 | cmd="${1:?run_linter: provide a command}" 56 | shift 57 | fi 58 | 59 | readonly cmd 60 | 61 | output="$("$cmd" "$@")" 62 | exitcode="$?" 63 | 64 | readonly output 65 | 66 | if [ "$output" != '' ]; then 67 | echo "$output" | sed -e "s/^/${cmd}: /" 68 | 69 | if [ "$exitcode" -eq '0' ] && [ "$exit_on_output" -eq '1' ]; then 70 | exitcode='1' 71 | fi 72 | fi 73 | 74 | return "$exitcode" 75 | ) 76 | -------------------------------------------------------------------------------- /scripts/make/md-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a remarkable change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | set -e -f -u 12 | 13 | if [ "$verbose" -gt '0' ]; then 14 | set -x 15 | fi 16 | 17 | markdownlint \ 18 | ./README.md \ 19 | ; 20 | -------------------------------------------------------------------------------- /scripts/make/sh-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. 
Bump 4 | # this number every time a remarkable change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Don't use -f, because we use globs in this script. 12 | set -e -u 13 | 14 | if [ "$verbose" -gt '0' ]; then 15 | set -x 16 | fi 17 | 18 | # Source the common helpers, including not_found and run_linter. 19 | . ./scripts/make/helper.sh 20 | 21 | run_linter -e shfmt --binary-next-line -d -p -s \ 22 | ./scripts/hooks/* \ 23 | ./scripts/make/*.sh \ 24 | ; 25 | 26 | shellcheck -e 'SC2250' -f 'gcc' -o 'all' -x -- \ 27 | ./scripts/hooks/* \ 28 | ./scripts/make/*.sh \ 29 | ; 30 | -------------------------------------------------------------------------------- /scripts/make/txt-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a remarkable change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 8 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '0' ]; then 12 | set -x 13 | fi 14 | 15 | # Set $EXIT_ON_ERROR to zero to see all errors. 16 | if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]; then 17 | set +e 18 | else 19 | set -e 20 | fi 21 | 22 | # We don't need glob expansions and we want to see errors about unset variables. 23 | set -f -u 24 | 25 | # Source the common helpers, including not_found. 26 | . ./scripts/make/helper.sh 27 | 28 | # Simple analyzers 29 | 30 | # trailing_newlines is a simple check that makes sure that all plain-text files 31 | # have a trailing newlines to make sure that all tools work correctly with them. 32 | trailing_newlines() ( 33 | nl="$(printf '\n')" 34 | readonly nl 35 | 36 | find . \ 37 | -type 'f' \ 38 | '!' 
'(' \ 39 | -name '*.out' \ 40 | -o -name '*.test' \ 41 | -o -name 'dnsproxy' \ 42 | -o -path './.git/*' \ 43 | -o -path './bin/*' \ 44 | ')' \ 45 | | while read -r f; do 46 | final_byte="$(tail -c -1 "$f")" 47 | if [ "$final_byte" != "$nl" ]; then 48 | printf '%s: must have a trailing newline\n' "$f" 49 | fi 50 | done 51 | ) 52 | 53 | # trailing_whitespace is a simple check that makes sure that there are no 54 | # trailing whitespace in plain-text files. 55 | trailing_whitespace() { 56 | find . \ 57 | -type 'f' \ 58 | '!' '(' \ 59 | -name '*.out' \ 60 | -o -name '*.test' \ 61 | -o -name 'dnsproxy' \ 62 | -o -path './.git/*' \ 63 | -o -path './bin/*' \ 64 | ')' \ 65 | | while read -r f; do 66 | grep -e '[[:space:]]$' -n -- "$f" \ 67 | | sed -e "s:^:${f}\::" -e 's/ \+$/>>>&<< timeout { 49 | t.Fatalf("exchange took more time than the configured timeout: %v", elapsed) 50 | } 51 | } 52 | 53 | func TestExchangeParallelEmpty(t *testing.T) { 54 | ups := []Upstream{ 55 | &testUpstream{empty: true}, 56 | &testUpstream{empty: true}, 57 | } 58 | 59 | req := createTestMessage() 60 | resp, up, err := ExchangeParallel(ups, req) 61 | require.Error(t, err) 62 | 63 | assert.Nil(t, resp) 64 | assert.Nil(t, up) 65 | } 66 | 67 | // testUpstream represents a mock upstream structure. 68 | type testUpstream struct { 69 | // addr is a mock A record IP address to be returned. 70 | addr netip.Addr 71 | 72 | // err is a mock error to be returned. 73 | err bool 74 | 75 | // empty indicates if a nil response is returned. 76 | empty bool 77 | 78 | // sleep is a delay before response. 79 | sleep time.Duration 80 | } 81 | 82 | // type check 83 | var _ Upstream = (*testUpstream)(nil) 84 | 85 | // Exchange implements the [Upstream] interface for *testUpstream. 
86 | func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { 87 | if u.sleep != 0 { 88 | time.Sleep(u.sleep) 89 | } 90 | 91 | if u.empty { 92 | return nil, nil 93 | } 94 | 95 | if u.err { 96 | return nil, fmt.Errorf("upstream error") 97 | } 98 | 99 | resp = &dns.Msg{} 100 | resp.SetReply(req) 101 | 102 | if u.addr != (netip.Addr{}) { 103 | a := dns.A{ 104 | A: u.addr.AsSlice(), 105 | } 106 | 107 | resp.Answer = append(resp.Answer, &a) 108 | } 109 | 110 | return resp, nil 111 | } 112 | 113 | // Address implements the [Upstream] interface for *testUpstream. 114 | func (u *testUpstream) Address() (addr string) { 115 | return "" 116 | } 117 | 118 | // Close implements the [Upstream] interface for *testUpstream. 119 | func (u *testUpstream) Close() (err error) { 120 | return nil 121 | } 122 | 123 | func TestExchangeAll(t *testing.T) { 124 | delayedAnsAddr := netip.MustParseAddr("1.1.1.1") 125 | ansAddr := netip.MustParseAddr("3.3.3.3") 126 | 127 | ups := []Upstream{&testUpstream{ 128 | addr: delayedAnsAddr, 129 | sleep: 100 * time.Millisecond, 130 | }, &testUpstream{ 131 | err: true, 132 | }, &testUpstream{ 133 | addr: ansAddr, 134 | }} 135 | 136 | req := createHostTestMessage("test.org") 137 | res, err := ExchangeAll(ups, req) 138 | require.NoError(t, err) 139 | require.Len(t, res, 2) 140 | 141 | resp := res[0].Resp 142 | require.NotNil(t, resp) 143 | require.NotEmpty(t, resp.Answer) 144 | require.IsType(t, new(dns.A), resp.Answer[0]) 145 | 146 | ip := resp.Answer[0].(*dns.A).A 147 | assert.Equal(t, ansAddr.AsSlice(), []byte(ip)) 148 | 149 | resp = res[1].Resp 150 | require.NotNil(t, resp) 151 | require.NotEmpty(t, resp.Answer) 152 | require.IsType(t, new(dns.A), resp.Answer[0]) 153 | 154 | ip = resp.Answer[0].(*dns.A).A 155 | assert.Equal(t, delayedAnsAddr.AsSlice(), []byte(ip)) 156 | } 157 | -------------------------------------------------------------------------------- /upstream/resolver_internal_test.go: 
-------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "testing" 7 | "time" 8 | 9 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 10 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/miekg/dns" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestCachingResolver_staleness(t *testing.T) { 18 | ip4 := netip.MustParseAddr("1.2.3.4") 19 | ip6 := netip.MustParseAddr("2001:db8::1") 20 | 21 | const ( 22 | smallTTL = 10 * time.Second 23 | largeTTL = 1000 * time.Second 24 | 25 | fqdn = "test.fully.qualified.name." 26 | ) 27 | 28 | onExchange := func(req *dns.Msg) (resp *dns.Msg, err error) { 29 | resp = (&dns.Msg{}).SetReply(req) 30 | 31 | hdr := dns.RR_Header{ 32 | Name: req.Question[0].Name, 33 | Rrtype: req.Question[0].Qtype, 34 | Class: dns.ClassINET, 35 | } 36 | var rr dns.RR 37 | switch q := req.Question[0]; q.Qtype { 38 | case dns.TypeA: 39 | hdr.Ttl = uint32(smallTTL.Seconds()) 40 | rr = &dns.A{Hdr: hdr, A: ip4.AsSlice()} 41 | case dns.TypeAAAA: 42 | hdr.Ttl = uint32(largeTTL.Seconds()) 43 | rr = &dns.AAAA{Hdr: hdr, AAAA: ip6.AsSlice()} 44 | default: 45 | require.Contains(testutil.PanicT{}, []uint16{dns.TypeA, dns.TypeAAAA}, q.Qtype) 46 | } 47 | resp.Answer = append(resp.Answer, rr) 48 | 49 | return resp, nil 50 | } 51 | 52 | ups := &dnsproxytest.FakeUpstream{ 53 | OnAddress: func() (_ string) { panic("not implemented") }, 54 | OnClose: func() (_ error) { panic("not implemented") }, 55 | OnExchange: onExchange, 56 | } 57 | 58 | r := NewCachingResolver(&UpstreamResolver{Upstream: ups}) 59 | 60 | require.True(t, t.Run("resolve", func(t *testing.T) { 61 | testCases := []struct { 62 | name string 63 | network bootstrap.Network 64 | want []netip.Addr 65 | }{{ 66 | name: "ip4", 67 | network: bootstrap.NetworkIP4, 68 | want: []netip.Addr{ip4}, 
69 | }, { 70 | name: "ip6", 71 | network: bootstrap.NetworkIP6, 72 | want: []netip.Addr{ip6}, 73 | }, { 74 | name: "both", 75 | network: bootstrap.NetworkIP, 76 | want: []netip.Addr{ip4, ip6}, 77 | }} 78 | 79 | for _, tc := range testCases { 80 | t.Run(tc.name, func(t *testing.T) { 81 | if tc.name != "both" { 82 | t.Skip(`TODO(e.burkov): Bootstrap now only uses "ip" network, see TODO there.`) 83 | } 84 | 85 | res, err := r.LookupNetIP(context.Background(), tc.network, fqdn) 86 | require.NoError(t, err) 87 | 88 | assert.ElementsMatch(t, tc.want, res) 89 | }) 90 | } 91 | })) 92 | 93 | t.Run("staleness", func(t *testing.T) { 94 | now := time.Now() 95 | cached := r.findCached(fqdn, now) 96 | require.ElementsMatch(t, []netip.Addr{ip4, ip6}, cached) 97 | 98 | cached = r.findCached(fqdn, now.Add(smallTTL+time.Second)) 99 | require.Empty(t, cached) 100 | }) 101 | } 102 | -------------------------------------------------------------------------------- /upstream/resolver_test.go: -------------------------------------------------------------------------------- 1 | package upstream_test 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "testing" 7 | "time" 8 | 9 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 10 | "github.com/AdguardTeam/dnsproxy/upstream" 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/AdguardTeam/golibs/logutil/slogutil" 13 | "github.com/miekg/dns" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func TestNewUpstreamResolver(t *testing.T) { 19 | ups := &dnsproxytest.FakeUpstream{ 20 | OnAddress: func() (_ string) { panic("not implemented") }, 21 | OnClose: func() (_ error) { panic("not implemented") }, 22 | OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { 23 | resp = (&dns.Msg{}).SetReply(req) 24 | resp.Answer = []dns.RR{&dns.A{ 25 | Hdr: dns.RR_Header{ 26 | Name: req.Question[0].Name, 27 | Rrtype: dns.TypeA, 28 | Class: dns.ClassINET, 29 | Ttl: 60, 30 | }, 31 | A: 
netip.MustParseAddr("1.2.3.4").AsSlice(), 32 | }} 33 | 34 | return resp, nil 35 | }, 36 | } 37 | 38 | r := &upstream.UpstreamResolver{Upstream: ups} 39 | 40 | ipAddrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com") 41 | require.NoError(t, err) 42 | 43 | assert.NotEmpty(t, ipAddrs) 44 | } 45 | 46 | func TestNewUpstreamResolver_validity(t *testing.T) { 47 | t.Parallel() 48 | 49 | withTimeoutOpt := &upstream.Options{ 50 | Logger: slogutil.NewDiscardLogger(), 51 | Timeout: 3 * time.Second, 52 | } 53 | 54 | testCases := []struct { 55 | name string 56 | addr string 57 | wantErrMsg string 58 | }{{ 59 | name: "udp", 60 | addr: "1.1.1.1:53", 61 | wantErrMsg: "", 62 | }, { 63 | name: "dot", 64 | addr: "tls://1.1.1.1", 65 | wantErrMsg: "", 66 | }, { 67 | name: "doh", 68 | addr: "https://1.1.1.1/dns-query", 69 | wantErrMsg: "", 70 | }, { 71 | name: "sdns", 72 | addr: "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", 73 | wantErrMsg: "", 74 | }, { 75 | name: "tcp", 76 | addr: "tcp://9.9.9.9", 77 | wantErrMsg: "", 78 | }, { 79 | name: "invalid_tls", 80 | addr: "tls://dns.adguard.com", 81 | wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` + 82 | `unexpected character (at "dns.adguard.com")`, 83 | }, { 84 | name: "invalid_https", 85 | addr: "https://dns.adguard.com/dns-query", 86 | wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` + 87 | `unexpected character (at "dns.adguard.com")`, 88 | }, { 89 | name: "invalid_tcp", 90 | addr: "tcp://dns.adguard.com", 91 | wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` + 92 | `unexpected character (at "dns.adguard.com")`, 93 | }, { 94 | name: "invalid_no_scheme", 95 | addr: "dns.adguard.com", 96 | wantErrMsg: `not a bootstrap: ParseAddr("dns.adguard.com"): ` + 97 | `unexpected character (at "dns.adguard.com")`, 98 | }} 99 | 100 | for _, tc := range testCases { 101 | t.Run(tc.name, func(t 
*testing.T) { 102 | t.Parallel() 103 | 104 | r, err := upstream.NewUpstreamResolver(tc.addr, withTimeoutOpt) 105 | if tc.wantErrMsg != "" { 106 | assert.Equal(t, tc.wantErrMsg, err.Error()) 107 | if nberr := (&upstream.NotBootstrapError{}); errors.As(err, &nberr) { 108 | assert.NotNil(t, r) 109 | } 110 | 111 | return 112 | } 113 | 114 | require.NoError(t, err) 115 | 116 | addrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com") 117 | require.NoError(t, err) 118 | 119 | assert.NotEmpty(t, addrs) 120 | }) 121 | } 122 | } 123 | --------------------------------------------------------------------------------