├── .gitattributes ├── .dockerignore ├── internal ├── handler │ ├── testdata │ │ └── TestDefault_resolveFromHosts │ │ │ └── hosts │ ├── handler.go │ ├── ipv6halt.go │ ├── default.go │ ├── hosts.go │ └── constructor.go ├── dnsproxytest │ ├── dnsproxytest.go │ └── interface.go ├── bootstrap │ ├── error.go │ ├── resolver_test.go │ ├── bootstrap.go │ └── resolver.go ├── netutil │ ├── testdata │ │ └── TestHosts │ │ │ ├── bad_file │ │ │ └── hosts │ │ │ └── good_file │ │ │ └── hosts │ ├── listenconfig_windows.go │ ├── paths.go │ ├── udpoob_others.go │ ├── paths_windows.go │ ├── paths_unix.go │ ├── udp_windows.go │ ├── udpoob_darwin.go │ ├── listenconfig.go │ ├── netutil.go │ ├── udp.go │ ├── listenconfig_unix.go │ └── udp_unix.go ├── version │ └── version.go ├── cmd │ ├── tls.go │ ├── flag.go │ └── cmd.go └── dnsmsg │ └── constructor.go ├── main.go ├── .codecov.yml ├── proxy ├── constructor.go ├── upstreammode_test.go ├── serverudp_internal_test.go ├── errors.go ├── errors_plan9.go ├── bogusnxdomain.go ├── errors_internal_test.go ├── lookup_internal_test.go ├── retry.go ├── servertcp_internal_test.go ├── upstreammode.go ├── ratelimit.go ├── retry_internal_test.go ├── handler_internal_test.go ├── optimisticresolver.go ├── ratelimit_internal_test.go ├── helpers.go ├── recursiondetector.go ├── lookup.go ├── beforerequest.go ├── serverdnscrypt_internal_test.go ├── bogusnxdomain_internal_test.go ├── optimisticresolver_internal_test.go ├── beforerequest_internal_test.go ├── serverdnscrypt.go ├── proxycache.go ├── recursiondetector_internal_test.go ├── exchange.go ├── pending_test.go └── pending.go ├── upstream ├── upstream_test.go ├── resolver_internal_test.go ├── dot_windows.go ├── dot_unix.go ├── hostsresolver_test.go ├── hostsresolver.go ├── parallel_internal_test.go ├── dnscrypt.go └── parallel.go ├── scripts ├── make │ ├── md-lint.sh │ ├── go-deps.sh │ ├── go-upd-tools.sh │ ├── sh-lint.sh │ ├── go-tools.sh │ ├── go-test.sh │ ├── txt-lint.sh │ ├── helper.sh │ ├── 
go-build.sh │ ├── build-docker.sh │ └── build-release.sh └── hooks │ └── pre-commit ├── .markdownlint.json ├── .gitignore ├── config.yaml.dist ├── staticcheck.conf ├── proxyutil └── dns.go ├── docker ├── README.md └── Dockerfile ├── .github └── workflows │ ├── lint.yaml │ ├── docker.yml │ └── build.yaml ├── fastip ├── cache_internal_test.go ├── cache.go ├── ping.go └── fastest_internal_test.go ├── bamboo-specs └── bamboo.yaml ├── Makefile └── go.mod /.gitattributes: -------------------------------------------------------------------------------- 1 | vendor/** binary 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore everything except for explicitly allowed stuff. 2 | * 3 | !build/docker 4 | -------------------------------------------------------------------------------- /internal/handler/testdata/TestDefault_resolveFromHosts/hosts: -------------------------------------------------------------------------------- 1 | 1.2.3.4 ipv4.domain.example 2 | 2001:db8::1 ipv6.domain.example 3 | # comment 4 | -------------------------------------------------------------------------------- /internal/handler/handler.go: -------------------------------------------------------------------------------- 1 | // Package handler provides some customizable DNS request handling logic used in 2 | // the proxy. 3 | package handler 4 | -------------------------------------------------------------------------------- /internal/dnsproxytest/dnsproxytest.go: -------------------------------------------------------------------------------- 1 | // Package dnsproxytest provides a set of test utilities for the dnsproxy 2 | // module. 
3 | package dnsproxytest 4 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/internal/cmd" 5 | ) 6 | 7 | func main() { 8 | cmd.Main() 9 | } 10 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 40% 6 | threshold: null 7 | patch: false 8 | changes: false 9 | -------------------------------------------------------------------------------- /proxy/constructor.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/internal/dnsmsg" 5 | ) 6 | 7 | // MessageConstructor creates DNS messages. 8 | type MessageConstructor = dnsmsg.MessageConstructor 9 | -------------------------------------------------------------------------------- /internal/bootstrap/error.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import "github.com/AdguardTeam/golibs/errors" 4 | 5 | // ErrNoResolvers is returned when zero resolvers specified. 
6 | const ErrNoResolvers errors.Error = "no resolvers specified" 7 | -------------------------------------------------------------------------------- /internal/netutil/testdata/TestHosts/bad_file/hosts: -------------------------------------------------------------------------------- 1 | # comment about the following empty line 2 | 3 | # comment about the above empty line 4 | 5 | 1.2.3.256 a.b # invalid address 6 | 1.2.3.4 a.123 # invalid top-level domain 7 | 1.2.3.4 .a.b # empty domain 8 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netutil 4 | 5 | import "syscall" 6 | 7 | // defaultListenControl does nothing and always returns nil on Windows, because 8 | // Windows doesn't support SO_REUSEPORT. 9 | func (listenControl) defaultListenControl(_, _ string, _ syscall.RawConn) (err error) { 10 | return nil 11 | } 12 | -------------------------------------------------------------------------------- /upstream/upstream_test.go: -------------------------------------------------------------------------------- 1 | package upstream_test 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/AdguardTeam/golibs/logutil/slogutil" 7 | ) 8 | 9 | // testTimeout is the common timeout for tests. 10 | const testTimeout = 1 * time.Second 11 | 12 | // testLogger is the common logger for tests. 13 | var testLogger = slogutil.NewDiscardLogger() 14 | -------------------------------------------------------------------------------- /internal/netutil/paths.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | // DefaultHostsPaths returns the slice of default paths to system hosts files. 4 | // 5 | // TODO(s.chzhen): Since [fs.FS] is no longer needed, update the 6 | // [hostsfile.DefaultHostsPaths] from golibs.
7 | func DefaultHostsPaths() (paths []string, err error) { 8 | return defaultHostsPaths() 9 | } 10 | -------------------------------------------------------------------------------- /upstream/resolver_internal_test.go: -------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "net/netip" 5 | "time" 6 | ) 7 | 8 | // FindCached exports the internal method r.findCached for testing. 9 | // 10 | // TODO(e.burkov): Find a way of testing without it. 11 | func (r *CachingResolver) FindCached(host string, now time.Time) (addrs []netip.Addr) { 12 | return r.findCached(host, now) 13 | } 14 | -------------------------------------------------------------------------------- /upstream/dot_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package upstream 4 | 5 | import ( 6 | "github.com/AdguardTeam/golibs/errors" 7 | "golang.org/x/sys/windows" 8 | ) 9 | 10 | // isConnBroken returns true if err means that a connection is broken. 11 | func isConnBroken(err error) (ok bool) { 12 | return errors.Is(err, windows.WSAECONNABORTED) || errors.Is(err, windows.WSAECONNRESET) 13 | } 14 | -------------------------------------------------------------------------------- /upstream/dot_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd || linux || openbsd || netbsd 2 | 3 | package upstream 4 | 5 | import ( 6 | "github.com/AdguardTeam/golibs/errors" 7 | "golang.org/x/sys/unix" 8 | ) 9 | 10 | // isConnBroken returns true if err means that a connection is broken. 
11 | func isConnBroken(err error) (ok bool) { 12 | return errors.Is(err, unix.EPIPE) || errors.Is(err, unix.ETIMEDOUT) 13 | } 14 | -------------------------------------------------------------------------------- /scripts/make/md-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | set -e -f -u 12 | 13 | if [ "$verbose" -gt '0' ]; then 14 | set -x 15 | fi 16 | 17 | markdownlint \ 18 | ./README.md \ 19 | ; 20 | -------------------------------------------------------------------------------- /proxy/upstreammode_test.go: -------------------------------------------------------------------------------- 1 | package proxy_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/dnsproxy/proxy" 7 | "github.com/AdguardTeam/golibs/testutil" 8 | ) 9 | 10 | func TestUpstreamMode_encoding(t *testing.T) { 11 | t.Parallel() 12 | 13 | v := proxy.UpstreamModeLoadBalance 14 | 15 | testutil.AssertMarshalText(t, "load_balance", &v) 16 | testutil.AssertUnmarshalText(t, "load_balance", &v) 17 | } 18 | -------------------------------------------------------------------------------- /proxy/serverudp_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/miekg/dns" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestUdpProxy(t *testing.T) { 11 | dnsProxy := mustStartDefaultProxy(t) 12 | 13 | // Create a DNS-over-UDP client connection 14 | addr := dnsProxy.Addr(ProtoUDP) 15 | conn, err := dns.Dial("udp", addr.String()) 16 | require.NoError(t, err) 17 | 18 | sendTestMessages(t, conn) 19 | } 20 | 
-------------------------------------------------------------------------------- /internal/netutil/udpoob_others.go: -------------------------------------------------------------------------------- 1 | //go:build !darwin 2 | 3 | package netutil 4 | 5 | import ( 6 | "net/netip" 7 | 8 | "golang.org/x/net/ipv4" 9 | "golang.org/x/net/ipv6" 10 | ) 11 | 12 | // udpMakeOOBWithSrc makes the OOB data with the specified source IP. 13 | func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) { 14 | if ip.Is4() { 15 | return (&ipv4.ControlMessage{ 16 | Src: ip.AsSlice(), 17 | }).Marshal() 18 | } 19 | 20 | return (&ipv6.ControlMessage{ 21 | Src: ip.AsSlice(), 22 | }).Marshal() 23 | } 24 | -------------------------------------------------------------------------------- /.markdownlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "ul-indent": { 3 | "indent": 4 4 | }, 5 | "ul-style": { 6 | "style": "dash" 7 | }, 8 | "emphasis-style": { 9 | "style": "asterisk" 10 | }, 11 | "no-duplicate-heading": { 12 | "siblings_only": true 13 | }, 14 | "no-inline-html": { 15 | "allowed_elements": [ 16 | "a" 17 | ] 18 | }, 19 | "no-trailing-spaces": { 20 | "br_spaces": 0 21 | }, 22 | "line-length": false, 23 | "no-bare-urls": false, 24 | "no-emphasis-as-heading": false, 25 | "link-fragments": false 26 | } 27 | -------------------------------------------------------------------------------- /internal/netutil/paths_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netutil 4 | 5 | import ( 6 | "fmt" 7 | "path/filepath" 8 | 9 | "golang.org/x/sys/windows" 10 | ) 11 | 12 | // defaultHostsPaths returns default paths to hosts files for Windows.
13 | func defaultHostsPaths() (paths []string, err error) { 14 | sysDir, err := windows.GetSystemDirectory() 15 | if err != nil { 16 | return nil, fmt.Errorf("getting system directory: %w", err) 17 | } 18 | 19 | // Use filepath rather than path, since sysDir is an OS-native Windows path. 20 | p := filepath.Join(sysDir, "drivers", "etc", "hosts") 21 | 22 | return []string{p}, nil 23 | } 24 | -------------------------------------------------------------------------------- /internal/netutil/paths_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package netutil 4 | 5 | import "github.com/AdguardTeam/golibs/hostsfile" 6 | 7 | // defaultHostsPaths returns default paths to hosts files for UNIX. 8 | func defaultHostsPaths() (paths []string, err error) { 9 | paths, err = hostsfile.DefaultHostsPaths() 10 | if err != nil { 11 | // Should not happen because error is always nil. 12 | panic(err) 13 | } 14 | 15 | res := make([]string, 0, len(paths)) 16 | for _, p := range paths { 17 | res = append(res, "/"+p) 18 | } 19 | 20 | return res, nil 21 | } 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Please, DO NOT put your text editors' temporary files here. The more are 2 | # added, the harder it gets to maintain and manage projects' gitignores. Put 3 | # them into your global gitignore file instead. 4 | # 5 | # See https://stackoverflow.com/a/7335487/1892060. 6 | # 7 | # Only build, run, and test outputs here. Sorted. With negations at the 8 | # bottom to make sure they take effect.
9 | *.exe 10 | *.out 11 | *.test 12 | /bin/ 13 | /tmp/ 14 | build 15 | dnsproxy 16 | dnsproxy.exe 17 | example.crt 18 | example.key 19 | coverage.txt 20 | config.yaml 21 | test-reports/ 22 | -------------------------------------------------------------------------------- /internal/handler/ipv6halt.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // haltAAAA halts the processing of AAAA requests if IPv6 is disabled. req must 10 | // not be nil. 11 | func (h *Default) haltAAAA(ctx context.Context, req *dns.Msg) (resp *dns.Msg) { 12 | if h.isIPv6Halted && req.Question[0].Qtype == dns.TypeAAAA { 13 | h.logger.DebugContext( 14 | ctx, 15 | "ipv6 is disabled; replying with empty response", 16 | "req", req.Question[0].Name, 17 | ) 18 | 19 | return h.messages.NewMsgNODATA(req) 20 | } 21 | 22 | return nil 23 | } 24 | -------------------------------------------------------------------------------- /scripts/make/go-deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 
5 | # 6 | # AdGuard-Project-Version: 2 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" mod download "$x_flags" 30 | -------------------------------------------------------------------------------- /proxy/errors.go: -------------------------------------------------------------------------------- 1 | //go:build !plan9 2 | // +build !plan9 3 | 4 | package proxy 5 | 6 | import ( 7 | "syscall" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | ) 11 | 12 | // isEPIPE checks if the underlying error is EPIPE. syscall.EPIPE exists on all 13 | // OSes except for Plan 9. Validate with: 14 | // 15 | // $ for os in $(go tool dist list | cut -d / -f 1 | sort -u) 16 | // do 17 | // echo -n "$os" 18 | // env GOOS="$os" go doc syscall.EPIPE | grep -F -e EPIPE 19 | // done 20 | // 21 | // For the Plan 9 version see ./errors_plan9.go. 22 | func isEPIPE(err error) (ok bool) { 23 | return errors.Is(err, syscall.EPIPE) 24 | } 25 | -------------------------------------------------------------------------------- /proxy/errors_plan9.go: -------------------------------------------------------------------------------- 1 | //go:build plan9 2 | // +build plan9 3 | 4 | package proxy 5 | 6 | import "strings" 7 | 8 | // isEPIPE checks if the underlying error is EPIPE. Plan 9 relies on error 9 | // strings instead of error codes. I couldn't find the exact constant with the 10 | // text returned by a write on a closed socket, but it seems to be "sys: write 11 | // on closed pipe". See Plan 9's "man 2 notify". 12 | // 13 | // We don't currently support Plan 9, so it's not critical, but when we do, this 14 | // needs to be rechecked. 
15 | func isEPIPE(err error) (ok bool) { 16 | // Calling Error on a nil error panics, and a nil error is never EPIPE. 17 | return err != nil && strings.Contains(err.Error(), "write on closed pipe") 18 | } 19 | -------------------------------------------------------------------------------- /config.yaml.dist: -------------------------------------------------------------------------------- 1 | # This is the yaml configuration file for dnsproxy with minimal working 2 | # configuration, all the options available can be seen with ./dnsproxy --help. 3 | # To use it within dnsproxy specify the --config-path=/path/to/config.yaml 4 | # option. Any other command-line options specified will override the values 5 | # from the config file. 6 | --- 7 | bootstrap: 8 | - "8.8.8.8:53" 9 | listen-addrs: 10 | - "0.0.0.0" 11 | listen-ports: 12 | - 53 13 | max-go-routines: 0 14 | ratelimit: 0 15 | ratelimit-subnet-len-ipv4: 24 16 | ratelimit-subnet-len-ipv6: 64 17 | udp-buf-size: 0 18 | upstream: 19 | - "1.1.1.1:53" 20 | timeout: '10s' 21 | -------------------------------------------------------------------------------- /scripts/make/go-upd-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script.
5 | # 6 | # AdGuard-Project-Version: 4 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" get -u "$x_flags" tool 30 | "$go" mod tidy "$x_flags" 31 | -------------------------------------------------------------------------------- /internal/netutil/udp_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netutil 4 | 5 | import ( 6 | "net" 7 | "net/netip" 8 | ) 9 | 10 | // udpGetOOBSize returns 0, since OOB data is not read on Windows. 11 | func udpGetOOBSize() int { 12 | return 0 13 | } 14 | 15 | // udpSetOptions is a no-op on Windows. 16 | func udpSetOptions(c *net.UDPConn) error { 17 | return nil 18 | } 19 | 20 | // udpRead reads a message from c into buf. OOB data is not read on Windows, so 21 | // the returned local IP is always the zero [netip.Addr]. 22 | func udpRead(c *net.UDPConn, buf []byte, _ int) (int, netip.Addr, *net.UDPAddr, error) { 23 | // ReadFromUDP returns the concrete address type, so no type assertion on 24 | // a [net.Addr] is needed. 25 | n, udpAddr, err := c.ReadFromUDP(buf) 26 | 27 | return n, netip.Addr{}, udpAddr, err 28 | } 29 | 30 | // udpWrite writes bytes to remoteAddr using conn. The local IP is ignored on 31 | // Windows. 32 | func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ netip.Addr) (int, error) { 33 | return conn.WriteToUDP(bytes, remoteAddr) 34 | } 35 | -------------------------------------------------------------------------------- /staticcheck.conf: -------------------------------------------------------------------------------- 1 | # This comment is used to simplify checking local copies of the staticcheck 2 | # configuration. Bump this number every time a significant change is made to 3 | # this file. 4 | # 5 | # AdGuard-Project-Version: 1 6 | checks = ["all"] 7 | initialisms = [ 8 | # See https://github.com/dominikh/go-tools/blob/master/config/config.go. 9 | # 10 | # Do not add "PTR" since we use "Ptr" as a suffix. 11 | "inherit" 12 | , "ASN" 13 | , "DHCP" 14 | , "DNSSEC" 15 | # E.g. SentryDSN.
16 | , "DSN" 17 | , "ECS" 18 | , "EDNS" 19 | , "MX" 20 | , "QUIC" 21 | , "RA" 22 | , "RRSIG" 23 | , "RTT" 24 | , "SDNS" 25 | , "SLAAC" 26 | , "SOA" 27 | , "SVCB" 28 | , "TLD" 29 | , "WHOIS" 30 | ] 31 | dot_import_whitelist = [] 32 | http_status_code_whitelist = [] 33 | -------------------------------------------------------------------------------- /internal/netutil/udpoob_darwin.go: -------------------------------------------------------------------------------- 1 | //go:build darwin 2 | 3 | package netutil 4 | 5 | import ( 6 | "net/netip" 7 | 8 | "golang.org/x/net/ipv6" 9 | ) 10 | 11 | // udpMakeOOBWithSrc makes the OOB data with the specified source IP. 12 | func udpMakeOOBWithSrc(ip netip.Addr) (b []byte) { 13 | if ip.Is4() { 14 | // Do not set the IPv4 source address via OOB, because it can cause the 15 | // address to become unspecified on darwin. 16 | // 17 | // See https://github.com/AdguardTeam/AdGuardHome/issues/2807. 18 | // 19 | // TODO(e.burkov): Develop a workaround to make it write OOB only when 20 | // listening on an unspecified address. 21 | return []byte{} 22 | } 23 | 24 | return (&ipv6.ControlMessage{ 25 | Src: ip.AsSlice(), 26 | }).Marshal() 27 | } 28 | -------------------------------------------------------------------------------- /scripts/make/sh-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Don't use -f, because we use globs in this script. 12 | set -e -u 13 | 14 | if [ "$verbose" -gt '0' ]; then 15 | set -x 16 | fi 17 | 18 | # Source the common helpers, including not_found and run_linter. 19 | . 
./scripts/make/helper.sh 20 | 21 | run_linter -e shfmt --binary-next-line -d -p -s \ 22 | ./scripts/hooks/* \ 23 | ./scripts/make/*.sh \ 24 | ; 25 | 26 | shellcheck -e 'SC2250' -f 'gcc' -o 'all' -x -- \ 27 | ./scripts/hooks/* \ 28 | ./scripts/make/*.sh \ 29 | ; 30 | -------------------------------------------------------------------------------- /proxy/bogusnxdomain.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/proxyutil" 5 | "github.com/AdguardTeam/golibs/netutil" 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // isBogusNXDomain returns true if m contains at least a single IP address in 10 | // the Answer section contained in BogusNXDomain subnets of p. 11 | func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) { 12 | if m == nil || len(p.BogusNXDomain) == 0 || len(m.Question) == 0 { 13 | return false 14 | } else if qt := m.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA { 15 | return false 16 | } 17 | 18 | set := netutil.SliceSubnetSet(p.BogusNXDomain) 19 | for _, rr := range m.Answer { 20 | ip := proxyutil.IPFromRR(rr) 21 | if set.Contains(ip) { 22 | return true 23 | } 24 | } 25 | 26 | return false 27 | } 28 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | import ( 4 | "log/slog" 5 | "net" 6 | ) 7 | 8 | // ListenConfig returns the default [net.ListenConfig] used by the plain-DNS 9 | // servers in this module. l must not be nil. 10 | // 11 | // TODO(a.garipov): Add tests. 12 | // 13 | // TODO(a.garipov): Add an option to not set SO_REUSEPORT on Unix to prevent 14 | // issues with OpenWrt. 15 | // 16 | // See https://github.com/AdguardTeam/AdGuardHome/issues/5872. 17 | // 18 | // TODO(a.garipov): DRY with AdGuard DNS when we can. 
19 | func ListenConfig(l *slog.Logger) (lc *net.ListenConfig) { 20 | return &net.ListenConfig{ 21 | Control: listenControl{logger: l}.defaultListenControl, 22 | } 23 | } 24 | 25 | // listenControl is a wrapper struct with logger. 26 | type listenControl struct { 27 | logger *slog.Logger 28 | } 29 | -------------------------------------------------------------------------------- /proxyutil/dns.go: -------------------------------------------------------------------------------- 1 | // Package proxyutil contains helper functions that are used in all other 2 | // dnsproxy packages. 3 | package proxyutil 4 | 5 | import ( 6 | "encoding/binary" 7 | "net/netip" 8 | 9 | "github.com/miekg/dns" 10 | ) 11 | 12 | // AddPrefix adds a 2-byte prefix with the DNS message length. 13 | func AddPrefix(b []byte) (m []byte) { 14 | m = make([]byte, 2+len(b)) 15 | binary.BigEndian.PutUint16(m, uint16(len(b))) 16 | copy(m[2:], b) 17 | 18 | return m 19 | } 20 | 21 | // IPFromRR returns the IP address from rr if any. 22 | func IPFromRR(rr dns.RR) (ip netip.Addr) { 23 | var data []byte 24 | switch rr := rr.(type) { 25 | case *dns.A: 26 | data = rr.A.To4() 27 | case *dns.AAAA: 28 | data = rr.AAAA 29 | default: 30 | return netip.Addr{} 31 | } 32 | 33 | ip, _ = netip.AddrFromSlice(data) 34 | 35 | return ip 36 | } 37 | -------------------------------------------------------------------------------- /internal/netutil/netutil.go: -------------------------------------------------------------------------------- 1 | // Package netutil contains network-related utilities common among dnsproxy 2 | // packages. 3 | // 4 | // TODO(a.garipov): Move improved versions of these into netutil in module 5 | // golibs. 6 | package netutil 7 | 8 | import ( 9 | "net/netip" 10 | "strings" 11 | ) 12 | 13 | // ParseSubnet parses s either as a CIDR prefix itself, or as an IP address, 14 | // returning the corresponding single-IP CIDR prefix. 15 | // 16 | // TODO(e.burkov): Replace usages with [netutil.Prefix]. 
17 | func ParseSubnet(s string) (p netip.Prefix, err error) { 18 | if strings.Contains(s, "/") { 19 | p, err = netip.ParsePrefix(s) 20 | if err != nil { 21 | return netip.Prefix{}, err 22 | } 23 | } else { 24 | var ip netip.Addr 25 | ip, err = netip.ParseAddr(s) 26 | if err != nil { 27 | return netip.Prefix{}, err 28 | } 29 | 30 | p = netip.PrefixFrom(ip, ip.BitLen()) 31 | } 32 | 33 | return p, nil 34 | } 35 | -------------------------------------------------------------------------------- /proxy/errors_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build !plan9 2 | // +build !plan9 3 | 4 | package proxy 5 | 6 | import ( 7 | "fmt" 8 | "syscall" 9 | "testing" 10 | 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestIsEPIPE(t *testing.T) { 16 | type testCase struct { 17 | err error 18 | name string 19 | want bool 20 | } 21 | 22 | testCases := []testCase{{ 23 | name: "nil", 24 | err: nil, 25 | want: false, 26 | }, { 27 | name: "epipe", 28 | err: syscall.EPIPE, 29 | want: true, 30 | }, { 31 | name: "not_epipe", 32 | err: errors.Error("test error"), 33 | want: false, 34 | }, { 35 | name: "wrapped_epipe", 36 | err: fmt.Errorf("test error: %w", syscall.EPIPE), 37 | want: true, 38 | }} 39 | 40 | for _, tc := range testCases { 41 | t.Run(tc.name, func(t *testing.T) { 42 | got := isEPIPE(tc.err) 43 | assert.Equal(t, tc.want, got) 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /scripts/make/go-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 
5 | # 6 | # AdGuard-Project-Version: 7 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | set -x 13 | v_flags='-v=1' 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | v_flags='-v=1' 18 | x_flags='-x=0' 19 | else 20 | set +x 21 | v_flags='-v=0' 22 | x_flags='-x=0' 23 | fi 24 | readonly v_flags x_flags 25 | 26 | set -e -f -u 27 | 28 | # Reset GOARCH and GOOS to make sure we install the tools for the native 29 | # architecture even when we're cross-compiling the main binary, and also to 30 | # prevent the "cannot install cross-compiled binaries when GOBIN is set" error. 31 | env \ 32 | GOARCH="" \ 33 | GOBIN="${PWD}/bin" \ 34 | GOOS="" \ 35 | GOWORK='off' \ 36 | "${GO:-go}" install "$v_flags" "$x_flags" tool \ 37 | ; 38 | -------------------------------------------------------------------------------- /internal/version/version.go: -------------------------------------------------------------------------------- 1 | // Package version contains dnsproxy version information. 2 | package version 3 | 4 | // Versions 5 | 6 | // These are set by the linker. Unfortunately, we cannot set constants during 7 | // linking, and Go doesn't have a concept of immutable variables, so to be 8 | // thorough we have to only export them through getters. 9 | var ( 10 | branch string 11 | committime string 12 | revision string 13 | version string 14 | ) 15 | 16 | // Branch returns the compiled-in value of the Git branch. 17 | func Branch() (b string) { 18 | return branch 19 | } 20 | 21 | // CommitTime returns the compiled-in value of the build time as a string. 22 | func CommitTime() (t string) { 23 | return committime 24 | } 25 | 26 | // Revision returns the compiled-in value of the Git revision. 27 | func Revision() (r string) { 28 | return revision 29 | } 30 | 31 | // Version returns the compiled-in value of the build version as a string. 
32 | func Version() (v string) { 33 | return version 34 | } 35 | -------------------------------------------------------------------------------- /internal/netutil/testdata/TestHosts/good_file/hosts: -------------------------------------------------------------------------------- 1 | # IPv4 2 | 3 | # 1st host. 4 | 0.0.0.1 Host.One 5 | 6 | # 2nd host. 7 | 0.0.0.2 Host.Two 8 | 9 | # 1st host full duplicate. 10 | 0.0.0.1 host.one 11 | 12 | # 2nd host duplicate with new name. 13 | 0.0.0.2 host.two Host.New 14 | 15 | # 1st host with foreign name. 16 | 0.0.0.1 host.new 17 | 18 | # 2nd host new name. 19 | 0.0.0.2 Again.Host.Two 20 | 21 | # Mapped 22 | 23 | # 1st host. 24 | ::ffff:0.0.0.1 Host.One 25 | 26 | # 2nd host. 27 | ::ffff:0.0.0.2 Host.Two 28 | 29 | # 1st host full duplicate. 30 | ::ffff:0.0.0.1 host.one 31 | 32 | # 2nd host duplicate with new name. 33 | ::ffff:0.0.0.2 host.two Host.New 34 | 35 | # 1st host with foreign name. 36 | ::ffff:0.0.0.1 host.new 37 | 38 | # 2nd host new name. 39 | ::ffff:0.0.0.2 Again.Host.Two 40 | 41 | # IPv6 42 | 43 | # 1st host. 44 | ::1 Host.One 45 | 46 | # 2nd host. 47 | ::2 Host.Two 48 | 49 | # 1st host full duplicate. 50 | ::1 host.one 51 | 52 | # 2nd host duplicate with new name. 53 | ::2 host.two Host.New 54 | 55 | # 1st host with foreign name. 56 | ::1 host.new 57 | 58 | # 2nd host new name. 59 | ::2 Again.Host.Two 60 | -------------------------------------------------------------------------------- /internal/netutil/udp.go: -------------------------------------------------------------------------------- 1 | package netutil 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | ) 7 | 8 | // UDPGetOOBSize returns maximum size of the received OOB data. 9 | func UDPGetOOBSize() (oobSize int) { 10 | return udpGetOOBSize() 11 | } 12 | 13 | // UDPSetOptions sets flag options on a UDP socket to be able to receive the 14 | // necessary OOB data. 
15 | func UDPSetOptions(c *net.UDPConn) (err error) { 16 | return udpSetOptions(c) 17 | } 18 | 19 | // UDPRead reads the message from conn using buf and receives a control-message 20 | // payload of size udpOOBSize from it. It returns the number of bytes copied 21 | // into buf, the local IP address, and the source address of the message. 22 | // 23 | // TODO(s.chzhen): Consider using netip.Addr. 24 | func UDPRead( 25 | conn *net.UDPConn, 26 | buf []byte, 27 | udpOOBSize int, 28 | ) (n int, localIP netip.Addr, remoteAddr *net.UDPAddr, err error) { 29 | return udpRead(conn, buf, udpOOBSize) 30 | } 31 | 32 | // UDPWrite writes the data to the remoteAddr using conn. 33 | // 34 | // TODO(s.chzhen): Consider using netip.Addr. 35 | func UDPWrite( 36 | data []byte, 37 | conn *net.UDPConn, 38 | remoteAddr *net.UDPAddr, 39 | localIP netip.Addr, 40 | ) (n int, err error) { 41 | return udpWrite(data, conn, remoteAddr, localIP) 42 | } 43 | -------------------------------------------------------------------------------- /proxy/lookup_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/dnsproxy/upstream" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestLookupNetIP(t *testing.T) { 15 | // Use AdGuard DNS here. 16 | dnsUpstream, err := upstream.AddressToUpstream( 17 | "94.140.14.14", 18 | &upstream.Options{ 19 | Logger: slogutil.NewDiscardLogger(), 20 | Timeout: defaultTimeout, 21 | }, 22 | ) 23 | require.NoError(t, err) 24 | 25 | conf := &Config{ 26 | Logger: slogutil.NewDiscardLogger(), 27 | UpstreamConfig: &UpstreamConfig{ 28 | Upstreams: []upstream.Upstream{dnsUpstream}, 29 | }, 30 | } 31 | 32 | p, err := New(conf) 33 | require.NoError(t, err) 34 | 35 | // Now let's try doing some lookups.
36 | addrs, err := p.LookupNetIP(context.Background(), "", "dns.google") 37 | require.NoError(t, err) 38 | require.NotEmpty(t, addrs) 39 | 40 | assert.Contains(t, addrs, netip.MustParseAddr("8.8.8.8")) 41 | assert.Contains(t, addrs, netip.MustParseAddr("8.8.4.4")) 42 | if len(addrs) > 2 { 43 | assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8888")) 44 | assert.Contains(t, addrs, netip.MustParseAddr("2001:4860:4860::8844")) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/netutil/listenconfig_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package netutil 4 | 5 | import ( 6 | "fmt" 7 | "syscall" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "golang.org/x/sys/unix" 12 | ) 13 | 14 | // defaultListenControl is used as a [net.ListenConfig.Control] function to set 15 | // the SO_REUSEADDR and SO_REUSEPORT socket options on all sockets used by the 16 | // DNS servers in this module. 17 | func (lc listenControl) defaultListenControl(_, _ string, c syscall.RawConn) (err error) { 18 | var opErr error 19 | err = c.Control(func(fd uintptr) { 20 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) 21 | if opErr != nil { 22 | opErr = fmt.Errorf("setting SO_REUSEADDR: %w", opErr) 23 | 24 | return 25 | } 26 | 27 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) 28 | if opErr != nil { 29 | if errors.Is(opErr, unix.ENOPROTOOPT) { 30 | // Some Linux OSs do not seem to support SO_REUSEPORT, including 31 | // some varieties of OpenWrt. Issue a warning.
32 | lc.logger.Warn("SO_REUSEPORT not supported", slogutil.KeyError, opErr) 33 | opErr = nil 34 | } else { 35 | opErr = fmt.Errorf("setting SO_REUSEPORT: %w", opErr) 36 | } 37 | } 38 | }) 39 | 40 | return errors.WithDeferred(opErr, err) 41 | } 42 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # DNS Proxy 2 | 3 | A simple DNS proxy server that supports all existing DNS protocols including 4 | `DNS-over-TLS`, `DNS-over-HTTPS`, `DNSCrypt`, and `DNS-over-QUIC`. Moreover, 5 | it can work as a `DNS-over-HTTPS`, `DNS-over-TLS` or `DNS-over-QUIC` server. 6 | 7 | Learn more about dnsproxy and its full capabilities in 8 | its [GitHub repo][dnsproxy]. 9 | 10 | [dnsproxy]: https://github.com/AdguardTeam/dnsproxy 11 | 12 | ## Quick start 13 | 14 | ### Pull the Docker image 15 | 16 | This command will pull the latest stable version: 17 | 18 | ```shell 19 | docker pull adguard/dnsproxy 20 | ``` 21 | 22 | ### Run the container 23 | 24 | Run the container with the default configuration (see `config.yaml.dist` in the 25 | repository) and expose DNS ports. 26 | 27 | ```shell 28 | docker run --name dnsproxy \ 29 | -p 53:53/tcp -p 53:53/udp \ 30 | adguard/dnsproxy 31 | ``` 32 | 33 | Run the container with command-line args configuration and expose DNS ports. 34 | 35 | ```shell 36 | docker run --name dnsproxy_google_dns \ 37 | -p 53:53/tcp -p 53:53/udp \ 38 | adguard/dnsproxy \ 39 | -u 8.8.8.8:53 40 | ``` 41 | 42 | Run the container with a configuration file and expose DNS ports.
43 | 44 | ```shell 45 | docker run --name dnsproxy_google_dns \ 46 | -p 53:53/tcp -p 53:53/udp \ 47 | -v $PWD/config.yaml:/opt/dnsproxy/config.yaml \ 48 | adguard/dnsproxy 49 | ``` 50 | -------------------------------------------------------------------------------- /proxy/retry.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | ) 9 | 10 | // BindRetryConfig contains configuration for the listeners binding retry 11 | // mechanism. 12 | type BindRetryConfig struct { 13 | // Interval is the minimum time to wait after the latest failure. It must 14 | // not be negative if Enabled is true. 15 | Interval time.Duration 16 | 17 | // Count is the maximum number of retries after the first attempt. 18 | Count uint 19 | 20 | // Enabled indicates whether the binding should be retried. 21 | Enabled bool 22 | } 23 | 24 | // bindWithRetry calls bindFunc until it returns no error or the retries limit is 25 | // reached, sleeping for the configured interval between attempts. If every 26 | // attempt fails, the error of the first attempt is returned. bindFunc must not be nil.
27 | func (p *Proxy) bindWithRetry(ctx context.Context, bindFunc func() (err error)) (err error) { 28 | err = bindFunc() 29 | if err == nil { 30 | return nil 31 | } 32 | 33 | p.logger.WarnContext(ctx, "binding", "attempt", 1, slogutil.KeyError, err) 34 | 35 | for attempt := uint(1); attempt <= p.bindRetryCount; attempt++ { 36 | time.Sleep(p.bindRetryIvl) 37 | 38 | retryErr := bindFunc() 39 | if retryErr == nil { 40 | return nil 41 | } 42 | 43 | p.logger.WarnContext(ctx, "binding", "attempt", attempt+1, slogutil.KeyError, retryErr) 44 | } 45 | 46 | return err 47 | } 48 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | 'name': 'lint' 2 | 3 | 'env': 4 | 'GO_VERSION': '1.25.5' 5 | 6 | 'on': 7 | 'push': 8 | 'tags': 9 | - 'v*' 10 | 'branches': 11 | - '*' 12 | 'pull_request': 13 | 14 | 'jobs': 15 | 'go-lint': 16 | 'runs-on': 'ubuntu-latest' 17 | 'steps': 18 | - 'uses': 'actions/checkout@v2' 19 | - 'name': 'Set up Go' 20 | 'uses': 'actions/setup-go@v3' 21 | 'with': 22 | 'go-version': '${{ env.GO_VERSION }}' 23 | - 'name': 'run-lint' 24 | 'run': > 25 | make go-deps go-tools go-lint 26 | 27 | 'notify': 28 | 'needs': 29 | - 'go-lint' 30 | # Secrets are not passed to workflows that are triggered by a pull request 31 | # from a fork. 32 | # 33 | # Use always() to signal to the runner that this job must run even if the 34 | # previous ones failed.
35 | 'if': 36 | ${{ 37 | always() && 38 | github.repository_owner == 'AdguardTeam' && 39 | ( 40 | github.event_name == 'push' || 41 | github.event.pull_request.head.repo.full_name == github.repository 42 | ) 43 | }} 44 | 'runs-on': 'ubuntu-latest' 45 | 'steps': 46 | - 'name': 'Conclusion' 47 | 'uses': 'technote-space/workflow-conclusion-action@v1' 48 | - 'name': 'Send Slack notif' 49 | 'uses': '8398a7/action-slack@v3' 50 | 'with': 51 | 'status': '${{ env.WORKFLOW_CONCLUSION }}' 52 | 'fields': 'workflow, repo, message, commit, author, eventName, ref' 53 | 'env': 54 | 'GITHUB_TOKEN': '${{ secrets.GITHUB_TOKEN }}' 55 | 'SLACK_WEBHOOK_URL': '${{ secrets.SLACK_WEBHOOK_URL }}' 56 | -------------------------------------------------------------------------------- /proxy/servertcp_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "net" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/AdguardTeam/golibs/testutil/servicetest" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestTcpProxy(t *testing.T) { 16 | dnsProxy := mustStartDefaultProxy(t) 17 | 18 | // Create a DNS-over-TCP client connection. 19 | addr := dnsProxy.Addr(ProtoTCP) 20 | conn, err := dns.Dial("tcp", addr.String()) 21 | require.NoError(t, err) 22 | 23 | sendTestMessages(t, conn) 24 | } 25 | 26 | func TestTlsProxy(t *testing.T) { 27 | serverConfig, caPem := newTLSConfig(t) 28 | dnsProxy := mustNew(t, &Config{ 29 | Logger: slogutil.NewDiscardLogger(), 30 | TLSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 31 | HTTPSListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 32 | QUICListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 33 | TLSConfig: serverConfig, 34 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 35 |
TrustedProxies: defaultTrustedProxies, 36 | RatelimitSubnetLenIPv4: 24, 37 | RatelimitSubnetLenIPv6: 64, 38 | }) 39 | 40 | servicetest.RequireRun(t, dnsProxy, testTimeout) 41 | 42 | roots := x509.NewCertPool() 43 | roots.AppendCertsFromPEM(caPem) 44 | tlsConfig := &tls.Config{ServerName: tlsServerName, RootCAs: roots} 45 | 46 | // Create a DNS-over-TLS client connection. 47 | addr := dnsProxy.Addr(ProtoTLS) 48 | conn, err := dns.DialWithTLS("tcp-tls", addr.String(), tlsConfig) 49 | require.NoError(t, err) 50 | 51 | sendTestMessages(t, conn) 52 | } 53 | -------------------------------------------------------------------------------- /proxy/upstreammode.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "encoding" 5 | "fmt" 6 | ) 7 | 8 | // UpstreamMode is an enumeration of upstream mode representations. 9 | // 10 | // TODO(d.kolyshev): Set uint8 as underlying type. 11 | type UpstreamMode string 12 | 13 | const ( 14 | // UpstreamModeLoadBalance is the default upstream mode. It balances the 15 | // upstreams load. 16 | UpstreamModeLoadBalance UpstreamMode = "load_balance" 17 | 18 | // UpstreamModeParallel makes the server query all configured upstream 19 | // servers in parallel. 20 | UpstreamModeParallel UpstreamMode = "parallel" 21 | 22 | // UpstreamModeFastestAddr controls whether the server should respond to A 23 | // or AAAA requests only with the fastest IP address detected by ICMP 24 | // response time or TCP connection time. 25 | UpstreamModeFastestAddr UpstreamMode = "fastest_addr" 26 | ) 27 | 28 | // type check 29 | var _ encoding.TextUnmarshaler = (*UpstreamMode)(nil) 30 | 31 | // UnmarshalText implements [encoding.TextUnmarshaler] interface for 32 | // *UpstreamMode.
33 | func (m *UpstreamMode) UnmarshalText(b []byte) (err error) { 34 | switch um := UpstreamMode(b); um { 35 | case 36 | UpstreamModeLoadBalance, 37 | UpstreamModeParallel, 38 | UpstreamModeFastestAddr: 39 | *m = um 40 | default: 41 | return fmt.Errorf( 42 | "invalid upstream mode %q, supported: %q, %q, %q", 43 | b, 44 | UpstreamModeLoadBalance, 45 | UpstreamModeParallel, 46 | UpstreamModeFastestAddr, 47 | ) 48 | } 49 | 50 | return nil 51 | } 52 | 53 | // type check 54 | var _ encoding.TextMarshaler = UpstreamMode("") 55 | 56 | // MarshalText implements [encoding.TextMarshaler] interface for UpstreamMode. 57 | func (m UpstreamMode) MarshalText() (text []byte, err error) { 58 | return []byte(m), nil 59 | } 60 | -------------------------------------------------------------------------------- /proxy/ratelimit.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | "slices" 7 | "time" 8 | 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | rate "github.com/beefsack/go-rate" 11 | gocache "github.com/patrickmn/go-cache" 12 | ) 13 | 14 | func (p *Proxy) limiterForIP(ip string) any { 15 | p.ratelimitLock.Lock() 16 | defer p.ratelimitLock.Unlock() 17 | if p.ratelimitBuckets == nil { 18 | p.ratelimitBuckets = gocache.New(time.Hour, time.Hour) 19 | } 20 | 21 | // Check if a rate limiter for that IP already exists; create one if not. 22 | value, found := p.ratelimitBuckets.Get(ip) 23 | if !found { 24 | value = rate.New(p.Ratelimit, time.Second) 25 | p.ratelimitBuckets.Set(ip, value, time.Hour) 26 | } 27 | 28 | return value 29 | } 30 | 31 | func (p *Proxy) isRatelimited(addr netip.Addr) (ok bool) { 32 | if p.Ratelimit <= 0 { 33 | // The ratelimit is disabled. 34 | return false 35 | } 36 | 37 | addr = addr.Unmap() 38 | // Already sorted by [Proxy.Init].
39 | _, ok = slices.BinarySearchFunc(p.RatelimitWhitelist, addr, netip.Addr.Compare) 40 | if ok { 41 | return false 42 | } 43 | 44 | var pref netip.Prefix 45 | if addr.Is4() { 46 | pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv4) 47 | } else { 48 | pref = netip.PrefixFrom(addr, p.RatelimitSubnetLenIPv6) 49 | } 50 | pref = pref.Masked() 51 | 52 | // TODO(s.chzhen): Improve caching. Decrease allocations. 53 | ipStr := pref.Addr().String() 54 | value := p.limiterForIP(ipStr) 55 | rl, ok := value.(*rate.RateLimiter) 56 | if !ok { 57 | p.logger.Error( 58 | "invalid value found in ratelimit cache", 59 | slogutil.KeyError, 60 | fmt.Errorf("bad type %T", value), 61 | ) 62 | 63 | return false 64 | } 65 | 66 | allow, _ := rl.Try() 67 | 68 | return !allow 69 | } 70 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # A docker file for scripts/make/build-docker.sh. 2 | 3 | FROM alpine:3.18 4 | 5 | ARG BUILD_DATE 6 | ARG VERSION 7 | ARG VCS_REF 8 | 9 | LABEL\ 10 | maintainer="AdGuard Team " \ 11 | org.opencontainers.image.authors="AdGuard Team " \ 12 | org.opencontainers.image.created=$BUILD_DATE \ 13 | org.opencontainers.image.description="Simple DNS proxy with DoH, DoT, DoQ and DNSCrypt support" \ 14 | org.opencontainers.image.documentation="https://github.com/AdguardTeam/dnsproxy" \ 15 | org.opencontainers.image.licenses="Apache-2.0" \ 16 | org.opencontainers.image.revision=$VCS_REF \ 17 | org.opencontainers.image.source="https://github.com/AdguardTeam/dnsproxy" \ 18 | org.opencontainers.image.title="dnsproxy" \ 19 | org.opencontainers.image.url="https://github.com/AdguardTeam/dnsproxy" \ 20 | org.opencontainers.image.vendor="AdGuard" \ 21 | org.opencontainers.image.version=$VERSION 22 | 23 | # Install CA certificates, libcap, and tzdata, and prepare /opt/dnsproxy.
24 | RUN apk --no-cache add ca-certificates libcap tzdata && \ 25 | mkdir -p /opt/dnsproxy && chown -R nobody: /opt/dnsproxy 26 | 27 | ARG DIST_DIR 28 | ARG TARGETARCH 29 | ARG TARGETOS 30 | ARG TARGETVARIANT 31 | 32 | COPY --chown=nobody:nogroup\ 33 | ./${DIST_DIR}/docker/dnsproxy_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\ 34 | /opt/dnsproxy/dnsproxy 35 | COPY --chown=nobody:nogroup\ 36 | ./${DIST_DIR}/docker/config.yaml\ 37 | /opt/dnsproxy/config.yaml 38 | 39 | RUN setcap 'cap_net_bind_service=+eip' /opt/dnsproxy/dnsproxy 40 | 41 | # 53 : TCP, UDP : DNS 42 | # 80 : TCP : HTTP 43 | # 443 : TCP, UDP : HTTPS, DNS-over-HTTPS (incl. HTTP/3), DNSCrypt (main) 44 | # 853 : TCP, UDP : DNS-over-TLS, DNS-over-QUIC 45 | # 5443 : TCP, UDP : DNSCrypt (alt) 46 | # 6060 : TCP : HTTP (pprof) 47 | EXPOSE 53/tcp 53/udp \ 48 | 80/tcp \ 49 | 443/tcp 443/udp \ 50 | 853/tcp 853/udp \ 51 | 5443/tcp 5443/udp \ 52 | 6060/tcp 53 | 54 | WORKDIR /opt/dnsproxy 55 | 56 | ENTRYPOINT ["/opt/dnsproxy/dnsproxy"] 57 | CMD ["--config-path=/opt/dnsproxy/config.yaml"] 58 | -------------------------------------------------------------------------------- /proxy/retry_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | "github.com/AdguardTeam/golibs/testutil" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestWithRetry(t *testing.T) { 13 | t.Parallel() 14 | 15 | const ( 16 | errA errors.Error = "error about a" 17 | errB errors.Error = "error about b" 18 | ) 19 | 20 | var ( 21 | good = func() (err error) { 22 | return nil 23 | } 24 | 25 | badOne = func() (err error) { 26 | return errA 27 | } 28 | 29 | // Don't protect against concurrent access, since the closure is only 30 | // used by a single test case.
31 | returnedA = false 32 | badBoth = func() (err error) { 33 | if !returnedA { 34 | returnedA = true 35 | 36 | return errA 37 | } 38 | 39 | return errB 40 | } 41 | 42 | // Don't protect against concurrent access, since the closure is only 43 | // used by a single test case. 44 | returnedErr = false 45 | badThenOk = func() (err error) { 46 | if !returnedErr { 47 | returnedErr = true 48 | 49 | return assert.AnError 50 | } 51 | 52 | return nil 53 | } 54 | ) 55 | 56 | testCases := []struct { 57 | f func() (err error) 58 | wantErr error 59 | name string 60 | }{{ 61 | f: good, 62 | wantErr: nil, 63 | name: "no_error", 64 | }, { 65 | f: badOne, 66 | wantErr: errA, 67 | name: "one_error", 68 | }, { 69 | f: badBoth, 70 | wantErr: errA, 71 | name: "two_errors", 72 | }, { 73 | f: badThenOk, 74 | wantErr: nil, 75 | name: "error_then_ok", 76 | }} 77 | 78 | p := &Proxy{ 79 | logger: slogutil.NewDiscardLogger(), 80 | bindRetryCount: 1, 81 | bindRetryIvl: 0, 82 | } 83 | 84 | for _, tc := range testCases { 85 | t.Run(tc.name, func(t *testing.T) { 86 | t.Parallel() 87 | 88 | ctx := testutil.ContextWithTimeout(t, testTimeout) 89 | 90 | err := p.bindWithRetry(ctx, tc.f) 91 | assert.ErrorIs(t, err, tc.wantErr) 92 | }) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /scripts/make/go-test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 6 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Verbosity levels: 12 | # 0 = Don't print anything except for errors. 13 | # 1 = Print commands, but not nested commands. 14 | # 2 = Print everything.
15 | if [ "$verbose" -gt '1' ]; then 16 | set -x 17 | v_flags='-v=1' 18 | x_flags='-x=1' 19 | elif [ "$verbose" -gt '0' ]; then 20 | set -x 21 | v_flags='-v=1' 22 | x_flags='-x=0' 23 | else 24 | set +x 25 | v_flags='-v=0' 26 | x_flags='-x=0' 27 | fi 28 | readonly v_flags x_flags 29 | 30 | set -e -f -u 31 | 32 | if [ "${RACE:-1}" -eq '0' ]; then 33 | race_flags='--race=0' 34 | else 35 | race_flags='--race=1' 36 | fi 37 | readonly race_flags 38 | 39 | count_flags='--count=2' 40 | cover_flags='--coverprofile=./cover.out' 41 | go="${GO:-go}" 42 | shuffle_flags='--shuffle=on' 43 | timeout_flags="${TIMEOUT_FLAGS:---timeout=2m}" 44 | readonly count_flags cover_flags go shuffle_flags timeout_flags 45 | 46 | go_test() { 47 | "$go" test \ 48 | "$count_flags" \ 49 | "$cover_flags" \ 50 | "$race_flags" \ 51 | "$shuffle_flags" \ 52 | "$timeout_flags" \ 53 | "$v_flags" \ 54 | "$x_flags" \ 55 | ./... 56 | } 57 | 58 | test_reports_dir="${TEST_REPORTS_DIR:-}" 59 | readonly test_reports_dir 60 | 61 | if [ "$test_reports_dir" = '' ]; then 62 | go_test 63 | 64 | exit "$?" 65 | fi 66 | 67 | mkdir -p "$test_reports_dir" 68 | 69 | # NOTE: The pipe ignoring the exit code here is intentional, as go-junit-report 70 | # will set the exit code to be saved. 71 | go_test 2>&1 \ 72 | | tee "${test_reports_dir}/test-output.txt" 73 | 74 | # Don't fail on errors in exporting, because TEST_REPORTS_DIR is generally only 75 | # not empty in CI, and so the exit code must be preserved to exit with it later. 76 | set +e 77 | go-junit-report \ 78 | --in "${test_reports_dir}/test-output.txt" \ 79 | --set-exit-code \ 80 | >"${test_reports_dir}/test-report.xml" 81 | printf '%s\n' "$?"
\ 82 | >"${test_reports_dir}/test-exit-code.txt" 83 | -------------------------------------------------------------------------------- /proxy/handler_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/golibs/logutil/slogutil" 9 | "github.com/AdguardTeam/golibs/testutil/servicetest" 10 | "github.com/miekg/dns" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestFilteringHandler(t *testing.T) { 16 | // Initialize the test middleware state. 17 | m := &sync.RWMutex{} 18 | blockResponse := false 19 | 20 | // Prepare the proxy server. 21 | dnsProxy := mustNew(t, &Config{ 22 | Logger: slogutil.NewDiscardLogger(), 23 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 24 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 25 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 26 | TrustedProxies: defaultTrustedProxies, 27 | RatelimitSubnetLenIPv4: 24, 28 | RatelimitSubnetLenIPv6: 64, 29 | RequestHandler: func(p *Proxy, d *DNSContext) error { 30 | m.Lock() 31 | defer m.Unlock() 32 | 33 | if !blockResponse { 34 | // Use the default Resolve method if the response is not blocked. 35 | return p.Resolve(d) 36 | } 37 | 38 | resp := dns.Msg{} 39 | resp.SetRcode(d.Req, dns.RcodeNotImplemented) 40 | resp.RecursionAvailable = true 41 | 42 | // Set the response right away. 43 | d.Res = &resp 44 | return nil 45 | }, 46 | }) 47 | 48 | servicetest.RequireRun(t, dnsProxy, testTimeout) 49 | 50 | // Create a DNS-over-UDP client connection. 51 | addr := dnsProxy.Addr(ProtoUDP) 52 | client := &dns.Client{ 53 | Net: string(ProtoUDP), 54 | Timeout: testTimeout, 55 | } 56 | 57 | // Send the first message, which is not blocked. 58 | req := newTestMessage() 59 | 60 | r, _, err := client.Exchange(req, addr.String()) 61 |
require.NoError(t, err) 62 | requireResponse(t, req, r) 63 | 64 | // Now send the second one and make sure it is blocked. 65 | m.Lock() 66 | blockResponse = true 67 | m.Unlock() 68 | 69 | r, _, err = client.Exchange(req, addr.String()) 70 | require.NoError(t, err) 71 | assert.Equal(t, dns.RcodeNotImplemented, r.Rcode) 72 | } 73 | -------------------------------------------------------------------------------- /proxy/optimisticresolver.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "log/slog" 7 | 8 | "github.com/AdguardTeam/golibs/logutil/slogutil" 9 | "github.com/AdguardTeam/golibs/syncutil" 10 | ) 11 | 12 | // cachingResolver is the DNS resolver that is also able to cache responses. 13 | type cachingResolver interface { 14 | // replyFromUpstream returns true if the request from dctx is successfully 15 | // resolved and the response may be cached. 16 | // 17 | // TODO(e.burkov): Find out when ok can be false with nil err. 18 | replyFromUpstream(dctx *DNSContext) (ok bool, err error) 19 | 20 | // cacheResp caches the response from dctx. 21 | cacheResp(dctx *DNSContext) 22 | } 23 | 24 | // type check 25 | var _ cachingResolver = (*Proxy)(nil) 26 | 27 | // unit is a convenient alias for struct{}. 28 | type unit = struct{} 29 | 30 | // optimisticResolver is used to eventually resolve expired cached requests. 31 | type optimisticResolver struct { 32 | reqs *syncutil.Map[string, unit] 33 | cr cachingResolver 34 | } 35 | 36 | // newOptimisticResolver returns a new resolver for expired cached requests. 37 | // cr must not be nil. 38 | func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) { 39 | return &optimisticResolver{ 40 | reqs: syncutil.NewMap[string, unit](), 41 | cr: cr, 42 | } 43 | } 44 | 45 | // resolveOnce tries to resolve the request from dctx but only a single request 46 | // with the same key at the same period of time.
It runs in a separate 47 | // goroutine. Do not pass the *DNSContext which is used elsewhere since it 48 | // isn't intended to be used concurrently. 49 | // 50 | // TODO(e.burkov): Pass the context. 51 | func (s *optimisticResolver) resolveOnce(dctx *DNSContext, key []byte, l *slog.Logger) { 52 | defer slogutil.RecoverAndLog(context.TODO(), l) 53 | 54 | keyHexed := hex.EncodeToString(key) 55 | if _, ok := s.reqs.LoadOrStore(keyHexed, unit{}); ok { 56 | return 57 | } 58 | defer s.reqs.Delete(keyHexed) 59 | 60 | ok, err := s.cr.replyFromUpstream(dctx) 61 | if err != nil { 62 | l.Debug("resolving request for optimistic cache", slogutil.KeyError, err) 63 | } 64 | 65 | if ok { 66 | s.cr.cacheResp(dctx) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /internal/cmd/tls.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "os" 7 | ) 8 | 9 | // newTLSConfig returns a TLS config that includes the certificate and key 10 | // loaded from conf.TLSCertPath and conf.TLSKeyPath, with the minimum and 11 | // maximum TLS versions taken from conf, defaulting to TLS 1.0 and TLS 1.3.
12 | func newTLSConfig(conf *configuration) (c *tls.Config, err error) { 13 | // Set the default TLS min/max versions. 14 | tlsMinVersion := tls.VersionTLS10 15 | tlsMaxVersion := tls.VersionTLS13 16 | 17 | switch conf.TLSMinVersion { 18 | case 1.1: 19 | tlsMinVersion = tls.VersionTLS11 20 | case 1.2: 21 | tlsMinVersion = tls.VersionTLS12 22 | case 1.3: 23 | tlsMinVersion = tls.VersionTLS13 24 | } 25 | 26 | switch conf.TLSMaxVersion { 27 | case 1.0: 28 | tlsMaxVersion = tls.VersionTLS10 29 | case 1.1: 30 | tlsMaxVersion = tls.VersionTLS11 31 | case 1.2: 32 | tlsMaxVersion = tls.VersionTLS12 33 | } 34 | 35 | cert, err := loadX509KeyPair(conf.TLSCertPath, conf.TLSKeyPath) 36 | if err != nil { 37 | return nil, fmt.Errorf("loading TLS cert: %w", err) 38 | } 39 | 40 | // #nosec G402 -- TLS MinVersion is configured by user. 41 | return &tls.Config{ 42 | Certificates: []tls.Certificate{cert}, 43 | MinVersion: uint16(tlsMinVersion), 44 | MaxVersion: uint16(tlsMaxVersion), 45 | }, nil 46 | } 47 | 48 | // loadX509KeyPair reads and parses a public/private key pair from a pair of 49 | // files. The files must contain PEM encoded data. The certificate file may 50 | // contain intermediate certificates following the leaf certificate to form a 51 | // certificate chain. Certificate.Leaf is populated by [tls.X509KeyPair] since 52 | // Go 1.23, unless the x509keypairleaf=0 GODEBUG setting is in effect. 53 | func loadX509KeyPair(certFile, keyFile string) (crt tls.Certificate, err error) { 54 | // #nosec G304 -- Trust the file path that is given in the configuration. 55 | certPEMBlock, err := os.ReadFile(certFile) 56 | if err != nil { 57 | return tls.Certificate{}, err 58 | } 59 | 60 | // #nosec G304 -- Trust the file path that is given in the configuration.
61 | keyPEMBlock, err := os.ReadFile(keyFile) 62 | if err != nil { 63 | return tls.Certificate{}, err 64 | } 65 | 66 | return tls.X509KeyPair(certPEMBlock, keyPEMBlock) 67 | } 68 | -------------------------------------------------------------------------------- /scripts/make/txt-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 10 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '0' ]; then 12 | set -x 13 | fi 14 | 15 | # Set $EXIT_ON_ERROR to zero to see all errors. 16 | if [ "${EXIT_ON_ERROR:-1}" -eq '0' ]; then 17 | set +e 18 | else 19 | set -e 20 | fi 21 | 22 | # We don't need glob expansions and we want to see errors about unset variables. 23 | set -f -u 24 | 25 | # Source the common helpers, including not_found. 26 | . ./scripts/make/helper.sh 27 | 28 | # Simple analyzers 29 | 30 | # trailing_newlines is a simple check that makes sure that all plain-text files 31 | # have a trailing newline to make sure that all tools work correctly with them. 32 | trailing_newlines() ( 33 | nl="$(printf '\n')" 34 | readonly nl 35 | 36 | find_with_ignore \ 37 | -type 'f' \ 38 | '!' '(' \ 39 | -name '*.exe' \ 40 | -o -name '*.out' \ 41 | -o -name '*.test' \ 42 | -o -name 'dnsproxy' \ 43 | ')' \ 44 | -print \ 45 | | while read -r f; do 46 | final_byte="$(tail -c -1 "$f")" 47 | if [ "$final_byte" != "$nl" ]; then 48 | printf '%s: must have a trailing newline\n' "$f" 49 | fi 50 | done 51 | ) 52 | 53 | # trailing_whitespace is a simple check that makes sure that there is no 54 | # trailing whitespace in plain-text files. 55 | trailing_whitespace() { 56 | find_with_ignore \ 57 | -type 'f' \ 58 | '!'
'(' \ 59 | -name '*.exe' \ 60 | -o -name '*.out' \ 61 | -o -name '*.test' \ 62 | -o -name 'dnsproxy' \ 63 | ')' \ 64 | -print \ 65 | | while read -r f; do 66 | grep -e '[[:space:]]$' -n -- "$f" \ 67 | | sed -e "s:^:${f}\::" -e 's/ \+$/>>>&<<> $GITHUB_ENV 46 | 47 | docker login \ 48 | -u="${DOCKER_USER}" \ 49 | -p="${DOCKER_PASSWORD}" 50 | 51 | make \ 52 | VERSION="${RELEASE_VERSION}" \ 53 | DOCKER_IMAGE_NAME="adguard/dnsproxy" \ 54 | DOCKER_OUTPUT="type=image,name=adguard/dnsproxy,push=true" \ 55 | VERBOSE="1" \ 56 | docker 57 | 58 | 'notify': 59 | 'needs': 60 | - 'docker' 61 | 'if': 62 | ${{ always() && 63 | ( 64 | github.event_name == 'push' || 65 | github.event.pull_request.head.repo.full_name == github.repository 66 | ) 67 | }} 68 | 'runs-on': ubuntu-latest 69 | 'steps': 70 | - 'name': Conclusion 71 | 'uses': technote-space/workflow-conclusion-action@v1 72 | - 'name': Send Slack notif 73 | 'uses': 8398a7/action-slack@v3 74 | 'with': 75 | 'status': ${{ env.WORKFLOW_CONCLUSION }} 76 | 'fields': workflow, repo, message, commit, author, eventName,ref 77 | 'env': 78 | 'GITHUB_TOKEN': ${{ secrets.GITHUB_TOKEN }} 79 | 'SLACK_WEBHOOK_URL': ${{ secrets.SLACK_WEBHOOK_URL }} 80 | -------------------------------------------------------------------------------- /proxy/helpers.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/AdguardTeam/golibs/netutil" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // ecsFromMsg returns the subnet and source scope from the EDNS Client Subnet option of m, if any.
11 | func ecsFromMsg(m *dns.Msg) (subnet *net.IPNet, scope int) { 12 | opt := m.IsEdns0() 13 | if opt == nil { 14 | return nil, 0 15 | } 16 | 17 | var ip net.IP 18 | var mask net.IPMask 19 | for _, e := range opt.Option { 20 | sn, ok := e.(*dns.EDNS0_SUBNET) 21 | if !ok { 22 | continue 23 | } 24 | 25 | switch sn.Family { 26 | case 1: 27 | ip = sn.Address.To4() 28 | mask = net.CIDRMask(int(sn.SourceNetmask), netutil.IPv4BitLen) 29 | case 2: 30 | ip = sn.Address 31 | mask = net.CIDRMask(int(sn.SourceNetmask), netutil.IPv6BitLen) 32 | default: 33 | continue 34 | } 35 | 36 | return &net.IPNet{IP: ip, Mask: mask}, int(sn.SourceScope) 37 | } 38 | 39 | return nil, 0 40 | } 41 | 42 | // setECS sets the EDNS client subnet option based on ip and scope into m. It 43 | // returns the masked subnet, including its mask, that was put into m. 44 | func setECS(m *dns.Msg, ip net.IP, scope uint8) (subnet *net.IPNet) { 45 | const ( 46 | // defaultECSv4 is the default length of network mask for IPv4 address 47 | // in ECS option. 48 | defaultECSv4 = 24 49 | 50 | // defaultECSv6 is the default length of network mask for IPv6 address 51 | // in ECS. The size of 7 octets is chosen as a reasonable minimum since 52 | // at least Google's public DNS refuses requests containing the options 53 | // with longer network masks. 54 | defaultECSv6 = 56 55 | ) 56 | 57 | e := &dns.EDNS0_SUBNET{ 58 | Code: dns.EDNS0SUBNET, 59 | SourceScope: scope, 60 | } 61 | 62 | subnet = &net.IPNet{} 63 | if ip4 := ip.To4(); ip4 != nil { 64 | e.Family = 1 65 | e.SourceNetmask = defaultECSv4 66 | subnet.Mask = net.CIDRMask(defaultECSv4, netutil.IPv4BitLen) 67 | ip = ip4 68 | } else { 69 | // Assume the IP address has already been validated. 70 | e.Family = 2 71 | e.SourceNetmask = defaultECSv6 72 | subnet.Mask = net.CIDRMask(defaultECSv6, netutil.IPv6BitLen) 73 | } 74 | subnet.IP = ip.Mask(subnet.Mask) 75 | e.Address = subnet.IP 76 | 77 | // If an OPT record already exists, just add the EDNS option to it.
Note 78 | // that servers may return FORMERR if they meet several OPT RRs. 79 | if opt := m.IsEdns0(); opt != nil { 80 | opt.Option = append(opt.Option, e) 81 | 82 | return subnet 83 | } 84 | 85 | // Create an OPT record and add the EDNS option to it. 86 | o := &dns.OPT{ 87 | Hdr: dns.RR_Header{ 88 | Name: ".", 89 | Rrtype: dns.TypeOPT, 90 | }, 91 | Option: []dns.EDNS0{e}, 92 | } 93 | o.SetUDPSize(4096) 94 | m.Extra = append(m.Extra, o) 95 | 96 | return subnet 97 | } 98 | -------------------------------------------------------------------------------- /scripts/make/helper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Common script helpers 4 | # 5 | # This file contains common script helpers. It should be sourced in scripts 6 | # right after the initial environment processing. 7 | 8 | # This comment is used to simplify checking local copies of the script. Bump 9 | # this number every time a significant change is made to this script. 10 | # 11 | # AdGuard-Project-Version: 5 12 | 13 | # Deferred helpers 14 | 15 | not_found_msg=' 16 | looks like a binary not found error. 17 | make sure you have installed the linter binaries, including using: 18 | 19 | $ make go-tools 20 | ' 21 | readonly not_found_msg 22 | 23 | not_found() { 24 | if [ "$?" -eq '127' ]; then 25 | # Code 127 is the exit status a shell uses when a command or a file is 26 | # not found, according to the Bash Hackers wiki. 27 | # 28 | # See https://wiki.bash-hackers.org/dict/terms/exit_status. 29 | echo "$not_found_msg" 1>&2 30 | fi 31 | } 32 | trap not_found EXIT 33 | 34 | # Helpers 35 | 36 | # run_linter runs the given linter with two additions: 37 | # 38 | # 1. If the first argument is "-e", run_linter exits with a nonzero exit code 39 | # if there is anything in the command's standard output. 40 | # 41 | # 2. In any case, run_linter prefixes each line of the command's captured standard output with the program's name; standard error is not captured.
42 | run_linter() ( 43 | set +e 44 | 45 | if [ "${VERBOSE:-0}" -lt '2' ]; then 46 | set +x 47 | fi 48 | 49 | cmd="${1:?run_linter: provide a command}" 50 | shift 51 | 52 | exit_on_output='0' 53 | if [ "$cmd" = '-e' ]; then 54 | exit_on_output='1' 55 | cmd="${1:?run_linter: provide a command}" 56 | shift 57 | fi 58 | 59 | readonly cmd 60 | 61 | output="$("$cmd" "$@")" 62 | exitcode="$?" 63 | 64 | readonly output 65 | 66 | if [ "$output" != '' ]; then 67 | echo "$output" | sed -e "s/^/${cmd}: /" 68 | 69 | if [ "$exitcode" -eq '0' ] && [ "$exit_on_output" -eq '1' ]; then 70 | exitcode='1' 71 | fi 72 | fi 73 | 74 | return "$exitcode" 75 | ) 76 | 77 | # find_with_ignore is a wrapper around find that does not descend into ignored 78 | # directories, such as ./tmp/. 79 | # 80 | # NOTE: The arguments must contain one of -exec, -ok, or -print; see 81 | # https://pubs.opengroup.org/onlinepubs/9799919799/utilities/find.html. 82 | # 83 | # TODO(a.garipov): Find a way to integrate the entire gitignore, including the 84 | # global one, without using git, as .git is not copied into the build container. 85 | # 86 | # Keep in sync with .gitignore. 87 | find_with_ignore() { 88 | find . \ 89 | '(' \ 90 | -type 'd' \ 91 | '(' \ 92 | -name '.git' \ 93 | -o -name 'bin' \ 94 | -o -name 'tmp' \ 95 | -o -name 'test-reports' \ 96 | ')' \ 97 | -prune \ 98 | ')' \ 99 | -o \ 100 | "$@" \ 101 | ; 102 | } 103 | -------------------------------------------------------------------------------- /scripts/hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e -f -u 4 | 5 | # This comment is used to simplify checking local copies of the script. Bump 6 | # this number every time a significant change is made to this script. 7 | # 8 | # AdGuard-Project-Version: 5 9 | 10 | # TODO(a.garipov): Add pre-merge-commit. 11 | 12 | # Only show interactive prompts if there a terminal is attached to stdout. 
13 | # While this technically doesn't guarantee that reading from /dev/tty works, 14 | # this should work reasonably well on all of our supported development systems 15 | # and in most terminal emulators. 16 | is_tty='0' 17 | if [ -t '1' ]; then 18 | is_tty='1' 19 | fi 20 | readonly is_tty 21 | 22 | # prompt is a helper that prompts the user for interactive input if that can be 23 | # done. If there is no terminal attached, it sleeps for two seconds, giving the 24 | # programmer some time to react, and returns with a zero exit code. 25 | prompt() { 26 | if [ "$is_tty" -eq '0' ]; then 27 | sleep 2 28 | 29 | return 0 30 | fi 31 | 32 | while true; do 33 | printf 'commit anyway? y/[n]: ' 34 | read -r ans > $GITHUB_ENV 56 | 57 | make VERBOSE=1 VERSION="${RELEASE_VERSION}" release 58 | 59 | ls -l build/dnsproxy-* 60 | - name: Create release 61 | if: startsWith(github.ref, 'refs/tags/v') 62 | id: create_release 63 | uses: actions/create-release@v1 64 | env: 65 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 66 | with: 67 | tag_name: ${{ github.ref }} 68 | release_name: Release ${{ github.ref }} 69 | draft: false 70 | prerelease: false 71 | - name: Upload 72 | if: startsWith(github.ref, 'refs/tags/v') 73 | uses: xresloader/upload-to-github-release@v1.3.12 74 | env: 75 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 76 | with: 77 | file: "build/dnsproxy-*.tar.gz;build/dnsproxy-*.zip" 78 | tags: true 79 | draft: false 80 | 81 | notify: 82 | needs: 83 | - build 84 | if: 85 | ${{ always() && 86 | ( 87 | github.event_name == 'push' || 88 | github.event.pull_request.head.repo.full_name == github.repository 89 | ) 90 | }} 91 | runs-on: ubuntu-latest 92 | steps: 93 | - name: Conclusion 94 | uses: technote-space/workflow-conclusion-action@v1 95 | - name: Send Slack notif 96 | uses: 8398a7/action-slack@v3 97 | with: 98 | status: ${{ env.WORKFLOW_CONCLUSION }} 99 | fields: workflow, repo, message, commit, author, eventName,ref 100 | env: 101 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 102 | 
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} 103 | -------------------------------------------------------------------------------- /internal/dnsproxytest/interface.go: -------------------------------------------------------------------------------- 1 | package dnsproxytest 2 | 3 | import ( 4 | "github.com/AdguardTeam/dnsproxy/internal/dnsmsg" 5 | "github.com/AdguardTeam/dnsproxy/upstream" 6 | "github.com/AdguardTeam/golibs/testutil" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // Upstream is a mock [upstream.Upstream] implementation for tests. 11 | // 12 | // TODO(e.burkov): Move to golibs. 13 | type Upstream struct { 14 | OnAddress func() (addr string) 15 | OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) 16 | OnClose func() (err error) 17 | } 18 | 19 | // type check 20 | var _ upstream.Upstream = (*Upstream)(nil) 21 | 22 | // Address implements the [upstream.Upstream] interface for *Upstream. 23 | func (u *Upstream) Address() (addr string) { 24 | return u.OnAddress() 25 | } 26 | 27 | // Exchange implements the [upstream.Upstream] interface for *Upstream. 28 | func (u *Upstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { 29 | return u.OnExchange(req) 30 | } 31 | 32 | // Close implements the [upstream.Upstream] interface for *Upstream. 33 | func (u *Upstream) Close() (err error) { 34 | return u.OnClose() 35 | } 36 | 37 | // MessageConstructor is a mock [dnsmsg.MessageConstructor] implementation for 38 | // tests. 39 | type MessageConstructor struct { 40 | OnNewMsgNXDOMAIN func(req *dns.Msg) (resp *dns.Msg) 41 | OnNewMsgSERVFAIL func(req *dns.Msg) (resp *dns.Msg) 42 | OnNewMsgNOTIMPLEMENTED func(req *dns.Msg) (resp *dns.Msg) 43 | OnNewMsgNODATA func(req *dns.Msg) (resp *dns.Msg) 44 | } 45 | 46 | // NewMessageConstructor creates a new *TestMessageConstructor with all it's 47 | // methods set to panic. 
48 | func NewMessageConstructor() (c *MessageConstructor) { 49 | return &MessageConstructor{ 50 | OnNewMsgNXDOMAIN: func(req *dns.Msg) (_ *dns.Msg) { 51 | panic(testutil.UnexpectedCall(req)) 52 | }, 53 | OnNewMsgSERVFAIL: func(req *dns.Msg) (_ *dns.Msg) { 54 | panic(testutil.UnexpectedCall(req)) 55 | }, 56 | OnNewMsgNOTIMPLEMENTED: func(req *dns.Msg) (_ *dns.Msg) { 57 | panic(testutil.UnexpectedCall(req)) 58 | }, 59 | OnNewMsgNODATA: func(req *dns.Msg) (_ *dns.Msg) { 60 | panic(testutil.UnexpectedCall(req)) 61 | }, 62 | } 63 | } 64 | 65 | // type check 66 | var _ dnsmsg.MessageConstructor = (*MessageConstructor)(nil) 67 | 68 | // NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for 69 | // *TestMessageConstructor. 70 | func (c *MessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) { 71 | return c.OnNewMsgNXDOMAIN(req) 72 | } 73 | 74 | // NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for 75 | // *TestMessageConstructor. 76 | func (c *MessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) { 77 | return c.OnNewMsgSERVFAIL(req) 78 | } 79 | 80 | // NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for 81 | // *TestMessageConstructor. 82 | func (c *MessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) { 83 | return c.OnNewMsgNOTIMPLEMENTED(req) 84 | } 85 | 86 | // NewMsgNODATA implements the [MessageConstructor] interface for 87 | // *TestMessageConstructor. 88 | func (c *MessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) { 89 | return c.OnNewMsgNODATA(req) 90 | } 91 | -------------------------------------------------------------------------------- /fastip/cache.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "encoding/binary" 5 | "net/netip" 6 | "time" 7 | ) 8 | 9 | const ( 10 | // fastestAddrCacheTTLSec is the cache TTL for IP addresses. 
11 | fastestAddrCacheTTLSec = 10 * 60 12 | ) 13 | 14 | // cacheEntry represents an item that will be stored in the cache. 15 | // 16 | // TODO(e.burkov): Rewrite the cache using zero-values instead of storing 17 | // useless boolean as an integer. 18 | type cacheEntry struct { 19 | // status is 1 if the item is timed out. 20 | status int 21 | latencyMsec uint 22 | } 23 | 24 | // packCacheEntry packs the cache entry and the TTL to bytes in the following 25 | // order: 26 | // 27 | // - expire [4]byte (Unix time, seconds), 28 | // - status byte (0 for ok, 1 for timed out), 29 | // - latency [2]byte (milliseconds). 30 | func packCacheEntry(ent *cacheEntry, ttl uint32) (d []byte) { 31 | expire := uint32(time.Now().Unix()) + ttl 32 | 33 | d = make([]byte, 4+1+2) 34 | binary.BigEndian.PutUint32(d, expire) 35 | i := 4 36 | 37 | d[i] = byte(ent.status) 38 | i++ 39 | 40 | binary.BigEndian.PutUint16(d[i:], uint16(ent.latencyMsec)) 41 | // i += 2 42 | 43 | return d 44 | } 45 | 46 | // unpackCacheEntry unpacks bytes to cache entry and checks TTL, if the record 47 | // is expired returns nil. 48 | func unpackCacheEntry(data []byte) (ent *cacheEntry) { 49 | now := time.Now().Unix() 50 | expire := binary.BigEndian.Uint32(data[:4]) 51 | if int64(expire) <= now { 52 | return nil 53 | } 54 | 55 | ent = &cacheEntry{} 56 | i := 4 57 | 58 | ent.status = int(data[i]) 59 | i++ 60 | 61 | ent.latencyMsec = uint(binary.BigEndian.Uint16(data[i:])) 62 | // i += 2 63 | 64 | return ent 65 | } 66 | 67 | // cacheFind finds entry in the cache for the given IP address. Returns nil if 68 | // nothing is found or if the record is expired. 69 | func (f *FastestAddr) cacheFind(ip netip.Addr) (ent *cacheEntry) { 70 | val := f.ipCache.Get(ip.AsSlice()) 71 | if val == nil { 72 | return nil 73 | } 74 | 75 | return unpackCacheEntry(val) 76 | } 77 | 78 | // cacheAddFailure stores unsuccessful attempt in cache. 
79 | func (f *FastestAddr) cacheAddFailure(ip netip.Addr) { 80 | ent := cacheEntry{ 81 | status: 1, 82 | } 83 | 84 | f.ipCacheLock.Lock() 85 | defer f.ipCacheLock.Unlock() 86 | 87 | if f.cacheFind(ip) == nil { 88 | f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec) 89 | } 90 | } 91 | 92 | // cacheAddSuccessful stores a successful ping result in the cache. Replaces 93 | // previous result if our latency is lower. 94 | func (f *FastestAddr) cacheAddSuccessful(ip netip.Addr, latency uint) { 95 | ent := cacheEntry{ 96 | latencyMsec: latency, 97 | } 98 | 99 | f.ipCacheLock.Lock() 100 | defer f.ipCacheLock.Unlock() 101 | 102 | entCached := f.cacheFind(ip) 103 | if entCached == nil || entCached.status != 0 || entCached.latencyMsec > latency { 104 | f.cacheAdd(&ent, ip, fastestAddrCacheTTLSec) 105 | } 106 | } 107 | 108 | // cacheAdd adds a new entry to the cache. 109 | func (f *FastestAddr) cacheAdd(ent *cacheEntry, ip netip.Addr, ttl uint32) { 110 | val := packCacheEntry(ent, ttl) 111 | f.ipCache.Set(ip.AsSlice(), val) 112 | } 113 | -------------------------------------------------------------------------------- /internal/bootstrap/resolver_test.go: -------------------------------------------------------------------------------- 1 | package bootstrap_test 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 10 | "github.com/AdguardTeam/golibs/netutil" 11 | "github.com/AdguardTeam/golibs/testutil" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // testResolver is the [Resolver] interface implementation for testing purposes. 17 | // 18 | // TODO(e.burkov): Move to [dnsproxytest]. 19 | type testResolver struct { 20 | onLookupNetIP func(ctx context.Context, network, host string) (addrs []netip.Addr, err error) 21 | } 22 | 23 | // LookupNetIP implements the [Resolver] interface for *testResolver. 
24 | func (r *testResolver) LookupNetIP( 25 | ctx context.Context, 26 | network string, 27 | host string, 28 | ) (addrs []netip.Addr, err error) { 29 | return r.onLookupNetIP(ctx, network, host) 30 | } 31 | 32 | func TestLookupParallel(t *testing.T) { 33 | const hostname = "host.name" 34 | 35 | t.Run("no_resolvers", func(t *testing.T) { 36 | addrs, err := bootstrap.ParallelResolver(nil).LookupNetIP(context.Background(), "ip", "") 37 | assert.ErrorIs(t, err, bootstrap.ErrNoResolvers) 38 | assert.Nil(t, addrs) 39 | }) 40 | 41 | pt := testutil.PanicT{} 42 | hostAddrs := []netip.Addr{netutil.IPv4Localhost()} 43 | 44 | immediate := &testResolver{ 45 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 46 | require.Equal(pt, hostname, host) 47 | require.Equal(pt, "ip", network) 48 | 49 | return hostAddrs, nil 50 | }, 51 | } 52 | 53 | t.Run("one_resolver", func(t *testing.T) { 54 | addrs, err := bootstrap.ParallelResolver{immediate}.LookupNetIP( 55 | context.Background(), 56 | "ip", 57 | hostname, 58 | ) 59 | require.NoError(t, err) 60 | 61 | assert.Equal(t, hostAddrs, addrs) 62 | }) 63 | 64 | t.Run("two_resolvers", func(t *testing.T) { 65 | delayCh := make(chan struct{}, 1) 66 | delayed := &testResolver{ 67 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 68 | require.Equal(pt, hostname, host) 69 | require.Equal(pt, "ip", network) 70 | 71 | testutil.RequireReceive(pt, delayCh, testTimeout) 72 | 73 | return []netip.Addr{netutil.IPv6Localhost()}, nil 74 | }, 75 | } 76 | 77 | addrs, err := bootstrap.ParallelResolver{immediate, delayed}.LookupNetIP( 78 | context.Background(), 79 | "ip", 80 | hostname, 81 | ) 82 | require.NoError(t, err) 83 | testutil.RequireSend(t, delayCh, struct{}{}, testTimeout) 84 | 85 | assert.Equal(t, hostAddrs, addrs) 86 | }) 87 | 88 | t.Run("all_errors", func(t *testing.T) { 89 | err := assert.AnError 90 | errStr := err.Error() 91 | wantErrMsg := strings.Join([]string{errStr, 
errStr, errStr}, "\n") 92 | 93 | r := &testResolver{ 94 | onLookupNetIP: func(_ context.Context, network, host string) ([]netip.Addr, error) { 95 | return nil, assert.AnError 96 | }, 97 | } 98 | 99 | addrs, err := bootstrap.ParallelResolver{r, r, r}.LookupNetIP( 100 | context.Background(), 101 | "ip", 102 | hostname, 103 | ) 104 | testutil.AssertErrorMsg(t, wantErrMsg, err) 105 | assert.Nil(t, addrs) 106 | }) 107 | } 108 | -------------------------------------------------------------------------------- /proxy/bogusnxdomain_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/dnsproxy/upstream" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/AdguardTeam/golibs/testutil/servicetest" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestProxy_IsBogusNXDomain(t *testing.T) { 17 | prx := mustNew(t, &Config{ 18 | Logger: slogutil.NewDiscardLogger(), 19 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 20 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 21 | UpstreamConfig: newTestUpstreamConfig(t, defaultTimeout, testDefaultUpstreamAddr), 22 | TrustedProxies: defaultTrustedProxies, 23 | RatelimitSubnetLenIPv4: 24, 24 | RatelimitSubnetLenIPv6: 64, 25 | CacheEnabled: true, 26 | BogusNXDomain: []netip.Prefix{ 27 | netip.MustParsePrefix("4.3.2.1/24"), 28 | netip.MustParsePrefix("1.2.3.4/8"), 29 | netip.MustParsePrefix("10.11.12.13/32"), 30 | netip.MustParsePrefix("102:304:506:708:90a:b0c:d0e:f10/120"), 31 | }, 32 | }) 33 | 34 | testCases := []struct { 35 | name string 36 | ans []dns.RR 37 | wantRcode int 38 | }{{ 39 | name: "bogus_subnet", 40 | ans: []dns.RR{&dns.A{ 41 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 42 | A: net.ParseIP("4.3.2.1"), 43 
| }}, 44 | wantRcode: dns.RcodeNameError, 45 | }, { 46 | name: "bogus_big_subnet", 47 | ans: []dns.RR{&dns.A{ 48 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 49 | A: net.ParseIP("1.254.254.254"), 50 | }}, 51 | wantRcode: dns.RcodeNameError, 52 | }, { 53 | name: "bogus_single_ip", 54 | ans: []dns.RR{&dns.A{ 55 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 56 | A: net.ParseIP("10.11.12.13"), 57 | }}, 58 | wantRcode: dns.RcodeNameError, 59 | }, { 60 | name: "bogus_6", 61 | ans: []dns.RR{&dns.AAAA{ 62 | Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, 63 | AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99}, 64 | }}, 65 | wantRcode: dns.RcodeNameError, 66 | }, { 67 | name: "non-bogus", 68 | ans: []dns.RR{&dns.A{ 69 | Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, 70 | A: net.ParseIP("10.11.12.14"), 71 | }}, 72 | wantRcode: dns.RcodeSuccess, 73 | }, { 74 | name: "non-bogus_6", 75 | ans: []dns.RR{&dns.AAAA{ 76 | Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, 77 | AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15}, 78 | }}, 79 | wantRcode: dns.RcodeSuccess, 80 | }} 81 | 82 | u := testUpstream{} 83 | prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u} 84 | 85 | servicetest.RequireRun(t, prx, testTimeout) 86 | 87 | d := &DNSContext{ 88 | Req: newHostTestMessage("host"), 89 | } 90 | 91 | for _, tc := range testCases { 92 | u.ans = tc.ans 93 | 94 | t.Run(tc.name, func(t *testing.T) { 95 | err := prx.Resolve(d) 96 | require.NoError(t, err) 97 | require.NotNil(t, d.Res) 98 | 99 | assert.Equal(t, tc.wantRcode, d.Res.Rcode) 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /upstream/hostsresolver.go: -------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/fs" 7 | "log/slog" 8 | "net/netip" 9 | 
"slices" 10 | 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/AdguardTeam/golibs/hostsfile" 13 | ) 14 | 15 | // HostsResolver is a [Resolver] that looks into system hosts files, see 16 | // [hostsfile]. 17 | type HostsResolver struct { 18 | // strg contains all the hosts file data needed for lookups. 19 | strg hostsfile.Storage 20 | } 21 | 22 | // NewHostsResolver is the resolver based on system hosts files. 23 | func NewHostsResolver(hosts hostsfile.Storage) (hr *HostsResolver) { 24 | return &HostsResolver{ 25 | strg: hosts, 26 | } 27 | } 28 | 29 | // NewDefaultHostsResolver returns a resolver based on system hosts files 30 | // provided by the [hostsfile.DefaultHostsPaths] and read from rootFSys. In 31 | // case the file by any default path doesn't exist it adds a log debug record. 32 | // If l is nil, [slog.Default] is used. 33 | func NewDefaultHostsResolver( 34 | ctx context.Context, 35 | rootFSys fs.FS, 36 | l *slog.Logger, 37 | ) (hr *HostsResolver, err error) { 38 | if l == nil { 39 | l = slog.Default() 40 | } 41 | 42 | paths, err := hostsfile.DefaultHostsPaths() 43 | if err != nil { 44 | return nil, fmt.Errorf("getting default hosts paths: %w", err) 45 | } 46 | 47 | // The error is always nil here since no readers passed. 48 | strg, _ := hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{ 49 | Logger: l, 50 | }) 51 | 52 | for _, filename := range paths { 53 | err = parseHostsFile(ctx, rootFSys, strg, filename, l) 54 | if err != nil { 55 | // Don't wrap the error since it's already informative enough as is. 56 | return nil, err 57 | } 58 | } 59 | 60 | return NewHostsResolver(strg), nil 61 | } 62 | 63 | // parseHostsFile reads a single hosts file from fsys and parses it into hosts. 
64 | func parseHostsFile( 65 | ctx context.Context, 66 | fsys fs.FS, 67 | hosts hostsfile.Set, 68 | filename string, 69 | l *slog.Logger, 70 | ) (err error) { 71 | f, err := fsys.Open(filename) 72 | if err != nil { 73 | if errors.Is(err, fs.ErrNotExist) { 74 | l.DebugContext(ctx, "hosts file does not exist", "filename", filename) 75 | 76 | return nil 77 | } 78 | 79 | // Don't wrap the error since it's already informative enough as is. 80 | return err 81 | } 82 | 83 | defer func() { err = errors.WithDeferred(err, f.Close()) }() 84 | 85 | return hostsfile.Parse(ctx, hosts, f, nil) 86 | } 87 | 88 | // type check 89 | var _ Resolver = (*HostsResolver)(nil) 90 | 91 | // LookupNetIP implements the [Resolver] interface for *hostsResolver. 92 | func (hr *HostsResolver) LookupNetIP( 93 | context context.Context, 94 | network string, 95 | host string, 96 | ) (addrs []netip.Addr, err error) { 97 | var ipMatches func(netip.Addr) (ok bool) 98 | switch network { 99 | case "ip4": 100 | ipMatches = netip.Addr.Is4 101 | case "ip6": 102 | ipMatches = netip.Addr.Is6 103 | case "ip": 104 | return slices.Clone(hr.strg.ByName(host)), nil 105 | default: 106 | return nil, fmt.Errorf("unsupported network %q", network) 107 | } 108 | 109 | for _, addr := range hr.strg.ByName(host) { 110 | if ipMatches(addr) { 111 | addrs = append(addrs, addr) 112 | } 113 | } 114 | 115 | return addrs, nil 116 | } 117 | -------------------------------------------------------------------------------- /scripts/make/go-build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # dnsproxy build script 4 | # 5 | # The commentary in this file is written with the assumption that the reader 6 | # only has superficial knowledge of the POSIX shell language and alike. 7 | # Experienced readers may find it overly verbose. 8 | 9 | # This comment is used to simplify checking local copies of the script. 
Bump 10 | # this number every time a significant change is made to this script. 11 | # 12 | # AdGuard-Project-Version: 2 13 | 14 | # The default verbosity level is 0. Show every command that is run and every 15 | # package that is processed if the caller requested verbosity level greater than 16 | # 0. Also show subcommands if the requested verbosity level is greater than 1. 17 | # Otherwise, do nothing. 18 | verbose="${VERBOSE:-0}" 19 | readonly verbose 20 | 21 | if [ "$verbose" -gt '1' ]; then 22 | env 23 | set -x 24 | v_flags='-v=1' 25 | x_flags='-x=1' 26 | elif [ "$verbose" -gt '0' ]; then 27 | set -x 28 | v_flags='-v=1' 29 | x_flags='-x=0' 30 | else 31 | set +x 32 | v_flags='-v=0' 33 | x_flags='-x=0' 34 | fi 35 | readonly x_flags v_flags 36 | 37 | # Exit the script if a pipeline fails (-e), prevent accidental filename 38 | # expansion (-f), and consider undefined variables as errors (-u). 39 | set -e -f -u 40 | 41 | # Allow users to override the go command from environment. For example, to 42 | # build two releases with two different Go versions and test the difference. 43 | go="${GO:-go}" 44 | readonly go 45 | 46 | # Set the build parameters unless already set. 47 | branch="${BRANCH:-$(git rev-parse --abbrev-ref HEAD)}" 48 | revision="${REVISION:-$(git rev-parse --short HEAD)}" 49 | version="${VERSION:-0}" 50 | readonly branch revision version 51 | 52 | # Set date and time of the latest commit unless already set. 53 | committime="${SOURCE_DATE_EPOCH:-$(git log -1 --pretty=%ct)}" 54 | readonly committime 55 | 56 | # Compile them in. 
57 | version_pkg='github.com/AdguardTeam/dnsproxy/internal/version' 58 | ldflags="-s -w" 59 | ldflags="${ldflags} -X ${version_pkg}.branch=${branch}" 60 | ldflags="${ldflags} -X ${version_pkg}.committime=${committime}" 61 | ldflags="${ldflags} -X ${version_pkg}.revision=${revision}" 62 | ldflags="${ldflags} -X ${version_pkg}.version=${version}" 63 | readonly ldflags version_pkg 64 | 65 | # Allow users to limit the build's parallelism. 66 | parallelism="${PARALLELISM:-}" 67 | readonly parallelism 68 | 69 | # Use GOFLAGS for -p, because -p=0 simply disables the build instead of leaving 70 | # the default value. 71 | if [ "${parallelism}" != '' ]; then 72 | GOFLAGS="${GOFLAGS:-} -p=${parallelism}" 73 | fi 74 | readonly GOFLAGS 75 | export GOFLAGS 76 | 77 | # Allow users to specify a different output name. 78 | out="${OUT:-dnsproxy}" 79 | readonly out 80 | 81 | o_flags="-o=${out}" 82 | readonly o_flags 83 | 84 | # Allow users to enable the race detector. Unfortunately, that means that cgo 85 | # must be enabled. 86 | if [ "${RACE:-0}" -eq '0' ]; then 87 | CGO_ENABLED='0' 88 | race_flags='--race=0' 89 | else 90 | CGO_ENABLED='1' 91 | race_flags='--race=1' 92 | fi 93 | readonly CGO_ENABLED race_flags 94 | export CGO_ENABLED 95 | 96 | if [ "$verbose" -gt '0' ]; then 97 | "$go" env 98 | fi 99 | 100 | "$go" build \ 101 | --ldflags="$ldflags" \ 102 | "$race_flags" \ 103 | --trimpath \ 104 | "$o_flags" \ 105 | "$v_flags" \ 106 | "$x_flags" \ 107 | ; 108 | -------------------------------------------------------------------------------- /scripts/make/build-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | verbose="${VERBOSE:-0}" 4 | 5 | if [ "$verbose" -gt '0' ]; then 6 | set -x 7 | debug_flags='--debug=1' 8 | else 9 | set +x 10 | debug_flags='--debug=0' 11 | fi 12 | readonly debug_flags 13 | 14 | set -e -f -u 15 | 16 | # Require these to be set. 
17 | commit="${REVISION:?please set REVISION}" 18 | dist_dir="${DIST_DIR:?please set DIST_DIR}" 19 | version="${VERSION:?please set VERSION}" 20 | readonly commit dist_dir version 21 | 22 | # Allow users to use sudo. 23 | sudo_cmd="${SUDO:-}" 24 | readonly sudo_cmd 25 | 26 | docker_platforms="\ 27 | linux/386,\ 28 | linux/amd64,\ 29 | linux/arm/v6,\ 30 | linux/arm/v7,\ 31 | linux/arm64,\ 32 | linux/ppc64le" 33 | readonly docker_platforms 34 | 35 | build_date="$(date -u +'%Y-%m-%dT%H:%M:%SZ')" 36 | readonly build_date 37 | 38 | # Set DOCKER_IMAGE_NAME to 'adguard/dnsproxy' if you want (and are allowed) 39 | # to push to DockerHub. 40 | docker_image_name="${DOCKER_IMAGE_NAME:-dnsproxy-dev}" 41 | readonly docker_image_name 42 | 43 | # Set DOCKER_OUTPUT to 'type=image,name=adguard/dnsproxy,push=true' if you 44 | # want (and are allowed) to push to DockerHub. 45 | # 46 | # If you want to inspect the resulting image using commands like "docker image 47 | # ls", change type to docker and also set docker_platforms to a single platform. 48 | # 49 | # See https://github.com/docker/buildx/issues/166. 50 | docker_output="${DOCKER_OUTPUT:-type=image,name=${docker_image_name},push=false}" 51 | readonly docker_output 52 | 53 | docker_version_tag="--tag=${docker_image_name}:${version}" 54 | docker_channel_tag="--tag=${docker_image_name}:latest" 55 | 56 | # If version is set to 'dev' or empty, only set the version tag and avoid 57 | # polluting the "latest" tag. 58 | if [ "${version:-}" = 'dev' ] || [ "${version:-}" = '' ]; then 59 | docker_channel_tag="" 60 | fi 61 | 62 | readonly docker_version_tag docker_channel_tag 63 | 64 | # Copy the binaries into a new directory under new names, so that it's easier to 65 | # COPY them later. DO NOT remove the trailing underscores. See file 66 | # docker/Dockerfile. 
67 | dist_docker="${dist_dir}/docker" 68 | readonly dist_docker 69 | 70 | mkdir -p "$dist_docker" 71 | cp "${dist_dir}/linux-386/dnsproxy" \ 72 | "${dist_docker}/dnsproxy_linux_386_" 73 | cp "${dist_dir}/linux-amd64/dnsproxy" \ 74 | "${dist_docker}/dnsproxy_linux_amd64_" 75 | cp "${dist_dir}/linux-arm64/dnsproxy" \ 76 | "${dist_docker}/dnsproxy_linux_arm64_" 77 | cp "${dist_dir}/linux-arm6/dnsproxy" \ 78 | "${dist_docker}/dnsproxy_linux_arm_v6" 79 | cp "${dist_dir}/linux-arm7/dnsproxy" \ 80 | "${dist_docker}/dnsproxy_linux_arm_v7" 81 | cp "${dist_dir}/linux-ppc64le/dnsproxy" \ 82 | "${dist_docker}/dnsproxy_linux_ppc64le_" 83 | 84 | # Prepare the default configuration for the Docker image. 85 | cp ./config.yaml.dist "${dist_docker}/config.yaml" 86 | 87 | # Don't use quotes with $docker_version_tag and $docker_channel_tag, because we 88 | # want word splitting and or an empty space if tags are empty. 89 | # 90 | # TODO(a.garipov): Once flag --tag of docker buildx build supports commas, use 91 | # them instead. 92 | # 93 | # shellcheck disable=SC2086 94 | $sudo_cmd docker \ 95 | "$debug_flags" \ 96 | buildx build \ 97 | --build-arg BUILD_DATE="$build_date" \ 98 | --build-arg DIST_DIR="$dist_dir" \ 99 | --build-arg VCS_REF="$commit" \ 100 | --build-arg VERSION="$version" \ 101 | --output "$docker_output" \ 102 | --platform "$docker_platforms" \ 103 | $docker_version_tag \ 104 | $docker_channel_tag \ 105 | -f ./docker/Dockerfile \ 106 | . 
\ 107 | ; 108 | -------------------------------------------------------------------------------- /bamboo-specs/bamboo.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 'version': 2 3 | 'plan': 4 | 'project-key': 'GO' 5 | 'key': 'DNSPROXY' 6 | 'name': 'dnsproxy - Build and run tests' 7 | 'variables': 8 | 'dockerFpm': 'alanfranz/fpm-within-docker:ubuntu-bionic' 9 | # When there is a patch release of Go available, set this property to an 10 | # exact patch version as opposed to a minor one to make sure that this exact 11 | # version is actually used and not whatever the docker daemon on the CI has 12 | # cached a few months ago. 13 | 'dockerGo': 'adguard/go-builder:1.25.5--1' 14 | 'maintainer': 'Adguard Go Team' 15 | 'name': 'dnsproxy' 16 | 17 | 'stages': 18 | # TODO(e.burkov): Add separate lint stage for texts. 19 | - 'Lint': 20 | 'manual': false 21 | 'final': false 22 | 'jobs': 23 | - 'Lint' 24 | - 'Test': 25 | 'manual': false 26 | 'final': false 27 | 'jobs': 28 | - 'Test' 29 | 30 | 'Lint': 31 | 'docker': 32 | 'image': '${bamboo.dockerGo}' 33 | 'volumes': 34 | '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}' 35 | '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}' 36 | 'key': 'LINT' 37 | 'other': 38 | 'clean-working-dir': true 39 | 'requirements': 40 | - 'adg-docker': true 41 | 'tasks': 42 | - 'checkout': 43 | 'force-clean-build': true 44 | - 'script': 45 | 'interpreter': 'SHELL' 46 | 'scripts': 47 | - | 48 | #!/bin/sh 49 | 50 | set -e -f -u -x 51 | 52 | make VERBOSE=1 GOMAXPROCS=1 go-tools go-lint 53 | 54 | 'Test': 55 | 'docker': 56 | 'image': '${bamboo.dockerGo}' 57 | 'volumes': 58 | '${system.GO_CACHE_DIR}': '${bamboo.cacheGo}' 59 | '${system.GO_PKG_CACHE_DIR}': '${bamboo.cacheGoPkg}' 60 | 'final-tasks': 61 | - 'test-parser': 62 | # The default pattern, '**/test-reports/*.xml', works, so don't set 63 | # the test-results property. 
64 | 'type': 'junit' 65 | 'ignore-time': true 66 | - 'clean' 67 | 'key': 'TEST' 68 | 'other': 69 | 'clean-working-dir': true 70 | 'requirements': 71 | - 'adg-docker': true 72 | 'tasks': 73 | - 'checkout': 74 | 'force-clean-build': true 75 | - 'script': 76 | 'interpreter': 'SHELL' 77 | # Projects that have go-bench and/or go-fuzz targets should add them 78 | # here as well. 79 | 'scripts': 80 | - | 81 | #!/bin/sh 82 | 83 | set -e -f -u -x 84 | 85 | make \ 86 | GOMAXPROCS=1 \ 87 | VERBOSE=1 \ 88 | go-deps go-tools 89 | 90 | make \ 91 | TEST_REPORTS_DIR="./test-reports/" \ 92 | VERBOSE=1 \ 93 | go-test \ 94 | ; 95 | 96 | exit "$(cat ./test-reports/test-exit-code.txt)" 97 | 98 | 'branches': 99 | 'create': 'for-pull-request' 100 | 'delete': 101 | 'after-deleted-days': 1 102 | 'after-inactive-days': 5 103 | 'link-to-jira': true 104 | 105 | 'notifications': 106 | - 'events': 107 | - 'plan-status-changed' 108 | 'recipients': 109 | - 'webhook': 110 | 'name': 'Build webhook' 111 | 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo' 112 | 113 | 'labels': [] 114 | 115 | 'other': 116 | 'concurrent-build-plugin': 'system-default' 117 | -------------------------------------------------------------------------------- /proxy/optimisticresolver_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "log/slog" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | // testCachingResolver is a stub implementation of the cachingResolver interface 15 | // to simplify testing. 16 | type testCachingResolver struct { 17 | onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error) 18 | onCacheResp func(dctx *DNSContext) 19 | } 20 | 21 | // replyFromUpstream implements the cachingResolver interface for 22 | // *testCachingResolver. 
23 | func (tcr *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) { 24 | return tcr.onReplyFromUpstream(dctx) 25 | } 26 | 27 | // cacheResp implements the cachingResolver interface for *testCachingResolver. 28 | func (tcr *testCachingResolver) cacheResp(dctx *DNSContext) { 29 | tcr.onCacheResp(dctx) 30 | } 31 | 32 | func TestOptimisticResolver_ResolveOnce(t *testing.T) { 33 | in, out := make(chan unit), make(chan unit) 34 | var timesResolved, timesSet int 35 | 36 | tcr := &testCachingResolver{ 37 | onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { 38 | timesResolved++ 39 | 40 | return true, nil 41 | }, 42 | onCacheResp: func(_ *DNSContext) { 43 | timesSet++ 44 | 45 | // Pass the signal to begin running secondary goroutines. 46 | out <- unit{} 47 | // Block until all the secondary goroutines finish. 48 | <-in 49 | }, 50 | } 51 | 52 | s := newOptimisticResolver(tcr) 53 | sameKey := []byte{1, 2, 3} 54 | 55 | // Start the primary goroutine. 56 | go s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger()) 57 | // Block until the primary goroutine reaches the resolve function. 58 | <-out 59 | 60 | wg := &sync.WaitGroup{} 61 | 62 | const secondaryNum = 10 63 | wg.Add(secondaryNum) 64 | for range secondaryNum { 65 | go func() { 66 | defer wg.Done() 67 | 68 | s.resolveOnce(nil, sameKey, slogutil.NewDiscardLogger()) 69 | }() 70 | } 71 | 72 | // Wait until all the secondary goroutines are finished. 73 | wg.Wait() 74 | // Pass the signal to terminate the primary goroutine. 75 | in <- unit{} 76 | 77 | assert.Equal(t, 1, timesResolved) 78 | assert.Equal(t, 1, timesSet) 79 | } 80 | 81 | func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) { 82 | key := []byte{1, 2, 3} 83 | 84 | t.Run("error", func(t *testing.T) { 85 | // TODO(d.kolyshev): Consider adding mock handler to golibs. 
86 | logOutput := &bytes.Buffer{} 87 | l := slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{ 88 | AddSource: false, 89 | Level: slog.LevelDebug, 90 | ReplaceAttr: nil, 91 | })) 92 | 93 | const rErr errors.Error = "sample resolving error" 94 | 95 | cached := false 96 | s := newOptimisticResolver(&testCachingResolver{ 97 | onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return true, rErr }, 98 | onCacheResp: func(_ *DNSContext) { cached = true }, 99 | }) 100 | s.resolveOnce(nil, key, l) 101 | 102 | assert.True(t, cached) 103 | assert.Contains(t, logOutput.String(), rErr.Error()) 104 | }) 105 | 106 | t.Run("not_ok", func(t *testing.T) { 107 | cached := false 108 | s := newOptimisticResolver(&testCachingResolver{ 109 | onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return false, nil }, 110 | onCacheResp: func(_ *DNSContext) { cached = true }, 111 | }) 112 | s.resolveOnce(nil, key, slogutil.NewDiscardLogger()) 113 | 114 | assert.False(t, cached) 115 | }) 116 | } 117 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Keep the Makefile POSIX-compliant. We currently allow hyphens in target 2 | # names, but that may change in the future. 3 | # 4 | # See https://pubs.opengroup.org/onlinepubs/9799919799/utilities/make.html. 5 | .POSIX: 6 | 7 | # This comment is used to simplify checking local copies of the Makefile. Bump 8 | # this number every time a significant change is made to this Makefile. 9 | # 10 | # AdGuard-Project-Version: 11 11 | 12 | # Don't name these macros "GO" etc., because GNU Make apparently makes them 13 | # exported environment variables with the literal value of "${GO:-go}" and so 14 | # on, which is not what we need. Use a dot in the name to make sure that users 15 | # don't have an environment variable with the same name. 
16 | # 17 | # See https://unix.stackexchange.com/q/646255/105635. 18 | GO.MACRO = $${GO:-go} 19 | VERBOSE.MACRO = $${VERBOSE:-0} 20 | 21 | BRANCH = $${BRANCH:-$$(git rev-parse --abbrev-ref HEAD)} 22 | DIST_DIR = build 23 | GOAMD64 = v1 24 | GOPROXY = https://proxy.golang.org|direct 25 | GOTELEMETRY = off 26 | OUT = dnsproxy 27 | GOTOOLCHAIN = go1.25.5 28 | RACE = 0 29 | REVISION = $${REVISION:-$$(git rev-parse --short HEAD)} 30 | VERSION = 0 31 | 32 | ENV = env\ 33 | BRANCH="$(BRANCH)"\ 34 | DIST_DIR='$(DIST_DIR)'\ 35 | GO="$(GO.MACRO)"\ 36 | GOAMD64='$(GOAMD64)'\ 37 | GOPROXY='$(GOPROXY)'\ 38 | GOTELEMETRY='$(GOTELEMETRY)'\ 39 | OUT='$(OUT)'\ 40 | GOTOOLCHAIN='$(GOTOOLCHAIN)'\ 41 | PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\ 42 | RACE='$(RACE)'\ 43 | REVISION="$(REVISION)"\ 44 | VERBOSE="$(VERBOSE.MACRO)"\ 45 | VERSION="$(VERSION)"\ 46 | 47 | # Keep the line above blank. 48 | 49 | ENV_MISC = env\ 50 | PATH="$${PWD}/bin:$$("$(GO.MACRO)" env GOPATH)/bin:$${PATH}"\ 51 | VERBOSE="$(VERBOSE.MACRO)"\ 52 | 53 | # Keep the line above blank. 54 | 55 | # Keep this target first, so that a naked make invocation triggers a full build. 
56 | .PHONY: build 57 | build: go-deps go-build 58 | 59 | .PHONY: init 60 | init: ; git config core.hooksPath ./scripts/hooks 61 | 62 | .PHONY: test 63 | test: go-test 64 | 65 | .PHONY: go-build go-deps go-env go-lint go-test go-tools go-upd-tools 66 | go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh 67 | go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh 68 | go-env: ; $(ENV) "$(GO.MACRO)" env 69 | go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh 70 | go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh 71 | go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh 72 | go-upd-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-upd-tools.sh 73 | 74 | .PHONY: go-check 75 | go-check: go-tools go-lint go-test 76 | 77 | # A quick check to make sure that all operating systems relevant to the 78 | # development of the project can be typechecked and built successfully. 79 | .PHONY: go-os-check 80 | go-os-check: 81 | $(ENV) GOOS='darwin' "$(GO.MACRO)" vet ./... 82 | $(ENV) GOOS='freebsd' "$(GO.MACRO)" vet ./... 83 | $(ENV) GOOS='openbsd' "$(GO.MACRO)" vet ./... 84 | $(ENV) GOOS='linux' "$(GO.MACRO)" vet ./... 85 | $(ENV) GOOS='windows' "$(GO.MACRO)" vet ./... 
86 | 87 | .PHONY: txt-lint 88 | txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh 89 | 90 | .PHONY: md-lint sh-lint 91 | md-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/md-lint.sh 92 | sh-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/sh-lint.sh 93 | 94 | .PHONY: clean 95 | clean: ; $(ENV) "$(GO.MACRO)" clean && rm -f -r '$(DIST_DIR)' 96 | 97 | .PHONY: release 98 | release: clean 99 | $(ENV) "$(SHELL)" ./scripts/make/build-release.sh 100 | 101 | .PHONY: docker 102 | docker: release 103 | $(ENV) "$(SHELL)" ./scripts/make/build-docker.sh 104 | -------------------------------------------------------------------------------- /proxy/beforerequest_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "testing" 7 | "time" 8 | 9 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 10 | "github.com/AdguardTeam/dnsproxy/upstream" 11 | "github.com/AdguardTeam/golibs/errors" 12 | "github.com/AdguardTeam/golibs/logutil/slogutil" 13 | "github.com/AdguardTeam/golibs/netutil" 14 | "github.com/AdguardTeam/golibs/testutil/servicetest" 15 | "github.com/miekg/dns" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | // testBeforeRequestHandler is a mock before request handler implementation to 21 | // simplify testing. 22 | type testBeforeRequestHandler struct { 23 | onHandleBefore func(p *Proxy, dctx *DNSContext) (err error) 24 | } 25 | 26 | // type check 27 | var _ BeforeRequestHandler = (*testBeforeRequestHandler)(nil) 28 | 29 | // HandleBefore implements the [BeforeRequestHandler] interface for 30 | // *testBeforeRequestHandler.
31 | func (h *testBeforeRequestHandler) HandleBefore(p *Proxy, dctx *DNSContext) (err error) { 32 | return h.onHandleBefore(p, dctx) 33 | } 34 | 35 | func TestProxy_HandleDNSRequest_beforeRequestHandler(t *testing.T) { 36 | t.Parallel() 37 | 38 | const ( 39 | allowedID = iota 40 | droppedID 41 | errorID 42 | ) 43 | 44 | allowedRequest := (&dns.Msg{}).SetQuestion("allowed.", dns.TypeA) 45 | allowedRequest.Id = allowedID 46 | allowedResponse := (&dns.Msg{}).SetReply(allowedRequest) 47 | 48 | droppedRequest := (&dns.Msg{}).SetQuestion("dropped.", dns.TypeA) 49 | droppedRequest.Id = droppedID 50 | 51 | errorRequest := (&dns.Msg{}).SetQuestion("error.", dns.TypeA) 52 | errorRequest.Id = errorID 53 | errorResponse := (&dns.Msg{}).SetReply(errorRequest) 54 | 55 | p := mustNew(t, &Config{ 56 | Logger: slogutil.NewDiscardLogger(), 57 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 58 | UpstreamConfig: &UpstreamConfig{ 59 | Upstreams: []upstream.Upstream{&dnsproxytest.Upstream{ 60 | OnExchange: func(m *dns.Msg) (resp *dns.Msg, err error) { 61 | return allowedResponse.Copy(), nil 62 | }, 63 | OnAddress: func() (addr string) { return "general" }, 64 | OnClose: func() (err error) { return nil }, 65 | }}, 66 | }, 67 | TrustedProxies: defaultTrustedProxies, 68 | PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed), 69 | BeforeRequestHandler: &testBeforeRequestHandler{ 70 | onHandleBefore: func(p *Proxy, dctx *DNSContext) (err error) { 71 | switch dctx.Req.Id { 72 | case allowedID: 73 | return nil 74 | case droppedID: 75 | return errors.Error("just drop") 76 | case errorID: 77 | return &BeforeRequestError{ 78 | Err: errors.Error("just error"), 79 | Response: errorResponse, 80 | } 81 | default: 82 | panic(fmt.Sprintf("unexpected request id: %d", dctx.Req.Id)) 83 | } 84 | }, 85 | }, 86 | }) 87 | 88 | servicetest.RequireRun(t, p, testTimeout) 89 | 90 | client := &dns.Client{ 91 | Net: string(ProtoTCP), 92 | Timeout: 200 * time.Millisecond, 93 | } 
94 | addr := p.Addr(ProtoTCP).String() 95 | 96 | t.Run("allowed", func(t *testing.T) { 97 | t.Parallel() 98 | 99 | resp, _, err := client.Exchange(allowedRequest, addr) 100 | require.NoError(t, err) 101 | assert.Equal(t, allowedResponse, resp) 102 | }) 103 | 104 | t.Run("dropped", func(t *testing.T) { 105 | t.Parallel() 106 | 107 | resp, _, err := client.Exchange(droppedRequest, addr) 108 | 109 | wantErr := &net.OpError{} 110 | require.ErrorAs(t, err, &wantErr) 111 | assert.True(t, wantErr.Timeout()) 112 | 113 | assert.Nil(t, resp) 114 | }) 115 | 116 | t.Run("error", func(t *testing.T) { 117 | t.Parallel() 118 | 119 | resp, _, err := client.Exchange(errorRequest, addr) 120 | require.NoError(t, err) 121 | assert.Equal(t, errorResponse, resp) 122 | }) 123 | } 124 | -------------------------------------------------------------------------------- /internal/dnsmsg/constructor.go: -------------------------------------------------------------------------------- 1 | // Package dnsmsg contains common constants, functions, and types for inspecting 2 | // and constructing DNS messages. 3 | package dnsmsg 4 | 5 | import ( 6 | "strings" 7 | 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // MessageConstructor creates DNS messages. 12 | type MessageConstructor interface { 13 | // NewMsgNXDOMAIN creates a new response message replying to req with the 14 | // NXDOMAIN code. 15 | NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) 16 | 17 | // NewMsgSERVFAIL creates a new response message replying to req with the 18 | // SERVFAIL code. 19 | NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) 20 | 21 | // NewMsgNOTIMPLEMENTED creates a new response message replying to req with 22 | // the NOTIMPLEMENTED code. 23 | NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) 24 | 25 | // NewMsgNODATA creates a new empty response message replying to req with 26 | // the NOERROR code. 27 | // 28 | // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. 
29 | NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) 30 | } 31 | 32 | // DefaultMessageConstructor is a default implementation of 33 | // [MessageConstructor]. 34 | type DefaultMessageConstructor struct{} 35 | 36 | // type check 37 | var _ MessageConstructor = DefaultMessageConstructor{} 38 | 39 | // NewMsgNXDOMAIN implements the [MessageConstructor] interface for 40 | // DefaultMessageConstructor. 41 | func (DefaultMessageConstructor) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) { 42 | return reply(req, dns.RcodeNameError) 43 | } 44 | 45 | // NewMsgSERVFAIL implements the [MessageConstructor] interface for 46 | // DefaultMessageConstructor. 47 | func (DefaultMessageConstructor) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) { 48 | return reply(req, dns.RcodeServerFailure) 49 | } 50 | 51 | // NewMsgNOTIMPLEMENTED implements the [MessageConstructor] interface for 52 | // DefaultMessageConstructor. 53 | func (DefaultMessageConstructor) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) { 54 | resp = reply(req, dns.RcodeNotImplemented) 55 | 56 | // Most of the Internet and especially the inner core has an MTU of at least 57 | // 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet 58 | // is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)). 59 | // 60 | // See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17. 61 | const maxUDPPayload = 1452 62 | 63 | // NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so 64 | // explicitly set it. 65 | resp.SetEdns0(maxUDPPayload, false) 66 | 67 | return resp 68 | } 69 | 70 | // NewMsgNODATA implements the [MessageConstructor] interface for 71 | // DefaultMessageConstructor. 72 | func (DefaultMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) { 73 | resp = reply(req, dns.RcodeSuccess) 74 | 75 | zone := req.Question[0].Name 76 | soa := &dns.SOA{ 77 | // Values copied from verisign's nonexistent .com domain. 
78 | // 79 | // Their exact values are not important in our use case because they are 80 | // used for domain transfers between primary/secondary DNS servers. 81 | Refresh: 1800, 82 | Retry: 60, 83 | Expire: 604800, 84 | Minttl: 86400, 85 | // copied from AdGuard DNS 86 | Ns: "fake-for-negative-caching.adguard.com.", 87 | Serial: 100500, 88 | Mbox: "hostmaster.", 89 | // rest is request-specific 90 | Hdr: dns.RR_Header{ 91 | Name: zone, 92 | Rrtype: dns.TypeSOA, 93 | Ttl: 10, 94 | Class: dns.ClassINET, 95 | }, 96 | } 97 | 98 | if !strings.HasPrefix(zone, ".") { 99 | soa.Mbox += zone 100 | } 101 | 102 | resp.Ns = append(resp.Ns, soa) 103 | 104 | return resp 105 | } 106 | 107 | // reply creates a new response message replying to req with the given code. 108 | func reply(req *dns.Msg, code int) (resp *dns.Msg) { 109 | resp = (&dns.Msg{}).SetRcode(req, code) 110 | resp.RecursionAvailable = true 111 | 112 | return resp 113 | } 114 | -------------------------------------------------------------------------------- /scripts/make/build-release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | verbose="${VERBOSE:-0}" 4 | readonly verbose 5 | 6 | if [ "$verbose" -gt '2' ]; then 7 | env 8 | set -x 9 | elif [ "$verbose" -gt '1' ]; then 10 | set -x 11 | fi 12 | 13 | set -e -f -u 14 | 15 | log() { 16 | if [ "$verbose" -gt '0' ]; then 17 | # Quote the argument to print the message as is, without word splitting. 18 | echo "$1" 1>&2 19 | fi 20 | } 21 | 22 | log 'starting to build dnsproxy release' 23 | 24 | version="${VERSION:-}" 25 | readonly version 26 | 27 | log "version '$version'" 28 | 29 | dist="${DIST_DIR:-build}" 30 | readonly dist 31 | 32 | out="${OUT:-dnsproxy}" 33 | 34 | log "checking tools" 35 | 36 | for tool in tar zip; do 37 | if ! command -v "$tool" >/dev/null; then 38 | log "tool '$tool' not found" 39 | 40 | exit 1 41 | fi 42 | done 43 | 44 | # Data section. Arrange data into space-separated tables for read -r to read.
45 | # Use 0 for missing values. 46 | 47 | # os arch arm mips 48 | platforms="\ 49 | darwin amd64 0 0 50 | darwin arm64 0 0 51 | freebsd 386 0 0 52 | freebsd amd64 0 0 53 | freebsd arm 5 0 54 | freebsd arm 6 0 55 | freebsd arm 7 0 56 | freebsd arm64 0 0 57 | linux 386 0 0 58 | linux amd64 0 0 59 | linux arm 5 0 60 | linux arm 6 0 61 | linux arm 7 0 62 | linux arm64 0 0 63 | linux mips 0 softfloat 64 | linux mips64 0 softfloat 65 | linux mips64le 0 softfloat 66 | linux mipsle 0 softfloat 67 | linux ppc64le 0 0 68 | openbsd amd64 0 0 69 | openbsd arm64 0 0 70 | windows 386 0 0 71 | windows amd64 0 0 72 | windows arm64 0 0" 73 | readonly platforms 74 | 75 | build() { 76 | # Get the arguments. Here and below, use the "build_" prefix for all 77 | # variables local to function build. 78 | build_dir="${dist}/${1}" \ 79 | build_name="$1" \ 80 | build_os="$2" \ 81 | build_arch="$3" \ 82 | build_arm="$4" \ 83 | build_mips="$5" \ 84 | ; 85 | 86 | # Use the ".exe" filename extension if we build a Windows release. 87 | if [ "$build_os" = 'windows' ]; then 88 | build_output="./${build_dir}/${out}.exe" 89 | else 90 | build_output="./${build_dir}/${out}" 91 | fi 92 | 93 | mkdir -p "./${build_dir}" 94 | 95 | # Build the binary. 96 | # 97 | # Set GOARM and GOMIPS to an empty string if $build_arm and $build_mips 98 | # are zero by removing the zero as if it's a prefix. 99 | # 100 | # Use $build_os from the arguments rather than the caller's loop variable 101 | # $os, so that this function only depends on its own arguments. 102 | env GOARCH="$build_arch" \ 103 | GOARM="${build_arm#0}" \ 104 | GOMIPS="${build_mips#0}" \ 105 | GOOS="$build_os" \ 106 | VERBOSE="$((verbose - 1))" \ 107 | VERSION="$version" \ 108 | OUT="$build_output" \ 109 | sh ./scripts/make/go-build.sh 110 | 111 | log "$build_output" 112 | 113 | # Prepare the build directory for archiving. 114 | cp ./LICENSE ./README.md "$build_dir" 115 | 116 | # Make archives. Windows prefers ZIP archives; the rest, gzipped tarballs.
117 | case "$build_os" in 118 | 'windows') 119 | build_archive="./${dist}/${out}-${build_name}-${version}.zip" 120 | # TODO(a.garipov): Find an option similar to the -C option of tar for 121 | # zip. 122 | (cd "${dist}" && zip -9 -q -r "../${build_archive}" "./${build_name}") 123 | ;; 124 | *) 125 | build_archive="./${dist}/${out}-${build_name}-${version}.tar.gz" 126 | tar -C "./${dist}" -c -f - "./${build_name}" | gzip -9 - >"$build_archive" 127 | ;; 128 | esac 129 | 130 | log "$build_archive" 131 | } 132 | 133 | log "starting builds" 134 | 135 | # Go over all platforms defined in the space-separated table above, tweak the 136 | # values where necessary, and feed to build. 137 | echo "$platforms" | while read -r os arch arm mips; do 138 | case "$arch" in 139 | arm) 140 | name="${os}-${arch}${arm}" 141 | ;; 142 | *) 143 | name="${os}-${arch}" 144 | ;; 145 | esac 146 | 147 | build "$name" "$os" "$arch" "$arm" "$mips" 148 | done 149 | 150 | log "finished" 151 | -------------------------------------------------------------------------------- /upstream/parallel_internal_test.go: -------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | "testing" 7 | "time" 8 | 9 | "github.com/AdguardTeam/golibs/testutil" 10 | "github.com/miekg/dns" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | const ( 16 | timeout = 2 * time.Second 17 | ) 18 | 19 | // TestExchangeParallel launches several parallel exchanges. 20 | func TestExchangeParallel(t *testing.T) { 21 | upstreams := []Upstream{} 22 | upstreamList := []string{"1.2.3.4:55", "8.8.8.1", "8.8.8.8:53"} 23 | 24 | for _, s := range upstreamList { 25 | u, err := AddressToUpstream(s, &Options{ 26 | Logger: testLogger, 27 | Timeout: timeout, 28 | }) 29 | if err != nil { 30 | t.Fatalf("cannot create upstream: %s", err) 31 | } 32 | upstreams = append(upstreams, u) 33 | } 34 | 35 | req :=
createTestMessage() 36 | start := time.Now() 37 | resp, u, err := ExchangeParallel(upstreams, req) 38 | if err != nil { 39 | t.Fatalf("no response from test upstreams: %s", err) 40 | } 41 | 42 | if u.Address() != "8.8.8.8:53" { 43 | t.Fatalf("shouldn't happen. This upstream can't resolve DNS request: %s", u.Address()) 44 | } 45 | 46 | requireResponse(t, req, resp) 47 | elapsed := time.Since(start) 48 | if elapsed > timeout { 49 | t.Fatalf("exchange took more time than the configured timeout: %v", elapsed) 50 | } 51 | } 52 | 53 | func TestExchangeParallelEmpty(t *testing.T) { 54 | ups := []Upstream{ 55 | &testUpstream{empty: true}, 56 | &testUpstream{empty: true}, 57 | } 58 | 59 | req := createTestMessage() 60 | resp, up, err := ExchangeParallel(ups, req) 61 | require.Error(t, err) 62 | 63 | assert.Nil(t, resp) 64 | assert.Nil(t, up) 65 | } 66 | 67 | // testUpstream represents a mock upstream structure. 68 | type testUpstream struct { 69 | // addr is a mock A record IP address to be returned. 70 | addr netip.Addr 71 | 72 | // err is a mock error to be returned. 73 | err bool 74 | 75 | // empty indicates if a nil response is returned. 76 | empty bool 77 | 78 | // sleep is a delay before response. 79 | sleep time.Duration 80 | } 81 | 82 | // type check 83 | var _ Upstream = (*testUpstream)(nil) 84 | 85 | // Exchange implements the [Upstream] interface for *testUpstream. 86 | func (u *testUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { 87 | if u.sleep != 0 { 88 | time.Sleep(u.sleep) 89 | } 90 | 91 | if u.empty { 92 | return nil, nil 93 | } 94 | 95 | if u.err { 96 | return nil, fmt.Errorf("upstream error") 97 | } 98 | 99 | resp = &dns.Msg{} 100 | resp.SetReply(req) 101 | 102 | if u.addr != (netip.Addr{}) { 103 | a := dns.A{ 104 | A: u.addr.AsSlice(), 105 | } 106 | 107 | resp.Answer = append(resp.Answer, &a) 108 | } 109 | 110 | return resp, nil 111 | } 112 | 113 | // Address implements the [Upstream] interface for *testUpstream. 
114 | func (u *testUpstream) Address() (addr string) { 115 | return "" 116 | } 117 | 118 | // Close implements the [Upstream] interface for *testUpstream. 119 | func (u *testUpstream) Close() (err error) { 120 | return nil 121 | } 122 | 123 | func TestExchangeAll(t *testing.T) { 124 | delayedAnsAddr := netip.MustParseAddr("1.1.1.1") 125 | ansAddr := netip.MustParseAddr("3.3.3.3") 126 | 127 | ups := []Upstream{&testUpstream{ 128 | addr: delayedAnsAddr, 129 | sleep: 100 * time.Millisecond, 130 | }, &testUpstream{ 131 | err: true, 132 | }, &testUpstream{ 133 | addr: ansAddr, 134 | }} 135 | 136 | req := createHostTestMessage("test.org") 137 | res, err := ExchangeAll(ups, req) 138 | require.NoError(t, err) 139 | require.Len(t, res, 2) 140 | 141 | resp := res[0].Resp 142 | require.NotNil(t, resp) 143 | require.NotEmpty(t, resp.Answer) 144 | 145 | ip := testutil.RequireTypeAssert[*dns.A](t, resp.Answer[0]).A 146 | assert.Equal(t, ansAddr.AsSlice(), []byte(ip)) 147 | 148 | resp = res[1].Resp 149 | require.NotNil(t, resp) 150 | require.NotEmpty(t, resp.Answer) 151 | 152 | ip = testutil.RequireTypeAssert[*dns.A](t, resp.Answer[0]).A 153 | assert.Equal(t, delayedAnsAddr.AsSlice(), []byte(ip)) 154 | } 155 | -------------------------------------------------------------------------------- /internal/handler/hosts.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "net/netip" 8 | "os" 9 | "slices" 10 | "strings" 11 | 12 | "github.com/AdguardTeam/golibs/errors" 13 | "github.com/AdguardTeam/golibs/hostsfile" 14 | "github.com/AdguardTeam/golibs/logutil/slogutil" 15 | "github.com/AdguardTeam/golibs/netutil" 16 | "github.com/miekg/dns" 17 | ) 18 | 19 | // emptyStorage is a [hostsfile.Storage] that contains no records. 20 | // 21 | // TODO(e.burkov): Move to [hostsfile]. 
22 | type emptyStorage [0]hostsfile.Record 23 | 24 | // type check 25 | var _ hostsfile.Storage = emptyStorage{} 26 | 27 | // ByAddr implements the [hostsfile.Storage] interface for [emptyStorage]. 28 | func (emptyStorage) ByAddr(_ netip.Addr) (names []string) { 29 | return nil 30 | } 31 | 32 | // ByName implements the [hostsfile.Storage] interface for [emptyStorage]. 33 | func (emptyStorage) ByName(_ string) (addrs []netip.Addr) { 34 | return nil 35 | } 36 | 37 | // ReadHosts reads the hosts files from the file system and returns a storage 38 | // with parsed records. strg is always usable even if an error occurred. 39 | func ReadHosts( 40 | ctx context.Context, 41 | l *slog.Logger, 42 | paths []string, 43 | ) (strg hostsfile.Storage, err error) { 44 | // Don't check the error, since it may only appear when readers are used. 45 | defaultStrg, _ := hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{ 46 | Logger: l, 47 | }) 48 | 49 | var errs []error 50 | for _, path := range paths { 51 | err = readHostsFile(ctx, defaultStrg, path) 52 | if err != nil { 53 | // Don't wrap the error since it's informative enough as is. 54 | errs = append(errs, err) 55 | } 56 | } 57 | 58 | // TODO(e.burkov): Add method for length. 59 | isEmpty := true 60 | defaultStrg.RangeAddrs(func(_ string, _ []netip.Addr) (cont bool) { 61 | isEmpty = false 62 | 63 | return false 64 | }) 65 | 66 | if isEmpty { 67 | return emptyStorage{}, errors.Join(errs...) 68 | } 69 | 70 | return defaultStrg, errors.Join(errs...) 71 | } 72 | 73 | // readHostsFile reads the hosts file at path and parses it into strg. 74 | func readHostsFile(ctx context.Context, strg *hostsfile.DefaultStorage, path string) (err error) { 75 | // #nosec G304 -- Trust the file path from the configuration file. 76 | f, err := os.Open(path) 77 | if err != nil { 78 | // Don't wrap the error since it's informative enough as is.
79 | return err 80 | } 81 | 82 | defer func() { err = errors.WithDeferred(err, f.Close()) }() 83 | 84 | err = hostsfile.Parse(ctx, strg, f, nil) 85 | if err != nil { 86 | return fmt.Errorf("parsing hosts file %q: %w", path, err) 87 | } 88 | 89 | return nil 90 | } 91 | 92 | // resolveFromHosts resolves the DNS query from the hosts file. It fills the 93 | // response with the A, AAAA, and PTR records from the hosts file. 94 | func (h *Default) resolveFromHosts(ctx context.Context, req *dns.Msg) (resp *dns.Msg) { 95 | var addrs []netip.Addr 96 | var ptrs []string 97 | 98 | q := req.Question[0] 99 | name := strings.TrimSuffix(q.Name, ".") 100 | switch q.Qtype { 101 | case dns.TypeA: 102 | addrs = slices.Clone(h.hosts.ByName(name)) 103 | addrs = slices.DeleteFunc(addrs, netip.Addr.Is6) 104 | case dns.TypeAAAA: 105 | addrs = slices.Clone(h.hosts.ByName(name)) 106 | addrs = slices.DeleteFunc(addrs, netip.Addr.Is4) 107 | case dns.TypePTR: 108 | addr, err := netutil.IPFromReversedAddr(name) 109 | if err != nil { 110 | h.logger.DebugContext(ctx, "failed parsing ptr", slogutil.KeyError, err) 111 | 112 | return nil 113 | } 114 | 115 | ptrs = h.hosts.ByAddr(addr) 116 | default: 117 | return nil 118 | } 119 | 120 | switch { 121 | case len(addrs) > 0: 122 | resp = h.messages.NewIPResponse(req, addrs) 123 | case len(ptrs) > 0: 124 | resp = h.messages.NewCompressedResponse(req, dns.RcodeSuccess) 125 | name = req.Question[0].Name 126 | for _, ptr := range ptrs { 127 | resp.Answer = append(resp.Answer, h.messages.NewPTRAnswer(name, dns.Fqdn(ptr))) 128 | } 129 | default: 130 | h.logger.DebugContext(ctx, "no hosts records found", "name", name, "qtype", q.Qtype) 131 | } 132 | 133 | return resp 134 | } 135 | -------------------------------------------------------------------------------- /internal/handler/constructor.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/netip" 5 | 6 | 
"github.com/AdguardTeam/dnsproxy/proxy" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // messageConstructor is an extension of the [proxy.MessageConstructor] 11 | // interface that also provides methods for creating DNS responses. 12 | type messageConstructor interface { 13 | proxy.MessageConstructor 14 | 15 | // NewCompressedResponse creates a new compressed response message for req 16 | // with the given response code. 17 | NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) 18 | 19 | // NewPTRAnswer creates a new resource record for PTR response with the 20 | // given FQDN and PTR domain. Arguments must be fully qualified domain 21 | // names. 22 | NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) 23 | 24 | // NewIPResponse creates a new A/AAAA response message for req with the 25 | // given IP addresses. All IP addresses must be of the same family. 26 | NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) 27 | } 28 | 29 | // defaultConstructor is a wrapper for [proxy.MessageConstructor] that also 30 | // implements the [messageConstructor] interface. 31 | // 32 | // TODO(e.burkov): This implementation reflects the one from AdGuard Home, 33 | // consider moving it to [golibs]. 34 | type defaultConstructor struct { 35 | proxy.MessageConstructor 36 | } 37 | 38 | // type check 39 | var _ messageConstructor = defaultConstructor{} 40 | 41 | // NewCompressedResponse implements the [messageConstructor] interface for 42 | // defaultConstructor. 43 | func (defaultConstructor) NewCompressedResponse(req *dns.Msg, code int) (resp *dns.Msg) { 44 | resp = reply(req, code) 45 | resp.Compress = true 46 | 47 | return resp 48 | } 49 | 50 | // NewPTRAnswer implements the [messageConstructor] interface for 51 | // [defaultConstructor]. 
52 | func (defaultConstructor) NewPTRAnswer(fqdn, ptrFQDN string) (ans *dns.PTR) { 53 | return &dns.PTR{ 54 | Hdr: hdr(fqdn, dns.TypePTR), 55 | Ptr: dns.Fqdn(ptrFQDN), 56 | } 57 | } 58 | 59 | // NewIPResponse implements the [messageConstructor] interface for 60 | // [defaultConstructor]. 61 | func (c defaultConstructor) NewIPResponse(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) { 62 | var ans []dns.RR 63 | switch req.Question[0].Qtype { 64 | case dns.TypeA: 65 | ans = genAnswersWithIPv4s(req, ips) 66 | case dns.TypeAAAA: 67 | for _, ip := range ips { 68 | if ip.Is6() { 69 | ans = append(ans, newAnswerAAAA(req, ip)) 70 | } 71 | } 72 | default: 73 | // Go on and return an empty response. 74 | } 75 | 76 | resp = c.NewCompressedResponse(req, dns.RcodeSuccess) 77 | resp.Answer = ans 78 | 79 | return resp 80 | } 81 | 82 | // defaultResponseTTL is the default TTL for the DNS responses in seconds. 83 | const defaultResponseTTL = 10 84 | 85 | // hdr creates a new DNS header with the given name and RR type. 86 | func hdr(name string, rrType uint16) (h dns.RR_Header) { 87 | return dns.RR_Header{ 88 | Name: name, 89 | Rrtype: rrType, 90 | Ttl: defaultResponseTTL, 91 | Class: dns.ClassINET, 92 | } 93 | } 94 | 95 | // reply creates a DNS response for req. 96 | func reply(req *dns.Msg, code int) (resp *dns.Msg) { 97 | resp = (&dns.Msg{}).SetRcode(req, code) 98 | resp.RecursionAvailable = true 99 | 100 | return resp 101 | } 102 | 103 | // newAnswerA creates a DNS A answer for req with the given IP address. 104 | func newAnswerA(req *dns.Msg, ip netip.Addr) (ans *dns.A) { 105 | return &dns.A{ 106 | Hdr: hdr(req.Question[0].Name, dns.TypeA), 107 | A: ip.AsSlice(), 108 | } 109 | } 110 | 111 | // newAnswerAAAA creates a DNS AAAA answer for req with the given IP address.
112 | func newAnswerAAAA(req *dns.Msg, ip netip.Addr) (ans *dns.AAAA) { 113 | return &dns.AAAA{ 114 | Hdr: hdr(req.Question[0].Name, dns.TypeAAAA), 115 | AAAA: ip.AsSlice(), 116 | } 117 | } 118 | 119 | // genAnswersWithIPv4s generates DNS A answers for req from the given IPv4 120 | // addresses. If any of ips isn't an IPv4 address, genAnswersWithIPv4s returns 121 | // nil. 122 | func genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) { 123 | for _, ip := range ips { 124 | if !ip.Is4() { 125 | return nil 126 | } 127 | 128 | ans = append(ans, newAnswerA(req, ip)) 129 | } 130 | 131 | return ans 132 | } 133 | -------------------------------------------------------------------------------- /proxy/serverdnscrypt.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | 8 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/AdguardTeam/golibs/netutil" 11 | "github.com/AdguardTeam/golibs/syncutil" 12 | "github.com/ameshkov/dnscrypt/v2" 13 | "github.com/miekg/dns" 14 | ) 15 | 16 | func (p *Proxy) initDNSCryptListeners(ctx context.Context) (err error) { 17 | if len(p.DNSCryptUDPListenAddr) == 0 && len(p.DNSCryptTCPListenAddr) == 0 { 18 | // Do nothing if DNSCrypt listen addresses are not specified.
19 | return nil 20 | } 21 | 22 | if p.DNSCryptResolverCert == nil || p.DNSCryptProviderName == "" { 23 | return errors.Error("invalid dnscrypt configuration: no certificate or provider name") 24 | } 25 | 26 | p.logger.InfoContext(ctx, "initializing dnscrypt", "provider", p.DNSCryptProviderName) 27 | p.dnsCryptServer = &dnscrypt.Server{ 28 | ProviderName: p.DNSCryptProviderName, 29 | ResolverCert: p.DNSCryptResolverCert, 30 | Handler: &dnsCryptHandler{ 31 | proxy: p, 32 | reqSema: p.requestsSema, 33 | }, 34 | Logger: p.logger, 35 | } 36 | 37 | for _, addr := range p.DNSCryptUDPListenAddr { 38 | udp, lErr := p.listenDNSCryptUDP(ctx, addr) 39 | if lErr != nil { 40 | return fmt.Errorf("listening to dnscrypt udp on addr %s: %w", addr, lErr) 41 | } 42 | 43 | p.dnsCryptUDPListen = append(p.dnsCryptUDPListen, udp) 44 | } 45 | 46 | for _, addr := range p.DNSCryptTCPListenAddr { 47 | tcp, lErr := p.listenDNSCryptTCP(ctx, addr) 48 | if lErr != nil { 49 | return fmt.Errorf("listening to dnscrypt tcp on addr %s: %w", addr, lErr) 50 | } 51 | 52 | p.dnsCryptTCPListen = append(p.dnsCryptTCPListen, tcp) 53 | } 54 | 55 | return nil 56 | } 57 | 58 | // listenDNSCryptUDP returns a new UDP connection for DNSCrypt listening on 59 | // addr. 60 | func (p *Proxy) listenDNSCryptUDP( 61 | ctx context.Context, 62 | addr *net.UDPAddr, 63 | ) (conn *net.UDPConn, err error) { 64 | addrStr := addr.String() 65 | p.logger.InfoContext(ctx, "creating dnscrypt udp server socket", "addr", addrStr) 66 | 67 | err = p.bindWithRetry(ctx, func() (listenErr error) { 68 | conn, listenErr = net.ListenUDP(bootstrap.NetworkUDP, addr) 69 | 70 | return listenErr 71 | }) 72 | if err != nil { 73 | return nil, fmt.Errorf("listening to udp socket: %w", err) 74 | } 75 | 76 | p.logger.InfoContext(ctx, "listening for dnscrypt messages on udp", "addr", conn.LocalAddr()) 77 | 78 | return conn, nil 79 | } 80 | 81 | // listenDNSCryptTCP returns a new TCP listener for DNSCrypt listening on addr. 
82 | func (p *Proxy) listenDNSCryptTCP( 83 | ctx context.Context, 84 | addr *net.TCPAddr, 85 | ) (conn *net.TCPListener, err error) { 86 | addrStr := addr.String() 87 | p.logger.InfoContext(ctx, "creating dnscrypt tcp server socket", "addr", addrStr) 88 | 89 | err = p.bindWithRetry(ctx, func() (listenErr error) { 90 | conn, listenErr = net.ListenTCP(bootstrap.NetworkTCP, addr) 91 | 92 | return listenErr 93 | }) 94 | if err != nil { 95 | return nil, fmt.Errorf("listening to tcp socket: %w", err) 96 | } 97 | 98 | p.logger.InfoContext(ctx, "listening for dnscrypt messages on tcp", "addr", conn.Addr()) 99 | 100 | return conn, nil 101 | } 102 | 103 | // dnsCryptHandler - dnscrypt.Handler implementation 104 | type dnsCryptHandler struct { 105 | proxy *Proxy 106 | 107 | reqSema syncutil.Semaphore 108 | } 109 | 110 | // compile-time type check 111 | var _ dnscrypt.Handler = &dnsCryptHandler{} 112 | 113 | // ServeDNS - processes the DNS query 114 | func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (err error) { 115 | d := h.proxy.newDNSContext(ProtoDNSCrypt, req, netutil.NetAddrToAddrPort(rw.RemoteAddr())) 116 | d.DNSCryptResponseWriter = rw 117 | 118 | // TODO(d.kolyshev): Pass and use context from above. 
119 | err = h.reqSema.Acquire(context.Background()) 120 | if err != nil { 121 | return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err) 122 | } 123 | defer h.reqSema.Release() 124 | 125 | return h.proxy.handleDNSRequest(d) 126 | } 127 | 128 | // Writes a response to the UDP client 129 | func (p *Proxy) respondDNSCrypt(d *DNSContext) error { 130 | if d.Res == nil { 131 | // If no response has been written, do nothing and let it drop 132 | return nil 133 | } 134 | 135 | return d.DNSCryptResponseWriter.WriteMsg(d.Res) 136 | } 137 | -------------------------------------------------------------------------------- /internal/bootstrap/bootstrap.go: -------------------------------------------------------------------------------- 1 | // Package bootstrap provides types and functions to resolve upstream hostnames 2 | // and to dial retrieved addresses. 3 | package bootstrap 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log/slog" 9 | "net" 10 | "net/netip" 11 | "net/url" 12 | "slices" 13 | "time" 14 | 15 | "github.com/AdguardTeam/golibs/errors" 16 | "github.com/AdguardTeam/golibs/logutil/slogutil" 17 | "github.com/AdguardTeam/golibs/netutil" 18 | ) 19 | 20 | // Network is a network type for use in [Resolver]'s methods. 21 | type Network = string 22 | 23 | const ( 24 | // NetworkIP is a network type for both address families. 25 | NetworkIP Network = "ip" 26 | 27 | // NetworkIP4 is a network type for IPv4 address family. 28 | NetworkIP4 Network = "ip4" 29 | 30 | // NetworkIP6 is a network type for IPv6 address family. 31 | NetworkIP6 Network = "ip6" 32 | 33 | // NetworkTCP is a network type for TCP connections. 34 | NetworkTCP Network = "tcp" 35 | 36 | // NetworkUDP is a network type for UDP connections. 37 | NetworkUDP Network = "udp" 38 | ) 39 | 40 | // DialHandler is a dial function for creating unencrypted network connections 41 | // to the upstream server. It establishes the connection to the server 42 | // specified at initialization and ignores the addr. 
network must be one of 43 | // [NetworkTCP] or [NetworkUDP]. 44 | type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error) 45 | 46 | // ResolveDialContext returns a DialHandler that uses addresses resolved from u 47 | // using resolver. l and u must not be nil. 48 | func ResolveDialContext( 49 | u *url.URL, 50 | timeout time.Duration, 51 | r Resolver, 52 | preferV6 bool, 53 | l *slog.Logger, 54 | ) (h DialHandler, err error) { 55 | defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }() 56 | 57 | host, port, err := netutil.SplitHostPort(u.Host) 58 | if err != nil { 59 | // Don't wrap the error since it's informative enough as is and there is 60 | // already deferred annotation here. 61 | return nil, err 62 | } 63 | 64 | if r == nil { 65 | return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers) 66 | } 67 | 68 | ctx := context.Background() 69 | if timeout > 0 { 70 | var cancel func() 71 | ctx, cancel = context.WithTimeout(ctx, timeout) 72 | defer cancel() 73 | } 74 | 75 | // TODO(e.burkov): Use network properly, perhaps, pass it through options. 76 | ips, err := r.LookupNetIP(ctx, NetworkIP, host) 77 | if err != nil { 78 | return nil, fmt.Errorf("resolving hostname: %w", err) 79 | } 80 | 81 | if preferV6 { 82 | slices.SortStableFunc(ips, netutil.PreferIPv6) 83 | } else { 84 | slices.SortStableFunc(ips, netutil.PreferIPv4) 85 | } 86 | 87 | addrs := make([]string, 0, len(ips)) 88 | for _, ip := range ips { 89 | addrs = append(addrs, netip.AddrPortFrom(ip, port).String()) 90 | } 91 | 92 | return NewDialContext(timeout, l, addrs...), nil 93 | } 94 | 95 | // NewDialContext returns a DialHandler that dials addrs and returns the first 96 | // successful connection. At least a single addr should be specified. l must 97 | // not be nil. 
98 | func NewDialContext(timeout time.Duration, l *slog.Logger, addrs ...string) (h DialHandler) { 99 | addrLen := len(addrs) 100 | if addrLen == 0 { 101 | l.Debug("no addresses to dial") 102 | 103 | return func(_ context.Context, _, _ string) (conn net.Conn, err error) { 104 | return nil, errors.Error("no addresses") 105 | } 106 | } 107 | 108 | dialer := &net.Dialer{ 109 | Timeout: timeout, 110 | } 111 | 112 | return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) { 113 | var errs []error 114 | 115 | // Return first succeeded connection. Note that we're using addrs 116 | // instead of what's passed to the function. 117 | for i, addr := range addrs { 118 | a := l.With("addr", addr) 119 | a.DebugContext(ctx, "dialing", "idx", i+1, "total", addrLen) 120 | 121 | start := time.Now() 122 | conn, err = dialer.DialContext(ctx, network, addr) 123 | elapsed := time.Since(start) 124 | if err != nil { 125 | a.DebugContext(ctx, "connection failed", "elapsed", elapsed, slogutil.KeyError, err) 126 | errs = append(errs, err) 127 | 128 | continue 129 | } 130 | 131 | a.DebugContext(ctx, "connection succeeded", "elapsed", elapsed) 132 | 133 | return conn, nil 134 | } 135 | 136 | return nil, errors.Join(errs...) 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /proxy/proxycache.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net" 5 | "slices" 6 | ) 7 | 8 | // cacheForContext returns cache object for the given context. 9 | func (p *Proxy) cacheForContext(d *DNSContext) (c *cache) { 10 | if d.CustomUpstreamConfig != nil && d.CustomUpstreamConfig.cache != nil { 11 | return d.CustomUpstreamConfig.cache 12 | } 13 | 14 | return p.cache 15 | } 16 | 17 | // replyFromCache tries to get the response from general or subnet cache. In 18 | // case the cache is present in d, it's used first. Returns true on success. 
19 | func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { 20 | dctxCache := p.cacheForContext(d) 21 | 22 | var ci *cacheItem 23 | var cacheSource string 24 | var expired bool 25 | var key []byte 26 | 27 | // TODO(d.kolyshev): Use EnableEDNSClientSubnet from dctxCache. 28 | if p.Config.EnableEDNSClientSubnet && d.ReqECS != nil { 29 | ci, expired, key = dctxCache.getWithSubnet(d.Req, d.ReqECS) 30 | cacheSource = "subnet cache" 31 | } else { 32 | ci, expired, key = dctxCache.get(d.Req) 33 | cacheSource = "general cache" 34 | } 35 | 36 | if hit = ci != nil; !hit { 37 | return hit 38 | } 39 | 40 | d.Res = ci.m 41 | d.queryStatistics = cachedQueryStatistics(ci.u) 42 | 43 | p.logger.Debug( 44 | "replying from cache", 45 | "source", cacheSource, 46 | "ecs_enabled", p.Config.EnableEDNSClientSubnet, 47 | ) 48 | 49 | if dctxCache.optimistic && expired { 50 | // Build a reduced clone of the current context to avoid data race. 51 | minCtxClone := &DNSContext{ 52 | // It is only read inside the optimistic resolver. 53 | CustomUpstreamConfig: d.CustomUpstreamConfig, 54 | ReqECS: cloneIPNet(d.ReqECS), 55 | IsPrivateClient: d.IsPrivateClient, 56 | } 57 | if d.Req != nil { 58 | minCtxClone.Req = d.Req.Copy() 59 | addDO(minCtxClone.Req) 60 | } 61 | 62 | go p.shortFlighter.resolveOnce(minCtxClone, key, p.logger) 63 | } 64 | 65 | return hit 66 | } 67 | 68 | // cloneIPNet returns a deep clone of n. 69 | func cloneIPNet(n *net.IPNet) (clone *net.IPNet) { 70 | if n == nil { 71 | return nil 72 | } 73 | 74 | return &net.IPNet{ 75 | IP: slices.Clone(n.IP), 76 | Mask: slices.Clone(n.Mask), 77 | } 78 | } 79 | 80 | // cacheResp stores the response from d in general or subnet cache. In case the 81 | // cache is present in d, it's used first. 
82 | func (p *Proxy) cacheResp(d *DNSContext) { 83 | dctxCache := p.cacheForContext(d) 84 | 85 | if !p.EnableEDNSClientSubnet { 86 | dctxCache.set(d.Res, d.Upstream, p.logger) 87 | 88 | return 89 | } 90 | 91 | switch ecs, scope := ecsFromMsg(d.Res); { 92 | case ecs != nil && d.ReqECS != nil: 93 | ones, bits := ecs.Mask.Size() 94 | reqOnes, _ := d.ReqECS.Mask.Size() 95 | 96 | // If FAMILY, SOURCE PREFIX-LENGTH, and SOURCE PREFIX-LENGTH bits of 97 | // ADDRESS in the response don't match the non-zero fields in the 98 | // corresponding query, the full response MUST be dropped. 99 | // 100 | // See RFC 7871 Section 7.3. 101 | // 102 | // TODO(a.meshkov): The whole response MUST be dropped if ECS in it 103 | // doesn't correspond. 104 | if !ecs.IP.Mask(ecs.Mask).Equal(d.ReqECS.IP.Mask(d.ReqECS.Mask)) || ones != reqOnes { 105 | p.logger.Debug( 106 | "not caching response; subnet mismatch", 107 | "ecs", ecs, 108 | "req_ecs", d.ReqECS, 109 | ) 110 | 111 | return 112 | } 113 | 114 | // If SCOPE PREFIX-LENGTH is not longer than SOURCE PREFIX-LENGTH, store 115 | // SCOPE PREFIX-LENGTH bits of ADDRESS, and then mark the response as 116 | // valid for all addresses that fall within that range. 117 | // 118 | // See RFC 7871 Section 7.3.1. 119 | if scope < reqOnes { 120 | ecs.Mask = net.CIDRMask(scope, bits) 121 | ecs.IP = ecs.IP.Mask(ecs.Mask) 122 | } 123 | 124 | p.logger.Debug("caching response", "ecs", ecs) 125 | 126 | dctxCache.setWithSubnet(d.Res, d.Upstream, ecs, p.logger) 127 | case d.ReqECS != nil: 128 | // Cache the response for all subnets since the server doesn't support 129 | // EDNS Client Subnet option. 130 | dctxCache.setWithSubnet(d.Res, d.Upstream, &net.IPNet{IP: nil, Mask: nil}, p.logger) 131 | default: 132 | dctxCache.set(d.Res, d.Upstream, p.logger) 133 | } 134 | } 135 | 136 | // ClearCache clears the DNS cache of p. 
137 | func (p *Proxy) ClearCache() { 138 | if p.cache == nil { 139 | return 140 | } 141 | 142 | p.cache.clearItems() 143 | p.cache.clearItemsWithSubnet() 144 | p.logger.Debug("cache cleared") 145 | } 146 | -------------------------------------------------------------------------------- /internal/cmd/flag.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/AdguardTeam/golibs/stringutil" 10 | ) 11 | 12 | // uint32Value is an uint32 that can be defined as a flag for [flag.FlagSet]. 13 | type uint32Value uint32 14 | 15 | // type check 16 | var _ flag.Value = (*uint32Value)(nil) 17 | 18 | // Set implements the [flag.Value] interface for *uint32Value. 19 | func (i *uint32Value) Set(s string) (err error) { 20 | v, err := strconv.ParseUint(s, 0, 32) 21 | *i = uint32Value(v) 22 | 23 | return err 24 | } 25 | 26 | // String implements the [flag.Value] interface for *uint32Value. 27 | func (i *uint32Value) String() (out string) { 28 | return strconv.FormatUint(uint64(*i), 10) 29 | } 30 | 31 | // float32Value is an float32 that can be defined as a flag for [flag.FlagSet]. 32 | type float32Value float32 33 | 34 | // type check 35 | var _ flag.Value = (*float32Value)(nil) 36 | 37 | // Set implements the [flag.Value] interface for *float32Value. 38 | func (i *float32Value) Set(s string) (err error) { 39 | v, err := strconv.ParseFloat(s, 32) 40 | *i = float32Value(v) 41 | 42 | return err 43 | } 44 | 45 | // String implements the [flag.Value] interface for *float32Value. 46 | func (i *float32Value) String() (out string) { 47 | return strconv.FormatFloat(float64(*i), 'f', 3, 32) 48 | } 49 | 50 | // intSliceValue represent a struct with a slice of integers that can be defined 51 | // as a flag for [flag.FlagSet]. 52 | type intSliceValue struct { 53 | // values is the pointer to a slice of integers to store parsed values. 
54 | values *[]int 55 | 56 | // isSet is false until the corresponding flag is met for the first time. 57 | // When the flag is found, the default value is overwritten with zero value. 58 | isSet bool 59 | } 60 | 61 | // newIntSliceValue returns a pointer to intSliceValue with the given value. 62 | func newIntSliceValue(p *[]int) (out *intSliceValue) { 63 | return &intSliceValue{ 64 | values: p, 65 | isSet: false, 66 | } 67 | } 68 | 69 | // type check 70 | var _ flag.Value = (*intSliceValue)(nil) 71 | 72 | // Set implements the [flag.Value] interface for *intSliceValue. 73 | func (i *intSliceValue) Set(s string) (err error) { 74 | v, err := strconv.Atoi(s) 75 | if err != nil { 76 | return fmt.Errorf("parsing integer slice arg %q: %w", s, err) 77 | } 78 | 79 | if !i.isSet { 80 | i.isSet = true 81 | *i.values = []int{} 82 | } 83 | 84 | *i.values = append(*i.values, v) 85 | 86 | return nil 87 | } 88 | 89 | // String implements the [flag.Value] interface for *intSliceValue. 90 | func (i *intSliceValue) String() (out string) { 91 | if i == nil || i.values == nil { 92 | return "" 93 | } 94 | 95 | sb := &strings.Builder{} 96 | for idx, v := range *i.values { 97 | if idx > 0 { 98 | stringutil.WriteToBuilder(sb, ",") 99 | } 100 | 101 | stringutil.WriteToBuilder(sb, strconv.Itoa(v)) 102 | } 103 | 104 | return sb.String() 105 | } 106 | 107 | // stringSliceValue represent a struct with a slice of strings that can be 108 | // defined as a flag for [flag.FlagSet]. 109 | type stringSliceValue struct { 110 | // values is the pointer to a slice of string to store parsed values. 111 | values *[]string 112 | 113 | // isSet is false until the corresponding flag is met for the first time. 114 | // When the flag is found, the default value is overwritten with zero value. 115 | isSet bool 116 | } 117 | 118 | // newStringSliceValue returns a pointer to stringSliceValue with the given 119 | // value. 
120 | func newStringSliceValue(p *[]string) (out *stringSliceValue) { 121 | return &stringSliceValue{ 122 | values: p, 123 | isSet: false, 124 | } 125 | } 126 | 127 | // type check 128 | var _ flag.Value = (*stringSliceValue)(nil) 129 | 130 | // Set implements the [flag.Value] interface for *stringSliceValue. 131 | func (i *stringSliceValue) Set(s string) (err error) { 132 | if !i.isSet { 133 | i.isSet = true 134 | *i.values = []string{} 135 | } 136 | 137 | *i.values = append(*i.values, s) 138 | 139 | return nil 140 | } 141 | 142 | // String implements the [flag.Value] interface for *stringSliceValue. 143 | func (i *stringSliceValue) String() (out string) { 144 | if i == nil || i.values == nil { 145 | return "" 146 | } 147 | 148 | sb := &strings.Builder{} 149 | for idx, v := range *i.values { 150 | if idx > 0 { 151 | stringutil.WriteToBuilder(sb, ",") 152 | } 153 | 154 | stringutil.WriteToBuilder(sb, v) 155 | } 156 | 157 | return sb.String() 158 | } 159 | -------------------------------------------------------------------------------- /fastip/ping.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "net/netip" 5 | "time" 6 | 7 | "github.com/AdguardTeam/dnsproxy/internal/bootstrap" 8 | "github.com/AdguardTeam/golibs/logutil/slogutil" 9 | ) 10 | 11 | // pingTCPTimeout is a TCP connection timeout. It's higher than pingWaitTimeout 12 | // since the slower connections will be cached anyway. 13 | const pingTCPTimeout = 4 * time.Second 14 | 15 | // pingResult is the result of dialing the address. 16 | type pingResult struct { 17 | // addrPort is the address-port pair the result is related to. 18 | addrPort netip.AddrPort 19 | 20 | // latency is the duration of dialing process in milliseconds. 21 | latency uint 22 | 23 | // success is true when the dialing succeeded. 
24 | success bool 25 | } 26 | 27 | // schedulePings returns the result with the fastest IP address from the cache, 28 | // if it's found, and starts pinging other IPs which are not cached or outdated. 29 | // Returns scheduled flag which indicates that some goroutines have been 30 | // scheduled. 31 | func (f *FastestAddr) schedulePings( 32 | resCh chan *pingResult, 33 | ips []netip.Addr, 34 | host string, 35 | ) (pr *pingResult, scheduled bool) { 36 | for _, ip := range ips { 37 | cached := f.cacheFind(ip) 38 | if cached == nil { 39 | scheduled = true 40 | for _, port := range f.pingPorts { 41 | go f.pingDoTCP(host, netip.AddrPortFrom(ip, uint16(port)), resCh) 42 | } 43 | 44 | continue 45 | } 46 | 47 | if cached.status == 0 && (pr == nil || cached.latencyMsec < pr.latency) { 48 | pr = &pingResult{ 49 | addrPort: netip.AddrPortFrom(ip, 0), 50 | latency: cached.latencyMsec, 51 | success: true, 52 | } 53 | } 54 | } 55 | 56 | return pr, scheduled 57 | } 58 | 59 | // pingAll pings all ips concurrently and returns as soon as the fastest one is 60 | // found or the timeout is exceeded. 61 | func (f *FastestAddr) pingAll(host string, ips []netip.Addr) (pr *pingResult) { 62 | ipN := len(ips) 63 | switch ipN { 64 | case 0: 65 | return nil 66 | case 1: 67 | return &pingResult{ 68 | addrPort: netip.AddrPortFrom(ips[0], 0), 69 | success: true, 70 | } 71 | } 72 | 73 | resCh := make(chan *pingResult, ipN*len(f.pingPorts)) 74 | pr, scheduled := f.schedulePings(resCh, ips, host) 75 | if !scheduled { 76 | if pr != nil { 77 | f.logger.Debug( 78 | "pinging all returns cached response", 79 | "host", host, 80 | "addr", pr.addrPort, 81 | ) 82 | } else { 83 | f.logger.Debug("pinging all returns nothing", "host", host) 84 | } 85 | 86 | return pr 87 | } 88 | 89 | res := f.firstSuccessRes(resCh, host) 90 | if res == nil { 91 | // In case of timeout return cached or nil. 
92 | return pr 93 | } 94 | 95 | if pr == nil || res.latency <= pr.latency { 96 | // Cache wasn't found or is worse than res. 97 | return res 98 | } 99 | 100 | // Return cached result. 101 | return pr 102 | } 103 | 104 | // firstSuccessRes waits and returns the first successful ping result or nil in 105 | // case of timeout. 106 | func (f *FastestAddr) firstSuccessRes(resCh chan *pingResult, host string) (res *pingResult) { 107 | after := time.After(f.pingWaitTimeout) 108 | for { 109 | select { 110 | case res = <-resCh: 111 | f.logger.Debug( 112 | "pinging all got result", 113 | "host", host, 114 | "addr", res.addrPort, 115 | "status", res.success, 116 | ) 117 | 118 | if !res.success { 119 | continue 120 | } 121 | 122 | return res 123 | case <-after: 124 | f.logger.Debug("pinging all timed out", "host", host) 125 | 126 | return nil 127 | } 128 | } 129 | } 130 | 131 | // pingDoTCP sends the result of dialing the specified address into resCh. 132 | func (f *FastestAddr) pingDoTCP(host string, addrPort netip.AddrPort, resCh chan *pingResult) { 133 | l := f.logger.With("host", host, "addr", addrPort) 134 | l.Debug("open tcp connection") 135 | 136 | start := time.Now() 137 | conn, err := f.pinger.Dial(bootstrap.NetworkTCP, addrPort.String()) 138 | elapsed := time.Since(start) 139 | 140 | success := err == nil 141 | if success { 142 | if cErr := conn.Close(); cErr != nil { 143 | l.Debug("closing tcp connection", slogutil.KeyError, cErr) 144 | } 145 | } 146 | 147 | latency := uint(elapsed.Milliseconds()) 148 | 149 | resCh <- &pingResult{ 150 | addrPort: addrPort, 151 | latency: latency, 152 | success: success, 153 | } 154 | 155 | addr := addrPort.Addr().Unmap() 156 | if success { 157 | l.Debug("tcp ping success", "elapsed", elapsed) 158 | f.cacheAddSuccessful(addr, latency) 159 | } else { 160 | l.Debug("tcp ping failed to connect", "elapsed", elapsed, slogutil.KeyError, err) 161 | f.cacheAddFailure(addr) 162 | } 163 | } 164 | 
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/AdguardTeam/dnsproxy 2 | 3 | go 1.25.5 4 | 5 | require ( 6 | github.com/AdguardTeam/golibs v0.35.2 7 | github.com/ameshkov/dnscrypt/v2 v2.4.0 8 | github.com/ameshkov/dnsstamps v1.0.3 9 | github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 10 | github.com/bluele/gcache v0.0.2 11 | github.com/miekg/dns v1.1.68 12 | github.com/patrickmn/go-cache v2.1.0+incompatible 13 | // TODO(s.chzhen): Update after investigation of the 0-RTT bug/behavior 14 | // when TestUpstreamDoH_serverRestart/http3/second_try keeps failing. 15 | github.com/quic-go/quic-go v0.56.0 16 | github.com/stretchr/testify v1.11.1 17 | golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect 18 | golang.org/x/net v0.47.0 19 | golang.org/x/sys v0.38.0 20 | gonum.org/v1/gonum v0.16.0 21 | gopkg.in/yaml.v3 v3.0.1 22 | ) 23 | 24 | require ( 25 | cloud.google.com/go v0.123.0 // indirect 26 | cloud.google.com/go/auth v0.17.0 // indirect 27 | cloud.google.com/go/compute/metadata v0.9.0 // indirect 28 | github.com/BurntSushi/toml v1.5.0 // indirect 29 | github.com/anthropics/anthropic-sdk-go v1.19.0 // indirect 30 | github.com/ccojocar/zxcvbn-go v1.0.4 // indirect 31 | github.com/davecgh/go-spew v1.1.1 // indirect 32 | github.com/felixge/httpsnoop v1.0.4 // indirect 33 | github.com/fzipp/gocyclo v0.6.0 // indirect 34 | github.com/go-logr/logr v1.4.3 // indirect 35 | github.com/go-logr/stdr v1.2.2 // indirect 36 | github.com/golangci/misspell v0.7.0 // indirect 37 | github.com/google/go-cmp v0.7.0 // indirect 38 | github.com/google/renameio/v2 v2.0.1 // indirect 39 | github.com/google/s2a-go v0.1.9 // indirect 40 | github.com/google/uuid v1.6.0 // indirect 41 | github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect 42 | github.com/googleapis/gax-go/v2 v2.15.0 // indirect 43 | 
github.com/gookit/color v1.6.0 // indirect 44 | github.com/gordonklaus/ineffassign v0.2.0 // indirect 45 | github.com/gorilla/websocket v1.5.3 // indirect 46 | github.com/jstemmer/go-junit-report/v2 v2.1.0 // indirect 47 | github.com/kisielk/errcheck v1.9.0 // indirect 48 | github.com/pmezard/go-difflib v1.0.0 // indirect 49 | github.com/quic-go/qpack v0.5.1 // indirect 50 | github.com/robfig/cron/v3 v3.0.1 // indirect 51 | github.com/rogpeppe/go-internal v1.14.1 // indirect 52 | github.com/securego/gosec/v2 v2.22.10 // indirect 53 | github.com/tidwall/gjson v1.18.0 // indirect 54 | github.com/tidwall/match v1.2.0 // indirect 55 | github.com/tidwall/pretty v1.2.1 // indirect 56 | github.com/tidwall/sjson v1.2.5 // indirect 57 | github.com/uudashr/gocognit v1.2.0 // indirect 58 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 59 | go.opentelemetry.io/auto/sdk v1.2.1 // indirect 60 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect 61 | go.opentelemetry.io/otel v1.38.0 // indirect 62 | go.opentelemetry.io/otel/metric v1.38.0 // indirect 63 | go.opentelemetry.io/otel/trace v1.38.0 // indirect 64 | go.uber.org/mock v0.6.0 // indirect 65 | golang.org/x/crypto v0.45.0 // indirect 66 | golang.org/x/exp/typeparams v0.0.0-20251125195548-87e1e737ad39 // indirect 67 | golang.org/x/mod v0.30.0 // indirect 68 | golang.org/x/sync v0.18.0 // indirect 69 | golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc // indirect 70 | golang.org/x/term v0.37.0 // indirect 71 | golang.org/x/text v0.31.0 // indirect 72 | golang.org/x/tools v0.39.0 // indirect 73 | golang.org/x/vuln v1.1.4 // indirect 74 | google.golang.org/genai v1.37.0 // indirect 75 | google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect 76 | google.golang.org/grpc v1.77.0 // indirect 77 | google.golang.org/protobuf v1.36.10 // indirect 78 | honnef.co/go/tools v0.6.1 // indirect 79 | mvdan.cc/editorconfig v0.3.0 // indirect 80 | 
mvdan.cc/gofumpt v0.9.2 // indirect 81 | mvdan.cc/sh/v3 v3.12.0 // indirect 82 | mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15 // indirect 83 | ) 84 | 85 | // NOTE: Keep in sync with .gitignore. 86 | ignore ( 87 | ./bin/ 88 | ./test-reports/ 89 | ./tmp/ 90 | ) 91 | 92 | tool ( 93 | github.com/fzipp/gocyclo/cmd/gocyclo 94 | github.com/golangci/misspell/cmd/misspell 95 | github.com/gordonklaus/ineffassign 96 | github.com/jstemmer/go-junit-report/v2 97 | github.com/kisielk/errcheck 98 | github.com/securego/gosec/v2/cmd/gosec 99 | github.com/uudashr/gocognit/cmd/gocognit 100 | golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment 101 | golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness 102 | golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow 103 | golang.org/x/vuln/cmd/govulncheck 104 | honnef.co/go/tools/cmd/staticcheck 105 | mvdan.cc/gofumpt 106 | mvdan.cc/sh/v3/cmd/shfmt 107 | mvdan.cc/unparam 108 | ) 109 | -------------------------------------------------------------------------------- /proxy/recursiondetector_internal_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "log/slog" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/netutil" 12 | "github.com/miekg/dns" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestRecursionDetector_Check(t *testing.T) { 17 | rd := newRecursionDetector(0, 2) 18 | 19 | const ( 20 | recID = 1234 21 | recTTL = time.Hour * 1 22 | ) 23 | 24 | const nonRecID = recID * 2 25 | 26 | sampleQuestion := dns.Question{ 27 | Name: "some.domain", 28 | Qtype: dns.TypeAAAA, 29 | } 30 | sampleMsg := &dns.Msg{ 31 | MsgHdr: dns.MsgHdr{ 32 | Id: recID, 33 | }, 34 | Question: []dns.Question{sampleQuestion}, 35 | } 36 | 37 | // Manually add the message with big ttl. 
38 | key := msgToSignature(sampleMsg) 39 | expire := make([]byte, uint64sz) 40 | binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano())) 41 | rd.recentRequests.Set(key, expire) 42 | 43 | // Add an expired message. 44 | sampleMsg.Id = nonRecID 45 | rd.add(sampleMsg) 46 | 47 | testCases := []struct { 48 | name string 49 | questions []dns.Question 50 | id uint16 51 | want bool 52 | }{{ 53 | name: "recurrent", 54 | questions: []dns.Question{sampleQuestion}, 55 | id: recID, 56 | want: true, 57 | }, { 58 | name: "not_suspected", 59 | questions: []dns.Question{sampleQuestion}, 60 | id: recID + 1, 61 | want: false, 62 | }, { 63 | name: "expired", 64 | questions: []dns.Question{sampleQuestion}, 65 | id: nonRecID, 66 | want: false, 67 | }, { 68 | name: "empty", 69 | questions: []dns.Question{}, 70 | id: nonRecID, 71 | want: false, 72 | }} 73 | 74 | for _, tc := range testCases { 75 | sampleMsg.Id = tc.id 76 | sampleMsg.Question = tc.questions 77 | t.Run(tc.name, func(t *testing.T) { 78 | detected := rd.check(sampleMsg) 79 | assert.Equal(t, tc.want, detected) 80 | }) 81 | } 82 | } 83 | 84 | func TestRecursionDetector_Suspect(t *testing.T) { 85 | rd := newRecursionDetector(0, 1) 86 | 87 | testCases := []struct { 88 | msg *dns.Msg 89 | name string 90 | want int 91 | }{{ 92 | msg: &dns.Msg{ 93 | MsgHdr: dns.MsgHdr{ 94 | Id: 1234, 95 | }, 96 | Question: []dns.Question{{ 97 | Name: "some.domain", 98 | Qtype: dns.TypeA, 99 | }}, 100 | }, 101 | name: "simple", 102 | want: 1, 103 | }, { 104 | msg: &dns.Msg{}, 105 | name: "unencumbered", 106 | want: 0, 107 | }} 108 | 109 | for _, tc := range testCases { 110 | t.Run(tc.name, func(t *testing.T) { 111 | t.Cleanup(rd.clear) 112 | rd.add(tc.msg) 113 | assert.Equal(t, tc.want, rd.recentRequests.Stats().Count) 114 | }) 115 | } 116 | } 117 | 118 | func BenchmarkMsgToSignature(b *testing.B) { 119 | const name = "some.not.very.long.host.name" 120 | 121 | msg := &dns.Msg{ 122 | MsgHdr: dns.MsgHdr{ 123 | Id: 1234, 124 | }, 
125 | Question: []dns.Question{{ 126 | Name: name, 127 | Qtype: dns.TypeAAAA, 128 | }}, 129 | } 130 | 131 | var sigData []byte 132 | 133 | b.Run("efficient", func(b *testing.B) { 134 | b.ReportAllocs() 135 | 136 | for b.Loop() { 137 | sigData = msgToSignature(msg) 138 | } 139 | 140 | assert.NotEmpty(b, sigData) 141 | }) 142 | 143 | b.Run("inefficient", func(b *testing.B) { 144 | b.ReportAllocs() 145 | 146 | for b.Loop() { 147 | sigData = msgToSignatureSlow(msg) 148 | } 149 | 150 | assert.NotEmpty(b, sigData) 151 | }) 152 | 153 | // Most recent results: 154 | // 155 | // goos: darwin 156 | // goarch: amd64 157 | // pkg: github.com/AdguardTeam/dnsproxy/proxy 158 | // cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz 159 | // BenchmarkMsgToSignature/efficient-12 18789852 61.07 ns/op 288 B/op 1 allocs/op 160 | // BenchmarkMsgToSignature/inefficient-12 582990 2016 ns/op 624 B/op 3 allocs/op 161 | } 162 | 163 | // msgToSignatureSlow converts msg into it's signature represented in bytes in 164 | // the less efficient way. 165 | // 166 | // See [BenchmarkMsgToSignature]. 
167 | func msgToSignatureSlow(msg *dns.Msg) (sig []byte) { 168 | type msgSignature struct { 169 | name [netutil.MaxDomainNameLen]byte 170 | id uint16 171 | qtype uint16 172 | } 173 | 174 | b := bytes.NewBuffer(sig) 175 | q := msg.Question[0] 176 | signature := msgSignature{ 177 | id: msg.Id, 178 | qtype: q.Qtype, 179 | } 180 | copy(signature.name[:], q.Name) 181 | if err := binary.Write(b, binary.BigEndian, signature); err != nil { 182 | slog.Default().Debug("writing message signature", slogutil.KeyError, err) 183 | } 184 | 185 | return b.Bytes() 186 | } 187 | -------------------------------------------------------------------------------- /proxy/exchange.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/AdguardTeam/dnsproxy/upstream" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/miekg/dns" 11 | "gonum.org/v1/gonum/stat/sampleuv" 12 | ) 13 | 14 | // exchangeUpstreams resolves req using the given upstreams. It returns the DNS 15 | // response, the upstream that successfully resolved the request, and the error 16 | // if any. 17 | func (p *Proxy) exchangeUpstreams( 18 | req *dns.Msg, 19 | ups []upstream.Upstream, 20 | ) (resp *dns.Msg, u upstream.Upstream, err error) { 21 | switch p.UpstreamMode { 22 | case UpstreamModeParallel: 23 | return upstream.ExchangeParallel(ups, req) 24 | case UpstreamModeFastestAddr: 25 | switch req.Question[0].Qtype { 26 | case dns.TypeA, dns.TypeAAAA: 27 | return p.fastestAddr.ExchangeFastest(req, ups) 28 | default: 29 | // Go on to the load-balancing mode. 30 | } 31 | default: 32 | // Go on to the load-balancing mode. 33 | } 34 | 35 | if len(ups) == 1 { 36 | u = ups[0] 37 | resp, _, err = p.exchange(u, req) 38 | if err != nil { 39 | return nil, nil, err 40 | } 41 | 42 | // TODO(e.burkov): Consider updating the RTT of a single upstream. 
43 | 44 | return resp, u, err 45 | } 46 | 47 | w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc) 48 | var errs []error 49 | for i, ok := w.Take(); ok; i, ok = w.Take() { 50 | u = ups[i] 51 | 52 | var elapsed time.Duration 53 | resp, elapsed, err = p.exchange(u, req) 54 | if err == nil { 55 | p.updateRTT(u.Address(), elapsed) 56 | 57 | return resp, u, nil 58 | } 59 | 60 | errs = append(errs, err) 61 | 62 | // TODO(e.burkov): Use the actual configured timeout or, perhaps, the 63 | // actual measured elapsed time. 64 | p.updateRTT(u.Address(), defaultTimeout) 65 | } 66 | 67 | err = fmt.Errorf("all upstreams failed to exchange request: %w", errors.Join(errs...)) 68 | 69 | return nil, nil, err 70 | } 71 | 72 | // exchange returns the result of the DNS request exchange with the given 73 | // upstream and the elapsed time in milliseconds. It uses the given clock to 74 | // measure the request duration. 75 | func (p *Proxy) exchange( 76 | u upstream.Upstream, 77 | req *dns.Msg, 78 | ) (resp *dns.Msg, dur time.Duration, err error) { 79 | startTime := p.time.Now() 80 | resp, err = u.Exchange(req) 81 | 82 | // Don't use [time.Since] because it uses [time.Now]. 83 | dur = p.time.Now().Sub(startTime) 84 | 85 | addr := u.Address() 86 | q := &req.Question[0] 87 | if err != nil { 88 | p.logger.Error( 89 | "exchange failed", 90 | "upstream", addr, 91 | "question", q, 92 | "duration", dur, 93 | slogutil.KeyError, err, 94 | ) 95 | } else { 96 | p.logger.Debug( 97 | "exchange successfully finished", 98 | "upstream", addr, 99 | "question", q, 100 | "duration", dur, 101 | ) 102 | } 103 | 104 | return resp, dur, err 105 | } 106 | 107 | // upstreamRTTStats is the statistics for a single upstream's round-trip time. 108 | type upstreamRTTStats struct { 109 | // rttSum is the sum of all the round-trip times in microseconds. The 110 | // float64 type is used since it's capable of representing about 285 years 111 | // in microseconds. 
112 | rttSum float64 113 | 114 | // reqNum is the number of requests to the upstream. The float64 type is 115 | // used since to avoid unnecessary type conversions. 116 | reqNum float64 117 | } 118 | 119 | // update returns updated stats after adding given RTT. 120 | func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) { 121 | return upstreamRTTStats{ 122 | rttSum: stats.rttSum + float64(rtt.Microseconds()), 123 | reqNum: stats.reqNum + 1, 124 | } 125 | } 126 | 127 | // calcWeights returns the slice of weights, each corresponding to the upstream 128 | // with the same index in the given slice. 129 | func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) { 130 | weights = make([]float64, 0, len(ups)) 131 | 132 | p.rttLock.Lock() 133 | defer p.rttLock.Unlock() 134 | 135 | for _, u := range ups { 136 | stat := p.upstreamRTTStats[u.Address()] 137 | if stat.rttSum == 0 || stat.reqNum == 0 { 138 | // Use 1 as the default weight. 139 | weights = append(weights, 1) 140 | } else { 141 | weights = append(weights, 1/(stat.rttSum/stat.reqNum)) 142 | } 143 | } 144 | 145 | return weights 146 | } 147 | 148 | // updateRTT updates the round-trip time in [upstreamRTTStats] for given 149 | // address. 
150 | func (p *Proxy) updateRTT(address string, rtt time.Duration) { 151 | p.rttLock.Lock() 152 | defer p.rttLock.Unlock() 153 | 154 | if p.upstreamRTTStats == nil { 155 | p.upstreamRTTStats = map[string]upstreamRTTStats{} 156 | } 157 | 158 | p.upstreamRTTStats[address] = p.upstreamRTTStats[address].update(rtt) 159 | } 160 | -------------------------------------------------------------------------------- /internal/bootstrap/resolver.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net" 7 | "net/netip" 8 | "slices" 9 | 10 | "github.com/AdguardTeam/golibs/errors" 11 | "github.com/AdguardTeam/golibs/logutil/slogutil" 12 | ) 13 | 14 | // Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] 15 | // from standard library also implements this interface. 16 | type Resolver interface { 17 | // LookupNetIP looks up the IP addresses for the given host. network should 18 | // be one of [NetworkIP], [NetworkIP4] or [NetworkIP6]. The response may be 19 | // empty even if err is nil. All the addrs must be valid. 20 | LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error) 21 | } 22 | 23 | // type check 24 | var _ Resolver = (*net.Resolver)(nil) 25 | 26 | // ParallelResolver is a slice of resolvers that are queried concurrently. The 27 | // first successful response is returned. 28 | type ParallelResolver []Resolver 29 | 30 | // type check 31 | var _ Resolver = ParallelResolver(nil) 32 | 33 | // LookupNetIP implements the [Resolver] interface for ParallelResolver. 34 | func (r ParallelResolver) LookupNetIP( 35 | ctx context.Context, 36 | network Network, 37 | host string, 38 | ) (addrs []netip.Addr, err error) { 39 | resolversNum := len(r) 40 | switch resolversNum { 41 | case 0: 42 | return nil, ErrNoResolvers 43 | case 1: 44 | return r[0].LookupNetIP(ctx, network, host) 45 | default: 46 | // Go on. 
47 | } 48 | 49 | // Size of channel must accommodate results of lookups from all resolvers, 50 | // sending into channel will block otherwise. 51 | ch := make(chan any, resolversNum) 52 | for _, rslv := range r { 53 | go lookupAsync(ctx, rslv, network, host, ch) 54 | } 55 | 56 | var errs []error 57 | for range r { 58 | switch result := <-ch; result := result.(type) { 59 | case error: 60 | errs = append(errs, result) 61 | case []netip.Addr: 62 | return result, nil 63 | } 64 | } 65 | 66 | return nil, errors.Join(errs...) 67 | } 68 | 69 | // recoverAndLog is a deferred helper that recovers from a panic and logs the 70 | // panic value with the logger from context or with a default one. It sends the 71 | // recovered value into resCh. resCh must not be nil. 72 | func recoverAndLog(ctx context.Context, resCh chan<- any) { 73 | err := errors.FromRecovered(recover()) 74 | if err == nil { 75 | return 76 | } 77 | 78 | l, ok := slogutil.LoggerFromContext(ctx) 79 | if !ok { 80 | l = slog.Default() 81 | } 82 | 83 | l.ErrorContext(ctx, "recovered panic", slogutil.KeyError, err) 84 | slogutil.PrintStack(ctx, l, slog.LevelError) 85 | 86 | resCh <- err 87 | } 88 | 89 | // lookupAsync performs a lookup for ip of host with r and sends the result into 90 | // resCh. It is intended to be used as a goroutine. r and resCh must not be 91 | // nil, network should be one of [NetworkIP], [NetworkIP4] or [NetworkIP6], host 92 | // should not be empty. 93 | func lookupAsync(ctx context.Context, r Resolver, network, host string, resCh chan<- any) { 94 | // TODO(d.kolyshev): Propose better solution to recover without requiring 95 | // logger in the context. 
96 | defer recoverAndLog(ctx, resCh) 97 | 98 | addrs, err := r.LookupNetIP(ctx, network, host) 99 | if err != nil { 100 | resCh <- err 101 | } else { 102 | resCh <- addrs 103 | } 104 | } 105 | 106 | // ConsequentResolver is a slice of resolvers that are queried in order until 107 | // the first successful non-empty response, as opposed to just successful 108 | // response requirement in [ParallelResolver]. 109 | type ConsequentResolver []Resolver 110 | 111 | // type check 112 | var _ Resolver = ConsequentResolver(nil) 113 | 114 | // LookupNetIP implements the [Resolver] interface for ConsequentResolver. 115 | func (resolvers ConsequentResolver) LookupNetIP( 116 | ctx context.Context, 117 | network Network, 118 | host string, 119 | ) (addrs []netip.Addr, err error) { 120 | if len(resolvers) == 0 { 121 | return nil, ErrNoResolvers 122 | } 123 | 124 | var errs []error 125 | for _, r := range resolvers { 126 | addrs, err = r.LookupNetIP(ctx, network, host) 127 | if err == nil && len(addrs) > 0 { 128 | return addrs, nil 129 | } 130 | 131 | errs = append(errs, err) 132 | } 133 | 134 | return nil, errors.Join(errs...) 135 | } 136 | 137 | // StaticResolver is a resolver which always responds with an underlying slice 138 | // of IP addresses regardless of host and network. 139 | type StaticResolver []netip.Addr 140 | 141 | // type check 142 | var _ Resolver = StaticResolver(nil) 143 | 144 | // LookupNetIP implements the [Resolver] interface for StaticResolver. It 145 | // always returns the cloned underlying slice of addresses. 
146 | func (r StaticResolver) LookupNetIP( 147 | _ context.Context, 148 | _ Network, 149 | _ string, 150 | ) (addrs []netip.Addr, err error) { 151 | return slices.Clone(r), nil 152 | } 153 | -------------------------------------------------------------------------------- /proxy/pending_test.go: -------------------------------------------------------------------------------- 1 | package proxy_test 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/AdguardTeam/dnsproxy/internal/dnsproxytest" 11 | "github.com/AdguardTeam/dnsproxy/proxy" 12 | "github.com/AdguardTeam/dnsproxy/upstream" 13 | "github.com/AdguardTeam/golibs/logutil/slogutil" 14 | "github.com/AdguardTeam/golibs/netutil" 15 | "github.com/AdguardTeam/golibs/testutil" 16 | "github.com/AdguardTeam/golibs/testutil/servicetest" 17 | "github.com/miekg/dns" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | ) 21 | 22 | // TODO(e.burkov): Merge those with the ones in internal tests and move to 23 | // dnsproxytest. 24 | 25 | const ( 26 | // testTimeout is the common timeout for tests and contexts. 27 | testTimeout = 1 * time.Second 28 | 29 | // testCacheSize is the default size of the cache in bytes. 30 | testCacheSize = 64 * 1024 31 | ) 32 | 33 | var ( 34 | // localhostAnyPort is a localhost address with an arbitrary port. 35 | localhostAnyPort = netip.AddrPortFrom(netutil.IPv4Localhost(), 0) 36 | 37 | // testTrustedProxies is a set of trusted proxies that includes all 38 | // addresses used in tests. 39 | testTrustedProxies = netutil.SliceSubnetSet{ 40 | netip.MustParsePrefix("0.0.0.0/0"), 41 | netip.MustParsePrefix("::0/0"), 42 | } 43 | ) 44 | 45 | // assertEqualResponses is a helper function that checks if two DNS messages are 46 | // equal, excluding their ID. 47 | // 48 | // TODO(e.burkov): Cosider using go-cmp. 
49 | func assertEqualResponses(tb testing.TB, expected, actual *dns.Msg) { 50 | tb.Helper() 51 | 52 | if expected == nil { 53 | require.Nil(tb, actual) 54 | 55 | return 56 | } 57 | 58 | require.NotNil(tb, actual) 59 | 60 | expectedHdr, actualHdr := expected.MsgHdr, actual.MsgHdr 61 | expectedHdr.Id, actualHdr.Id = 0, 0 62 | assert.Equal(tb, expectedHdr, actualHdr) 63 | 64 | assert.Equal(tb, expected.Question, actual.Question) 65 | assert.Equal(tb, expected.Answer, actual.Answer) 66 | assert.Equal(tb, expected.Ns, actual.Ns) 67 | assert.Equal(tb, expected.Extra, actual.Extra) 68 | } 69 | 70 | func TestPendingRequests(t *testing.T) { 71 | t.Parallel() 72 | 73 | const reqsNum = 100 74 | 75 | // workloadWG is used to hold the upstream response until as many requests 76 | // as possible reach the [proxy.Resolve] method. This is a best-effort 77 | // approach, so it's not strictly guaranteed to hold all requests, but it 78 | // works for the test. 79 | workloadWG := &sync.WaitGroup{} 80 | workloadWG.Add(reqsNum) 81 | 82 | once := &sync.Once{} 83 | u := &dnsproxytest.Upstream{ 84 | OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { 85 | once.Do(func() { 86 | resp = (&dns.Msg{}).SetReply(req) 87 | }) 88 | 89 | // Only allow a single request to be processed. 
90 | require.NotNil(testutil.PanicT{}, resp) 91 | 92 | workloadWG.Wait() 93 | 94 | return resp, nil 95 | }, 96 | OnAddress: func() (addr string) { return "" }, 97 | OnClose: func() (err error) { return nil }, 98 | } 99 | 100 | p, err := proxy.New(&proxy.Config{ 101 | Logger: slogutil.NewDiscardLogger(), 102 | UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)}, 103 | TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)}, 104 | UpstreamConfig: &proxy.UpstreamConfig{Upstreams: []upstream.Upstream{u}}, 105 | TrustedProxies: testTrustedProxies, 106 | RatelimitSubnetLenIPv4: 24, 107 | RatelimitSubnetLenIPv6: 64, 108 | Ratelimit: 0, 109 | CacheEnabled: true, 110 | CacheSizeBytes: testCacheSize, 111 | EnableEDNSClientSubnet: true, 112 | PendingRequests: &proxy.PendingRequestsConfig{ 113 | Enabled: true, 114 | }, 115 | RequestHandler: func(prx *proxy.Proxy, dctx *proxy.DNSContext) (err error) { 116 | workloadWG.Done() 117 | 118 | return prx.Resolve(dctx) 119 | }, 120 | }) 121 | require.NoError(t, err) 122 | 123 | servicetest.RequireRun(t, p, testTimeout) 124 | 125 | addr := p.Addr(proxy.ProtoTCP).String() 126 | client := &dns.Client{ 127 | Net: string(proxy.ProtoTCP), 128 | Timeout: testTimeout, 129 | } 130 | 131 | resolveWG := &sync.WaitGroup{} 132 | responses := make([]*dns.Msg, reqsNum) 133 | errs := make([]error, reqsNum) 134 | 135 | for i := range reqsNum { 136 | resolveWG.Add(1) 137 | 138 | req := (&dns.Msg{}).SetQuestion("domain.example.", dns.TypeA) 139 | 140 | go func() { 141 | defer resolveWG.Done() 142 | 143 | reqCtx := testutil.ContextWithTimeout(t, testTimeout) 144 | responses[i], _, errs[i] = client.ExchangeContext(reqCtx, req, addr) 145 | }() 146 | } 147 | 148 | resolveWG.Wait() 149 | 150 | require.NoError(t, errs[0]) 151 | 152 | for i, resp := range responses[:len(responses)-1] { 153 | assert.Equal(t, errs[i], errs[i+1]) 154 | assertEqualResponses(t, resp, responses[i+1]) 155 | } 156 | } 157 | 
-------------------------------------------------------------------------------- /upstream/dnscrypt.go: -------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log/slog" 7 | "net/url" 8 | "os" 9 | "sync" 10 | "time" 11 | 12 | "github.com/AdguardTeam/golibs/errors" 13 | "github.com/ameshkov/dnscrypt/v2" 14 | "github.com/miekg/dns" 15 | ) 16 | 17 | // dnsCrypt implements the [Upstream] interface for the DNSCrypt protocol. 18 | type dnsCrypt struct { 19 | // mu protects client and serverInfo. 20 | mu *sync.RWMutex 21 | 22 | // client stores the DNSCrypt client properties. 23 | client *dnscrypt.Client 24 | 25 | // resolverInfo stores the DNSCrypt server properties. 26 | resolverInfo *dnscrypt.ResolverInfo 27 | 28 | // addr is the DNSCrypt server URL. 29 | addr *url.URL 30 | 31 | // logger is used for exchange logging. It is never nil. 32 | logger *slog.Logger 33 | 34 | // verifyCert is a callback that verifies the resolver's certificate. 35 | verifyCert func(cert *dnscrypt.Cert) (err error) 36 | 37 | // timeout is the timeout for the DNS requests. 38 | timeout time.Duration 39 | } 40 | 41 | // newDNSCrypt returns a new DNSCrypt Upstream. 42 | func newDNSCrypt(addr *url.URL, opts *Options) (u *dnsCrypt) { 43 | return &dnsCrypt{ 44 | mu: &sync.RWMutex{}, 45 | addr: addr, 46 | logger: opts.Logger, 47 | verifyCert: opts.VerifyDNSCryptCertificate, 48 | timeout: opts.Timeout, 49 | } 50 | } 51 | 52 | // type check 53 | var _ Upstream = (*dnsCrypt)(nil) 54 | 55 | // Address implements the [Upstream] interface for *dnsCrypt. 56 | func (p *dnsCrypt) Address() string { return p.addr.String() } 57 | 58 | // Exchange implements the [Upstream] interface for *dnsCrypt. 
59 | func (p *dnsCrypt) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { 60 | resp, err = p.exchangeDNSCrypt(req) 61 | if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) { 62 | // If request times out, it is possible that the server configuration 63 | // has been changed. It is safe to assume that the key was rotated, see 64 | // https://dnscrypt.pl/2017/02/26/how-key-rotation-is-automated. 65 | // Re-fetch the server certificate info for new requests to not fail. 66 | _, _, err = p.resetClient() 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | return p.exchangeDNSCrypt(req) 72 | } 73 | 74 | return resp, err 75 | } 76 | 77 | // Close implements the [Upstream] interface for *dnsCrypt. 78 | func (p *dnsCrypt) Close() (err error) { 79 | return nil 80 | } 81 | 82 | // exchangeDNSCrypt attempts to send the DNS query and returns the response. 83 | func (p *dnsCrypt) exchangeDNSCrypt(req *dns.Msg) (resp *dns.Msg, err error) { 84 | var client *dnscrypt.Client 85 | var resolverInfo *dnscrypt.ResolverInfo 86 | func() { 87 | p.mu.RLock() 88 | defer p.mu.RUnlock() 89 | 90 | client, resolverInfo = p.client, p.resolverInfo 91 | }() 92 | 93 | // Check the client and server info are set and the certificate is not 94 | // expired, since any of these cases require a client reset. 95 | // 96 | // TODO(ameshkov): Consider using [time.Time] for [dnscrypt.Cert.NotAfter]. 97 | switch { 98 | case 99 | client == nil, 100 | resolverInfo == nil, 101 | resolverInfo.ResolverCert.NotAfter < uint32(time.Now().Unix()): 102 | client, resolverInfo, err = p.resetClient() 103 | if err != nil { 104 | // Don't wrap the error, because it's informative enough as is. 105 | return nil, err 106 | } 107 | default: 108 | // Go on. 
109 | } 110 | 111 | resp, err = client.Exchange(req, resolverInfo) 112 | if resp != nil && resp.Truncated { 113 | q := &req.Question[0] 114 | p.logger.Debug( 115 | "dnscrypt received truncated, falling back to tcp", 116 | "addr", p.addr, 117 | "question", q, 118 | ) 119 | 120 | tcpClient := &dnscrypt.Client{Timeout: p.timeout, Net: networkTCP} 121 | resp, err = tcpClient.Exchange(req, resolverInfo) 122 | } 123 | if err == nil && resp != nil && resp.Id != req.Id { 124 | err = dns.ErrId 125 | } 126 | 127 | return resp, err 128 | } 129 | 130 | // resetClient renews the DNSCrypt client and server properties and also sets 131 | // those to nil on fail. 132 | func (p *dnsCrypt) resetClient() (client *dnscrypt.Client, ri *dnscrypt.ResolverInfo, err error) { 133 | addr := p.Address() 134 | 135 | defer func() { 136 | p.mu.Lock() 137 | defer p.mu.Unlock() 138 | 139 | p.client, p.resolverInfo = client, ri 140 | }() 141 | 142 | // Use UDP for DNSCrypt upstreams by default. 143 | client = &dnscrypt.Client{Timeout: p.timeout, Net: networkUDP} 144 | ri, err = client.Dial(addr) 145 | if err != nil { 146 | // Trigger client and server info renewal on the next request. 147 | client, ri = nil, nil 148 | err = fmt.Errorf("fetching certificate info from %s: %w", addr, err) 149 | } else if p.verifyCert != nil { 150 | err = p.verifyCert(ri.ResolverCert) 151 | if err != nil { 152 | // Trigger client and server info renewal on the next request. 153 | client, ri = nil, nil 154 | err = fmt.Errorf("verifying certificate info from %s: %w", addr, err) 155 | } 156 | } 157 | 158 | return client, ri, err 159 | } 160 | -------------------------------------------------------------------------------- /internal/cmd/cmd.go: -------------------------------------------------------------------------------- 1 | // Package cmd is the dnsproxy CLI entry point. 
2 | package cmd 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "log/slog" 8 | "net/http" 9 | "net/http/pprof" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | "time" 14 | 15 | "github.com/AdguardTeam/dnsproxy/internal/version" 16 | "github.com/AdguardTeam/dnsproxy/proxy" 17 | "github.com/AdguardTeam/golibs/errors" 18 | "github.com/AdguardTeam/golibs/logutil/slogutil" 19 | "github.com/AdguardTeam/golibs/osutil" 20 | ) 21 | 22 | // Main is the entrypoint of dnsproxy CLI. Main may accept arguments, such as 23 | // embedded assets and command-line arguments. 24 | func Main() { 25 | conf, exitCode, err := parseConfig() 26 | if err != nil { 27 | _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("parsing options: %w", err)) 28 | } 29 | 30 | if conf == nil { 31 | os.Exit(exitCode) 32 | } 33 | 34 | logOutput := os.Stdout 35 | if conf.LogOutput != "" { 36 | // #nosec G302 -- Trust the file path that is given in the 37 | // configuration. 38 | logOutput, err = os.OpenFile(conf.LogOutput, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644) 39 | if err != nil { 40 | _, _ = fmt.Fprintln(os.Stderr, fmt.Errorf("cannot create a log file: %s", err)) 41 | 42 | os.Exit(osutil.ExitCodeArgumentError) 43 | } 44 | 45 | defer func() { _ = logOutput.Close() }() 46 | } 47 | 48 | lvl := slog.LevelInfo 49 | if conf.Verbose { 50 | lvl = slog.LevelDebug 51 | } 52 | 53 | l := slogutil.New(&slogutil.Config{ 54 | Output: logOutput, 55 | Format: slogutil.FormatDefault, 56 | Level: lvl, 57 | // TODO(d.kolyshev): Consider making configurable. 58 | AddTimestamp: true, 59 | }) 60 | 61 | ctx := context.Background() 62 | 63 | if conf.Pprof { 64 | runPprof(ctx, l) 65 | } 66 | 67 | err = runProxy(ctx, l, conf) 68 | if err != nil { 69 | l.ErrorContext(ctx, "running dnsproxy", slogutil.KeyError, err) 70 | 71 | // As defers are skipped in case of os.Exit, close logOutput manually. 72 | // 73 | // TODO(a.garipov): Consider making logger.Close method. 
74 | if logOutput != os.Stdout { 75 | _ = logOutput.Close() 76 | } 77 | 78 | os.Exit(osutil.ExitCodeFailure) 79 | } 80 | } 81 | 82 | // runProxy starts and runs the proxy. l must not be nil. 83 | // 84 | // TODO(e.burkov): Move into separate dnssvc package. 85 | func runProxy(ctx context.Context, l *slog.Logger, conf *configuration) (err error) { 86 | var ( 87 | buildVersion = version.Version() 88 | revision = version.Revision() 89 | branch = version.Branch() 90 | commitTime = version.CommitTime() 91 | ) 92 | 93 | l.InfoContext( 94 | ctx, 95 | "dnsproxy starting", 96 | "version", buildVersion, 97 | "revision", revision, 98 | "branch", branch, 99 | "commit_time", commitTime, 100 | ) 101 | 102 | // Prepare the proxy server and its configuration. 103 | proxyConf, err := createProxyConfig(ctx, l, conf) 104 | if err != nil { 105 | return fmt.Errorf("configuring proxy: %w", err) 106 | } 107 | 108 | dnsProxy, err := proxy.New(proxyConf) 109 | if err != nil { 110 | return fmt.Errorf("creating proxy: %w", err) 111 | } 112 | 113 | // Start the proxy server. 114 | err = dnsProxy.Start(ctx) 115 | if err != nil { 116 | return fmt.Errorf("starting dnsproxy: %w", err) 117 | } 118 | 119 | // TODO(e.burkov): Use [service.SignalHandler]. 120 | signalChannel := make(chan os.Signal, 1) 121 | signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) 122 | <-signalChannel 123 | 124 | // Stopping the proxy. 125 | err = dnsProxy.Shutdown(ctx) 126 | if err != nil { 127 | return fmt.Errorf("stopping dnsproxy: %w", err) 128 | } 129 | 130 | return nil 131 | } 132 | 133 | // runPprof runs pprof server on localhost:6060. 134 | // 135 | // TODO(e.burkov): Add debugsvc. 
136 | func runPprof(ctx context.Context, l *slog.Logger) { 137 | mux := http.NewServeMux() 138 | mux.HandleFunc("/debug/pprof/", pprof.Index) 139 | mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 140 | mux.HandleFunc("/debug/pprof/profile", pprof.Profile) 141 | mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 142 | mux.HandleFunc("/debug/pprof/trace", pprof.Trace) 143 | mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) 144 | mux.Handle("/debug/pprof/block", pprof.Handler("block")) 145 | mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) 146 | mux.Handle("/debug/pprof/heap", pprof.Handler("heap")) 147 | mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) 148 | mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) 149 | 150 | go func() { 151 | // TODO(d.kolyshev): Consider making configurable. 152 | const pprofAddr = "localhost:6060" 153 | l.InfoContext(ctx, "starting pprof", "addr", pprofAddr) 154 | 155 | srv := &http.Server{ 156 | Addr: pprofAddr, 157 | ReadTimeout: 60 * time.Second, 158 | Handler: mux, 159 | } 160 | 161 | err := srv.ListenAndServe() 162 | if err != nil && !errors.Is(err, http.ErrServerClosed) { 163 | l.ErrorContext(ctx, "pprof failed to listen", "addr", pprofAddr, slogutil.KeyError, err) 164 | } 165 | }() 166 | } 167 | -------------------------------------------------------------------------------- /proxy/pending.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/AdguardTeam/golibs/errors" 8 | "github.com/AdguardTeam/golibs/syncutil" 9 | ) 10 | 11 | // pendingRequests handles identical requests that are in progress. It is used 12 | // to avoid sending the same request multiple times to the upstream server. The 13 | // implementations are: 14 | // - [defaultPendingRequests]. 15 | // - [emptyPendingRequests]. 
16 | type pendingRequests interface { 17 | // queue is called for each request. It returns false if there are no 18 | // identical requests in progress. Otherwise it blocks until the first 19 | // request is completed and returns the error that occurred during its 20 | // resolution. 21 | queue(ctx context.Context, dctx *DNSContext) (loaded bool, err error) 22 | 23 | // done must be called after the request is completed, if queue returned 24 | // false for it. 25 | done(ctx context.Context, dctx *DNSContext, err error) 26 | } 27 | 28 | // defaultPendingRequests is a default implementation of the [pendingRequests] 29 | // interface. It must be created with [newDefaultPendingRequests]. 30 | type defaultPendingRequests struct { 31 | storage *syncutil.Map[string, *pendingRequest] 32 | } 33 | 34 | // pendingRequest is a structure that stores the query state and result. 35 | type pendingRequest struct { 36 | // finish is a channel that is closed when the request is completed. It is 37 | // used to block request processing for any but the first one. 38 | finish chan struct{} 39 | 40 | // resolveErr is the error that occurred during the request processing. It 41 | // may be nil. It must only be accessed for reading after the finish 42 | // channel is closed. 43 | resolveErr error 44 | 45 | // cloneDNSCtx is a clone of the DNSContext that was used to create the 46 | // pendingRequest and store its result. It must only be accessed for 47 | // reading after the finish channel is closed. 48 | cloneDNSCtx *DNSContext 49 | } 50 | 51 | // newDefaultPendingRequests creates a new instance of DefaultPendingRequests. 52 | func newDefaultPendingRequests() (pr *defaultPendingRequests) { 53 | return &defaultPendingRequests{ 54 | storage: syncutil.NewMap[string, *pendingRequest](), 55 | } 56 | } 57 | 58 | // type check 59 | var _ pendingRequests = (*defaultPendingRequests)(nil) 60 | 61 | // queue implements the [pendingRequests] interface for 62 | // [defaultPendingRequests]. 
63 | func (pr *defaultPendingRequests) queue( 64 | ctx context.Context, 65 | dctx *DNSContext, 66 | ) (loaded bool, err error) { 67 | var key []byte 68 | if dctx.ReqECS != nil { 69 | ones, _ := dctx.ReqECS.Mask.Size() 70 | key = msgToKeyWithSubnet(dctx.Req, dctx.ReqECS.IP, ones) 71 | } else { 72 | key = msgToKey(dctx.Req) 73 | } 74 | 75 | req := &pendingRequest{ 76 | finish: make(chan struct{}), 77 | } 78 | 79 | pending, loaded := pr.storage.LoadOrStore(string(key), req) 80 | if !loaded { 81 | return false, nil 82 | } 83 | 84 | <-pending.finish 85 | 86 | origDNSCtx := pending.cloneDNSCtx 87 | 88 | // TODO(a.garipov): Perhaps, statistics should be calculated separately for 89 | // each request. 90 | dctx.queryStatistics = origDNSCtx.queryStatistics 91 | dctx.Upstream = origDNSCtx.Upstream 92 | if origDNSCtx.Res != nil { 93 | // TODO(e.burkov): Add cloner for DNS messages. 94 | dctx.Res = origDNSCtx.Res.Copy().SetReply(dctx.Req) 95 | } 96 | 97 | return loaded, pending.resolveErr 98 | } 99 | 100 | // done implements the [pendingRequests] interface for [defaultPendingRequests]. 
101 | func (pr *defaultPendingRequests) done(ctx context.Context, dctx *DNSContext, err error) { 102 | var key []byte 103 | if dctx.ReqECS != nil { 104 | ones, _ := dctx.ReqECS.Mask.Size() 105 | key = msgToKeyWithSubnet(dctx.Req, dctx.ReqECS.IP, ones) 106 | } else { 107 | key = msgToKey(dctx.Req) 108 | } 109 | 110 | pending, ok := pr.storage.Load(string(key)) 111 | if !ok { 112 | panic(fmt.Errorf("loading pending request: key %x: %w", key, errors.ErrNoValue)) 113 | } 114 | 115 | pending.resolveErr = err 116 | 117 | cloneCtx := &DNSContext{ 118 | Upstream: dctx.Upstream, 119 | queryStatistics: dctx.queryStatistics, 120 | } 121 | 122 | if dctx.Res != nil { 123 | cloneCtx.Res = dctx.Res.Copy() 124 | } 125 | 126 | pending.cloneDNSCtx = cloneCtx 127 | 128 | pr.storage.Delete(string(key)) 129 | close(pending.finish) 130 | } 131 | 132 | // emptyPendingRequests is a no-op implementation of PendingRequests. It is 133 | // used when pending requests are not needed. 134 | type emptyPendingRequests struct{} 135 | 136 | // type check 137 | var _ pendingRequests = emptyPendingRequests{} 138 | 139 | // queue implements the [pendingRequests] interface for [emptyPendingRequests]. 140 | // It always returns false and does not block. 141 | func (emptyPendingRequests) queue(_ context.Context, _ *DNSContext) (loaded bool, err error) { 142 | return false, nil 143 | } 144 | 145 | // done implements the [pendingRequests] interface for [emptyPendingRequests]. 
146 | func (emptyPendingRequests) done(_ context.Context, _ *DNSContext, _ error) {} 147 | -------------------------------------------------------------------------------- /fastip/fastest_internal_test.go: -------------------------------------------------------------------------------- 1 | package fastip 2 | 3 | import ( 4 | "net/netip" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/dnsproxy/upstream" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/AdguardTeam/golibs/testutil" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestFastestAddr_ExchangeFastest(t *testing.T) { 17 | l := slogutil.NewDiscardLogger() 18 | 19 | t.Run("error", func(t *testing.T) { 20 | const errDesired errors.Error = "this is expected" 21 | 22 | u := &errUpstream{ 23 | err: errDesired, 24 | } 25 | f := New(&Config{ 26 | Logger: l, 27 | PingWaitTimeout: DefaultPingWaitTimeout, 28 | }) 29 | 30 | resp, up, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{u}) 31 | require.Error(t, err) 32 | 33 | assert.ErrorIs(t, err, errDesired) 34 | assert.Nil(t, resp) 35 | assert.Nil(t, up) 36 | }) 37 | 38 | t.Run("one_dead", func(t *testing.T) { 39 | port := listen(t, netip.IPv4Unspecified()) 40 | 41 | f := New(&Config{ 42 | Logger: l, 43 | PingWaitTimeout: DefaultPingWaitTimeout, 44 | }) 45 | f.pingPorts = []uint{port} 46 | 47 | // The alive IP is the just created local listener's address. The dead 48 | // one is known as TEST-NET-1 which shouldn't be routed at all. See 49 | // RFC-5737 (https://datatracker.ietf.org/doc/html/rfc5737). 
50 | aliveAddr := netip.MustParseAddr("127.0.0.1") 51 | 52 | alive := &testAUpstream{ 53 | recs: []*dns.A{newTestRec(t, aliveAddr)}, 54 | } 55 | dead := &testAUpstream{ 56 | recs: []*dns.A{newTestRec(t, netip.MustParseAddr("192.0.2.1"))}, 57 | } 58 | 59 | rep, ups, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{dead, alive}) 60 | require.NoError(t, err) 61 | 62 | assert.Equal(t, ups, alive) 63 | 64 | require.NotNil(t, rep) 65 | require.NotEmpty(t, rep.Answer) 66 | 67 | ip := testutil.RequireTypeAssert[*dns.A](t, rep.Answer[0]).A 68 | assert.Equal(t, aliveAddr.AsSlice(), []byte(ip)) 69 | }) 70 | 71 | t.Run("all_dead", func(t *testing.T) { 72 | f := New(&Config{ 73 | Logger: l, 74 | PingWaitTimeout: DefaultPingWaitTimeout, 75 | }) 76 | f.pingPorts = []uint{getFreePort(t)} 77 | 78 | firstIP := netip.MustParseAddr("127.0.0.1") 79 | ups := &testAUpstream{ 80 | recs: []*dns.A{ 81 | newTestRec(t, firstIP), 82 | newTestRec(t, netip.MustParseAddr("127.0.0.2")), 83 | newTestRec(t, netip.MustParseAddr("127.0.0.3")), 84 | }, 85 | } 86 | 87 | resp, _, err := f.ExchangeFastest(newTestReq(t), []upstream.Upstream{ups}) 88 | require.NoError(t, err) 89 | 90 | require.NotNil(t, resp) 91 | require.NotEmpty(t, resp.Answer) 92 | 93 | ip := testutil.RequireTypeAssert[*dns.A](t, resp.Answer[0]).A 94 | assert.Equal(t, firstIP.AsSlice(), []byte(ip)) 95 | }) 96 | } 97 | 98 | // testAUpstream is a mock err upstream structure for tests. 99 | type errUpstream struct { 100 | err error 101 | closeErr error 102 | } 103 | 104 | // Address implements the [upstream.Upstream] interface for *errUpstream. 105 | func (u *errUpstream) Address() string { 106 | return "bad_upstream" 107 | } 108 | 109 | // Exchange implements the [upstream.Upstream] interface for *errUpstream. 110 | func (u *errUpstream) Exchange(_ *dns.Msg) (*dns.Msg, error) { 111 | return nil, u.err 112 | } 113 | 114 | // Close implements the [upstream.Upstream] interface for *errUpstream. 
115 | func (u *errUpstream) Close() error { 116 | return u.closeErr 117 | } 118 | 119 | // testAUpstream is a mock A upstream structure for tests. 120 | type testAUpstream struct { 121 | recs []*dns.A 122 | } 123 | 124 | // type check 125 | var _ upstream.Upstream = (*testAUpstream)(nil) 126 | 127 | // Exchange implements the [upstream.Upstream] interface for *testAUpstream. 128 | func (u *testAUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { 129 | resp = &dns.Msg{} 130 | resp.SetReply(m) 131 | 132 | for _, a := range u.recs { 133 | resp.Answer = append(resp.Answer, a) 134 | } 135 | 136 | return resp, nil 137 | } 138 | 139 | // Address implements the [upstream.Upstream] interface for *testAUpstream. 140 | func (u *testAUpstream) Address() (addr string) { 141 | return "" 142 | } 143 | 144 | // Close implements the [upstream.Upstream] interface for *testAUpstream. 145 | func (u *testAUpstream) Close() (err error) { 146 | return nil 147 | } 148 | 149 | // newTestRec returns a new test A record. 150 | func newTestRec(t *testing.T, addr netip.Addr) (rr *dns.A) { 151 | return &dns.A{ 152 | Hdr: dns.RR_Header{ 153 | Rrtype: dns.TypeA, 154 | Name: dns.Fqdn(t.Name()), 155 | Ttl: 60, 156 | }, 157 | A: addr.AsSlice(), 158 | } 159 | } 160 | 161 | // newTestReq returns a new test A request. 
162 | func newTestReq(t *testing.T) (req *dns.Msg) { 163 | return &dns.Msg{ 164 | MsgHdr: dns.MsgHdr{ 165 | Id: dns.Id(), 166 | RecursionDesired: true, 167 | }, 168 | Question: []dns.Question{{ 169 | Name: dns.Fqdn(t.Name()), 170 | Qtype: dns.TypeA, 171 | Qclass: dns.ClassINET, 172 | }}, 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /upstream/parallel.go: -------------------------------------------------------------------------------- 1 | package upstream 2 | 3 | import ( 4 | "fmt" 5 | "slices" 6 | 7 | "github.com/AdguardTeam/golibs/errors" 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // TODO(e.burkov): Consider using wrapped [errors.ErrNoValue] and 12 | // [errors.ErrEmptyValue] instead. 13 | const ( 14 | // ErrNoUpstreams is returned from the methods that expect at least a single 15 | // upstream to work with when no upstreams specified. 16 | ErrNoUpstreams errors.Error = "no upstream specified" 17 | 18 | // ErrNoReply is returned from [ExchangeAll] when no upstreams replied. 19 | ErrNoReply errors.Error = "no reply" 20 | ) 21 | 22 | // ExchangeParallel returns the first successful response from one of u. It 23 | // returns an error if all upstreams failed to exchange the request. 24 | func ExchangeParallel(ups []Upstream, req *dns.Msg) (reply *dns.Msg, resolved Upstream, err error) { 25 | upsNum := len(ups) 26 | switch upsNum { 27 | case 0: 28 | return nil, nil, ErrNoUpstreams 29 | case 1: 30 | return exchangeSingle(ups[0], req) 31 | default: 32 | // Go on. 33 | } 34 | 35 | resCh := make(chan any, upsNum) 36 | for _, f := range ups { 37 | // Use a copy to prevent data races, as [dns.Client] can modify the DNS 38 | // request during the exchange. 39 | // 40 | // TODO(s.chzhen): Consider using buffer pool. 
41 | copyReq := req.Copy() 42 | go exchangeAsync(f, copyReq, resCh) 43 | } 44 | 45 | errs := []error{} 46 | for range ups { 47 | var r *ExchangeAllResult 48 | r, err = receiveAsyncResult(resCh) 49 | if err != nil { 50 | if !errors.Is(err, ErrNoReply) { 51 | errs = append(errs, err) 52 | } 53 | } else { 54 | return r.Resp, r.Upstream, nil 55 | } 56 | } 57 | 58 | // TODO(e.burkov): Probably it's better to return the joined error from 59 | // each upstream that returned no response, and get rid of multiple 60 | // [errors.Is] calls. This will change the behavior though. 61 | if len(errs) == 0 { 62 | return nil, nil, errors.Error("none of upstream servers responded") 63 | } 64 | 65 | return nil, nil, errors.Join(errs...) 66 | } 67 | 68 | // exchangeSingle returns a successful response and resolver if a DNS lookup was 69 | // successful. 70 | func exchangeSingle( 71 | ups Upstream, 72 | req *dns.Msg, 73 | ) (resp *dns.Msg, resolved Upstream, err error) { 74 | resp, err = ups.Exchange(req) 75 | if err != nil { 76 | return nil, nil, err 77 | } 78 | 79 | return resp, ups, err 80 | } 81 | 82 | // ExchangeAllResult is the successful result of [ExchangeAll] for a single 83 | // upstream. 84 | type ExchangeAllResult struct { 85 | // Resp is the response DNS request resolved into. 86 | Resp *dns.Msg 87 | 88 | // Upstream is the upstream that successfully resolved the request. 89 | Upstream Upstream 90 | } 91 | 92 | // ExchangeAll returns the responses from all of u. It returns an error only if 93 | // all upstreams failed to exchange the request. 
func ExchangeAll(ups []Upstream, req *dns.Msg) (res []ExchangeAllResult, err error) {
	upsNum := len(ups)
	switch upsNum {
	case 0:
		return nil, ErrNoUpstreams
	case 1:
		var reply *dns.Msg
		reply, err = ups[0].Exchange(req)
		if err != nil {
			return nil, err
		}

		if reply == nil {
			return nil, ErrNoReply
		}

		return []ExchangeAllResult{{Upstream: ups[0], Resp: reply}}, nil
	default:
		// Go on.
	}

	// The channel has room for a result from every upstream, so none of the
	// goroutines blocks on sending.
	resCh := make(chan any, upsNum)

	// Start exchanging concurrently.
	for _, u := range ups {
		// Use a copy to prevent data races, as [dns.Client] can modify the DNS
		// request during the exchange.
		//
		// TODO(s.chzhen): Consider using buffer pool.
		go exchangeAsync(u, req.Copy(), resCh)
	}

	res = make([]ExchangeAllResult, 0, upsNum)
	var errs []error

	// Wait for all exchanges to finish.
	for range ups {
		var r *ExchangeAllResult
		r, err = receiveAsyncResult(resCh)
		if err != nil {
			errs = append(errs, err)

			continue
		}

		res = append(res, *r)
	}

	if len(errs) == upsNum {
		// Don't hand out the empty, but allocated, slice along with the error.
		return nil, fmt.Errorf("all upstreams failed: %w", errors.Join(errs...))
	}

	return slices.Clip(res), nil
}

// receiveAsyncResult receives a single value sent by [exchangeAsync] from resCh.
// It returns either a non-nil result or an error; a result with a nil response
// is reported as [ErrNoReply].
func receiveAsyncResult(resCh <-chan any) (res *ExchangeAllResult, err error) {
	// Use a distinct name for the switch variable, so that it doesn't shadow
	// the named result.
	switch v := (<-resCh).(type) {
	case error:
		return nil, v
	case *ExchangeAllResult:
		if v.Resp == nil {
			return nil, ErrNoReply
		}

		return v, nil
	default:
		return nil, fmt.Errorf("unexpected type %T of result", v)
	}
}

// exchangeAsync tries to resolve req with u and sends exactly one value to
// resCh: the error returned by u, or an *ExchangeAllResult holding the
// response, which may be nil.
func exchangeAsync(u Upstream, req *dns.Msg, resCh chan<- any) {
	reply, err := u.Exchange(req)
	if err != nil {
		resCh <- err

		return
	}

	resCh <- &ExchangeAllResult{Resp: reply, Upstream: u}
}