├── .github └── ISSUE_TEMPLATE │ ├── config.yaml │ ├── feature_request_template.yaml │ └── issue_template.yaml ├── .gitignore ├── .markdownlint.json ├── CHANGELOG.md ├── COPYING ├── HACKING.md ├── Makefile ├── README.md ├── config.dist.yaml ├── doc ├── configuration.md ├── debugdns.md ├── debughttp.md ├── development.md ├── environment.md ├── externalhttp.md ├── http.md ├── metrics.md └── querylog.md ├── go.mod ├── go.sum ├── go.work ├── go.work.sum ├── internal ├── access │ ├── access.go │ ├── access_test.go │ ├── blocker.go │ ├── engine.go │ ├── engine_internal_test.go │ ├── metrics.go │ ├── profile.go │ ├── profile_test.go │ ├── profileconstructor.go │ ├── standardaccess.go │ └── standardaccess_test.go ├── agd │ ├── account.go │ ├── agd.go │ ├── agd_test.go │ ├── context.go │ ├── customdomain.go │ ├── device.go │ ├── device_test.go │ ├── devicefinder.go │ ├── devicetype.go │ ├── devicetype_test.go │ ├── dns.go │ ├── error.go │ ├── filteringgroup.go │ ├── humanid.go │ ├── humanid_test.go │ ├── os.go │ ├── profile.go │ ├── ratelimit.go │ ├── requestid.go │ ├── requestid_test.go │ ├── server.go │ ├── server_test.go │ ├── tls.go │ └── tls_test.go ├── agdcache │ ├── agdcache.go │ ├── agdcache_test.go │ ├── default.go │ ├── default_test.go │ ├── lru.go │ ├── lru_test.go │ ├── manager.go │ └── manager_test.go ├── agdhttp │ ├── agdhttp.go │ ├── agdhttp_test.go │ ├── client.go │ ├── error.go │ ├── error_test.go │ ├── url.go │ └── url_test.go ├── agdnet │ ├── agdnet.go │ ├── agdnet_example_test.go │ ├── agdnet_test.go │ ├── prefixaddr.go │ ├── prefixaddr_example_test.go │ └── prefixaddr_test.go ├── agdpasswd │ ├── authenticator.go │ └── authenticator_test.go ├── agdprotobuf │ └── pbutil.go ├── agdtest │ ├── agdtest.go │ ├── interface.go │ └── profile.go ├── agdtime │ ├── agdtime.go │ ├── agdtime_example_test.go │ ├── schedule.go │ └── schedule_example_test.go ├── agdurlflt │ ├── agdurlflt.go │ └── agdurlflt_test.go ├── agdvalidate │ └── agdvalidate.go ├── backendpb │ ├── 
backendpb.go │ ├── backendpb_internal_test.go │ ├── backendpb_test.go │ ├── billstat.go │ ├── billstat_test.go │ ├── customdomain.go │ ├── customdomain_test.go │ ├── device.go │ ├── devicechange.go │ ├── devicechange_internal_test.go │ ├── dns.pb.go │ ├── dns.proto │ ├── dns_grpc.pb.go │ ├── error.go │ ├── metrics.go │ ├── profile.go │ ├── profilestorage.go │ ├── profilestorage_internal_test.go │ ├── profilestorage_test.go │ ├── ratelimiter.go │ ├── ratelimiter_test.go │ ├── remotekv.go │ ├── remotekv_test.go │ ├── standardaccess.go │ ├── stats.go │ ├── ticket.go │ ├── ticketstorage.go │ ├── ticketstorage_internal_test.go │ └── ticketstorage_test.go ├── billstat │ ├── billstat.go │ ├── billstat_test.go │ ├── metrics.go │ ├── runtime.go │ └── runtime_test.go ├── bindtodevice │ ├── bindtodevice.go │ ├── bindtodevice_internal_test.go │ ├── bindtodevice_linux_internal_test.go │ ├── bindtodevice_test.go │ ├── chanlistener_linux.go │ ├── chanlistener_linux_internal_test.go │ ├── chanpacketconn_linux.go │ ├── chanpacketconn_linux_internal_test.go │ ├── connindex_linux.go │ ├── connindex_linux_internal_test.go │ ├── interfacelistener_linux.go │ ├── interfacestorage.go │ ├── listenconfig_linux.go │ ├── listenconfig_linux_internal_test.go │ ├── listenconfig_others.go │ ├── manager.go │ ├── manager_linux.go │ ├── manager_linux_test.go │ ├── manager_others.go │ ├── metrics.go │ ├── packetsession_linux.go │ ├── socket_linux.go │ └── socket_linux_internal_test.go ├── cmd │ ├── access.go │ ├── additional.go │ ├── backend.go │ ├── builder.go │ ├── cache.go │ ├── check.go │ ├── cmd.go │ ├── config.go │ ├── conncheck.go │ ├── crash.go │ ├── ddr.go │ ├── dns.go │ ├── dnscrypt.go │ ├── dnsdb.go │ ├── env.go │ ├── error.go │ ├── filter.go │ ├── filteringgroup.go │ ├── geoip.go │ ├── ifacelistener.go │ ├── network.go │ ├── plugin │ │ └── plugin.go │ ├── querylog.go │ ├── ratelimit.go │ ├── runtime.go │ ├── safebrowsing.go │ ├── server.go │ ├── servergroup.go │ ├── tls.go │ ├── 
upstream.go │ └── websvc.go ├── connlimiter │ ├── conn.go │ ├── counter.go │ ├── counter_internal_test.go │ ├── limiter.go │ ├── limiter_test.go │ ├── listenconfig.go │ ├── listenconfig_test.go │ ├── listener.go │ └── metrics.go ├── consul │ ├── allowlist.go │ ├── allowlist_test.go │ └── metrics.go ├── debugsvc │ ├── cache.go │ ├── debugsvc.go │ ├── debugsvc_test.go │ ├── refresh.go │ └── route.go ├── dnscheck │ ├── dnscheck.go │ ├── dnscheck_test.go │ ├── metrics.go │ ├── remotekv.go │ └── remotekv_test.go ├── dnsdb │ ├── buffer.go │ ├── dnsdb.go │ ├── http.go │ ├── http_test.go │ ├── metrics.go │ └── record.go ├── dnsmsg │ ├── blockingmode.go │ ├── cloner.go │ ├── cloner_test.go │ ├── clonerstat.go │ ├── constructor.go │ ├── constructor_test.go │ ├── dnsmsg.go │ ├── dnsmsg_test.go │ ├── error.go │ ├── error_test.go │ ├── httpscloner.go │ ├── optcloner.go │ ├── response.go │ ├── response_test.go │ ├── rrconstructor.go │ ├── structurederror.go │ ├── svcbmsg.go │ └── svcbmsg_test.go ├── dnsserver │ ├── cache │ │ ├── cache.go │ │ ├── cache_test.go │ │ └── metrics.go │ ├── context.go │ ├── context_test.go │ ├── disposer.go │ ├── dnsserver.go │ ├── dnsserver_test.go │ ├── dnsservertest │ │ ├── dnsservertest.go │ │ ├── error_unix.go │ │ ├── error_windows.go │ │ ├── handler.go │ │ ├── msg.go │ │ ├── msg_test.go │ │ ├── quictracer.go │ │ ├── server.go │ │ └── tls.go │ ├── doc.go │ ├── error.go │ ├── example_test.go │ ├── forward │ │ ├── context.go │ │ ├── error.go │ │ ├── example_test.go │ │ ├── forward.go │ │ ├── forward_test.go │ │ ├── healthcheck.go │ │ ├── healthcheck_test.go │ │ ├── metrics.go │ │ ├── network.go │ │ ├── upstream.go │ │ ├── upstreamplain.go │ │ └── upstreamplain_test.go │ ├── go.mod │ ├── go.sum │ ├── handler.go │ ├── metrics.go │ ├── middleware.go │ ├── msg.go │ ├── netext │ │ ├── listenconfig.go │ │ ├── listenconfig_unix.go │ │ ├── listenconfig_unix_test.go │ │ ├── listenconfig_windows.go │ │ ├── packetconn.go │ │ ├── packetconn_linux.go │ │ ├── 
packetconn_linux_internal_test.go │ │ ├── packetconn_linux_test.go │ │ └── packetconn_others.go │ ├── nonwriter.go │ ├── normalize.go │ ├── pool │ │ ├── conn.go │ │ ├── example_test.go │ │ ├── pool.go │ │ └── pool_test.go │ ├── prometheus │ │ ├── cache.go │ │ ├── cache_test.go │ │ ├── dns.go │ │ ├── forward.go │ │ ├── forward_test.go │ │ ├── helper.go │ │ ├── prometheus.go │ │ ├── prometheus_test.go │ │ ├── ratelimit.go │ │ ├── ratelimit_test.go │ │ ├── server.go │ │ └── server_test.go │ ├── protocol.go │ ├── querylog │ │ ├── querylog.go │ │ └── querylog_test.go │ ├── ratelimit │ │ ├── allowlist.go │ │ ├── backoff.go │ │ ├── counter.go │ │ ├── metrics.go │ │ ├── ratelimit.go │ │ └── ratelimit_test.go │ ├── responsewriter.go │ ├── serverbase.go │ ├── serverbench_test.go │ ├── serverdns.go │ ├── serverdns_test.go │ ├── serverdnscrypt.go │ ├── serverdnscrypt_test.go │ ├── serverdnstcp.go │ ├── serverdnsudp.go │ ├── serverhttps.go │ ├── serverhttps_test.go │ ├── serverhttpsjson.go │ ├── serverquic.go │ ├── serverquic_test.go │ ├── servertls.go │ ├── servertls_test.go │ ├── staticcheck.conf │ ├── task.go │ ├── tls.go │ └── ttl.go ├── dnssvc │ ├── config.go │ ├── context.go │ ├── dnssvc.go │ ├── dnssvc_test.go │ ├── errcoll.go │ ├── handler.go │ ├── handler_test.go │ ├── integration_test.go │ ├── internal │ │ ├── devicefinder │ │ │ ├── customdomain.go │ │ │ ├── device.go │ │ │ ├── device_test.go │ │ │ ├── devicedata.go │ │ │ ├── devicedata_test.go │ │ │ ├── devicefinder.go │ │ │ ├── devicefinder_test.go │ │ │ ├── error.go │ │ │ ├── humanid.go │ │ │ ├── humanid_test.go │ │ │ └── metrics.go │ │ ├── dnssvctest │ │ │ └── dnssvctest.go │ │ ├── initial │ │ │ ├── initial.go │ │ │ ├── initial_test.go │ │ │ ├── metrics.go │ │ │ ├── specialdomain.go │ │ │ └── specialdomain_test.go │ │ ├── internal.go │ │ ├── mainmw │ │ │ ├── debug.go │ │ │ ├── debug_internal_test.go │ │ │ ├── error.go │ │ │ ├── filter.go │ │ │ ├── filter_internal_test.go │ │ │ ├── mainmw.go │ │ │ ├── 
mainmw_test.go │ │ │ ├── metrics.go │ │ │ ├── record.go │ │ │ └── record_internal_test.go │ │ ├── preservice │ │ │ ├── preservice.go │ │ │ └── preservice_test.go │ │ ├── preupstream │ │ │ ├── preupstream.go │ │ │ └── preupstream_test.go │ │ └── ratelimitmw │ │ │ ├── access.go │ │ │ ├── access_test.go │ │ │ ├── limit.go │ │ │ ├── metrics.go │ │ │ ├── ratelimitmw.go │ │ │ └── requestinfo.go │ └── reexport.go ├── ecscache │ ├── cache.go │ ├── cache_internal_test.go │ ├── ecsblocklist.go │ ├── ecsblocklist_generate.go │ ├── ecscache.go │ ├── ecscache_internal_test.go │ ├── ecscache_test.go │ ├── metrics.go │ └── msg.go ├── errcoll │ ├── errcoll.go │ ├── sentry.go │ ├── sentry_test.go │ ├── writer.go │ └── writer_test.go ├── experiment │ └── experiment.go ├── filter │ ├── config.go │ ├── custom │ │ ├── custom.go │ │ └── custom_test.go │ ├── filter.go │ ├── filter_test.go │ ├── filterstorage │ │ ├── config.go │ │ ├── default.go │ │ ├── default_test.go │ │ ├── filterstorage.go │ │ ├── filterstorage_test.go │ │ ├── index.go │ │ ├── index_internal_test.go │ │ ├── refresh.go │ │ ├── refresh_test.go │ │ ├── standardaccess.go │ │ ├── standardaccess_test.go │ │ └── testdata │ │ │ └── TestStandardAccess_cache │ │ │ ├── bad_version │ │ │ └── standard_profile_access.json │ │ │ └── success │ │ │ └── standard_profile_access.json │ ├── hashprefix │ │ ├── filter.go │ │ ├── filter_test.go │ │ ├── hashprefix.go │ │ ├── hashprefix_test.go │ │ ├── matcher.go │ │ ├── matcher_test.go │ │ ├── metrics.go │ │ ├── storage.go │ │ └── storage_test.go │ ├── id.go │ ├── id_test.go │ ├── internal │ │ ├── composite │ │ │ ├── composite.go │ │ │ ├── composite_internal_test.go │ │ │ ├── composite_test.go │ │ │ ├── request.go │ │ │ └── result.go │ │ ├── filtertest │ │ │ ├── filtertest.go │ │ │ ├── hashprefix.go │ │ │ ├── refresh.go │ │ │ └── result.go │ │ ├── refreshable │ │ │ ├── refreshable.go │ │ │ └── refreshable_test.go │ │ ├── rulelist │ │ │ ├── cache.go │ │ │ ├── dnsrewrite.go │ │ │ ├── 
immutable.go │ │ │ ├── refreshable.go │ │ │ ├── refreshable_test.go │ │ │ ├── rulelist.go │ │ │ └── rulelist_internal_test.go │ │ ├── safesearch │ │ │ ├── safesearch.go │ │ │ └── safesearch_test.go │ │ └── serviceblock │ │ │ ├── index.go │ │ │ ├── serviceblock.go │ │ │ └── serviceblock_test.go │ ├── metrics.go │ ├── result.go │ ├── schedule.go │ ├── schedule_test.go │ └── storage.go ├── geoip │ ├── asntops.go │ ├── asntops_generate.go │ ├── country.go │ ├── country_generate.go │ ├── error.go │ ├── file.go │ ├── file_test.go │ ├── filescanner.go │ ├── geoip.go │ ├── geoip_test.go │ ├── location.go │ ├── metrics.go │ └── testdata │ │ ├── GeoIP2-City-Test.mmdb │ │ ├── GeoIP2-Country-Test.mmdb │ │ └── GeoIP2-ISP-Test.mmdb ├── metrics │ ├── access.go │ ├── backend.go │ ├── billstat.go │ ├── bindtodevice.go │ ├── connlimiter.go │ ├── consul.go │ ├── customdomain.go │ ├── devicefinder.go │ ├── dnscheck.go │ ├── dnsdb.go │ ├── dnsmsg.go │ ├── dnssvc.go │ ├── ecscache.go │ ├── filter.go │ ├── geoip.go │ ├── hashprefix.go │ ├── mainmw.go │ ├── metrics.go │ ├── metrics_test.go │ ├── profiledb.go │ ├── querylog.go │ ├── ratelimitmw.go │ ├── remotekv.go │ ├── research.go │ ├── rulestat.go │ ├── standardaccess.go │ ├── tlsconfig.go │ ├── tlsconfig_test.go │ ├── tlstickets.go │ ├── usercount.go │ ├── usercount_test.go │ └── websvc.go ├── profiledb │ ├── config.go │ ├── customdomaindb.go │ ├── default.go │ ├── default_test.go │ ├── error.go │ ├── internal │ │ ├── filecachepb │ │ │ ├── filecache.pb.go │ │ │ ├── filecache.proto │ │ │ ├── filecachepb.go │ │ │ ├── filecachepb_internal_test.go │ │ │ ├── storage.go │ │ │ ├── storage_test.go │ │ │ └── unsafe.go │ │ ├── internal.go │ │ └── profiledbtest │ │ │ └── profiledbtest.go │ ├── metrics.go │ ├── profiledb.go │ ├── profiledb_test.go │ └── storage.go ├── querylog │ ├── entry.go │ ├── fs.go │ ├── fs_test.go │ ├── metrics.go │ ├── querylog.go │ └── querylog_test.go ├── remotekv │ ├── cachekv.go │ ├── cachekv_test.go │ ├── consulkv │ │ 
├── consulkv.go │ │ ├── consulkv_test.go │ │ └── error.go │ ├── keynamespace.go │ ├── keynamespace_test.go │ ├── rediskv │ │ ├── rediskv.go │ │ └── rediskv_test.go │ └── remotekv.go ├── rulestat │ ├── http.go │ ├── http_test.go │ ├── metrics.go │ ├── rulestat.go │ └── rulestat_test.go ├── tlsconfig │ ├── certindex.go │ ├── certindex_internal_test.go │ ├── customdomaindb.go │ ├── customdomaindb_test.go │ ├── customdomainindex.go │ ├── customdomainstorage.go │ ├── manager.go │ ├── manager_test.go │ ├── metrics.go │ ├── ticket.go │ ├── ticket_test.go │ ├── ticketdb.go │ ├── ticketdb_test.go │ ├── tlsconfig.go │ └── tlsconfig_test.go ├── version │ ├── norace.go │ ├── race.go │ └── version.go └── websvc │ ├── blockpage.go │ ├── blockpage_test.go │ ├── config.go │ ├── linkip.go │ ├── linkip_test.go │ ├── metrics.go │ ├── nondoh.go │ ├── nondoh_test.go │ ├── server.go │ ├── servergroup.go │ ├── static.go │ ├── static_test.go │ ├── testdata │ └── block_page.html │ ├── websvc.go │ ├── websvc_internal_test.go │ └── websvc_test.go ├── main.go ├── scripts ├── backend │ ├── dns.go │ ├── main.go │ ├── ratelimit.go │ ├── remotekv.go │ └── ticketstorage.go ├── hooks │ └── pre-commit ├── make │ ├── go-bench.sh │ ├── go-build.sh │ ├── go-deps.sh │ ├── go-fuzz.sh │ ├── go-gen.sh │ ├── go-lint.sh │ ├── go-test.sh │ ├── go-tools.sh │ ├── go-upd-tools.sh │ ├── helper.sh │ ├── md-lint.sh │ ├── sh-lint.sh │ └── txt-lint.sh └── test │ ├── bindtodevice.docker │ └── bindtodevice.sh └── staticcheck.conf /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | 'blank_issues_enabled': false 2 | 'contact_links': 3 | - 'about': > 4 | Please report filtering issues, for example advertising filters 5 | misfiring or safe browsing false positives, using the form on our 6 | website 7 | 'name': 'AdGuard filters issues' 8 | 'url': 'https://link.adtidy.org/forward.html?action=report&app=home&from=github' 9 | - 'about': > 10 | Please send 
requests for new blocked services and vetted filtering lists 11 | to the Hostlists Registry repository 12 | 'name': 'Blocked services and vetted filtering rule lists: AdGuard Hostlists Registry' 13 | 'url': 'https://github.com/AdguardTeam/HostlistsRegistry' 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request_template.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | name: 🌱 Feature request 4 | description: Create a feature request to help us improve AdGuard DNS. 5 | labels: [ 'feature request' ] 6 | body: 7 | - type: textarea 8 | id: what-happened 9 | attributes: 10 | label: Issue Details 11 | description: What happened? 12 | placeholder: Is your feature request related to a problem? Please add a clear and concise description of what the problem is. 13 | validations: 14 | required: false 15 | - type: textarea 16 | id: how_it_should_be 17 | attributes: 18 | label: Proposed solution 19 | description: 20 | placeholder: Describe the solution you'd like in a clear and concise manner. 21 | validations: 22 | required: false 23 | - type: textarea 24 | id: additional 25 | attributes: 26 | label: Alternative solution 27 | description: 28 | placeholder: A clear and concise description of any alternative solutions or features you've considered. 29 | validations: 30 | required: false 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # This comment is used to simplify checking local copies of the file. Bump 2 | # this number every time a significant change is made to this file. 3 | # 4 | # AdGuard-Project-Version: 3 5 | 6 | # Please, DO NOT put your text editors' temporary files here. The more are 7 | # added, the harder it gets to maintain and manage projects' gitignores. 
Put 8 | # them into your global gitignore file instead. 9 | # 10 | # See https://stackoverflow.com/a/7335487/1892060. 11 | # 12 | # Only build, run, and test outputs here. Sorted. With negations at the 13 | # bottom to make sure they take effect. 14 | *.exe 15 | *.out 16 | *.test 17 | /bin/ 18 | /filters/ 19 | /github-mirror/ 20 | /test-reports/ 21 | /test/ 22 | /tmp/ 23 | AdGuardDNS 24 | agdns 25 | asn.mmdb 26 | config.yaml 27 | country.mmdb 28 | profilecache.json 29 | querylog.jsonl 30 | -------------------------------------------------------------------------------- /.markdownlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "ul-indent": { 3 | "indent": 4 4 | }, 5 | "ul-style": { 6 | "style": "dash" 7 | }, 8 | "emphasis-style": { 9 | "style": "asterisk" 10 | }, 11 | "no-duplicate-heading": { 12 | "siblings_only": true 13 | }, 14 | "no-inline-html": { 15 | "allowed_elements": [ 16 | "a" 17 | ] 18 | }, 19 | "no-trailing-spaces": { 20 | "br_spaces": 0 21 | }, 22 | "line-length": false, 23 | "no-bare-urls": false, 24 | "no-emphasis-as-heading": false, 25 | "link-fragments": false 26 | } 27 | -------------------------------------------------------------------------------- /HACKING.md: -------------------------------------------------------------------------------- 1 | # Code guidelines 2 | 3 | See the [AdGuard Code Guidelines](https://github.com/AdguardTeam/CodeGuidelines/). 4 | -------------------------------------------------------------------------------- /doc/metrics.md: -------------------------------------------------------------------------------- 1 | # AdGuard DNS Prometheus metrics 2 | 3 | **TODO(a.garipov):** Describe the metrics. 4 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.25.1 2 | 3 | use ( 4 | .
5 | 	./internal/dnsserver
6 | )
7 | 
--------------------------------------------------------------------------------
/internal/access/blocker.go:
--------------------------------------------------------------------------------
1 | package access
2 | 
3 | import (
4 | 	"context"
5 | 	"net/netip"
6 | 
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
8 | 	"github.com/miekg/dns"
9 | )
10 | 
11 | // Blocker is the interface to control DNS resolution access.
12 | type Blocker interface {
13 | 	// IsBlocked returns true if the req should be blocked. req must not be
14 | 	// nil, and req.Question must have one item.
15 | 	//
16 | 	// rAddr is presumably the client's remote address and l its GeoIP
17 | 	// location. NOTE(review): the StandardBlocker benchmark passes a nil l,
18 | 	// so implementations likely must tolerate nil; confirm.
19 | 	IsBlocked(
20 | 		ctx context.Context,
21 | 		req *dns.Msg,
22 | 		rAddr netip.AddrPort,
23 | 		l *geoip.Location,
24 | 	) (isBlocked bool)
25 | }
26 | 
27 | // EmptyBlocker is an empty [Blocker] implementation that does nothing.
28 | type EmptyBlocker struct{}
29 | 
30 | // type check
31 | var _ Blocker = EmptyBlocker{}
32 | 
33 | // IsBlocked implements the [Blocker] interface for EmptyBlocker. It always
34 | // returns false.
35 | func (EmptyBlocker) IsBlocked(
36 | 	_ context.Context,
37 | 	_ *dns.Msg,
38 | 	_ netip.AddrPort,
39 | 	_ *geoip.Location,
40 | ) (isBlocked bool) {
41 | 	return false
42 | }
43 | 
--------------------------------------------------------------------------------
/internal/access/metrics.go:
--------------------------------------------------------------------------------
1 | package access
2 | 
3 | import (
4 | 	"context"
5 | 	"time"
6 | )
7 | 
8 | // ProfileMetrics is an interface used for collecting statistics related to the
9 | // profile access engine.
10 | type ProfileMetrics interface {
11 | 	// ObserveProfileInit records the duration taken for the initialization of
12 | 	// the profile access engine.
13 | 	ObserveProfileInit(ctx context.Context, dur time.Duration)
14 | }
15 | 
16 | // EmptyProfileMetrics is the implementation of the [ProfileMetrics] interface
17 | // that does nothing.
18 | type EmptyProfileMetrics struct{}
19 | 
20 | // type check
21 | var _ ProfileMetrics = EmptyProfileMetrics{}
22 | 
23 | // ObserveProfileInit implements the [ProfileMetrics] interface for
24 | // EmptyProfileMetrics. It ignores its arguments and records nothing.
25 | func (EmptyProfileMetrics) ObserveProfileInit(context.Context, time.Duration) {}
26 | 
--------------------------------------------------------------------------------
/internal/access/profileconstructor.go:
--------------------------------------------------------------------------------
1 | package access
2 | 
3 | import (
4 | 	"github.com/AdguardTeam/golibs/syncutil"
5 | 	"github.com/AdguardTeam/urlfilter"
6 | )
7 | 
8 | // ProfileConstructorConfig is the configuration for the [ProfileConstructor].
9 | type ProfileConstructorConfig struct {
10 | 	// Metrics collects statistics of the profile access managers created by
11 | 	// the constructor. It must not be nil.
12 | 	Metrics ProfileMetrics
13 | 
14 | 	// Standard is the blocker shared by every profile that has the standard
15 | 	// access feature enabled. It must not be nil.
16 | 	Standard Blocker
17 | }
18 | 
19 | // ProfileConstructor creates default access managers for profiles. The pools
20 | // and metrics it holds are shared by every profile it creates.
21 | type ProfileConstructor struct {
22 | 	// reqPool and resPool are handed to every created profile.
23 | 	reqPool *syncutil.Pool[urlfilter.DNSRequest]
24 | 	resPool *syncutil.Pool[urlfilter.DNSResult]
25 | 
26 | 	// metrics is taken from [ProfileConstructorConfig.Metrics].
27 | 	metrics ProfileMetrics
28 | 
29 | 	// standard is taken from [ProfileConstructorConfig.Standard] and is only
30 | 	// used by profiles with StandardEnabled set.
31 | 	standard Blocker
32 | }
33 | 
34 | // NewProfileConstructor returns a properly initialized *ProfileConstructor.
35 | // conf must not be nil.
36 | func NewProfileConstructor(conf *ProfileConstructorConfig) (c *ProfileConstructor) {
37 | 	newReq := func() (req *urlfilter.DNSRequest) { return &urlfilter.DNSRequest{} }
38 | 	newRes := func() (res *urlfilter.DNSResult) { return &urlfilter.DNSResult{} }
39 | 
40 | 	c = &ProfileConstructor{
41 | 		reqPool:  syncutil.NewPool(newReq),
42 | 		resPool:  syncutil.NewPool(newRes),
43 | 		metrics:  conf.Metrics,
44 | 		standard: conf.Standard,
45 | 	}
46 | 
47 | 	return c
48 | }
49 | 
50 | // New creates a new access manager for a profile based on the configuration.
51 | // conf must not be nil and must be valid. Profiles without StandardEnabled get
52 | // an [EmptyBlocker] instead of the shared standard blocker.
53 | func (c *ProfileConstructor) New(conf *ProfileConfig) (p *DefaultProfile) {
54 | 	standard := c.standard
55 | 	if !conf.StandardEnabled {
56 | 		standard = EmptyBlocker{}
57 | 	}
58 | 
59 | 	profConf := &defaultProfileConfig{
60 | 		conf:     conf,
61 | 		reqPool:  c.reqPool,
62 | 		resPool:  c.resPool,
63 | 		metrics:  c.metrics,
64 | 		standard: standard,
65 | 	}
66 | 
67 | 	return newDefaultProfile(profConf)
68 | }
69 | 
--------------------------------------------------------------------------------
/internal/access/standardaccess_test.go:
--------------------------------------------------------------------------------
1 | package access_test
2 | 
3 | import (
4 | 	"net/netip"
5 | 	"testing"
6 | 
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/access"
8 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
9 | 	"github.com/AdguardTeam/golibs/testutil"
10 | 	"github.com/miekg/dns"
11 | 	"github.com/stretchr/testify/require"
12 | )
13 | 
14 | func BenchmarkStandardBlocker_IsBlocked(b *testing.B) {
15 | 	blocker := access.NewStandardBlocker(&access.StandardBlockerConfig{
16 | 		BlocklistDomainRules: []string{
17 | 			"block.test",
18 | 		},
19 | 	})
20 | 
21 | 	ctx := testutil.ContextWithTimeout(b, testTimeout)
22 | 	remoteAddr := netip.AddrPort{}
23 | 
24 | 	benchCases := []struct {
25 | 		name      string
26 | 		domain    string
27 | 		wantBlock bool
28 | 	}{{
29 | 		name:      "pass",
30 | 		domain:    "pass.test",
31 | 		wantBlock: false,
32 | 	}, {
33 | 		name:      "block",
34 | 		domain:    "block.test",
35 | 		wantBlock: true,
36 | 	}}
37 | 
38 | 	for _, bc := range benchCases {
39 | 		b.Run(bc.name, func(b *testing.B) {
40 | 			req := dnsservertest.NewReq(bc.domain, dns.TypeA, dns.ClassINET)
41 | 
42 | 			// Warmup to fill the pools and the slices.
43 | 			blocked := blocker.IsBlocked(ctx, req, remoteAddr, nil)
44 | 			require.Equal(b, bc.wantBlock, blocked)
45 | 
46 | 			b.ReportAllocs()
47 | 			for b.Loop() {
48 | 				blocked = blocker.IsBlocked(ctx, req, remoteAddr, nil)
49 | 			}
50 | 
51 | 			require.Equal(b, bc.wantBlock, blocked)
52 | 		})
53 | 	}
54 | 
55 | 	// Most recent results:
56 | 	// goos: linux
57 | 	// goarch: amd64
58 | 	// pkg: github.com/AdguardTeam/AdGuardDNS/internal/access
59 | 	// cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics
60 | 	// BenchmarkStandardBlocker_IsBlocked/pass-16 3568975 335.4 ns/op 16 B/op 1 allocs/op
61 | 	// BenchmarkStandardBlocker_IsBlocked/block-16 3286392 364.1 ns/op 24 B/op 1 allocs/op
62 | }
63 | 
--------------------------------------------------------------------------------
/internal/agd/account.go:
--------------------------------------------------------------------------------
1 | package agd
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdvalidate"
7 | )
8 | 
9 | // AccountID is the ID of an account containing multiple profiles (a.k.a. DNS
10 | // servers). It is an opaque string.
11 | type AccountID string
12 | 
13 | // NewAccountID converts a simple string into an AccountID and makes sure that
14 | // it's valid. This should be preferred to a simple type conversion.
15 | func NewAccountID(s string) (id AccountID, err error) {
16 | 	// Only printable, non-whitespace ASCII characters are allowed for now.
17 | 	// Strictly speaking, only carriage returns and line feeds have to be
18 | 	// rejected, but being stricter is safer.
19 | 	i, r := agdvalidate.FirstNonIDRune(s, false)
20 | 	if i == -1 {
21 | 		return AccountID(s), nil
22 | 	}
23 | 
24 | 	return "", fmt.Errorf("bad account id: bad char %q at index %d", r, i)
25 | }
26 | 
--------------------------------------------------------------------------------
/internal/agd/agd.go:
--------------------------------------------------------------------------------
1 | // Package agd contains common entities and interfaces of AdGuard DNS.
2 | package agd
3 | 
--------------------------------------------------------------------------------
/internal/agd/agd_test.go:
--------------------------------------------------------------------------------
1 | package agd_test
2 | 
3 | import (
4 | 	"strings"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
7 | )
8 | 
9 | // Common long strings for tests. Both are 200 runes long, which is longer
10 | // than the 128-rune limit that TestNewDeviceName expects for device names.
11 | //
12 | // TODO(a.garipov): Move to a new validation package.
13 | var (
14 | 	testLongStr        = strings.Repeat("a", 200)
15 | 	testLongStrUnicode = strings.Repeat("ы", 200)
16 | )
17 | 
18 | // Common IDs for tests.
19 | const (
20 | 	testHumanIDStr      = "My-Device-X--10"
21 | 	testHumanIDLowerStr = "my-device-x--10"
22 | 
23 | 	testHumanID agd.HumanID = testHumanIDStr
24 | )
25 | 
--------------------------------------------------------------------------------
/internal/agd/device_test.go:
--------------------------------------------------------------------------------
1 | package agd_test
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
7 | 	"github.com/AdguardTeam/golibs/testutil"
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 | 
11 | func TestNewDeviceName(t *testing.T) {
12 | 	t.Parallel()
13 | 
14 | 	testCases := []struct {
15 | 		name       string
16 | 		in         string
17 | 		wantErrMsg string
18 | 	}{{
19 | 		name:       "empty",
20 | 		in:         "",
21 | 		wantErrMsg: "",
22 | 	}, {
23 | 		name:       "normal",
24 | 		in:         "Normal name",
25 | 		wantErrMsg: "",
26 | 	}, {
27 | 		name:       "normal_unicode",
28 | 		in:         "Нормальное имя",
29 | 		wantErrMsg: "",
30 | 	}, {
31 | 		name:       "too_long",
32 | 		in:         testLongStr,
33 | 		wantErrMsg: `bad device name "` + testLongStr + `": too long: got 200 runes, max 128`,
34 | 	}, {
35 | 		name: "too_long_unicode",
36 | 		in:   testLongStrUnicode,
37 | 		wantErrMsg: `bad device name "` + testLongStrUnicode +
38 | 			`": too long: got 200 runes, max 128`,
39 | 	}}
40 | 
41 | 	for _, tc := range testCases {
42 | 		t.Run(tc.name, func(t *testing.T) {
43 | 			t.Parallel()
44 | 
45 | 			n, err := agd.NewDeviceName(tc.in)
46 | 			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
47 | 
48 | 			// Only valid, non-empty input is expected to produce a non-empty
49 | 			// name; empty and invalid inputs must produce an empty one.
50 | 			if tc.wantErrMsg == "" && tc.in != "" {
51 | 				assert.NotEmpty(t, n)
52 | 			} else {
53 | 				assert.Empty(t, n)
54 | 			}
55 | 		})
56 | 	}
57 | }
58 | 
--------------------------------------------------------------------------------
/internal/agd/devicetype_test.go:
--------------------------------------------------------------------------------
1 | package agd_test
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
7 | 	"github.com/AdguardTeam/golibs/testutil"
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 | 
11 | func TestDeviceTypeFromDNS(t *testing.T) {
12 | 	t.Parallel()
13 | 
14 | 	testCases := []struct {
15 | 		name       string
16 | 		in         string
17 | 		wantErrMsg string
18 | 		want       agd.DeviceType
19 | 	}{{
20 | 		name:       "success",
21 | 		in:         "adr",
22 | 		wantErrMsg: "",
23 | 		want:       agd.DeviceTypeAndroid,
24 | 	}, {
25 | 		name:       "success_case",
26 | 		in:         "Adr",
27 | 		wantErrMsg: "",
28 | 		want:       agd.DeviceTypeAndroid,
29 | 	}, {
30 | 		name:       "too_long",
31 | 		in:         "windows",
32 | 		wantErrMsg: `bad device type "windows": too long: got 7 bytes, max 3`,
33 | 		want:       agd.DeviceTypeNone,
34 | 	}, {
35 | 		// NOTE(review): this case checks the "too short" error, so "too_short"
36 | 		// may be a clearer name.
37 | 		name:       "too_small",
38 | 		in:         "x",
39 | 		wantErrMsg: `bad device type "x": too short: got 1 bytes, min 3`,
40 | 		want:       agd.DeviceTypeNone,
41 | 	}, {
42 | 		// "(none)" is the String form of DeviceTypeNone, but it is not accepted
43 | 		// as a DNS device type.
44 | 		name:       "none",
45 | 		in:         "(none)",
46 | 		wantErrMsg: `bad device type "(none)": too long: got 6 bytes, max 3`,
47 | 		want:       agd.DeviceTypeNone,
48 | 	}, {
49 | 		name:       "unknown",
50 | 		in:         "xxx",
51 | 		wantErrMsg: `bad device type "xxx": unknown device type`,
52 | 		want:       agd.DeviceTypeNone,
53 | 	}}
54 | 
55 | 	for _, tc := range testCases {
56 | 		t.Run(tc.name, func(t *testing.T) {
57 | 			t.Parallel()
58 | 
59 | 			got, err := agd.DeviceTypeFromDNS(tc.in)
60 | 			assert.Equal(t, tc.want, got)
61 | 			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
62 | 		})
63 | 	}
64 | }
65 | 
66 | func TestDeviceType_String(t *testing.T) {
67 | 	t.Parallel()
68 | 
69 | 	assert.Equal(t, "(none)", agd.DeviceTypeNone.String())
70 | 	assert.Equal(t, "adr", agd.DeviceTypeAndroid.String())
71 | 	assert.Equal(t, "!bad_device_type_42", agd.DeviceType(42).String())
72 | }
73 | 
--------------------------------------------------------------------------------
/internal/agd/dns.go:
--------------------------------------------------------------------------------
1 | package agd
2 | 
3 | import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver"
4 | 
5 | // Common DNS Message Constants, Types, And Utilities
6 | 
7 | // Protocol is a DNS protocol. It is reexported here to lower the degree of
8 | // dependency on the dnsserver module.
9 | type Protocol = dnsserver.Protocol
10 | 
11 | // Protocol value constants. They are reexported here to lower the degree of
12 | // dependency on the dnsserver module.
13 | const (
14 | 	// NOTE: DO NOT change the numerical values or use iota, because other
15 | 	// packages and modules may depend on the numerical values. These numerical
16 | 	// values are a part of the API.
17 | 
18 | 	ProtoInvalid  = dnsserver.ProtoInvalid
19 | 	ProtoDNS      = dnsserver.ProtoDNS
20 | 	ProtoDoH      = dnsserver.ProtoDoH
21 | 	ProtoDoQ      = dnsserver.ProtoDoQ
22 | 	ProtoDoT      = dnsserver.ProtoDoT
23 | 	ProtoDNSCrypt = dnsserver.ProtoDNSCrypt
24 | )
25 | 
--------------------------------------------------------------------------------
/internal/agd/error.go:
--------------------------------------------------------------------------------
1 | package agd
2 | 
3 | import (
4 | 	"fmt"
5 | )
6 | 
7 | // Common Errors
8 | 
9 | // ArgumentError is returned by functions when a value of an argument is
10 | // invalid.
11 | type ArgumentError struct {
12 | 	// Name is the name of the argument.
13 | 	Name string
14 | 
15 | 	// Message is an optional additional message.
16 | 	Message string
17 | }
18 | 
19 | // Error implements the error interface for *ArgumentError.
20 | func (err *ArgumentError) Error() (msg string) {
21 | 	// Message is optional, so only append it when it is set.
22 | 	if err.Message == "" {
23 | 		return fmt.Sprintf("argument %s is invalid", err.Name)
24 | 	}
25 | 
26 | 	return fmt.Sprintf("argument %s is invalid: %s", err.Name, err.Message)
27 | }
28 | 
--------------------------------------------------------------------------------
/internal/agd/filteringgroup.go:
--------------------------------------------------------------------------------
1 | package agd
2 | 
3 | import "github.com/AdguardTeam/AdGuardDNS/internal/filter"
4 | 
5 | // FilteringGroup represents a set of filtering settings.
6 | //
7 | // TODO(a.garipov): Extract the pre-filtering booleans and logic into a new
8 | // package.
9 | type FilteringGroup struct {
10 | 	// FilterConfig is the configuration of the filters used for this filtering
11 | 	// group. It must not be nil.
12 | 	FilterConfig *filter.ConfigGroup
13 | 
14 | 	// ID is the unique ID of this filtering group. It must be set.
15 | 	ID FilteringGroupID
16 | 
17 | 	// BlockChromePrefetch shows if the Chrome prefetch proxy feature should be
18 | 	// disabled for requests using this filtering group.
19 | 	BlockChromePrefetch bool
20 | 
21 | 	// BlockFirefoxCanary shows if the Firefox canary domain is blocked for
22 | 	// requests using this filtering group.
23 | 	BlockFirefoxCanary bool
24 | 
25 | 	// BlockPrivateRelay shows if Apple Private Relay is blocked for requests
26 | 	// using this filtering group.
27 | 	BlockPrivateRelay bool
28 | }
29 | 
30 | // FilteringGroupID is the ID of a filtering group. It is an opaque string.
31 | type FilteringGroupID string
32 | 
--------------------------------------------------------------------------------
/internal/agd/os.go:
--------------------------------------------------------------------------------
1 | package agd
2 | 
3 | import (
4 | 	"io/fs"
5 | 	"os"
6 | )
7 | 
8 | // OS-Related Constants
9 | 
10 | // DefaultWOFlags is the default set of flags for opening write-only files.
11 | const DefaultWOFlags = os.O_APPEND | os.O_CREATE | os.O_WRONLY 12 | 13 | // DefaultPerm is the default set of permissions for non-executable files. Be 14 | // strict and allow only reading and writing for the file, and only to the user. 15 | const DefaultPerm fs.FileMode = 0o600 16 | 17 | // DefaultDirPerm is the default set of permissions for directories. 18 | const DefaultDirPerm fs.FileMode = 0o700 19 | -------------------------------------------------------------------------------- /internal/agd/requestid.go: -------------------------------------------------------------------------------- 1 | package agd 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | 7 | "github.com/AdguardTeam/golibs/mathutil/randutil" 8 | ) 9 | 10 | // RequestIDLen is the length of a [RequestID] in bytes. A RequestID is 11 | // currently a random 16-byte (128-bit) number. 12 | const RequestIDLen = 16 13 | 14 | // RequestID is the ID of a request. It is an opaque, randomly generated 15 | // byte array. API users should not rely on it being pseudorandom or 16 | // cryptographically random. 17 | type RequestID [RequestIDLen]byte 18 | 19 | // requestIDRand is used to create [RequestID]s. 20 | // 21 | // TODO(a.garipov): Consider making a struct instead of using one global source. 22 | var requestIDRand = randutil.NewReader(randutil.MustNewSeed()) 23 | 24 | // NewRequestID returns a new pseudorandom RequestID; prefer it to building one 25 | // manually. It panics if reading from the random source fails. 26 | func NewRequestID() (id RequestID) { 27 | _, err := requestIDRand.Read(id[:]) 28 | if err != nil { 29 | panic(fmt.Errorf("generating random request id: %w", err)) 30 | } 31 | 32 | return id 33 | } 34 | 35 | // type check 36 | var _ fmt.Stringer = RequestID{} 37 | 38 | // String implements the [fmt.Stringer] interface for RequestID.
39 | func (id RequestID) String() (s string) { 40 | // Use the predefined unpadded URL-safe encoding and a fixed-size array 41 | // instead of constructing a new encoding and a new slice on every call. 42 | // The array length is the unpadded base64 length of RequestIDLen bytes. 43 | var buf [(RequestIDLen*8 + 5) / 6]byte 44 | base64.RawURLEncoding.Encode(buf[:], id[:]) 45 | 46 | return string(buf[:]) 47 | } 48 | -------------------------------------------------------------------------------- /internal/agd/requestid_test.go: -------------------------------------------------------------------------------- 1 | package agd_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agd" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func BenchmarkNewRequestID(b *testing.B) { 11 | var reqID agd.RequestID 12 | 13 | b.ReportAllocs() 14 | for b.Loop() { 15 | reqID = agd.NewRequestID() 16 | } 17 | 18 | require.NotEmpty(b, reqID) 19 | 20 | // Most recent results: 21 | // 22 | // goos: darwin 23 | // goarch: amd64 24 | // pkg: github.com/AdguardTeam/AdGuardDNS/internal/agd 25 | // cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz 26 | // BenchmarkNewRequestID-12 41978553 27.96 ns/op 0 B/op 0 allocs/op 27 | } 28 | -------------------------------------------------------------------------------- /internal/agd/tls.go: -------------------------------------------------------------------------------- 1 | package agd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | "github.com/AdguardTeam/golibs/validate" 8 | ) 9 | 10 | // maxCertificateNameLen is the maximum length of a [CertificateName]. 11 | const maxCertificateNameLen = 32 12 | 13 | // CertificateName is the unique name identifying the TLS certificate. 14 | type CertificateName string 15 | 16 | // NewCertificateName creates a new CertificateName from the given string.
17 | func NewCertificateName(str string) (name CertificateName, err error) { 18 | if str == "" { 19 | return "", errors.ErrEmptyValue 20 | } 21 | 22 | err = validate.InRange("length", len(str), 1, maxCertificateNameLen) 23 | if err != nil { 24 | // Don't wrap the error, since it's informative enough as is. 25 | return "", err 26 | } 27 | 28 | for i, r := range str { 29 | // Don't use [agdvalidate.FirstNonIDRune] as it allows invalid symbols 30 | // for file names. 31 | if !isValidCertNameRune(r) { 32 | return "", fmt.Errorf("at index %d: bad symbol: %q", i, r) 33 | } 34 | } 35 | 36 | return CertificateName(str), nil 37 | } 38 | 39 | // isValidCertNameRune returns true if the given rune is valid to be used in a 40 | // [CertificateName]. It essentially allows alphanumeric symbols, underscores, 41 | // and hyphens. 42 | func isValidCertNameRune(r rune) (ok bool) { 43 | switch { 44 | case 45 | r >= 'a' && r <= 'z', 46 | r >= 'A' && r <= 'Z', 47 | r >= '0' && r <= '9', 48 | r == '_', r == '-': 49 | return true 50 | default: 51 | return false 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /internal/agd/tls_test.go: -------------------------------------------------------------------------------- 1 | package agd_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agd" 7 | "github.com/AdguardTeam/golibs/testutil" 8 | ) 9 | 10 | func TestNewCertificateName(t *testing.T) { 11 | t.Parallel() 12 | 13 | testCases := []struct { 14 | name string 15 | value string 16 | wantErrMsg string 17 | }{{ 18 | name: "empty", 19 | value: "", 20 | wantErrMsg: "empty value", 21 | }, { 22 | name: "bad_symbol", 23 | value: "not valid", 24 | wantErrMsg: "at index 3: bad symbol: ' '", 25 | }, { 26 | name: "bad_base_name", 27 | value: "bad/base_name", 28 | wantErrMsg: "at index 3: bad symbol: '/'", 29 | }, { 30 | name: "too_long", 31 | value: "this_is_a_very_long_certificate_name", 32 | wantErrMsg: "length: out 
of range: must be no greater than 32, got 36", 33 | }, { 34 | name: "ok", 35 | value: "ok_cert_name", 36 | wantErrMsg: "", 37 | }, { 38 | name: "ok_numeric", 39 | value: "1234567890", 40 | wantErrMsg: "", 41 | }} 42 | 43 | for _, tc := range testCases { 44 | t.Run(tc.name, func(t *testing.T) { 45 | t.Parallel() 46 | 47 | _, err := agd.NewCertificateName(tc.value) 48 | testutil.AssertErrorMsg(t, tc.wantErrMsg, err) 49 | }) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /internal/agdcache/agdcache.go: -------------------------------------------------------------------------------- 1 | // Package agdcache contains cache interfaces, helpers, and implementations. 2 | package agdcache 3 | 4 | import ( 5 | "time" 6 | ) 7 | 8 | // Interface is the cache interface. 9 | type Interface[K, T any] interface { 10 | // Set sets key and val as cache pair. 11 | Set(key K, val T) 12 | 13 | // SetWithExpire sets key and val as cache pair with expiration time. 14 | SetWithExpire(key K, val T, expiration time.Duration) 15 | 16 | // Get gets val from the cache using key. 17 | Get(key K) (val T, ok bool) 18 | 19 | // Clearer completely clears cache. 20 | Clearer 21 | 22 | // Len returns the number of items in the cache. 23 | Len() (n int) 24 | } 25 | 26 | // Clearer is a partial cache interface. 27 | type Clearer interface { 28 | // Clear completely clears cache. 29 | Clear() 30 | } 31 | 32 | // Empty is an [Interface] implementation that does nothing. 33 | type Empty[K, T any] struct{} 34 | 35 | // type check 36 | var _ Interface[any, any] = Empty[any, any]{} 37 | 38 | // Set implements the [Interface] interface for Empty. 39 | func (c Empty[K, T]) Set(key K, val T) {} 40 | 41 | // SetWithExpire implements the [Interface] interface for Empty. 42 | func (c Empty[K, T]) SetWithExpire(key K, val T, expiration time.Duration) {} 43 | 44 | // Get implements the [Interface] interface for Empty. 
45 | func (c Empty[K, T]) Get(key K) (val T, ok bool) { 46 | return val, false 47 | } 48 | 49 | // type check 50 | var _ Clearer = Empty[any, any]{} 51 | 52 | // Clear implements the [Clearer] interface for Empty. 53 | func (c Empty[K, T]) Clear() {} 54 | 55 | // Len implements the [Interface] interface for Empty. It always returns 0, 56 | // since Empty never stores any items. 57 | func (c Empty[K, T]) Len() (n int) { 58 | return 0 59 | } 60 | -------------------------------------------------------------------------------- /internal/agdcache/agdcache_test.go: -------------------------------------------------------------------------------- 1 | package agdcache_test 2 | 3 | import "time" 4 | 5 | // Constants used in tests. 6 | const ( 7 | key = "key" 8 | val = 123 9 | 10 | nonExistingKey = "nonExistingKey" 11 | 12 | expDuration = 100 * time.Millisecond 13 | ) 14 | -------------------------------------------------------------------------------- /internal/agdcache/manager_test.go: -------------------------------------------------------------------------------- 1 | package agdcache_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agdcache" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestManager(t *testing.T) { 11 | const ( 12 | cacheID = "cacheID" 13 | cacheIDNonExisting = "non_existing_cache_id" 14 | ) 15 | 16 | isCleared := false 17 | mc := &mockClearer{ 18 | onClear: func() { 19 | isCleared = true 20 | }, 21 | } 22 | 23 | m := agdcache.NewDefaultManager() 24 | m.Add(cacheID, mc) 25 | m.ClearByID(cacheID) 26 | 27 | assert.True(t, isCleared) 28 | 29 | assert.NotPanics(t, func() { m.ClearByID(cacheIDNonExisting) }) 30 | } 31 | 32 | // mockClearer is the mock implementation of the [agdcache.Clearer] for tests.
33 | type mockClearer struct { 34 | onClear func() 35 | } 36 | 37 | // type check 38 | var _ agdcache.Clearer = (*mockClearer)(nil) 39 | 40 | // Clear implements the [agdcache.Clearer] interface for *mockClearer. 41 | func (mc *mockClearer) Clear() { 42 | mc.onClear() 43 | } 44 | -------------------------------------------------------------------------------- /internal/agdhttp/agdhttp.go: -------------------------------------------------------------------------------- 1 | // Package agdhttp contains common constants, functions, and types for working 2 | // with HTTP. 3 | // 4 | // TODO(a.garipov): Consider moving all or some of this stuff to module golibs. 5 | package agdhttp 6 | 7 | import "github.com/AdguardTeam/AdGuardDNS/internal/version" 8 | 9 | // Common Constants, Functions And Types 10 | 11 | // HTTP header value constants. 12 | const ( 13 | HdrValApplicationJSON = "application/json" 14 | HdrValApplicationOctetStream = "application/octet-stream" 15 | HdrValGzip = "gzip" 16 | HdrValTextCSV = "text/csv" 17 | HdrValTextHTML = "text/html" 18 | HdrValTextPlain = "text/plain" 19 | HdrValWildcard = "*" 20 | ) 21 | 22 | // RobotsDisallowAll is a predefined robots disallow all content. 23 | const RobotsDisallowAll = "User-agent: *\nDisallow: /\n" 24 | 25 | // NotFoundString is the text used by the standard library in 26 | // [http.NotFoundHandler]. 27 | // 28 | // TODO(a.garipov): Move to golibs. 29 | const NotFoundString = "404 page not found\n" 30 | 31 | // userAgent is the cached User-Agent string for AdGuardDNS. 32 | var userAgent = version.Name() + "/" + version.Version() 33 | 34 | // UserAgent returns the ID of the service as a User-Agent string. It can also 35 | // be used as the value of the Server HTTP header. 
36 | func UserAgent() (ua string) { 37 | return userAgent 38 | } 39 | -------------------------------------------------------------------------------- /internal/agdhttp/agdhttp_test.go: -------------------------------------------------------------------------------- 1 | package agdhttp_test 2 | 3 | import "github.com/AdguardTeam/golibs/errors" 4 | 5 | // Common Testing Constants And Variables 6 | 7 | // testSrv is the common Server header value for tests. 8 | const testSrv = "testServer/1.0" 9 | 10 | // testError is the common error for tests. 11 | const testError errors.Error = "test error" 12 | -------------------------------------------------------------------------------- /internal/agdhttp/url.go: -------------------------------------------------------------------------------- 1 | package agdhttp 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | 7 | "github.com/AdguardTeam/golibs/errors" 8 | "github.com/AdguardTeam/golibs/netutil/urlutil" 9 | ) 10 | 11 | // ParseHTTPURL parses an absolute URL and makes sure that it is a valid HTTP(S) 12 | // URL. All returned errors will have the underlying type [*url.Error]. 13 | // 14 | // TODO(a.garipov): Define as a type? 
15 | func ParseHTTPURL(s string) (u *url.URL, err error) { 16 | u, err = url.Parse(s) 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | switch { 22 | case u.Host == "": 23 | return nil, &url.Error{ 24 | Op: "parse", 25 | URL: s, 26 | Err: errors.Error("empty host"), 27 | } 28 | case !urlutil.IsValidHTTPURLScheme(u.Scheme): 29 | return nil, &url.Error{ 30 | Op: "parse", 31 | URL: s, 32 | Err: fmt.Errorf("bad scheme %q", u.Scheme), 33 | } 34 | default: 35 | return u, nil 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /internal/agdnet/agdnet_example_test.go: -------------------------------------------------------------------------------- 1 | package agdnet_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 7 | ) 8 | 9 | func ExampleAndroidMetricDomainReplacement() { 10 | printResult := func(input string) { 11 | fmt.Printf("%-42q: %q\n", input, agdnet.AndroidMetricDomainReplacement(input)) 12 | } 13 | 14 | anAndroidDomain := "12345678-dnsotls-ds.metric.gstatic.com." 15 | printResult(anAndroidDomain) 16 | 17 | anAndroidDomain = "123456-dnsohttps-ds.metric.gstatic.com." 18 | printResult(anAndroidDomain) 19 | 20 | notAndroidDomain := "example.com" 21 | printResult(notAndroidDomain) 22 | 23 | // Output: 24 | // "12345678-dnsotls-ds.metric.gstatic.com." : "00000000-dnsotls-ds.metric.gstatic.com." 25 | // "123456-dnsohttps-ds.metric.gstatic.com." : "000000-dnsohttps-ds.metric.gstatic.com." 26 | // "example.com" : "" 27 | } 28 | -------------------------------------------------------------------------------- /internal/agdnet/agdnet_test.go: -------------------------------------------------------------------------------- 1 | package agdnet_test 2 | 3 | import "net/netip" 4 | 5 | // Common subnets for tests. 
6 | var ( 7 | testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24") 8 | testSubnetIPv6 = netip.MustParsePrefix("1234:5678::/64") 9 | ) 10 | -------------------------------------------------------------------------------- /internal/agdnet/prefixaddr.go: -------------------------------------------------------------------------------- 1 | package agdnet 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/netip" 7 | ) 8 | 9 | // PrefixNetAddr is a wrapper around netip.Prefix that makes it a [net.Addr]. 10 | type PrefixNetAddr struct { 11 | Prefix netip.Prefix 12 | Net string 13 | Port uint16 14 | } 15 | 16 | // type check 17 | var _ net.Addr = (*PrefixNetAddr)(nil) 18 | 19 | // String implements the [net.Addr] interface for *PrefixNetAddr. It returns 20 | // either a simple IP:port address or one with the prefix length appended after 21 | // a slash, depending on whether or not subnet is a single-address subnet. This 22 | // is done to make using the IP:port part easier to split off using functions 23 | // like [strings.Cut]. 24 | func (addr *PrefixNetAddr) String() (n string) { 25 | p := addr.Prefix 26 | addrPort := netip.AddrPortFrom(p.Addr(), addr.Port) 27 | if p.IsSingleIP() { 28 | return addrPort.String() 29 | } 30 | 31 | return fmt.Sprintf("%s/%d", addrPort, p.Bits()) 32 | } 33 | 34 | // Network implements the [net.Addr] interface for *PrefixNetAddr. 
35 | func (addr *PrefixNetAddr) Network() (n string) { return addr.Net } 36 | -------------------------------------------------------------------------------- /internal/agdnet/prefixaddr_example_test.go: -------------------------------------------------------------------------------- 1 | package agdnet_test 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 8 | ) 9 | 10 | func ExamplePrefixNetAddr_string() { 11 | fmt.Println(&agdnet.PrefixNetAddr{ 12 | Prefix: netip.MustParsePrefix("1.2.3.4/32"), 13 | Net: "", 14 | Port: 5678, 15 | }) 16 | fmt.Println(&agdnet.PrefixNetAddr{ 17 | Prefix: netip.MustParsePrefix("1.2.3.0/24"), 18 | Net: "", 19 | Port: 5678, 20 | }) 21 | 22 | // Output: 23 | // 1.2.3.4:5678 24 | // 1.2.3.0:5678/24 25 | } 26 | -------------------------------------------------------------------------------- /internal/agdnet/prefixaddr_test.go: -------------------------------------------------------------------------------- 1 | package agdnet_test 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestPrefixAddr(t *testing.T) { 13 | const ( 14 | port = 56789 15 | network = "tcp" 16 | ) 17 | 18 | fullPrefix := netip.MustParsePrefix("1.2.3.4/32") 19 | 20 | testCases := []struct { 21 | in *agdnet.PrefixNetAddr 22 | want string 23 | name string 24 | }{{ 25 | in: &agdnet.PrefixNetAddr{ 26 | Prefix: testSubnetIPv4, 27 | Net: network, 28 | Port: port, 29 | }, 30 | want: fmt.Sprintf( 31 | "%s/%d", 32 | netip.AddrPortFrom(testSubnetIPv4.Addr(), port), testSubnetIPv4.Bits(), 33 | ), 34 | name: "ipv4", 35 | }, { 36 | in: &agdnet.PrefixNetAddr{ 37 | Prefix: testSubnetIPv6, 38 | Net: network, 39 | Port: port, 40 | }, 41 | want: fmt.Sprintf( 42 | "%s/%d", 43 | netip.AddrPortFrom(testSubnetIPv6.Addr(), port), testSubnetIPv6.Bits(), 44 | ), 45 | name: "ipv6", 46 | }, { 47 | in: 
&agdnet.PrefixNetAddr{ 48 | Prefix: fullPrefix, 49 | Net: network, 50 | Port: port, 51 | }, 52 | want: netip.AddrPortFrom(fullPrefix.Addr(), port).String(), 53 | name: "ipv4_full", 54 | }} 55 | 56 | for _, tc := range testCases { 57 | t.Run(tc.name, func(t *testing.T) { 58 | assert.Equal(t, tc.want, tc.in.String()) 59 | assert.Equal(t, network, tc.in.Network()) 60 | }) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /internal/agdpasswd/authenticator.go: -------------------------------------------------------------------------------- 1 | // Package agdpasswd contains authentication utils. 2 | package agdpasswd 3 | 4 | import ( 5 | "context" 6 | 7 | "golang.org/x/crypto/bcrypt" 8 | ) 9 | 10 | // Authenticator represents a password authenticator. 11 | type Authenticator interface { 12 | // Authenticate returns true if the given passwd is allowed. 13 | Authenticate(ctx context.Context, passwd []byte) (ok bool) 14 | } 15 | 16 | // AllowAuthenticator is an empty authenticator implementation that always 17 | // grants access, regardless of any restrictions. 18 | type AllowAuthenticator struct{} 19 | 20 | // type check 21 | var _ Authenticator = AllowAuthenticator{} 22 | 23 | // Authenticate implements the [Authenticator] interface for AllowAuthenticator. 24 | func (AllowAuthenticator) Authenticate(_ context.Context, _ []byte) (ok bool) { 25 | return true 26 | } 27 | 28 | // PasswordHashBcrypt is the Bcrypt implementation of [Authenticator]. 29 | type PasswordHashBcrypt struct { 30 | // bytes contains the password hash. 31 | bytes []byte 32 | } 33 | 34 | // NewPasswordHashBcrypt returns a new bcrypt hashed password authenticator. 35 | func NewPasswordHashBcrypt(hashedPassword []byte) (p *PasswordHashBcrypt) { 36 | return &PasswordHashBcrypt{bytes: hashedPassword} 37 | } 38 | 39 | // PasswordHash returns password hash bytes slice. 
40 | func (p *PasswordHashBcrypt) PasswordHash() (b []byte) { 41 | return p.bytes 42 | } 43 | 44 | // type check 45 | var _ Authenticator = (*PasswordHashBcrypt)(nil) 46 | 47 | // Authenticate implements the [Authenticator] interface for 48 | // *PasswordHashBcrypt. 49 | func (p *PasswordHashBcrypt) Authenticate(_ context.Context, passwd []byte) (ok bool) { 50 | return bcrypt.CompareHashAndPassword(p.bytes, passwd) == nil 51 | } 52 | -------------------------------------------------------------------------------- /internal/agdpasswd/authenticator_test.go: -------------------------------------------------------------------------------- 1 | package agdpasswd_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agdpasswd" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | "golang.org/x/crypto/bcrypt" 11 | ) 12 | 13 | func TestPasswordHashBcrypt_Authenticate(t *testing.T) { 14 | t.Parallel() 15 | 16 | const passwd = "mypassword" 17 | 18 | hash, err := bcrypt.GenerateFromPassword([]byte(passwd), 0) 19 | require.NoError(t, err) 20 | 21 | authenticator := agdpasswd.NewPasswordHashBcrypt(hash) 22 | 23 | testCases := []struct { 24 | want assert.BoolAssertionFunc 25 | name string 26 | pass string 27 | }{{ 28 | want: assert.True, 29 | name: "success", 30 | pass: passwd, 31 | }, { 32 | want: assert.False, 33 | name: "fail", 34 | pass: "an-other-passwd", 35 | }} 36 | 37 | for _, tc := range testCases { 38 | t.Run(tc.name, func(t *testing.T) { 39 | t.Parallel() 40 | 41 | tc.want(t, authenticator.Authenticate(context.Background(), []byte(tc.pass))) 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/agdprotobuf/pbutil.go: -------------------------------------------------------------------------------- 1 | // Package agdprotobuf contains protobuf utils. 
2 | package agdprotobuf 3 | 4 | import ( 5 | "fmt" 6 | "net/netip" 7 | ) 8 | 9 | // ByteSlicesToIPs converts a slice of byte slices into a slice of netip.Addrs. 10 | func ByteSlicesToIPs(data [][]byte) (ips []netip.Addr, err error) { 11 | if data == nil { 12 | return nil, nil 13 | } 14 | 15 | ips = make([]netip.Addr, 0, len(data)) 16 | for i, ipData := range data { 17 | var ip netip.Addr 18 | err = ip.UnmarshalBinary(ipData) 19 | if err != nil { 20 | return nil, fmt.Errorf("ip at index %d: %w", i, err) 21 | } 22 | 23 | ips = append(ips, ip) 24 | } 25 | 26 | return ips, nil 27 | } 28 | -------------------------------------------------------------------------------- /internal/agdtest/profile.go: -------------------------------------------------------------------------------- 1 | package agdtest 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/access" 8 | gocmp "github.com/google/go-cmp/cmp" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // AssertEqualProfile compares two values while ignoring internal details of 13 | // some fields of profiles, such as pools. 14 | func AssertEqualProfile(tb testing.TB, want, got any) (ok bool) { 15 | tb.Helper() 16 | 17 | exportAll := gocmp.Exporter(func(_ reflect.Type) (ok bool) { return true }) 18 | 19 | defAccCmp := gocmp.Comparer(func(want, got *access.DefaultProfile) (ok bool) { 20 | return gocmp.Equal(want.Config(), got.Config(), exportAll) 21 | }) 22 | 23 | diff := gocmp.Diff(want, got, defAccCmp, exportAll) 24 | if diff == "" { 25 | return true 26 | } 27 | 28 | // Use assert.Failf instead of tb.Errorf to get a more consistent error 29 | // message. 
30 | return assert.Failf(tb, "not equal", "got: %+v\nwant: %+v\ndiff: %s", got, want, diff) 31 | } 32 | -------------------------------------------------------------------------------- /internal/agdtime/agdtime.go: -------------------------------------------------------------------------------- 1 | // Package agdtime contains time-related utilities. 2 | package agdtime 3 | 4 | import ( 5 | "encoding" 6 | "time" 7 | 8 | "github.com/AdguardTeam/golibs/errors" 9 | ) 10 | 11 | // Location is a wrapper around time.Location that can de/serialize itself from 12 | // and to JSON. 13 | // 14 | // TODO(a.garipov): Move to timeutil. 15 | type Location struct { 16 | time.Location 17 | } 18 | 19 | // LoadLocation is a wrapper around [time.LoadLocation] that returns a 20 | // *Location instead. 21 | func LoadLocation(name string) (l *Location, err error) { 22 | tl, err := time.LoadLocation(name) 23 | if err != nil { 24 | // Don't wrap the error, because this function is a wrapper. 25 | return nil, err 26 | } 27 | 28 | return &Location{ 29 | Location: *tl, 30 | }, nil 31 | } 32 | 33 | // UTC returns [time.UTC] as *Location. 34 | func UTC() (l *Location) { 35 | return &Location{ 36 | Location: *time.UTC, 37 | } 38 | } 39 | 40 | // type check 41 | var _ encoding.TextMarshaler = Location{} 42 | 43 | // MarshalText implements the [encoding.TextMarshaler] interface for Location. 44 | func (l Location) MarshalText() (text []byte, err error) { 45 | return []byte(l.String()), nil 46 | } 47 | 48 | var _ encoding.TextUnmarshaler = (*Location)(nil) 49 | 50 | // UnmarshalText implements the [encoding.TextUnmarshaler] interface for 51 | // *Location. 
52 | func (l *Location) UnmarshalText(b []byte) (err error) { 53 | defer func() { err = errors.Annotate(err, "unmarshaling location: %w") }() 54 | 55 | tl, err := time.LoadLocation(string(b)) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | l.Location = *tl 61 | 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/agdtime/agdtime_example_test.go: -------------------------------------------------------------------------------- 1 | package agdtime_test 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" 9 | ) 10 | 11 | func ExampleLocation() { 12 | var req struct { 13 | TimeZone *agdtime.Location `json:"tmz"` 14 | } 15 | 16 | l, err := agdtime.LoadLocation("Europe/Brussels") 17 | if err != nil { 18 | panic(err) 19 | } 20 | 21 | req.TimeZone = l 22 | buf := &bytes.Buffer{} 23 | err = json.NewEncoder(buf).Encode(req) 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | fmt.Print(buf) 29 | 30 | req.TimeZone = nil 31 | err = json.NewDecoder(buf).Decode(&req) 32 | if err != nil { 33 | panic(err) 34 | } 35 | 36 | fmt.Printf("%+v\n", req) 37 | 38 | // Output: 39 | // {"tmz":"Europe/Brussels"} 40 | // {TimeZone:Europe/Brussels} 41 | } 42 | -------------------------------------------------------------------------------- /internal/agdtime/schedule.go: -------------------------------------------------------------------------------- 1 | package agdtime 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/AdguardTeam/golibs/timeutil" 7 | ) 8 | 9 | // ExponentialSchedule is a [timeutil.Schedule] that exponentially increases the 10 | // time until the next event until it reaches the maximum. 11 | // 12 | // TODO(a.garipov): Consider moving to golibs. 
13 | type ExponentialSchedule struct { 14 | // current is the duration to return from the next call to UntilNext. 15 | current time.Duration 16 | 17 | // max is the maximum duration UntilNext returns. 18 | max time.Duration 19 | 20 | // base is the multiplier applied to current after each call. 21 | base uint64 22 | } 23 | 24 | // NewExponentialSchedule returns a new properly initialized 25 | // *ExponentialSchedule. The first call to UntilNext returns initial, and 26 | // every following call returns the previous duration multiplied by base, 27 | // capped at max. 28 | func NewExponentialSchedule(initial, max time.Duration, base uint64) (s *ExponentialSchedule) { 29 | return &ExponentialSchedule{ 30 | current: initial, 31 | max: max, 32 | base: base, 33 | } 34 | } 35 | 36 | // type check 37 | var _ timeutil.Schedule = (*ExponentialSchedule)(nil) 38 | 39 | // UntilNext implements the [timeutil.Schedule] interface for 40 | // *ExponentialSchedule. 41 | func (s *ExponentialSchedule) UntilNext(_ time.Time) (d time.Duration) { 42 | d = s.current 43 | 44 | // A negative s.current can only come from a negative initial duration, 45 | // since the multiplication below never overflows. 46 | if d >= s.max || d < 0 { 47 | return s.max 48 | } 49 | 50 | // Both s.current and s.max are non-negative here. Saturate at max instead 51 | // of multiplying if the product would exceed it, since an overflowing 52 | // multiplication can wrap around to a small non-negative value. 53 | // 54 | // #nosec G115 -- The signs and the overflow are checked here. 55 | if s.base != 0 && uint64(s.current) > uint64(s.max)/s.base { 56 | s.current = s.max 57 | } else { 58 | s.current *= time.Duration(s.base) 59 | } 60 | 61 | return d 62 | } 63 | -------------------------------------------------------------------------------- /internal/agdtime/schedule_example_test.go: -------------------------------------------------------------------------------- 1 | package agdtime_test 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agdtime" 8 | ) 9 | 10 | func ExampleExponentialSchedule() { 11 | s := agdtime.NewExponentialSchedule(1*time.Second, 1*time.Minute, 2) 12 | 13 | for range 10 { 14 | fmt.Println(s.UntilNext(time.Time{})) 15 | } 16 | 17 | // Output: 18 | // 1s 19 | // 2s 20 | // 4s 21 | // 8s 22 | // 16s 23 | // 32s 24 | // 1m0s 25 | // 1m0s 26 | // 1m0s 27 | // 1m0s 28 | } 29 | -------------------------------------------------------------------------------- /internal/agdurlflt/agdurlflt.go: -------------------------------------------------------------------------------- 1 | // Package agdurlflt contains utilities for the urlfilter module.
2 | package agdurlflt 3 | 4 | import ( 5 | "bytes" 6 | "unicode" 7 | ) 8 | 9 | // RulesLen returns the length of the byte buffer necessary to write ruleStrs, 10 | // separated by a newline, to it. 11 | func RulesLen[S ~string](ruleStrs []S) (l int) { 12 | if len(ruleStrs) == 0 { 13 | return 0 14 | } 15 | 16 | for _, s := range ruleStrs { 17 | l += len(s) + len("\n") 18 | } 19 | 20 | return l 21 | } 22 | 23 | // RulesToBytes writes ruleStrs to a byte slice and returns it. 24 | // 25 | // TODO(a.garipov): Consider moving to golibs or urlfilter. 26 | func RulesToBytes[S ~string](ruleStrs []S) (b []byte) { 27 | l := RulesLen(ruleStrs) 28 | if l == 0 { 29 | return nil 30 | } 31 | 32 | buf := bytes.NewBuffer(make([]byte, 0, l)) 33 | for _, s := range ruleStrs { 34 | _, _ = buf.WriteString(string(s)) 35 | _ = buf.WriteByte('\n') 36 | } 37 | 38 | return buf.Bytes() 39 | } 40 | 41 | // RulesToBytesLower writes lowercase versions of ruleStrs to a byte slice and 42 | // returns it. 43 | // 44 | // NOTE: Do not use this for rules that can include dnsrewrite modifiers, since 45 | // their DNS types are case-sensitive. 46 | // 47 | // TODO(a.garipov): Consider moving to golibs or urlfilter. 48 | func RulesToBytesLower(ruleStrs []string) (b []byte) { 49 | l := RulesLen(ruleStrs) 50 | if l == 0 { 51 | return nil 52 | } 53 | 54 | buf := bytes.NewBuffer(make([]byte, 0, l)) 55 | for _, s := range ruleStrs { 56 | for _, c := range s { 57 | // NOTE: Theoretically there might be cases where a lowercase 58 | // version of a rune takes up more space or less space than an 59 | // uppercase one, but that doesn't matter since we're using a 60 | // bytes.Buffer and rules generally are ASCII-only. 
61 | _, _ = buf.WriteRune(unicode.ToLower(c)) 62 | } 63 | 64 | _ = buf.WriteByte('\n') 65 | } 66 | 67 | return buf.Bytes() 68 | } 69 | -------------------------------------------------------------------------------- /internal/agdurlflt/agdurlflt_test.go: -------------------------------------------------------------------------------- 1 | package agdurlflt_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agdurlflt" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // testRulesStrs are the common filtering rules for tests. 12 | var testRulesStrs = []string{ 13 | `||blocked.example^`, 14 | `@@||allowed.example^`, 15 | `||dnsrewrite.example^$dnsrewrite=192.0.2.1`, 16 | } 17 | 18 | // testRulesData is the data of [testRulesStrs] as bytes. 19 | var testRulesData = []byte(strings.Join(testRulesStrs, "\n") + "\n") 20 | 21 | func BenchmarkRulesToBytes(b *testing.B) { 22 | var got []byte 23 | 24 | b.ReportAllocs() 25 | for b.Loop() { 26 | got = agdurlflt.RulesToBytes(testRulesStrs) 27 | } 28 | 29 | require.Equal(b, testRulesData, got) 30 | 31 | // Most recent results: 32 | // 33 | // goos: linux 34 | // goarch: amd64 35 | // pkg: github.com/AdguardTeam/AdGuardDNS/internal/agdurlflt 36 | // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics 37 | // BenchmarkRulesToBytes-16 7925872 145.3 ns/op 96 B/op 1 allocs/op 38 | } 39 | 40 | func BenchmarkRulesToBytesLower(b *testing.B) { 41 | var got []byte 42 | 43 | b.ReportAllocs() 44 | for b.Loop() { 45 | got = agdurlflt.RulesToBytesLower(testRulesStrs) 46 | } 47 | 48 | require.Equal(b, testRulesData, got) 49 | 50 | // Most recent results: 51 | // 52 | // goos: linux 53 | // goarch: amd64 54 | // pkg: github.com/AdguardTeam/AdGuardDNS/internal/agdurlflt 55 | // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics 56 | // BenchmarkRulesToBytesLower-16 1000000 1188 ns/op 96 B/op 1 allocs/op 57 | } 58 | -------------------------------------------------------------------------------- 
/internal/agdvalidate/agdvalidate.go: -------------------------------------------------------------------------------- 1 | // Package agdvalidate contains validation utilities. 2 | package agdvalidate 3 | 4 | import "fmt" 5 | 6 | // FirstNonIDRune returns the first non-printable or non-ASCII rune and its 7 | // index. If includeSlashes is true, it also looks for slashes. If there are 8 | // no such runes, i is -1. 9 | func FirstNonIDRune(s string, excludeSlashes bool) (i int, r rune) { 10 | for i, r = range s { 11 | if r < '!' || r > '~' || (excludeSlashes && r == '/') { 12 | return i, r 13 | } 14 | } 15 | 16 | return -1, 0 17 | } 18 | 19 | // Unit name constants. 20 | const ( 21 | UnitByte = "bytes" 22 | UnitRune = "runes" 23 | ) 24 | 25 | // Inclusion returns an error if n is greater than maxVal or less than minVal. 26 | // unitName is used for error messages, see [UnitByte] and the related 27 | // constants. 28 | func Inclusion(n, minVal, maxVal int, unitName string) (err error) { 29 | switch { 30 | case n > maxVal: 31 | return fmt.Errorf("too long: got %d %s, max %d", n, unitName, maxVal) 32 | case n < minVal: 33 | return fmt.Errorf("too short: got %d %s, min %d", n, unitName, minVal) 34 | default: 35 | return nil 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /internal/backendpb/backendpb.go: -------------------------------------------------------------------------------- 1 | // Package backendpb contains the protobuf structures for the backend API. 2 | // 3 | // TODO(a.garipov): Move the generated code into a separate package. 4 | package backendpb 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "net/url" 10 | 11 | "github.com/AdguardTeam/golibs/httphdr" 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/credentials" 14 | "google.golang.org/grpc/credentials/insecure" 15 | "google.golang.org/grpc/metadata" 16 | ) 17 | 18 | // newClient returns new properly initialized gRPC connection to the API server. 
19 | func newClient(apiURL *url.URL) (client *grpc.ClientConn, err error) { 20 | var creds credentials.TransportCredentials 21 | switch s := apiURL.Scheme; s { 22 | case "grpc": 23 | creds = insecure.NewCredentials() 24 | case "grpcs": 25 | // Use a nil [tls.Config] to get the default TLS configuration. 26 | creds = credentials.NewTLS(nil) 27 | default: 28 | return nil, fmt.Errorf("bad grpc url scheme %q", s) 29 | } 30 | 31 | conn, err := grpc.NewClient( 32 | apiURL.Host, 33 | grpc.WithDisableServiceConfig(), 34 | grpc.WithTransportCredentials(creds), 35 | ) 36 | if err != nil { 37 | return nil, fmt.Errorf("dialing: %w", err) 38 | } 39 | 40 | // Immediately make a connection attempt, since the constructor is often 41 | // called right before the initial refresh. 42 | conn.Connect() 43 | 44 | return conn, nil 45 | } 46 | 47 | // ctxWithAuthentication adds the API key authentication header to the outgoing 48 | // request context if apiKey is not empty. If it is empty, ctx is parent. 49 | func ctxWithAuthentication(parent context.Context, apiKey string) (ctx context.Context) { 50 | ctx = parent 51 | if apiKey == "" { 52 | return ctx 53 | } 54 | 55 | // TODO(a.garipov): Better validations for the key. 56 | md := metadata.Pairs(httphdr.Authorization, fmt.Sprintf("Bearer %s", apiKey)) 57 | 58 | return metadata.NewOutgoingContext(ctx, md) 59 | } 60 | -------------------------------------------------------------------------------- /internal/backendpb/stats.go: -------------------------------------------------------------------------------- 1 | package backendpb 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "time" 7 | ) 8 | 9 | // profilesCallStats is a stateful structure that collects and reports 10 | // statistics about a [ProfileStorage.Profiles] call. 
11 | type profilesCallStats struct { 12 | logger *slog.Logger 13 | 14 | recvStart time.Time 15 | decStart time.Time 16 | 17 | initRecv time.Duration 18 | totalRecv time.Duration 19 | totalDec time.Duration 20 | 21 | numRecv int 22 | 23 | isFullSync bool 24 | } 25 | 26 | // startRecv starts the receive timer. 27 | func (s *profilesCallStats) startRecv() { 28 | s.recvStart = time.Now() 29 | } 30 | 31 | // endRecv ends the receive timer and records the results. 32 | func (s *profilesCallStats) endRecv() { 33 | d := time.Since(s.recvStart) 34 | if s.numRecv == 0 { 35 | // Count the initial receive separately, since it is often not 36 | // representative of an average receive, because this is when gRPC 37 | // actually performs the call. 38 | s.initRecv = d 39 | } else { 40 | s.totalRecv += d 41 | } 42 | 43 | s.numRecv++ 44 | } 45 | 46 | // startDec starts the decoding timer. 47 | func (s *profilesCallStats) startDec() { 48 | s.decStart = time.Now() 49 | } 50 | 51 | // endDec ends the decoding timer and records the results. 52 | func (s *profilesCallStats) endDec() { 53 | s.totalDec += time.Since(s.decStart) 54 | } 55 | 56 | // report writes the statistics to the log and the metrics. 
57 | func (s *profilesCallStats) report(ctx context.Context, mtrc ProfileDBMetrics) { 58 | lvl := slog.LevelDebug 59 | if s.isFullSync { 60 | lvl = slog.LevelInfo 61 | } 62 | 63 | if s.numRecv == 0 { 64 | s.logger.Log(ctx, lvl, "no recv") 65 | 66 | return 67 | } 68 | 69 | n := time.Duration(s.numRecv) 70 | avgRecv := s.totalRecv / n 71 | avgDec := s.totalDec / n 72 | 73 | s.logger.Log(ctx, lvl, "recv stats", "total", s.totalRecv, "avg", avgRecv, "init", s.initRecv) 74 | s.logger.Log(ctx, lvl, "decode stats", "total", s.totalDec, "avg", avgDec) 75 | 76 | mtrc.UpdateStats(ctx, avgRecv, avgDec) 77 | } 78 | -------------------------------------------------------------------------------- /internal/backendpb/ticket.go: -------------------------------------------------------------------------------- 1 | package backendpb 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/validate" 10 | ) 11 | 12 | // ticketsToInternal converts received session tickets to internal format, 13 | // mapping each ticket to its name. 14 | func (ts *TicketStorage) ticketsToInternal( 15 | ctx context.Context, 16 | received []*SessionTicket, 17 | ) (tickets map[tlsconfig.SessionTicketName]tlsconfig.SessionTicket, err error) { 18 | err = validate.NotEmptySlice("received", received) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | tickets = make(map[tlsconfig.SessionTicketName]tlsconfig.SessionTicket, len(received)) 24 | 25 | var errs []error 26 | for i, recTicket := range received { 27 | name, ticket, convErr := recTicket.toInternal() 28 | ts.metrics.SetTicketStatus(ctx, string(name), ts.clock.Now(), convErr) 29 | if convErr != nil { 30 | convErr = fmt.Errorf("loaded session ticket: at index %d: %w", i, convErr) 31 | errs = append(errs, convErr) 32 | 33 | continue 34 | } 35 | 36 | tickets[name] = ticket 37 | } 38 | 39 | return tickets, errors.Join(errs...) 
40 | } 41 | 42 | // toInternal converts the received session ticket to internal format. It 43 | // always returns non-nil nt, but it may be invalid if the conversion fails. 44 | func (x *SessionTicket) toInternal() ( 45 | name tlsconfig.SessionTicketName, 46 | ticket tlsconfig.SessionTicket, 47 | err error, 48 | ) { 49 | var errs []error 50 | 51 | name, err = tlsconfig.NewSessionTicketName(x.GetName()) 52 | if err != nil { 53 | // Don't wrap the error, since it's informative enough as is. 54 | errs = append(errs, err) 55 | } 56 | 57 | ticket, err = tlsconfig.NewSessionTicket(x.GetData()) 58 | if err != nil { 59 | errs = append(errs, fmt.Errorf("ticket: %w", err)) 60 | } 61 | 62 | return name, ticket, errors.Join(errs...) 63 | } 64 | -------------------------------------------------------------------------------- /internal/backendpb/ticketstorage_internal_test.go: -------------------------------------------------------------------------------- 1 | package backendpb 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/tlsconfig" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestTicketStorage_CalcTicketsHash(t *testing.T) { 11 | t.Parallel() 12 | 13 | testCases := []struct { 14 | tickets tlsconfig.NamedTickets 15 | name string 16 | want float64 17 | }{{ 18 | tickets: tlsconfig.NamedTickets{ 19 | "foo": tlsconfig.SessionTicket{1, 2, 3, 4}, 20 | "bar": tlsconfig.SessionTicket{5, 6, 7, 8}, 21 | }, 22 | name: "data", 23 | want: 2.5599110696847e+14, 24 | }, { 25 | tickets: tlsconfig.NamedTickets{"foo": tlsconfig.SessionTicket{}}, 26 | name: "no_data", 27 | want: 1.76700443131662e+14, 28 | }, { 29 | tickets: tlsconfig.NamedTickets{}, 30 | name: "no_tickets", 31 | want: 0, 32 | }} 33 | 34 | for _, tc := range testCases { 35 | t.Run(tc.name, func(t *testing.T) { 36 | t.Parallel() 37 | 38 | assert.Equal(t, tc.want, calcTicketsHash(tc.tickets)) 39 | }) 40 | } 41 | } 42 | 
-------------------------------------------------------------------------------- /internal/billstat/billstat_test.go: -------------------------------------------------------------------------------- 1 | package billstat_test 2 | 3 | import "time" 4 | 5 | // testTimeout is the timeout for common test operations. 6 | const testTimeout = 1 * time.Second 7 | -------------------------------------------------------------------------------- /internal/billstat/metrics.go: -------------------------------------------------------------------------------- 1 | package billstat 2 | 3 | import "context" 4 | 5 | // Metrics is an interface that is used for the collection of the billing 6 | // statistics. 7 | type Metrics interface { 8 | // SetRecordCount sets the total number of records stored. 9 | SetRecordCount(ctx context.Context, count int) 10 | 11 | // HandleUploadDuration handles the upload duration of billing statistics. 12 | HandleUploadDuration(ctx context.Context, dur float64, err error) 13 | } 14 | 15 | // EmptyMetrics is the implementation of the [Metrics] interface that does 16 | // nothing. 17 | type EmptyMetrics struct{} 18 | 19 | // type check 20 | var _ Metrics = EmptyMetrics{} 21 | 22 | // SetRecordCount implements the [Metrics] interface for EmptyMetrics. 23 | func (EmptyMetrics) SetRecordCount(_ context.Context, _ int) {} 24 | 25 | // HandleUploadDuration implements the [Metrics] interface for EmptyMetrics. 26 | func (EmptyMetrics) HandleUploadDuration(_ context.Context, _ float64, _ error) {} 27 | -------------------------------------------------------------------------------- /internal/bindtodevice/bindtodevice.go: -------------------------------------------------------------------------------- 1 | // Package bindtodevice contains an implementation of the [netext.ListenConfig] 2 | // interface that uses Linux's SO_BINDTODEVICE socket option to be able to bind 3 | // to a device. 
4 | package bindtodevice 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | ) 10 | 11 | // ID is the unique identifier of an interface listener. 12 | type ID string 13 | 14 | // unit is a convenient alias for struct{}. 15 | type unit = struct{} 16 | 17 | // Convenient constants containing type names for error reporting using 18 | // [wrapConnError]. 19 | const ( 20 | tnChanPConn = "chanPacketConn" 21 | tnChanLsnr = "chanListener" 22 | ) 23 | 24 | // wrapConnError is a helper for creating informative errors. 25 | func wrapConnError(typeName, methodName string, laddr net.Addr, err error) (wrapped error) { 26 | return fmt.Errorf("bindtodevice: %s %s: %s: %w", typeName, laddr, methodName, err) 27 | } 28 | -------------------------------------------------------------------------------- /internal/bindtodevice/bindtodevice_internal_test.go: -------------------------------------------------------------------------------- 1 | package bindtodevice 2 | 3 | import ( 4 | "net" 5 | "net/netip" 6 | "time" 7 | ) 8 | 9 | // testTimeout is a common timeout for tests. 10 | const testTimeout = 1 * time.Second 11 | 12 | // Common addresses for tests. 13 | var ( 14 | testLAddr = &net.UDPAddr{ 15 | IP: net.IP{1, 2, 3, 4}, 16 | Port: 53, 17 | } 18 | testRAddr = &net.UDPAddr{ 19 | IP: net.IP{5, 6, 7, 8}, 20 | Port: 1234, 21 | } 22 | ) 23 | 24 | // Common subnets for tests. 25 | var ( 26 | testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24") 27 | ) 28 | -------------------------------------------------------------------------------- /internal/bindtodevice/bindtodevice_linux_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "net" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // newTestChanListener is a helper for creating a *chanListener for tests. 
13 | func newTestChanListener(tb testing.TB, conns chan net.Conn) (l *chanListener) { 14 | tb.Helper() 15 | 16 | l = newChanListener(EmptyMetrics{}, conns, testSubnetIPv4, testLAddr) 17 | require.NotNil(tb, l) 18 | 19 | return l 20 | } 21 | 22 | // newTestChanPacketConn is a helper for creating a *chanPacketConn for tests. 23 | func newTestChanPacketConn( 24 | tb testing.TB, 25 | sessions chan *packetSession, 26 | writeReqs chan *packetConnWriteReq, 27 | ) (c *chanPacketConn) { 28 | tb.Helper() 29 | 30 | c = newChanPacketConn( 31 | EmptyMetrics{}, 32 | sessions, 33 | testSubnetIPv4, 34 | writeReqs, 35 | "", 36 | testLAddr, 37 | ) 38 | require.NotNil(tb, c) 39 | 40 | return c 41 | } 42 | -------------------------------------------------------------------------------- /internal/bindtodevice/bindtodevice_test.go: -------------------------------------------------------------------------------- 1 | package bindtodevice_test 2 | 3 | import ( 4 | "net/netip" 5 | "time" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" 8 | ) 9 | 10 | // testTimeout is a common timeout for tests. 11 | const testTimeout = 1 * time.Second 12 | 13 | // Common interface listener IDs for tests 14 | const ( 15 | testID1 bindtodevice.ID = "id1" 16 | testID2 bindtodevice.ID = "id2" 17 | ) 18 | 19 | // Common port numbers for tests. 20 | // 21 | // TODO(a.garipov): Figure a way to use 0 in most real tests. 22 | const ( 23 | testPort1 uint16 = 12345 24 | testPort2 uint16 = 12346 25 | ) 26 | 27 | // testIfaceName is the common network interface name for tests. 28 | const testIfaceName = "not_a_real_iface0" 29 | 30 | // testSubnetIPv4 is a common subnet for tests. 
31 | var testSubnetIPv4 = netip.MustParsePrefix("1.2.3.0/24") 32 | -------------------------------------------------------------------------------- /internal/bindtodevice/chanlistener_linux_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "net" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestChanListener_Accept(t *testing.T) { 14 | conns := make(chan net.Conn, 1) 15 | l := newTestChanListener(t, conns) 16 | 17 | // A simple way to have a distinct net.Conn without actually implementing 18 | // the entire interface. 19 | c := struct { 20 | net.Conn 21 | Value int 22 | }{ 23 | Value: 1, 24 | } 25 | 26 | conns <- c 27 | 28 | got, err := l.Accept() 29 | require.NoError(t, err) 30 | 31 | assert.Equal(t, c, got) 32 | } 33 | 34 | func TestChanListener_Addr(t *testing.T) { 35 | l := newTestChanListener(t, nil) 36 | got := l.Addr() 37 | assert.Equal(t, testLAddr, got) 38 | } 39 | 40 | func TestChanListener_Close(t *testing.T) { 41 | conns := make(chan net.Conn) 42 | l := newTestChanListener(t, conns) 43 | err := l.Close() 44 | assert.NoError(t, err) 45 | 46 | err = l.Close() 47 | assert.Error(t, err) 48 | } 49 | -------------------------------------------------------------------------------- /internal/bindtodevice/connindex_linux_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "net/netip" 7 | "slices" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestSubnetCompare(t *testing.T) { 14 | want := []netip.Prefix{ 15 | netip.MustParsePrefix("1.0.0.0/24"), 16 | netip.MustParsePrefix("1.2.3.0/24"), 17 | netip.MustParsePrefix("1.0.0.0/16"), 18 | netip.MustParsePrefix("1.2.0.0/16"), 19 | } 20 | got := []netip.Prefix{ 21 | 
netip.MustParsePrefix("1.0.0.0/16"), 22 | netip.MustParsePrefix("1.0.0.0/24"), 23 | netip.MustParsePrefix("1.2.0.0/16"), 24 | netip.MustParsePrefix("1.2.3.0/24"), 25 | } 26 | 27 | slices.SortFunc(got, subnetCompare) 28 | assert.Equalf(t, want, got, "got (as strings): %q", got) 29 | } 30 | -------------------------------------------------------------------------------- /internal/bindtodevice/listenconfig_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "context" 7 | "net" 8 | 9 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 10 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" 11 | ) 12 | 13 | // ListenConfig is a [netext.ListenConfig] implementation that uses the 14 | // provided channel-based packet connection and listener to implement the 15 | // methods of the interface. 16 | // 17 | // netext.ListenConfig instances of this type are the ones that are going to be 18 | // set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic 19 | // work. 20 | type ListenConfig struct { 21 | packetConn *chanPacketConn 22 | listener *chanListener 23 | addr *agdnet.PrefixNetAddr 24 | } 25 | 26 | // type check 27 | var _ netext.ListenConfig = (*ListenConfig)(nil) 28 | 29 | // Listen implements the [netext.ListenConfig] interface for *ListenConfig. 30 | func (lc *ListenConfig) Listen( 31 | ctx context.Context, 32 | network string, 33 | address string, 34 | ) (l net.Listener, err error) { 35 | return lc.listener, nil 36 | } 37 | 38 | // ListenPacket implements the [netext.ListenConfig] interface for 39 | // *ListenConfig. 40 | func (lc *ListenConfig) ListenPacket( 41 | ctx context.Context, 42 | network string, 43 | address string, 44 | ) (c net.PacketConn, err error) { 45 | return lc.packetConn, nil 46 | } 47 | 48 | // Addr returns the address on which lc accepts connections. addr.Net is empty. 
49 | func (lc *ListenConfig) Addr() (addr *agdnet.PrefixNetAddr) { 50 | return lc.addr 51 | } 52 | -------------------------------------------------------------------------------- /internal/bindtodevice/listenconfig_linux_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "context" 7 | "testing" 8 | 9 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestListenConfig(t *testing.T) { 15 | pc := newTestChanPacketConn(t, nil, nil) 16 | lsnr := newTestChanListener(t, nil) 17 | addr := &agdnet.PrefixNetAddr{ 18 | Prefix: testSubnetIPv4, 19 | Net: "", 20 | Port: 1234, 21 | } 22 | c := &ListenConfig{ 23 | packetConn: pc, 24 | listener: lsnr, 25 | addr: addr, 26 | } 27 | 28 | ctx := context.Background() 29 | 30 | gotPC, err := c.ListenPacket(ctx, "", "") 31 | require.NoError(t, err) 32 | 33 | assert.Equal(t, pc, gotPC) 34 | 35 | gotLsnr, err := c.Listen(ctx, "", "") 36 | require.NoError(t, err) 37 | 38 | assert.Equal(t, lsnr, gotLsnr) 39 | 40 | gotAddr := c.Addr() 41 | assert.Equal(t, addr, gotAddr) 42 | } 43 | -------------------------------------------------------------------------------- /internal/bindtodevice/listenconfig_others.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net" 9 | 10 | "github.com/AdguardTeam/AdGuardDNS/internal/agdnet" 11 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" 12 | "github.com/AdguardTeam/golibs/errors" 13 | ) 14 | 15 | // ListenConfig is a [netext.ListenConfig] implementation that uses the 16 | // provided channel-based packet connection and listener to implement the 17 | // methods of the interface. 
18 | // 19 | // netext.ListenConfig instances of this type are the ones that are going to be 20 | // set as [dnsserver.ConfigBase.ListenConfig] to make the bind-to-device logic 21 | // work. 22 | // 23 | // It is only supported on Linux. 24 | type ListenConfig struct{} 25 | 26 | // type check 27 | var _ netext.ListenConfig = (*ListenConfig)(nil) 28 | 29 | // Listen implements the [netext.ListenConfig] interface for *ListenConfig. 30 | // 31 | // It is only supported on Linux. 32 | func (lc *ListenConfig) Listen( 33 | ctx context.Context, 34 | network string, 35 | address string, 36 | ) (l net.Listener, err error) { 37 | return nil, fmt.Errorf( 38 | "bindtodevice: listen: %w; only supported on linux", 39 | errors.ErrUnsupported, 40 | ) 41 | } 42 | 43 | // ListenPacket implements the [netext.ListenConfig] interface for 44 | // *ListenConfig. 45 | // 46 | // It is only supported on Linux. 47 | func (lc *ListenConfig) ListenPacket( 48 | ctx context.Context, 49 | network string, 50 | address string, 51 | ) (c net.PacketConn, err error) { 52 | return nil, fmt.Errorf( 53 | "bindtodevice: listenpacket: %w; only supported on linux", 54 | errors.ErrUnsupported, 55 | ) 56 | } 57 | 58 | // Addr returns the address on which lc accepts connections. 59 | // 60 | // It is only supported on Linux. 61 | func (lc *ListenConfig) Addr() (addr *agdnet.PrefixNetAddr) { 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/bindtodevice/manager.go: -------------------------------------------------------------------------------- 1 | package bindtodevice 2 | 3 | import ( 4 | "log/slog" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 7 | ) 8 | 9 | // ManagerConfig is the configuration structure for [NewManager]. All fields 10 | // must be set. 11 | type ManagerConfig struct { 12 | // Logger is used to log the operation of the manager. 
13 | Logger *slog.Logger 14 | 15 | // InterfaceStorage is used to get the information about the system's 16 | // network interfaces. Normally, this is [DefaultInterfaceStorage]. 17 | InterfaceStorage InterfaceStorage 18 | 19 | // ErrColl is the error collector that is used to collect non-critical 20 | // errors. 21 | ErrColl errcoll.Interface 22 | 23 | // Metrics collects bindtodevice-related statistics. It must not be nil. 24 | Metrics Metrics 25 | 26 | // ChannelBufferSize is the size of the buffers of the channels used to 27 | // dispatch TCP connections and UDP sessions. 28 | ChannelBufferSize int 29 | } 30 | 31 | // ControlConfig is the configuration of socket options. 32 | type ControlConfig struct { 33 | // RcvBufSize defines the size of socket receive buffer in bytes. Default 34 | // is zero (uses system settings). 35 | RcvBufSize int 36 | 37 | // SndBufSize defines the size of socket send buffer in bytes. Default is 38 | // zero (uses system settings). 39 | SndBufSize int 40 | } 41 | -------------------------------------------------------------------------------- /internal/bindtodevice/packetsession_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package bindtodevice 4 | 5 | import ( 6 | "net" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" 9 | ) 10 | 11 | // packetSession is a [netext.PacketSession] that contains additional 12 | // information about the packet read from a UDP connection that has the 13 | // SO_BINDTODEVICE option set. 14 | type packetSession struct { 15 | laddr *net.UDPAddr 16 | raddr *net.UDPAddr 17 | readBody []byte 18 | respOOB []byte 19 | } 20 | 21 | // type check 22 | var _ netext.PacketSession = (*packetSession)(nil) 23 | 24 | // LocalAddr implements the [netext.PacketSession] interface for *packetSession. 
25 | func (s *packetSession) LocalAddr() (addr net.Addr) { return s.laddr } 26 | 27 | // RemoteAddr implements the [netext.PacketSession] interface for 28 | // *packetSession. 29 | func (s *packetSession) RemoteAddr() (addr net.Addr) { return s.raddr } 30 | -------------------------------------------------------------------------------- /internal/cmd/access.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/AdguardTeam/golibs/errors" 5 | "github.com/AdguardTeam/golibs/netutil" 6 | "github.com/AdguardTeam/golibs/validate" 7 | ) 8 | 9 | // Possible values of the STANDARD_ACCESS_TYPE environment variable. 10 | const ( 11 | standardAccessOff = "off" 12 | standardAccessBackend = "backend" 13 | ) 14 | 15 | // accessConfig is the configuration that controls IP and hosts blocking. 16 | type accessConfig struct { 17 | // BlockedQuestionDomains is a list of AdBlock rules used to block access. 18 | BlockedQuestionDomains []string `yaml:"blocked_question_domains"` 19 | 20 | // BlockedClientSubnets is a list of IP addresses or subnets to block. 21 | BlockedClientSubnets []netutil.Prefix `yaml:"blocked_client_subnets"` 22 | } 23 | 24 | // type check 25 | var _ validate.Interface = (*accessConfig)(nil) 26 | 27 | // Validate implements the [validate.Interface] interface for *accessConfig. 28 | func (c *accessConfig) Validate() (err error) { 29 | if c == nil { 30 | return errors.ErrNoValue 31 | } 32 | 33 | return nil 34 | } 35 | -------------------------------------------------------------------------------- /internal/cmd/additional.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | "maps" 6 | "slices" 7 | 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/validate" 10 | "github.com/prometheus/common/model" 11 | ) 12 | 13 | // additionalInfo is a extra info configuration. 
14 | type additionalInfo map[string]string 15 | 16 | // type check 17 | var _ validate.Interface = additionalInfo(nil) 18 | 19 | // Validate implements the [validate.Interface] interface for additionalInfo. 20 | func (c additionalInfo) Validate() (err error) { 21 | var errs []error 22 | for _, k := range slices.Sorted(maps.Keys(c)) { 23 | if !model.LegacyValidation.IsValidLabelName(k) { 24 | errs = append(errs, fmt.Errorf( 25 | "prometheus labels must match %s, got %q", 26 | model.LabelNameRE, 27 | k, 28 | )) 29 | } 30 | } 31 | 32 | return errors.Join(errs...) 33 | } 34 | -------------------------------------------------------------------------------- /internal/cmd/dnsdb.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/AdguardTeam/golibs/errors" 5 | "github.com/AdguardTeam/golibs/validate" 6 | ) 7 | 8 | // dnsDBConfig is the configuration of the DNSDB module. 9 | type dnsDBConfig struct { 10 | // MaxSize is the maximum amount of records in the memory buffer. 11 | MaxSize int `yaml:"max_size"` 12 | 13 | // Enabled describes if the DNSDB memory buffer is enabled. 14 | Enabled bool `yaml:"enabled"` 15 | } 16 | 17 | // type check 18 | var _ validate.Interface = (*dnsDBConfig)(nil) 19 | 20 | // Validate implements the [validate.Interface] interface for *dnsDBConfig. 
21 | func (c *dnsDBConfig) Validate() (err error) { 22 | if c == nil { 23 | return errors.ErrNoValue 24 | } else if !c.Enabled { 25 | return nil 26 | } 27 | 28 | return validate.Positive("max_size", c.MaxSize) 29 | } 30 | -------------------------------------------------------------------------------- /internal/cmd/error.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | ) 11 | 12 | // reportPanics reports all panics in Main using the Sentry client, logs them, 13 | // and repanics. It should be called in a defer. 14 | // 15 | // TODO(a.garipov): Consider switching to pure Sentry. 16 | func reportPanics(ctx context.Context, errColl errcoll.Interface, l *slog.Logger) { 17 | v := recover() 18 | if v == nil { 19 | return 20 | } 21 | 22 | slogutil.PrintRecovered(ctx, l, v) 23 | 24 | err := errors.FromRecovered(v) 25 | errColl.Collect(ctx, err) 26 | errFlushColl, ok := errColl.(errcoll.ErrorFlushCollector) 27 | if ok { 28 | errFlushColl.Flush() 29 | } 30 | 31 | panic(v) 32 | } 33 | -------------------------------------------------------------------------------- /internal/cmd/geoip.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/AdguardTeam/golibs/errors" 5 | "github.com/AdguardTeam/golibs/timeutil" 6 | "github.com/AdguardTeam/golibs/validate" 7 | ) 8 | 9 | // geoIPConfig is the GeoIP database configuration. 10 | type geoIPConfig struct { 11 | // HostCacheSize is the size of the hostname lookup cache, in entries. 12 | // 13 | // TODO(a.garipov): Rename to "host_cache_count"? 14 | HostCacheSize int `yaml:"host_cache_size"` 15 | 16 | // IPCacheSize is the size of the IP lookup cache, in entries. 
17 | // 18 | // TODO(a.garipov): Rename to "ip_cache_count"? 19 | IPCacheSize int `yaml:"ip_cache_size"` 20 | 21 | // RefreshIvl defines how often AdGuard DNS reopens the GeoIP database 22 | // files. 23 | RefreshIvl timeutil.Duration `yaml:"refresh_interval"` 24 | } 25 | 26 | // type check 27 | var _ validate.Interface = (*geoIPConfig)(nil) 28 | 29 | // Validate implements the [validate.Interface] interface for *geoIPConfig. 30 | func (c *geoIPConfig) Validate() (err error) { 31 | if c == nil { 32 | return errors.ErrNoValue 33 | } 34 | 35 | return errors.Join( 36 | // NOTE: While a [geoip.File] can work with an empty host cache, that 37 | // feature is only used for tests. 38 | validate.Positive("host_cache_size", c.HostCacheSize), 39 | validate.Positive("ip_cache_size", c.IPCacheSize), 40 | validate.Positive("refresh_interval", c.RefreshIvl), 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /internal/cmd/network.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "math" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/bindtodevice" 7 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" 8 | "github.com/AdguardTeam/golibs/errors" 9 | "github.com/AdguardTeam/golibs/validate" 10 | "github.com/c2h5oh/datasize" 11 | ) 12 | 13 | // network defines the network settings. 14 | type network struct { 15 | // SndBufSize defines the size of socket send buffer. Default is zero (uses 16 | // system settings). 17 | SndBufSize datasize.ByteSize `yaml:"so_sndbuf"` 18 | 19 | // RcvBufSize defines the size of socket receive buffer. Default is zero 20 | // (uses system settings). 21 | RcvBufSize datasize.ByteSize `yaml:"so_rcvbuf"` 22 | } 23 | 24 | // type check 25 | var _ validate.Interface = (*network)(nil) 26 | 27 | // Validate implements the [validate.Interface] interface for *network. 
28 | func (n *network) Validate() (err error) { 29 | if n == nil { 30 | return errors.ErrNoValue 31 | } 32 | 33 | const maxBufSize datasize.ByteSize = math.MaxInt32 34 | 35 | return errors.Join( 36 | validate.NoGreaterThan("so_sndbuf", n.SndBufSize, maxBufSize), 37 | validate.NoGreaterThan("so_rcvbuf", n.RcvBufSize, maxBufSize), 38 | ) 39 | } 40 | 41 | // toInternal converts n to the bindtodevice control configuration and network 42 | // extension control configuration. n must be valid. 43 | func (n *network) toInternal() (bc *bindtodevice.ControlConfig, nc *netext.ControlConfig) { 44 | bc = &bindtodevice.ControlConfig{ 45 | // #nosec G115 -- Validated in [network.validate]. 46 | SndBufSize: int(n.SndBufSize.Bytes()), 47 | // #nosec G115 -- Validated in [network.validate]. 48 | RcvBufSize: int(n.RcvBufSize.Bytes()), 49 | } 50 | nc = &netext.ControlConfig{ 51 | // #nosec G115 -- Validated in [network.validate]. 52 | SndBufSize: int(n.SndBufSize.Bytes()), 53 | // #nosec G115 -- Validated in [network.validate]. 54 | RcvBufSize: int(n.RcvBufSize.Bytes()), 55 | } 56 | 57 | return bc, nc 58 | } 59 | -------------------------------------------------------------------------------- /internal/cmd/querylog.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/AdguardTeam/golibs/errors" 5 | "github.com/AdguardTeam/golibs/validate" 6 | ) 7 | 8 | // queryLogConfig is the query log configuration. 9 | type queryLogConfig struct { 10 | // File contains the JSONL file query log configuration. 11 | File *queryLogFileConfig `yaml:"file"` 12 | } 13 | 14 | // type check 15 | var _ validate.Interface = (*queryLogConfig)(nil) 16 | 17 | // Validate implements the [validate.Interface] interface for *queryLogConfig. 
18 | func (c *queryLogConfig) Validate() (err error) { 19 | if c == nil { 20 | return errors.ErrNoValue 21 | } 22 | 23 | return validate.NotNil("file", c.File) 24 | } 25 | 26 | // queryLogFileConfig is the JSONL file query log configuration. 27 | type queryLogFileConfig struct { 28 | // Enabled defines whether the JSONL file query log is enabled. 29 | Enabled bool `yaml:"enabled"` 30 | } 31 | -------------------------------------------------------------------------------- /internal/cmd/runtime.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "runtime/debug" 7 | 8 | "github.com/AdguardTeam/golibs/logutil/slogutil" 9 | ) 10 | 11 | // setMaxThreads sets the maximum number of threads for the Go runtime, if 12 | // necessary. If n is zero, the limit is left unchanged. l must not be nil, n 13 | // must not be negative. 14 | func setMaxThreads(ctx context.Context, l *slog.Logger, n int) { 15 | if n == 0 { 16 | l.Log(ctx, slogutil.LevelTrace, "go max threads not set") 17 | 18 | return 19 | } 20 | 21 | debug.SetMaxThreads(n) 22 | 23 | l.InfoContext(ctx, "set go max threads", "n", n) 24 | } 25 | -------------------------------------------------------------------------------- /internal/cmd/safebrowsing.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/AdguardTeam/golibs/errors" 5 | "github.com/AdguardTeam/golibs/timeutil" 6 | "github.com/AdguardTeam/golibs/validate" 7 | ) 8 | 9 | // safeBrowsingConfig is the configuration for one of the safe browsing filters. 10 | type safeBrowsingConfig struct { 11 | // BlockHost is the hostname with which to respond to any requests that 12 | // match the filter. 13 | // 14 | // TODO(a.garipov): Consider replacing with a list of IPv4 and IPv6 15 | // addresses. 16 | BlockHost string `yaml:"block_host"` 17 | 18 | // CacheSize is the size of the response cache, in entries. 19 | CacheSize int `yaml:"cache_size"` 20 | 21 | // CacheTTL is the TTL of the response cache.
22 | CacheTTL timeutil.Duration `yaml:"cache_ttl"` 23 | 24 | // RefreshIvl defines how often AdGuard DNS refreshes the filter. 25 | RefreshIvl timeutil.Duration `yaml:"refresh_interval"` 26 | 27 | // RefreshTimeout is the timeout for the filter update operation. 28 | RefreshTimeout timeutil.Duration `yaml:"refresh_timeout"` 29 | } 30 | 31 | // type check 32 | var _ validate.Interface = (*safeBrowsingConfig)(nil) 33 | 34 | // Validate implements the [validate.Interface] interface for 35 | // *safeBrowsingConfig. 36 | func (c *safeBrowsingConfig) Validate() (err error) { 37 | if c == nil { 38 | return errors.ErrNoValue 39 | } 40 | 41 | return errors.Join( 42 | validate.NotEmpty("block_host", c.BlockHost), 43 | validate.Positive("cache_size", c.CacheSize), 44 | validate.Positive("cache_ttl", c.CacheTTL), 45 | validate.Positive("refresh_interval", c.RefreshIvl), 46 | validate.Positive("refresh_timeout", c.RefreshTimeout), 47 | ) 48 | } 49 | -------------------------------------------------------------------------------- /internal/connlimiter/conn.go: -------------------------------------------------------------------------------- 1 | package connlimiter 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net" 7 | "sync/atomic" 8 | "time" 9 | 10 | "github.com/AdguardTeam/golibs/errors" 11 | "github.com/AdguardTeam/golibs/logutil/optslog" 12 | ) 13 | 14 | // limitConn is a wrapper for a stream connection that decreases the counter 15 | // value on close. 16 | // 17 | // See https://pkg.go.dev/golang.org/x/net/netutil#LimitListener. 18 | type limitConn struct { 19 | net.Conn 20 | 21 | connInfo *ConnMetricsData 22 | logger *slog.Logger 23 | metrics Metrics 24 | decrement func(ctx context.Context) 25 | start time.Time 26 | isClosed atomic.Bool 27 | } 28 | 29 | // Close closes the underlying connection and decrements the counter.
30 | func (c *limitConn) Close() (err error) { 31 | defer func() { err = errors.Annotate(err, "limit conn: %w") }() 32 | 33 | // Only the first call closes the connection and decrements the counter; 34 | // later calls return an error wrapping [net.ErrClosed]. 35 | if !c.isClosed.CompareAndSwap(false, true) { 36 | return net.ErrClosed 37 | } 38 | 39 | // Close the connection immediately, and only then log, record the metrics, 40 | // and decrement the counter. 41 | err = c.Conn.Close() 42 | 43 | ctx := context.Background() 44 | connLife := time.Since(c.start) 45 | optslog.Trace2(ctx, c.logger, "closed conn", "raddr", c.RemoteAddr(), "conn_life", connLife) 46 | 47 | c.metrics.ObserveLifeDuration(ctx, c.connInfo, connLife) 48 | 49 | c.decrement(ctx) 50 | 51 | return err 52 | } 53 | -------------------------------------------------------------------------------- /internal/connlimiter/counter.go: -------------------------------------------------------------------------------- 1 | package connlimiter 2 | 3 | // counter is the simultaneous stream-connection counter. It stops accepting 4 | // new connections once it reaches stop and resumes when the number of active 5 | // connections goes back to resume. 6 | // 7 | // Note that current is the number of both active stream-connections as well as 8 | // goroutines that are currently in the process of accepting a new connection 9 | // but haven't accepted one yet. 10 | type counter struct { 11 | current uint64 12 | stop uint64 13 | resume uint64 14 | isAccepting bool 15 | } 16 | 17 | // increment tries to add the connection to the current active connection count. 18 | // If the counter does not accept new connections, shouldAccept is false. 19 | func (c *counter) increment() (shouldAccept bool) { 20 | if !c.isAccepting { 21 | return false 22 | } 23 | 24 | c.current++ 25 | c.isAccepting = c.current < c.stop 26 | 27 | return true 28 | } 29 | 30 | // decrement decreases the number of current active connections.
31 | func (c *counter) decrement() { 32 | c.current-- 33 | 34 | // A counter that has stopped accepting resumes only once the number of 35 | // active connections drops to resume or below. 36 | c.isAccepting = c.isAccepting || c.current <= c.resume 37 | } 38 | -------------------------------------------------------------------------------- /internal/connlimiter/counter_internal_test.go: -------------------------------------------------------------------------------- 1 | package connlimiter 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCounter(t *testing.T) { 10 | t.Run("same", func(t *testing.T) { 11 | c := &counter{ 12 | current: 0, 13 | stop: 1, 14 | resume: 1, 15 | isAccepting: true, 16 | } 17 | 18 | assert.True(t, c.increment()) 19 | assert.False(t, c.increment()) 20 | 21 | c.decrement() 22 | assert.True(t, c.increment()) 23 | assert.False(t, c.increment()) 24 | }) 25 | 26 | t.Run("more", func(t *testing.T) { 27 | c := &counter{ 28 | current: 0, 29 | stop: 2, 30 | resume: 1, 31 | isAccepting: true, 32 | } 33 | 34 | assert.True(t, c.increment()) 35 | assert.True(t, c.increment()) 36 | assert.False(t, c.increment()) 37 | 38 | c.decrement() 39 | assert.True(t, c.increment()) 40 | assert.False(t, c.increment()) 41 | }) 42 | } 43 | -------------------------------------------------------------------------------- /internal/connlimiter/listenconfig.go: -------------------------------------------------------------------------------- 1 | package connlimiter 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 8 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/netext" 9 | ) 10 | 11 | // type check 12 | var _ netext.ListenConfig = (*ListenConfig)(nil) 13 | 14 | // ListenConfig is a [netext.ListenConfig] that uses a [*Limiter] to limit the
16 | type ListenConfig struct { 17 | listenConfig netext.ListenConfig 18 | limiter *Limiter 19 | } 20 | 21 | // NewListenConfig returns a new *ListenConfig that wraps c and uses l to 22 | // limit the number of active stream-connections. 23 | func NewListenConfig(c netext.ListenConfig, l *Limiter) (limited *ListenConfig) { 24 | return &ListenConfig{ 25 | listenConfig: c, 26 | limiter: l, 27 | } 28 | } 29 | 30 | // ListenPacket implements the [netext.ListenConfig] interface for 31 | // *ListenConfig. Packet connections are not limited. 32 | func (c *ListenConfig) ListenPacket( 33 | ctx context.Context, 34 | network string, 35 | address string, 36 | ) (conn net.PacketConn, err error) { 37 | return c.listenConfig.ListenPacket(ctx, network, address) 38 | } 39 | 40 | // Listen implements the [netext.ListenConfig] interface for *ListenConfig. 41 | // Listen returns a net.Listener wrapped by c's limiter. ctx must contain a 42 | // [dnsserver.ServerInfo]; otherwise, Listen panics. 43 | func (c *ListenConfig) Listen( 44 | ctx context.Context, 45 | network string, 46 | address string, 47 | ) (l net.Listener, err error) { 48 | l, err = c.listenConfig.Listen(ctx, network, address) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return c.limiter.Limit(l, dnsserver.MustServerInfoFromContext(ctx)), nil 54 | } 55 | -------------------------------------------------------------------------------- /internal/consul/metrics.go: -------------------------------------------------------------------------------- 1 | package consul 2 | 3 | import "context" 4 | 5 | // Metrics is an interface that is used for the collection of the allowlist 6 | // statistics. 7 | type Metrics interface { 8 | // SetSize sets the number of received subnets. 9 | SetSize(ctx context.Context, n int) 10 | 11 | // SetStatus sets the status and time of the allowlist refresh attempt. 12 | SetStatus(ctx context.Context, err error) 13 | } 14 | 15 | // EmptyMetrics is the implementation of the [Metrics] interface that does
17 | type EmptyMetrics struct{} 18 | 19 | // type check 20 | var _ Metrics = EmptyMetrics{} 21 | 22 | // SetSize implements the [Metrics] interface for EmptyMetrics. 23 | func (EmptyMetrics) SetSize(_ context.Context, _ int) {} 24 | 25 | // SetStatus implements the [Metrics] interface for EmptyMetrics. 26 | func (EmptyMetrics) SetStatus(_ context.Context, _ error) {} 27 | -------------------------------------------------------------------------------- /internal/dnscheck/dnscheck_test.go: -------------------------------------------------------------------------------- 1 | package dnscheck_test 2 | 3 | import ( 4 | "net/netip" 5 | ) 6 | 7 | // Test data. 8 | var ( 9 | testRemoteIP = netip.MustParseAddr("1.2.3.4") 10 | ) 11 | -------------------------------------------------------------------------------- /internal/dnsdb/http.go: -------------------------------------------------------------------------------- 1 | package dnsdb 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/csv" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strings" 10 | 11 | "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" 12 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 13 | "github.com/AdguardTeam/golibs/errors" 14 | "github.com/AdguardTeam/golibs/httphdr" 15 | ) 16 | 17 | // type check 18 | var _ http.Handler = (*Default)(nil) 19 | 20 | // ServeHTTP implements the [http.Handler] interface for *Default. 21 | func (db *Default) ServeHTTP(w http.ResponseWriter, r *http.Request) { 22 | var err error 23 | ctx := r.Context() 24 | 25 | records := db.reset(ctx) 26 | 27 | h := w.Header() 28 | h.Add(httphdr.ContentType, agdhttp.HdrValTextCSV) 29 | 30 | h.Set(httphdr.Trailer, httphdr.XError) 31 | defer func() { 32 | if err != nil { 33 | h.Set(httphdr.XError, err.Error()) 34 | errcoll.Collect(ctx, db.errColl, db.logger, "handling http", err) 35 | } 36 | }() 37 | 38 | var rw io.Writer = w 39 | 40 | // TODO(a.garipov): Parse the quality value.
41 | // 42 | // TODO(a.garipov): Support other compression algorithms. 43 | if strings.Contains(r.Header.Get(httphdr.AcceptEncoding), agdhttp.HdrValGzip) { 44 | h.Set(httphdr.ContentEncoding, agdhttp.HdrValGzip) 45 | gw := gzip.NewWriter(w) 46 | defer func() { err = errors.WithDeferred(err, gw.Close()) }() 47 | 48 | rw = gw 49 | } 50 | 51 | w.WriteHeader(http.StatusOK) 52 | 53 | csvw := csv.NewWriter(rw) 54 | err = writeCSVRecs(csvw, records) 55 | 56 | // Flush explicitly instead of deferring it, because [csv.Writer.Flush] 57 | // doesn't return an error, so a failed final write can only be detected 58 | // with [csv.Writer.Error]. The flush still happens before the deferred 59 | // closing of the gzip writer, if any. 60 | csvw.Flush() 61 | if err == nil { 62 | err = csvw.Error() 63 | } 64 | } 65 | 66 | // writeCSVRecs writes the CSV representation of recs into w. It doesn't 67 | // flush w. 68 | func writeCSVRecs(w *csv.Writer, recs []*record) (err error) { 69 | for i, r := range recs { 70 | err = w.Write(r.csv()) 71 | if err != nil { 72 | return fmt.Errorf("record at index %d: %w", i, err) 73 | } 74 | } 75 | 76 | return nil 77 | } 78 | -------------------------------------------------------------------------------- /internal/dnsdb/metrics.go: -------------------------------------------------------------------------------- 1 | package dnsdb 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // Metrics is an interface that is used for the collection of the DNS database 9 | // statistics. 10 | type Metrics interface { 11 | // SetRecordCount sets the number of records that have not yet been 12 | // uploaded. 13 | SetRecordCount(ctx context.Context, count int) 14 | 15 | // ObserveRotation updates the time of the database rotation and stores the 16 | // duration of the rotation. 17 | ObserveRotation(ctx context.Context, dur time.Duration) 18 | } 19 | 20 | // EmptyMetrics is the implementation of the [Metrics] interface that does 21 | // nothing. 22 | type EmptyMetrics struct{} 23 | 24 | // type check 25 | var _ Metrics = EmptyMetrics{} 26 | 27 | // SetRecordCount implements the [Metrics] interface for EmptyMetrics. 28 | func (EmptyMetrics) SetRecordCount(_ context.Context, _ int) {} 29 | 30 | // ObserveRotation implements the [Metrics] interface for EmptyMetrics.
31 | func (EmptyMetrics) ObserveRotation(_ context.Context, _ time.Duration) {} 32 | -------------------------------------------------------------------------------- /internal/dnsmsg/clonerstat.go: -------------------------------------------------------------------------------- 1 | package dnsmsg 2 | 3 | // ClonerStat is an interface for entities that collect statistics about a 4 | // [Cloner]. 5 | // 6 | // All methods must be safe for concurrent use. 7 | type ClonerStat interface { 8 | // OnClone is called on [Cloner.Clone] calls. isFull is true if the clone 9 | // was full. 10 | OnClone(isFull bool) 11 | } 12 | 13 | // EmptyClonerStat is a [ClonerStat] implementation that does nothing. 14 | type EmptyClonerStat struct{} 15 | 16 | // type check 17 | var _ ClonerStat = EmptyClonerStat{} 18 | 19 | // OnClone implements the [ClonerStat] interface for EmptyClonerStat. 20 | func (EmptyClonerStat) OnClone(_ bool) {} 21 | -------------------------------------------------------------------------------- /internal/dnsmsg/error.go: -------------------------------------------------------------------------------- 1 | package dnsmsg 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | ) 8 | 9 | // BadECSError is returned by functions that work with the EDNS Client Subnet 10 | // option when the data in the option is invalid. 11 | type BadECSError struct { 12 | Err error 13 | } 14 | 15 | // type check 16 | var _ error = BadECSError{} 17 | 18 | // Error implements the error interface for BadECSError. 19 | func (err BadECSError) Error() (msg string) { 20 | return fmt.Sprintf("bad ecs: %s", err.Err) 21 | } 22 | 23 | // type check 24 | var _ errors.Wrapper = BadECSError{} 25 | 26 | // Unwrap implements the [errors.Wrapper] interface for BadECSError. 27 | func (err BadECSError) Unwrap() (unwrapped error) { 28 | return err.Err 29 | } 30 | 31 | // IsSentryReportable implements the [errcoll.SentryReportableError] interface 32 | // for BadECSError.
33 | func (err BadECSError) IsSentryReportable() (ok bool) { return false } 34 | -------------------------------------------------------------------------------- /internal/dnsmsg/error_test.go: -------------------------------------------------------------------------------- 1 | package dnsmsg_test 2 | 3 | import ( 4 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsmsg" 5 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 6 | ) 7 | 8 | // type check 9 | var _ errcoll.SentryReportableError = dnsmsg.BadECSError{} 10 | -------------------------------------------------------------------------------- /internal/dnsserver/cache/metrics.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // MetricsListener is an interface that is used for monitoring the 10 | // cache.Middleware state. The middleware user may opt to supply a metrics 11 | // interface implementation that would increment different kinds of metrics (for 12 | // instance, Prometheus metrics). 13 | type MetricsListener interface { 14 | // OnCacheItemAdded is called when an item has been added to the cache. 15 | OnCacheItemAdded(ctx context.Context, resp *dns.Msg, cacheLen int) 16 | 17 | // OnCacheHit is called when a response for the specified request has been 18 | // found in the cache. 19 | OnCacheHit(ctx context.Context, req *dns.Msg) 20 | 21 | // OnCacheMiss is called when a response for the specified request has not 22 | // been found in the cache. 23 | OnCacheMiss(ctx context.Context, req *dns.Msg) 24 | } 25 | 26 | // EmptyMetricsListener implements MetricsListener with empty functions. This 27 | // implementation is used by default if the user does not supply a custom one.
28 | type EmptyMetricsListener struct{} 29 | 30 | // type check 31 | var _ MetricsListener = EmptyMetricsListener{} 32 | 33 | // OnCacheItemAdded implements the [MetricsListener] interface for 34 | // EmptyMetricsListener. 35 | func (EmptyMetricsListener) OnCacheItemAdded(_ context.Context, _ *dns.Msg, _ int) {} 36 | 37 | // OnCacheHit implements the [MetricsListener] interface for 38 | // EmptyMetricsListener. 39 | func (EmptyMetricsListener) OnCacheHit(_ context.Context, _ *dns.Msg) {} 40 | 41 | // OnCacheMiss implements the [MetricsListener] interface for 42 | // EmptyMetricsListener. 43 | func (EmptyMetricsListener) OnCacheMiss(_ context.Context, _ *dns.Msg) {} 44 | -------------------------------------------------------------------------------- /internal/dnsserver/context_test.go: -------------------------------------------------------------------------------- 1 | package dnsserver_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestServerInfoFromContext(t *testing.T) { 12 | ctx := context.Background() 13 | _, ok := dnsserver.ServerInfoFromContext(ctx) 14 | require.False(t, ok) 15 | 16 | serverInfo := &dnsserver.ServerInfo{ 17 | Name: "test", 18 | Addr: "127.0.0.1", 19 | Proto: dnsserver.ProtoDNS, 20 | } 21 | ctx = dnsserver.ContextWithServerInfo(ctx, serverInfo) 22 | 23 | s, ok := dnsserver.ServerInfoFromContext(ctx) 24 | require.True(t, ok) 25 | require.Equal(t, serverInfo, s) 26 | } 27 | 28 | func TestMustServerInfoFromContext(t *testing.T) { 29 | require.Panics(t, func() { 30 | ctx := context.Background() 31 | _ = dnsserver.MustServerInfoFromContext(ctx) 32 | }) 33 | } 34 | -------------------------------------------------------------------------------- /internal/dnsserver/disposer.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import "github.com/miekg/dns" 4 | 5 | // Disposer
is an interface for pools that can save parts of DNS response 6 | // messages for later reuse. 7 | // 8 | // TODO(a.garipov): Think of ways of extending [ResponseWriter] to do this 9 | // instead. 10 | // 11 | // TODO(a.garipov): Think of a better name. Recycle? Scrap? 12 | type Disposer interface { 13 | // Dispose saves parts of resp for later reuse. resp may be nil. 14 | // Implementations must be safe for concurrent use. 15 | Dispose(resp *dns.Msg) 16 | } 17 | 18 | // EmptyDisposer is a [Disposer] that does nothing. 19 | type EmptyDisposer struct{} 20 | 21 | // type check 22 | var _ Disposer = EmptyDisposer{} 23 | 24 | // Dispose implements the [Disposer] interface for EmptyDisposer. 25 | func (EmptyDisposer) Dispose(_ *dns.Msg) {} 26 | -------------------------------------------------------------------------------- /internal/dnsserver/dnsserver_test.go: -------------------------------------------------------------------------------- 1 | package dnsserver_test 2 | 3 | import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 4 | 5 | // testTimeout is a common timeout for tests. 6 | const testTimeout = dnsserver.DefaultReadTimeout 7 | -------------------------------------------------------------------------------- /internal/dnsserver/dnsservertest/dnsservertest.go: -------------------------------------------------------------------------------- 1 | // Package dnsservertest provides convenient helper functions for unit tests 2 | // in packages related to dnsserver. 3 | package dnsservertest 4 | -------------------------------------------------------------------------------- /internal/dnsserver/dnsservertest/error_unix.go: -------------------------------------------------------------------------------- 1 | //go:build unix 2 | 3 | package dnsservertest 4 | 5 | import ( 6 | "github.com/AdguardTeam/golibs/errors" 7 | "golang.org/x/sys/unix" 8 | ) 9 | 10 | // errorIsAddrInUse returns true if err is an address already in use error.
11 | func errorIsAddrInUse(err error) (ok bool) { 12 | return errors.Is(err, unix.EADDRINUSE) 13 | } 14 | -------------------------------------------------------------------------------- /internal/dnsserver/dnsservertest/error_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package dnsservertest 4 | 5 | import ( 6 | "github.com/AdguardTeam/golibs/errors" 7 | "golang.org/x/sys/windows" 8 | ) 9 | 10 | // errorIsAddrInUse returns true if err is an address already in use error. 11 | func errorIsAddrInUse(err error) (ok bool) { 12 | return errors.Is(err, windows.WSAEADDRINUSE) 13 | } 14 | -------------------------------------------------------------------------------- /internal/dnsserver/forward/context.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/AdguardTeam/golibs/errors" 8 | ) 9 | 10 | // ctxKey is the type for context keys within this package. 11 | type ctxKey uint8 12 | 13 | const ( 14 | ctxKeyNetworkOverride ctxKey = iota 15 | ) 16 | 17 | // type check 18 | var _ fmt.Stringer = ctxKey(0) 19 | 20 | // String implements the [fmt.Stringer] interface for ctxKey. 21 | func (k ctxKey) String() (s string) { 22 | switch k { 23 | case ctxKeyNetworkOverride: 24 | return "ctxKeyNetworkOverride" 25 | default: 26 | panic(fmt.Errorf("ctx key: %w: %d", errors.ErrBadEnumValue, k)) 27 | } 28 | } 29 | 30 | // panicBadType is a helper that panics with a message about the context key and 31 | // the actual type and value of v. 32 | func panicBadType(key ctxKey, v any) { 33 | panic(fmt.Errorf("bad type for %s: %T(%[2]v)", key, v)) 34 | } 35 | 36 | // withNetworkOverride returns a copy of the parent context with the network 37 | // override added.
38 | func withNetworkOverride(ctx context.Context, network Network) (withNet context.Context) { 39 | return context.WithValue(ctx, ctxKeyNetworkOverride, network) 40 | } 41 | 42 | // networkOverrideFromContext returns the network override from the context, if 43 | // any. It panics if the stored value is not a [Network]. 44 | func networkOverrideFromContext(ctx context.Context) (network Network, ok bool) { 45 | const key = ctxKeyNetworkOverride 46 | 47 | v := ctx.Value(key) 48 | if v == nil { 49 | return NetworkAny, false 50 | } 51 | 52 | network, ok = v.(Network) 53 | if !ok { 54 | panicBadType(key, v) 55 | } 56 | 57 | return network, true 58 | } 59 | -------------------------------------------------------------------------------- /internal/dnsserver/forward/error.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | ) 8 | 9 | // Common Errors 10 | 11 | // Error is the forwarding error. 12 | type Error struct { 13 | Err error 14 | Main Upstream 15 | Fallback Upstream 16 | } 17 | 18 | // type check 19 | var _ error = (*Error)(nil) 20 | 21 | // Error implements the error interface for *Error. 22 | func (err *Error) Error() (msg string) { 23 | if err.Fallback == nil { 24 | return fmt.Sprintf("forwarding to %s: %s", err.Main, err.Err) 25 | } else if err.Main == nil { 26 | return fmt.Sprintf("forwarding to fallback %s: %s", err.Fallback, err.Err) 27 | } 28 | 29 | return fmt.Sprintf( 30 | "forwarding to %s with fallback %s: %s", 31 | err.Main, 32 | err.Fallback, 33 | err.Err, 34 | ) 35 | } 36 | 37 | // type check 38 | var _ errors.Wrapper = (*Error)(nil) 39 | 40 | // Unwrap implements the [errors.Wrapper] interface for *Error. 41 | func (err *Error) Unwrap() (unwrapped error) { 42 | return err.Err 43 | } 44 | 45 | // annotate is a deferrable helper for forwarding errors.
46 | func annotate(err error, ups, fallbackUps Upstream) (wrapped error) { 47 | // Return nil explicitly: wrapping a nil err into an *Error would turn a 48 | // success into a failure for callers that defer annotate. 49 | if err == nil { 50 | return nil 51 | } 52 | 53 | return &Error{ 54 | Err: err, 55 | Main: ups, 56 | Fallback: fallbackUps, 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /internal/dnsserver/forward/example_test.go: -------------------------------------------------------------------------------- 1 | package forward_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/netip" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 9 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/forward" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | ) 12 | 13 | func ExampleNewHandler() { 14 | conf := &dnsserver.ConfigDNS{ 15 | Base: &dnsserver.ConfigBase{ 16 | BaseLogger: slogutil.NewDiscardLogger(), 17 | Name: "srv", 18 | Addr: "127.0.0.1:0", 19 | Handler: forward.NewHandler(&forward.HandlerConfig{ 20 | UpstreamsAddresses: []*forward.UpstreamPlainConfig{{ 21 | Network: forward.NetworkAny, 22 | Address: netip.MustParseAddrPort("8.8.8.8:53"), 23 | Timeout: testTimeout, 24 | }}, 25 | FallbackAddresses: []*forward.UpstreamPlainConfig{{ 26 | Network: forward.NetworkAny, 27 | Address: netip.MustParseAddrPort("1.1.1.1:53"), 28 | Timeout: testTimeout, 29 | }}, 30 | }), 31 | }, 32 | } 33 | 34 | srv := dnsserver.NewServerDNS(conf) 35 | err := srv.Start(context.Background()) 36 | if err != nil { 37 | panic("failed to start the server") 38 | } 39 | 40 | fmt.Println("started server") 41 | 42 | defer func() { 43 | err = srv.Shutdown(context.Background()) 44 | if err != nil { 45 | panic("failed to shutdown the server") 46 | } 47 | 48 | fmt.Println("stopped server") 49 | }() 50 | 51 | // Output: 52 | // 53 | // started server 54 | // stopped server 55 | } 56 | -------------------------------------------------------------------------------- /internal/dnsserver/forward/network.go:
-------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/golibs/errors" 7 | ) 8 | 9 | // Network is an enumeration of networks [UpstreamPlain] supports. 10 | type Network string 11 | 12 | const ( 13 | // NetworkAny means that [UpstreamPlain] will use the regular way of sending 14 | // a DNS query. First, it will send it over UDP. If the response is 15 | // truncated, it will automatically switch to using TCP. 16 | NetworkAny Network = "" 17 | 18 | // NetworkUDP means that [UpstreamPlain] will only use UDP. 19 | NetworkUDP Network = "udp" 20 | 21 | // NetworkTCP means that [UpstreamPlain] will only use TCP. 22 | NetworkTCP Network = "tcp" 23 | ) 24 | 25 | // NewNetwork parses networkStr and returns the corresponding Network value. 26 | func NewNetwork(networkStr string) (network Network, err error) { 27 | switch network = Network(networkStr); network { 28 | case NetworkAny, NetworkUDP, NetworkTCP: 29 | return network, nil 30 | default: 31 | return "", fmt.Errorf("networkStr: %w: %q", errors.ErrBadEnumValue, networkStr) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /internal/dnsserver/forward/upstream.go: -------------------------------------------------------------------------------- 1 | package forward 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // Upstream is the interface for a DNS client. 12 | type Upstream interface { 13 | // Exchange processes the given request. It returns the response, the 14 | // network over which the request has been processed, and an error, if any. 15 | // 16 | // TODO(a.garipov): Make it more extensible. Either metrics through context, 17 | // or returning some interface value, similar to [netext.PacketSession].
18 | Exchange(ctx context.Context, req *dns.Msg) (resp *dns.Msg, nw Network, err error) 19 | 20 | io.Closer 21 | fmt.Stringer 22 | } 23 | -------------------------------------------------------------------------------- /internal/dnsserver/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/AdguardTeam/AdGuardDNS/internal/dnsserver 2 | 3 | go 1.25.1 4 | 5 | require ( 6 | github.com/AdguardTeam/golibs v0.34.1 7 | github.com/ameshkov/dnscrypt/v2 v2.4.0 8 | github.com/ameshkov/dnsstamps v1.0.3 9 | github.com/bluele/gcache v0.0.2 10 | github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 11 | github.com/miekg/dns v1.1.68 12 | github.com/panjf2000/ants/v2 v2.11.3 13 | github.com/patrickmn/go-cache v2.1.1-0.20191004192108-46f407853014+incompatible 14 | github.com/prometheus/client_golang v1.23.1 15 | github.com/quic-go/quic-go v0.54.0 16 | github.com/stretchr/testify v1.11.1 17 | golang.org/x/net v0.44.0 18 | golang.org/x/sys v0.36.0 19 | ) 20 | 21 | require ( 22 | github.com/beorn7/perks v1.0.1 // indirect 23 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 24 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 25 | github.com/grafana/regexp v0.0.0-20240518133315-a468a5bfb3bc // indirect 26 | github.com/kr/text v0.2.0 // indirect 27 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 28 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 29 | github.com/prometheus/client_model v0.6.2 // indirect 30 | github.com/prometheus/common v0.66.0 // indirect 31 | github.com/prometheus/procfs v0.17.0 // indirect 32 | github.com/quic-go/qpack v0.5.1 // indirect 33 | github.com/robfig/cron/v3 v3.0.1 // indirect 34 | go.uber.org/mock v0.6.0 // indirect 35 | golang.org/x/crypto v0.42.0 // indirect 36 | golang.org/x/exp v0.0.0-20250911091902-df9299821621 // indirect 37 | golang.org/x/mod v0.28.0 // indirect 38 | golang.org/x/sync v0.17.0 //
indirect 39 | golang.org/x/text v0.29.0 // indirect 40 | golang.org/x/tools v0.37.0 // indirect 41 | google.golang.org/protobuf v1.36.9 // indirect 42 | gopkg.in/yaml.v2 v2.4.0 // indirect 43 | gopkg.in/yaml.v3 v3.0.1 // indirect 44 | ) 45 | -------------------------------------------------------------------------------- /internal/dnsserver/handler.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // Handler is an interface that defines how the DNS server would process DNS 10 | // queries. Inspired by net/http.Server and its Handler. 11 | type Handler interface { 12 | // ServeDNS processes the request and writes a DNS response to rw. ctx must 13 | // contain [*ServerInfo] and [*RequestInfo]. rw and req must not be nil. 14 | // req must have exactly one question. 15 | ServeDNS(ctx context.Context, rw ResponseWriter, req *dns.Msg) (err error) 16 | } 17 | 18 | // The HandlerFunc type is an adapter to allow the use of ordinary functions 19 | // as DNS handlers. If f is a function with the appropriate signature, 20 | // HandlerFunc(f) is a Handler that calls f. 21 | type HandlerFunc func(context.Context, ResponseWriter, *dns.Msg) (err error) 22 | 23 | // type check 24 | var _ Handler = HandlerFunc(nil) 25 | 26 | // ServeDNS implements the [Handler] interface for HandlerFunc. 27 | func (f HandlerFunc) ServeDNS(ctx context.Context, rw ResponseWriter, req *dns.Msg) (err error) { 28 | return f(ctx, rw, req) 29 | } 30 | 31 | // notImplementedHandlerFunc is used if no Handler is configured for a server.
32 | var notImplementedHandlerFunc HandlerFunc = func( 33 | ctx context.Context, 34 | w ResponseWriter, 35 | r *dns.Msg, 36 | ) (err error) { 37 | res := (&dns.Msg{}).SetRcode(r, dns.RcodeNotImplemented) 38 | 39 | return w.WriteMsg(ctx, r, res) 40 | } 41 | -------------------------------------------------------------------------------- /internal/dnsserver/middleware.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | // Middleware is a general interface for dnsserver.Server middlewares. 4 | type Middleware interface { 5 | // Wrap wraps the specified Handler and returns a new handler. This 6 | // handler may call the underlying one and implement additional logic. 7 | Wrap(h Handler) (wrapped Handler) 8 | } 9 | 10 | // WithMiddlewares is a helper function that attaches the specified middlewares 11 | // to the Handler. Middlewares will be called in the same order in which they 12 | // were specified. 13 | func WithMiddlewares(h Handler, middlewares ...Middleware) (wrapped Handler) { 14 | wrapped = h 15 | 16 | // Go through middlewares in the reverse order. This way the middleware 17 | // that was specified first will be called first. 18 | for i := len(middlewares) - 1; i >= 0; i-- { 19 | m := middlewares[i] 20 | wrapped = m.Wrap(wrapped) 21 | } 22 | 23 | return wrapped 24 | } 25 | -------------------------------------------------------------------------------- /internal/dnsserver/msg.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "slices" 7 | 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | // genErrorResponse creates a short DNS message with the specified rcode. 12 | // It is supposed to be used for generating errors (server failure, bad format, 13 | // etc.).
14 | func genErrorResponse(req *dns.Msg, code int) (m *dns.Msg) { 15 | m = &dns.Msg{} 16 | m.SetRcode(req, code) 17 | 18 | return m 19 | } 20 | 21 | // packWithPrefix packs a DNS message with a 2-byte big-endian length prefix, 22 | // using buf as the packing buffer if possible, and returns the result. 23 | func packWithPrefix(m *dns.Msg, buf []byte) (packed []byte, err error) { 24 | buf, err = m.PackBuffer(buf) 25 | if err != nil { 26 | return nil, fmt.Errorf("packing buffer: %w", err) 27 | } 28 | 29 | l := len(buf) 30 | if l > dns.MaxMsgSize { 31 | // Generally shouldn't happen, but l must fit into the uint16 prefix. 32 | return nil, fmt.Errorf("buffer too large: %d bytes", l) 33 | } 34 | 35 | // Try to reuse the slice if there is already space there. 36 | packed = slices.Grow(buf, 2)[:l+2] 37 | 38 | copy(packed[2:], buf) // packed and buf may overlap, which copy handles. 39 | binary.BigEndian.PutUint16(packed[:2], uint16(l)) 40 | 41 | return packed, nil 42 | } 43 | -------------------------------------------------------------------------------- /internal/dnsserver/netext/listenconfig_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package netext 4 | 5 | import ( 6 | "net" 7 | "syscall" 8 | ) 9 | 10 | // listenControlWithSO is nil on Windows, because it doesn't support socket 11 | // options. 12 | var listenControlWithSO func(_ *ControlConfig, _ syscall.RawConn) (_ error) 13 | 14 | // setIPOpts is a no-op on Windows; it always returns nil.
15 | func setIPOpts(c net.PacketConn) (err error) { 16 | return nil 17 | } 18 | -------------------------------------------------------------------------------- /internal/dnsserver/netext/packetconn_linux_internal_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | package netext 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "golang.org/x/net/ipv4" 10 | "golang.org/x/net/ipv6" 11 | ) 12 | 13 | func TestUDPOOBSize(t *testing.T) { 14 | // See https://github.com/miekg/dns/blob/v1.1.50/udp.go. 15 | 16 | len4 := len(ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)) 17 | len6 := len(ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)) 18 | 19 | max := len4 20 | if len6 > max { 21 | max = len6 22 | } 23 | 24 | assert.Equal(t, max, IPDstOOBSize) 25 | } 26 | -------------------------------------------------------------------------------- /internal/dnsserver/netext/packetconn_others.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package netext 4 | 5 | import "net" 6 | 7 | // wrapPacketConn wraps c to make it a [SessionPacketConn], if the OS supports 8 | // that. On non-Linux systems, it returns c unchanged. 9 | func wrapPacketConn(c net.PacketConn) (wrapped net.PacketConn) { 10 | return c 11 | } 12 | -------------------------------------------------------------------------------- /internal/dnsserver/nonwriter.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // NonWriterResponseWriter saves the response that has been written but doesn't 11 | // actually send it to the client.
12 | type NonWriterResponseWriter struct { 13 | localAddr net.Addr 14 | remoteAddr net.Addr 15 | req *dns.Msg // request (should be supplied in the WriteMsg method) 16 | res *dns.Msg // message that has been written (if any) 17 | } 18 | 19 | // type check 20 | var _ ResponseWriter = (*NonWriterResponseWriter)(nil) 21 | 22 | // NewNonWriterResponseWriter creates a new instance of the NonWriterResponseWriter. 23 | func NewNonWriterResponseWriter(localAddr, remoteAddr net.Addr) (nrw *NonWriterResponseWriter) { 24 | return &NonWriterResponseWriter{ 25 | localAddr: localAddr, 26 | remoteAddr: remoteAddr, 27 | } 28 | } 29 | 30 | // LocalAddr implements the ResponseWriter interface for *NonWriterResponseWriter. 31 | func (r *NonWriterResponseWriter) LocalAddr() (addr net.Addr) { 32 | return r.localAddr 33 | } 34 | 35 | // RemoteAddr implements the ResponseWriter interface for *NonWriterResponseWriter. 36 | func (r *NonWriterResponseWriter) RemoteAddr() (addr net.Addr) { 37 | return r.remoteAddr 38 | } 39 | 40 | // WriteMsg implements the ResponseWriter interface for *NonWriterResponseWriter. 41 | func (r *NonWriterResponseWriter) WriteMsg(_ context.Context, req, resp *dns.Msg) (err error) { 42 | // Just save the response, we'll use it later (see httpHandler for instance) 43 | r.req = req 44 | r.res = resp 45 | 46 | return nil 47 | } 48 | 49 | // Msg returns the last message that has been written, or nil if there is none. 50 | func (r *NonWriterResponseWriter) Msg() (m *dns.Msg) { 51 | return r.res 52 | } 53 | -------------------------------------------------------------------------------- /internal/dnsserver/pool/conn.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "net" 5 | "time" 6 | ) 7 | 8 | // Conn wraps a net.Conn and contains additional info that could be required 9 | // by the Pool instance. It can be used directly instead of a net.Conn or you 10 | // may choose to use the underlying Conn.Conn instead.
11 | type Conn struct { 12 | net.Conn 13 | 14 | // lastTimeUsed is the last time when this connection was used, i.e. 15 | // requested from the pool. 16 | lastTimeUsed time.Time 17 | } 18 | 19 | // wrapConn wraps a net.Conn in a Conn instance. 20 | func wrapConn(conn net.Conn) (c *Conn) { 21 | return &Conn{ 22 | Conn: conn, 23 | lastTimeUsed: time.Now(), 24 | } 25 | } 26 | 27 | // isExpired checks if conn has expired. A non-positive timeout disables expiry. 28 | func isExpired(conn *Conn, timeout time.Duration) (exp bool) { 29 | return timeout > 0 && 30 | time.Since(conn.lastTimeUsed) > timeout 31 | } 32 | -------------------------------------------------------------------------------- /internal/dnsserver/pool/example_test.go: -------------------------------------------------------------------------------- 1 | package pool_test 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/pool" 8 | ) 9 | 10 | func ExampleNewPool() { 11 | f := pool.Factory(func(_ context.Context) (net.Conn, error) { 12 | return net.Dial("udp", "8.8.8.8:53") 13 | }) 14 | p := pool.NewPool(10, f) 15 | 16 | // Create a new connection or get it from the pool. 17 | conn, err := p.Get(context.Background()) 18 | if err != nil { 19 | panic("cannot create a new connection") 20 | } 21 | 22 | // Put the connection back to the pool when it's not needed anymore. 23 | err = p.Put(conn) 24 | if err != nil { 25 | panic("cannot put connection back to the pool") 26 | } 27 | 28 | // Close the pool when you don't need it anymore. 29 | err = p.Close() 30 | if err != nil { 31 | panic("cannot close the pool") 32 | } 33 | 34 | // Output: 35 | } 36 | -------------------------------------------------------------------------------- /internal/dnsserver/prometheus/dns.go: -------------------------------------------------------------------------------- 1 | package prometheus 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | // typeToString converts the query type of req to a human-readable string.
string. 10 | func typeToString(req *dns.Msg) string { 11 | var qType uint16 12 | if len(req.Question) == 1 { 13 | // NOTE: req can be invalid here, so check if the question is okay. 14 | qType = req.Question[0].Qtype 15 | } 16 | 17 | switch qType { 18 | case 19 | dns.TypeA, 20 | dns.TypeAAAA, 21 | dns.TypeCNAME, 22 | dns.TypeDNSKEY, 23 | dns.TypeDS, 24 | dns.TypeHTTPS, 25 | dns.TypeMX, 26 | dns.TypeNS, 27 | dns.TypeNSEC, 28 | dns.TypeNSEC3, 29 | dns.TypePTR, 30 | dns.TypeRRSIG, 31 | dns.TypeSOA, 32 | dns.TypeSRV, 33 | dns.TypeSVCB, 34 | dns.TypeTXT, 35 | // Meta Qtypes: 36 | dns.TypeANY, 37 | dns.TypeAXFR, 38 | dns.TypeIXFR: 39 | return dns.Type(qType).String() 40 | } 41 | 42 | // Sometimes people prefer to log something like "TYPE{qtype}". However, 43 | // practice shows that this creates quite a huge cardinality. 44 | return "OTHER" 45 | } 46 | 47 | // rCodeToString converts response code to a human-readable string. 48 | func rCodeToString(rCode int) string { 49 | rc, ok := dns.RcodeToString[rCode] 50 | if !ok { 51 | rc = strconv.Itoa(rCode) 52 | } 53 | 54 | return rc 55 | } 56 | -------------------------------------------------------------------------------- /internal/dnsserver/prometheus/prometheus_test.go: -------------------------------------------------------------------------------- 1 | package prometheus_test 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | "time" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 9 | "github.com/AdguardTeam/golibs/logutil/slogutil" 10 | "github.com/prometheus/client_golang/prometheus" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // testTimeout is the common timeout for tests. 15 | const testTimeout = 1 * time.Second 16 | 17 | // testLogger is the common logger for tests. 18 | var testLogger = slogutil.NewDiscardLogger() 19 | 20 | // testNamespace is a test namespace for metrics. 21 | const testNamespace = "dns" 22 | 23 | // testReqDomain is the common request domain for tests. 
24 | const testReqDomain = "request-domain.example" 25 | 26 | // testServerInfo is the common server information structure for tests. 27 | var testServerInfo = &dnsserver.ServerInfo{ 28 | Name: "test_server", 29 | Addr: "127.0.0.1:80", 30 | Proto: dnsserver.ProtoDNS, 31 | } 32 | 33 | // testUDPAddr is the common UDP address for tests. 34 | var testUDPAddr = &net.UDPAddr{ 35 | IP: net.IP{1, 2, 3, 4}, 36 | Port: 53, 37 | } 38 | 39 | // requireMetrics accepts a list of metric names and checks that they exist in 40 | // reg. 41 | func requireMetrics(t testing.TB, reg *prometheus.Registry, args ...string) { 42 | t.Helper() 43 | 44 | mf, err := reg.Gather() 45 | require.NoError(t, err) 46 | require.NotNil(t, mf) 47 | 48 | // Check that metrics were incremented. If they're present in the collection 49 | // returned by Gatherer, it means that they were used. 50 | metricsToCheck := map[string]bool{} 51 | for _, m := range args { 52 | metricsToCheck[m] = true 53 | } 54 | 55 | // Delete from metricsToCheck if the metric was found. 56 | // metricsToCheck must be empty in the end. 57 | for _, m := range mf { 58 | delete(metricsToCheck, m.GetName()) 59 | } 60 | 61 | require.Len(t, metricsToCheck, 0, "Some metrics weren't reported: %v", metricsToCheck) 62 | } 63 | -------------------------------------------------------------------------------- /internal/dnsserver/ratelimit/allowlist.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "sync" 7 | ) 8 | 9 | // IP Address And Network Allowlist 10 | 11 | // Allowlist decides whether ip should be excluded from rate limiting. All 12 | // methods must be safe for concurrent use. 13 | type Allowlist interface { 14 | IsAllowed(ctx context.Context, ip netip.Addr) (ok bool, err error) 15 | } 16 | 17 | // DynamicAllowlist is an allowlist that has a dynamic and a persistent list of 18 | // IP networks to allow.
19 | type DynamicAllowlist struct { 20 | // mu protects dynamic. 21 | mu *sync.RWMutex 22 | dynamic []netip.Prefix 23 | 24 | persistent []netip.Prefix 25 | } 26 | 27 | // NewDynamicAllowlist returns a new dynamic allowlist. 28 | func NewDynamicAllowlist(persistent, dynamic []netip.Prefix) (l *DynamicAllowlist) { 29 | l = &DynamicAllowlist{ 30 | mu: &sync.RWMutex{}, 31 | dynamic: dynamic, 32 | persistent: persistent, 33 | } 34 | 35 | return l 36 | } 37 | 38 | // IsAllowed implements the Allowlist interface for *DynamicAllowlist. 39 | func (l *DynamicAllowlist) IsAllowed(_ context.Context, ip netip.Addr) (ok bool, err error) { 40 | for _, n := range l.persistent { 41 | if n.Contains(ip) { 42 | return true, nil 43 | } 44 | } 45 | 46 | l.mu.RLock() 47 | defer l.mu.RUnlock() 48 | 49 | for _, n := range l.dynamic { 50 | if n.Contains(ip) { 51 | return true, nil 52 | } 53 | } 54 | 55 | return false, nil 56 | } 57 | 58 | // Update replaces the previous list of dynamic subnets with subnets. 59 | func (l *DynamicAllowlist) Update(subnets []netip.Prefix) { 60 | l.mu.Lock() 61 | defer l.mu.Unlock() 62 | 63 | l.dynamic = subnets 64 | } 65 | -------------------------------------------------------------------------------- /internal/dnsserver/ratelimit/counter.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/AdguardTeam/golibs/container" 8 | ) 9 | 10 | // RequestCounter is a single request-per-interval counter. 11 | // 12 | // TODO(a.garipov): Add clock interface. 13 | type RequestCounter struct { 14 | // mu protects all fields. 15 | mu *sync.Mutex 16 | 17 | // ring contains the Unix-nanosecond timestamps of requests. It is never nil. 18 | ring *container.RingBuffer[int64] 19 | 20 | // ivl is a time duration in which the requests are counted. 21 | ivl time.Duration 22 | } 23 | 24 | // NewRequestCounter returns a new requests-per-interval counter.
25 | func NewRequestCounter(num uint, ivl time.Duration) (r *RequestCounter) { 26 | return &RequestCounter{ 27 | mu: &sync.Mutex{}, 28 | // Add one, because we need to always keep track of the previous 29 | // request. For example, consider num == 1. 30 | ring: container.NewRingBuffer[int64](num + 1), 31 | ivl: ivl, 32 | } 33 | } 34 | 35 | // Add adds another request to r. isAbove is true if the request goes above the 36 | // counter value. It is safe for concurrent use. 37 | func (r *RequestCounter) Add(t time.Time) (isAbove bool) { 38 | r.mu.Lock() 39 | defer r.mu.Unlock() 40 | 41 | ts := t.UnixNano() 42 | 43 | r.ring.Push(ts) 44 | tail := r.ring.Current() // Presumably the oldest timestamp; 0 if unset. 45 | 46 | return tail > 0 && ts-tail <= int64(1*r.ivl) 47 | } 48 | -------------------------------------------------------------------------------- /internal/dnsserver/ratelimit/metrics.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | // Metrics is an interface for monitoring the [ratelimit.Middleware] state. The 11 | // middleware user may opt to supply a metrics interface implementation that 12 | // would increment different kinds of metrics (for instance, Prometheus 13 | // metrics). 14 | type Metrics interface { 15 | // OnRateLimited is called when the DNS query is dropped. 16 | OnRateLimited(ctx context.Context, req *dns.Msg, rw dnsserver.ResponseWriter) 17 | 18 | // OnAllowlisted is called when the DNS query is allowlisted. 19 | OnAllowlisted(ctx context.Context, req *dns.Msg, rw dnsserver.ResponseWriter) 20 | } 21 | 22 | // EmptyMetrics implements [Metrics] with empty functions. This implementation 23 | // is used by default if the user does not supply a custom one.
24 | type EmptyMetrics struct{} 25 | 26 | // type check 27 | var _ Metrics = EmptyMetrics{} 28 | 29 | // OnRateLimited implements the [Metrics] interface for EmptyMetrics. 30 | func (EmptyMetrics) OnRateLimited(context.Context, *dns.Msg, dnsserver.ResponseWriter) {} 31 | 32 | // OnAllowlisted implements the [Metrics] interface for EmptyMetrics. 33 | func (EmptyMetrics) OnAllowlisted(context.Context, *dns.Msg, dnsserver.ResponseWriter) {} 34 | -------------------------------------------------------------------------------- /internal/dnsserver/staticcheck.conf: -------------------------------------------------------------------------------- 1 | checks = ["all"] 2 | initialisms = [ 3 | # See https://github.com/dominikh/go-tools/blob/master/config/config.go. 4 | # 5 | # Do not add "PTR" since we use "Ptr" as a suffix. 6 | "inherit" 7 | , "DNSSEC" 8 | , "EDNS" 9 | , "MX" 10 | , "QUIC" 11 | , "SDNS" 12 | , "SVCB" 13 | , "TLD" 14 | ] 15 | dot_import_whitelist = [] 16 | http_status_code_whitelist = [] 17 | -------------------------------------------------------------------------------- /internal/dnsserver/tls.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import ( 4 | "crypto/tls" 5 | "net" 6 | ) 7 | 8 | // tlsListener is the implementation of net.Listener that accepts tls.Conn. 9 | // The only point of using our own implementation is to close underlying TCP 10 | // connections gracefully. 11 | // The bug itself is described here: https://github.com/golang/go/issues/45709. 12 | type tlsListener struct { 13 | tcp net.Listener 14 | tlsConfig *tls.Config 15 | } 16 | 17 | // newTLSListener creates a new instance of tlsListener.
18 | func newTLSListener(l net.Listener, tlsConfig *tls.Config) (tlsListen *tlsListener) { 19 | return &tlsListener{ 20 | tcp: l, 21 | tlsConfig: tlsConfig, 22 | } 23 | } 24 | 25 | // type check 26 | var _ net.Listener = (*tlsListener)(nil) 27 | 28 | // Accept implements the net.Listener interface for *tlsListener. 29 | func (l *tlsListener) Accept() (conn net.Conn, err error) { 30 | var c net.Conn 31 | c, err = l.tcp.Accept() 32 | if err != nil { 33 | return nil, err 34 | } 35 | conn = &tlsConn{ 36 | Conn: tls.Server(c, l.tlsConfig), 37 | baseConn: c, 38 | } 39 | return conn, nil 40 | } 41 | 42 | // Close implements the net.Listener interface for *tlsListener. 43 | func (l *tlsListener) Close() (err error) { 44 | return l.tcp.Close() 45 | } 46 | 47 | // Addr implements the net.Listener interface for *tlsListener. 48 | func (l *tlsListener) Addr() (addr net.Addr) { 49 | return l.tcp.Addr() 50 | } 51 | 52 | // tlsConn is the implementation of net.Conn with a minuscule change: when 53 | // the Close method is called, it closes the underlying connection instead 54 | // of sending the TLS close_notify alert. 55 | type tlsConn struct { 56 | *tls.Conn 57 | baseConn net.Conn // underlying TCP connection 58 | } 59 | 60 | // type check 61 | var _ net.Conn = (*tlsConn)(nil) 62 | 63 | // Close implements the net.Conn interface for *tlsConn. 64 | // It changes the basic logic in order to fix this issue: 65 | // https://github.com/golang/go/issues/45709 66 | func (c *tlsConn) Close() (err error) { 67 | return c.baseConn.Close() 68 | } 69 | -------------------------------------------------------------------------------- /internal/dnsserver/ttl.go: -------------------------------------------------------------------------------- 1 | package dnsserver 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/miekg/dns" 7 | ) 8 | 9 | const ( 10 | // minimalDefaultTTL is the absolute lowest TTL we can use.
11 | minimalDefaultTTL = 5 * time.Second 12 | // maximumDefaultTTL is the maximum TTL we use on RRsets. 13 | maximumDefaultTTL = 1 * time.Hour 14 | ) 15 | 16 | // minimalTTL scans the message and returns the lowest TTL found. 17 | func minimalTTL(m *dns.Msg) (d time.Duration) { 18 | if m.Rcode != dns.RcodeSuccess && m.Rcode != dns.RcodeNameError { 19 | return minimalDefaultTTL 20 | } 21 | 22 | // If the message is empty, i.e. there are no records with TTL, 23 | // return a short TTL as a failsafe. 24 | if isEmptyMessage(m) { 25 | return minimalDefaultTTL 26 | } 27 | 28 | return minimalTTLMsgRRs(m) 29 | } 30 | 31 | // isEmptyMessage returns true if the message has no records at all 32 | // or if it has just an OPT record. We consider it an "empty" message 33 | // in this case. 34 | func isEmptyMessage(m *dns.Msg) (empty bool) { 35 | return len(m.Answer) == 0 && len(m.Ns) == 0 && 36 | (len(m.Extra) == 0 || 37 | (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) 38 | } 39 | 40 | // minimalTTLMsgRRs gets the minimal TTL of m's RRs, at most maximumDefaultTTL. 41 | func minimalTTLMsgRRs(m *dns.Msg) (d time.Duration) { 42 | minTTL32 := uint32(maximumDefaultTTL.Seconds()) 43 | 44 | for _, r := range m.Answer { 45 | minTTL32 = min(minTTL32, r.Header().Ttl) 46 | } 47 | 48 | for _, r := range m.Ns { 49 | minTTL32 = min(minTTL32, r.Header().Ttl) 50 | } 51 | 52 | for _, r := range m.Extra { 53 | // OPT records use TTL field for extended rcode and flags.
54 | if h := r.Header(); h.Rrtype != dns.TypeOPT { 55 | minTTL32 = min(minTTL32, h.Ttl) 56 | } 57 | } 58 | 59 | return time.Duration(minTTL32) * time.Second 60 | } 61 | -------------------------------------------------------------------------------- /internal/dnssvc/context.go: -------------------------------------------------------------------------------- 1 | package dnssvc 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agd" 8 | "github.com/AdguardTeam/golibs/contextutil" 9 | ) 10 | 11 | // contextConstructor is a [contextutil.Constructor] implementation that returns 12 | // a context with the given timeout as well as a new [agd.RequestID]. 13 | type contextConstructor struct { 14 | timeout time.Duration 15 | } 16 | 17 | // newContextConstructor returns a new properly initialized *contextConstructor. 18 | func newContextConstructor(timeout time.Duration) (c *contextConstructor) { 19 | return &contextConstructor{ 20 | timeout: timeout, 21 | } 22 | } 23 | 24 | // type check 25 | var _ contextutil.Constructor = (*contextConstructor)(nil) 26 | 27 | // New implements the [contextutil.Constructor] interface for 28 | // *contextConstructor. It returns a context with a new [agd.RequestID] as well 29 | // as its timeout and the corresponding cancellation function.
30 | func (c *contextConstructor) New( 31 | parent context.Context, 32 | ) (ctx context.Context, cancel context.CancelFunc) { 33 | ctx, cancel = context.WithTimeout(parent, c.timeout) 34 | ctx = agd.WithRequestID(ctx, agd.NewRequestID()) 35 | 36 | return ctx, cancel 37 | } 38 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/devicefinder/customdomain.go: -------------------------------------------------------------------------------- 1 | package devicefinder 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agd" 7 | ) 8 | 9 | // CustomDomainDB contains information about custom domains and matches domains. 10 | type CustomDomainDB interface { 11 | // Match returns the domain name or wildcard that matches the client-sent 12 | // server name. cliSrvName must be lowercased. 13 | // 14 | // If there is a match, matchedDomain must be a valid domain name or 15 | // wildcard, and profIDs must not be empty and its items must be valid. 16 | // Otherwise, matchedDomain must be empty and profIDs must be nil. 17 | // 18 | // TODO(a.garipov, e.burkov): Reduce allocations of profIDs. 19 | Match(ctx context.Context, cliSrvName string) (matchedDomain string, profIDs []agd.ProfileID) 20 | } 21 | 22 | // EmptyCustomDomainDB is a [CustomDomainDB] that does nothing. 23 | type EmptyCustomDomainDB struct{} 24 | 25 | // type check 26 | var _ CustomDomainDB = EmptyCustomDomainDB{} 27 | 28 | // Match implements the [CustomDomainDB] interface for EmptyCustomDomainDB. 29 | // matchedDomain and profIDs are always empty.
30 | func (EmptyCustomDomainDB) Match( 31 | _ context.Context, 32 | _ string, 33 | ) (matchedDomain string, profIDs []agd.ProfileID) { 34 | return "", nil 35 | } 36 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/devicefinder/error.go: -------------------------------------------------------------------------------- 1 | package devicefinder 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 7 | "github.com/AdguardTeam/golibs/errors" 8 | ) 9 | 10 | // Authentication errors. 11 | // 12 | // TODO(a.garipov): Consider using errors from package [errors] instead of some 13 | // of these. 14 | const ( 15 | ErrAuthenticationFailed errors.Error = "basic authentication failed" 16 | ErrNoPassword errors.Error = "no password" 17 | ErrNoUserInfo errors.Error = "no userinfo" 18 | ErrNotDoH errors.Error = "not doh" 19 | ) 20 | 21 | // deviceDataError is an error about bad device data or other issues found 22 | // during device data checking. 23 | type deviceDataError struct { 24 | err error 25 | typ string 26 | } 27 | 28 | // type check 29 | var _ error = (*deviceDataError)(nil) 30 | 31 | // newDeviceDataError wraps orig; typ is used as the error message prefix. 32 | func newDeviceDataError(orig error, typ string) (err error) { 33 | return &deviceDataError{ 34 | err: orig, 35 | typ: typ, 36 | } 37 | } 38 | 39 | // Error implements the error interface for *deviceDataError. 40 | func (err *deviceDataError) Error() (msg string) { 41 | return fmt.Sprintf("%s device id check: %s", err.typ, err.err) 42 | } 43 | 44 | // type check 45 | var _ errors.Wrapper = (*deviceDataError)(nil) 46 | 47 | // Unwrap implements the [errors.Wrapper] interface for *deviceDataError.
48 | func (err *deviceDataError) Unwrap() (unwrapped error) { return err.err } 49 | 50 | // type check 51 | var _ errcoll.SentryReportableError = (*deviceDataError)(nil) 52 | 53 | // IsSentryReportable implements the [errcoll.SentryReportableError] interface 54 | // for *deviceDataError. Device-data errors are never reported to Sentry. 55 | func (*deviceDataError) IsSentryReportable() (ok bool) { return false } 56 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/devicefinder/metrics.go: -------------------------------------------------------------------------------- 1 | package devicefinder 2 | 3 | import "context" 4 | 5 | // Metrics is an interface for collection of the statistics of the default 6 | // device finder. 7 | type Metrics interface { 8 | // IncrementCustomDomainMismatches is called when a detected device does not 9 | // belong to the profile which the custom domain belongs to. 10 | IncrementCustomDomainMismatches(ctx context.Context, domain string) 11 | 12 | // IncrementCustomDomainRequests is called when a request is recognized as 13 | // being to a custom domain belonging to a profile. 14 | IncrementCustomDomainRequests(ctx context.Context, domain string) 15 | 16 | // IncrementDoHAuthenticationFails is called when a request fails DoH 17 | // authentication. 18 | IncrementDoHAuthenticationFails(ctx context.Context) 19 | 20 | // IncrementUnknownDedicated is called when the DNS request is sent to an 21 | // unknown local address. 22 | IncrementUnknownDedicated(ctx context.Context) 23 | } 24 | 25 | // EmptyMetrics is an empty [Metrics] implementation that does nothing. 26 | type EmptyMetrics struct{} 27 | 28 | // type check 29 | var _ Metrics = EmptyMetrics{} 30 | 31 | // IncrementCustomDomainRequests implements the [Metrics] interface for 32 | // EmptyMetrics.
33 | func (EmptyMetrics) IncrementCustomDomainRequests(_ context.Context, _ string) {} 34 | 35 | // IncrementCustomDomainMismatches implements the [Metrics] interface for 36 | // EmptyMetrics. 37 | func (EmptyMetrics) IncrementCustomDomainMismatches(_ context.Context, _ string) {} 38 | 39 | // IncrementDoHAuthenticationFails implements the [Metrics] interface for 40 | // EmptyMetrics. 41 | func (EmptyMetrics) IncrementDoHAuthenticationFails(_ context.Context) {} 42 | 43 | // IncrementUnknownDedicated implements the [Metrics] interface for 44 | // EmptyMetrics. 45 | func (EmptyMetrics) IncrementUnknownDedicated(_ context.Context) {} 46 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/initial/initial_test.go: -------------------------------------------------------------------------------- 1 | package initial_test 2 | 3 | // TODO(a.garipov): Rewrite tests. 4 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/internal.go: -------------------------------------------------------------------------------- 1 | // Package internal contains common utilities for DNS middlewares. 2 | package internal 3 | 4 | import "github.com/AdguardTeam/AdGuardDNS/internal/dnsserver" 5 | 6 | // MakeNonWriter returns rw if it is a *dnsserver.NonWriterResponseWriter. 7 | // Otherwise, it returns a new one with the same addresses as rw.
8 | func MakeNonWriter(rw dnsserver.ResponseWriter) (nwrw *dnsserver.NonWriterResponseWriter) { 9 | nwrw, ok := rw.(*dnsserver.NonWriterResponseWriter) 10 | if ok { 11 | return nwrw 12 | } 13 | 14 | return dnsserver.NewNonWriterResponseWriter(rw.LocalAddr(), rw.RemoteAddr()) 15 | } 16 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/mainmw/error.go: -------------------------------------------------------------------------------- 1 | package mainmw 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 8 | "github.com/AdguardTeam/golibs/errors" 9 | ) 10 | 11 | // afterFilteringError is returned by the handler function of [Middleware.Wrap] 12 | // in case there is an error after filtering. 13 | type afterFilteringError struct { 14 | err error 15 | } 16 | 17 | // type check 18 | var _ error = afterFilteringError{} 19 | 20 | // Error implements the error interface for afterFilteringError. 21 | func (err afterFilteringError) Error() (msg string) { 22 | return fmt.Sprintf("after filtering: %s", err.err) 23 | } 24 | 25 | // type check 26 | var _ errors.Wrapper = afterFilteringError{} 27 | 28 | // Unwrap implements the [errors.Wrapper] interface for afterFilteringError. 29 | func (err afterFilteringError) Unwrap() (unwrapped error) { 30 | return err.err 31 | } 32 | 33 | // type check 34 | var _ errcoll.SentryReportableError = afterFilteringError{} 35 | 36 | // IsSentryReportable implements the [errcoll.SentryReportableError] interface 37 | // for afterFilteringError. Timeouts and cancellations are not reported.
38 | func (err afterFilteringError) IsSentryReportable() (ok bool) { 39 | return !errors.Is(err.err, context.DeadlineExceeded) && 40 | !errors.Is(err.err, context.Canceled) 41 | } 42 | -------------------------------------------------------------------------------- /internal/dnssvc/internal/mainmw/metrics.go: -------------------------------------------------------------------------------- 1 | package mainmw 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | "time" 7 | ) 8 | 9 | // Metrics is an interface for collection of the statistics of the main 10 | // filtering middleware. 11 | type Metrics interface { 12 | // OnRequest records the request metrics. m must not be nil. 13 | OnRequest(ctx context.Context, m *RequestMetrics) 14 | } 15 | 16 | // RequestMetrics is an alias for a structure that contains the information 17 | // about a request that has reached the filtering middleware. 18 | // 19 | // NOTE: This is an alias to reduce the number of dependencies required by 20 | // implementations. This is also the reason why only built-in or stdlib types 21 | // are used. 22 | type RequestMetrics = struct { 23 | // RemoteIP is the IP address of the client. 24 | RemoteIP netip.Addr 25 | 26 | // Continent is the continent code, if any. 27 | Continent string 28 | 29 | // Country is the country code, if any. 30 | Country string 31 | 32 | // FilterListID is the ID of the filtering-rule list affecting this query, 33 | // if any. 34 | FilterListID string 35 | 36 | // FilteringDuration is the total amount of time spent filtering the query. 37 | FilteringDuration time.Duration 38 | 39 | // ASN is the autonomous-system number, if any. 40 | ASN uint32 41 | 42 | // IsAnonymous is true if the request does not have a profile associated 43 | // with it. 44 | IsAnonymous bool 45 | 46 | // IsBlocked is true if the request is blocked or rewritten. 47 | IsBlocked bool 48 | } 49 | 50 | // EmptyMetrics is an implementation of the [Metrics] interface that does 51 | // nothing.
52 | type EmptyMetrics struct{}
53 |
54 | // type check
55 | var _ Metrics = EmptyMetrics{}
56 |
57 | // OnRequest implements the [Metrics] interface for EmptyMetrics.
58 | func (EmptyMetrics) OnRequest(_ context.Context, _ *RequestMetrics) {}
59 |
--------------------------------------------------------------------------------
/internal/dnssvc/internal/ratelimitmw/access.go:
--------------------------------------------------------------------------------
1 | package ratelimitmw
2 |
3 | import (
4 | 	"context"
5 | 	"net/netip"
6 |
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
8 | 	"github.com/AdguardTeam/golibs/logutil/optslog"
9 | 	"github.com/miekg/dns"
10 | )
11 |
12 | // isBlockedByAccess returns true if req is blocked by global or profile access
13 | // settings. The global settings are checked first: by the IP address from
14 | // raddr and then by the host and query type from ri. The profile settings are
15 | // only checked when ri has a profile. Each kind of blocking increments its own
16 | // metric.
17 | func (mw *Middleware) isBlockedByAccess(
18 | 	ctx context.Context,
19 | 	ri *agd.RequestInfo,
20 | 	req *dns.Msg,
21 | 	raddr netip.AddrPort,
22 | ) (isBlocked bool) {
23 | 	// NOTE: Global access has priority over the profile one.
24 | 	if mw.accessManager.IsBlockedIP(raddr.Addr()) {
25 | 		mw.metrics.IncrementAccessBlockedBySubnet(ctx)
26 | 		optslog.Debug1(ctx, mw.logger, "access denied globally by ip", "remote_ip", ri.RemoteIP)
27 |
28 | 		return true
29 | 	}
30 |
31 | 	if mw.accessManager.IsBlockedHost(ri.Host, ri.QType) {
32 | 		mw.metrics.IncrementAccessBlockedByHost(ctx)
33 | 		optslog.Debug2(
34 | 			ctx,
35 | 			mw.logger,
36 | 			"access denied globally by rule",
37 | 			"remote_ip", ri.RemoteIP,
38 | 			"host", ri.Host,
39 | 		)
40 |
41 | 		return true
42 | 	}
43 |
44 | 	p, _ := ri.DeviceData()
45 | 	if p == nil {
46 | 		// Requests without a profile are only subject to the global access.
47 | 		return false
48 | 	}
49 |
50 | 	if p.Access.IsBlocked(ctx, req, raddr, ri.Location) {
51 | 		mw.metrics.IncrementAccessBlockedByProfile(ctx)
52 | 		optslog.Debug2(
53 | 			ctx,
54 | 			mw.logger,
55 | 			"access denied by profile",
56 | 			"remote_ip", ri.RemoteIP,
57 | 			"profile_id", p.ID,
58 | 		)
59 |
60 | 		return true
61 | 	}
62 |
63 | 	return false
64 | }
65 |
--------------------------------------------------------------------------------
/internal/dnssvc/reexport.go:
--------------------------------------------------------------------------------
1 | package dnssvc
2 |
3 | import (
4 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/devicefinder"
5 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/initial"
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/mainmw"
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnssvc/internal/ratelimitmw"
8 | )
9 |
10 | // Re-exports related to configuration.
11 | type (
12 | 	// DDRConfig is the configuration for the server group's Discovery Of
13 | 	// Designated Resolvers (DDR) handlers.
14 | 	DDRConfig = initial.DDRConfig
15 | )
16 |
17 | // Re-exports related to custom domains.
18 | type (
19 | 	// CustomDomainDB contains information about custom domains and matches
20 | 	// domains.
21 | 	CustomDomainDB = devicefinder.CustomDomainDB
22 |
23 | 	// EmptyCustomDomainDB is a [CustomDomainDB] that does nothing.
24 | 	EmptyCustomDomainDB = devicefinder.EmptyCustomDomainDB
25 | )
26 |
27 | // Re-exports related to metrics.
28 | type (
29 | 	// DeviceFinderMetrics is an interface for collection of the statistics of
30 | 	// the default device finder.
31 | 	DeviceFinderMetrics = devicefinder.Metrics
32 |
33 | 	// InitialMiddlewareMetrics is an interface for monitoring the initial
34 | 	// middleware state.
35 | 	InitialMiddlewareMetrics = initial.Metrics
36 |
37 | 	// MainMiddlewareMetrics is an interface for collection of the statistics of
38 | 	// the main filtering middleware.
39 | 	MainMiddlewareMetrics = mainmw.Metrics
40 |
41 | 	// RatelimitMiddlewareMetrics is an interface for monitoring the ratelimit
42 | 	// middleware state.
43 | 	RatelimitMiddlewareMetrics = ratelimitmw.Metrics
44 | )
45 |
--------------------------------------------------------------------------------
/internal/ecscache/cache_internal_test.go:
--------------------------------------------------------------------------------
1 | package ecscache
2 |
3 | import (
4 | 	"context"
5 | 	"net/netip"
6 | 	"testing"
7 |
8 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
9 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
10 | 	"github.com/AdguardTeam/AdGuardDNS/internal/dnsserver/dnsservertest"
11 | 	"github.com/AdguardTeam/golibs/logutil/slogutil"
12 | 	"github.com/AdguardTeam/golibs/timeutil"
13 | 	"github.com/miekg/dns"
14 | 	"github.com/stretchr/testify/assert"
15 | )
16 |
17 | func BenchmarkMiddleware(b *testing.B) {
18 | 	mw := NewMiddleware(&MiddlewareConfig{
19 | 		Metrics:      EmptyMetrics{},
20 | 		Clock:        timeutil.SystemClock{},
21 | 		Cloner:       agdtest.NewCloner(),
22 | 		Logger:       slogutil.NewDiscardLogger(),
23 | 		CacheManager: agdcache.EmptyManager{},
24 | 		GeoIP:        agdtest.NewGeoIP(),
25 | 		NoECSCount:   100,
26 | 		ECSCount:     100,
27 | 	})
28 |
29 | 	const (
30 | 		host = "benchmark.example"
31 | 		fqdn = host + "."
32 |
33 | 		defaultTTL uint32 = 3600
34 | 	)
35 |
36 | 	reqAddr := netip.MustParseAddr("1.2.3.4")
37 |
38 | 	req := dnsservertest.NewReq(fqdn, dns.TypeA, dns.ClassINET)
39 | 	cr := &cacheRequest{
40 | 		host:   host,
41 | 		subnet: netip.MustParsePrefix("1.2.3.0/24"),
42 | 		qType:  dns.TypeA,
43 | 		qClass: dns.ClassINET,
44 | 		reqDO:  true,
45 | 	}
46 | 	resp := dnsservertest.NewResp(dns.RcodeSuccess, req, dnsservertest.SectionAnswer{
47 | 		dnsservertest.NewA(host, defaultTTL, reqAddr),
48 | 	})
49 |
50 | 	ctx := context.Background()
51 |
52 | 	var msg *dns.Msg
53 |
54 | 	b.ReportAllocs()
55 | 	for b.Loop() {
56 | 		mw.set(resp, cr, true)
57 |
58 | 		msg, _ = mw.get(ctx, req, cr)
59 | 	}
60 |
61 | 	assert.NotNil(b, msg)
62 |
63 | 	// Most recent results:
64 | 	//
65 | 	// goos: darwin
66 | 	// goarch: arm64
67 | 	// pkg: github.com/AdguardTeam/AdGuardDNS/internal/ecscache
68 | 	// cpu: Apple M1 Pro
69 | 	// BenchmarkMiddleware_Get-8   	 1647064	       726.8 ns/op	     568 B/op	      12 allocs/op
70 | }
71 |
--------------------------------------------------------------------------------
/internal/ecscache/ecscache_internal_test.go:
--------------------------------------------------------------------------------
1 | package ecscache
2 |
3 | import (
4 | 	"math"
5 | 	"testing"
6 | 	"testing/quick"
7 | 	"time"
8 |
9 | 	"github.com/stretchr/testify/assert"
10 | )
11 |
12 | // isSafeFloatInt returns true if d can be safely represented inside a float64.
13 | func isSafeFloatInt(d time.Duration) (ok bool) { 14 | const ( 15 | maxSafeFloatInt = 1<<53 - 1 16 | minSafeFloatInt = -maxSafeFloatInt 17 | ) 18 | 19 | return d > minSafeFloatInt && d < maxSafeFloatInt 20 | } 21 | 22 | func TestRoundDiv(t *testing.T) { 23 | roundDivCheck := func(a, b time.Duration) (res time.Duration) { 24 | if !isSafeFloatInt(a) || !isSafeFloatInt(b) { 25 | return 0 26 | } 27 | 28 | return roundDiv(a, b) 29 | } 30 | 31 | mathRoundCheck := func(a, b time.Duration) (res time.Duration) { 32 | if !isSafeFloatInt(a) || !isSafeFloatInt(b) { 33 | return 0 34 | } 35 | 36 | return time.Duration(math.Round(float64(a) / float64(b))) 37 | } 38 | 39 | assert.NoError(t, quick.CheckEqual(roundDivCheck, mathRoundCheck, &quick.Config{ 40 | MaxCount: 100_000, 41 | })) 42 | } 43 | -------------------------------------------------------------------------------- /internal/ecscache/metrics.go: -------------------------------------------------------------------------------- 1 | package ecscache 2 | 3 | import "context" 4 | 5 | // Metrics is an interface that is used for the collection of the ECS cache 6 | // statistics. 7 | type Metrics interface { 8 | // SetElementsCount sets the total number of items in the cache for domain 9 | // names that support or do not support ECS. 10 | SetElementsCount(ctx context.Context, supportsECS bool, count int) 11 | 12 | // IncrementLookups increments the number of ECS cache lookups for hosts 13 | // that does or doesn't support ECS. 14 | IncrementLookups(ctx context.Context, supportsECS, hit bool) 15 | } 16 | 17 | // EmptyMetrics is the implementation of the [Metrics] interface that does 18 | // nothing. 19 | type EmptyMetrics struct{} 20 | 21 | // type check 22 | var _ Metrics = EmptyMetrics{} 23 | 24 | // SetElementsCount implements the [Metrics] interface for EmptyMetrics. 
25 | func (EmptyMetrics) SetElementsCount(_ context.Context, _ bool, _ int) {} 26 | 27 | // IncrementLookups implements the [Metrics] interface for EmptyMetrics. 28 | func (EmptyMetrics) IncrementLookups(_ context.Context, _, _ bool) {} 29 | -------------------------------------------------------------------------------- /internal/errcoll/errcoll.go: -------------------------------------------------------------------------------- 1 | // Package errcoll contains implementations of error collectors, most notably 2 | // Sentry. 3 | package errcoll 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "log/slog" 9 | 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/service" 12 | ) 13 | 14 | // Interface is the interface for error collectors that process information 15 | // about errors, possibly sending them to a remote location. 16 | type Interface interface { 17 | Collect(ctx context.Context, err error) 18 | } 19 | 20 | // Collect is a helper method for reporting non-critical errors. It writes the 21 | // resulting error into the log and also into errColl. 22 | // 23 | // TODO(a.garipov): Find a way to extract the prefix from l and add to err. 24 | func Collect(ctx context.Context, errColl Interface, l *slog.Logger, msg string, err error) { 25 | l.ErrorContext(ctx, msg, slogutil.KeyError, err) 26 | errColl.Collect(ctx, fmt.Errorf("%s: %w", msg, err)) 27 | } 28 | 29 | // RefreshErrorHandler is a [service.ErrorHandler] that can be used whenever a 30 | // [service.Refresher] cannot report its own errors for some reason. 31 | type RefreshErrorHandler struct { 32 | logger *slog.Logger 33 | errColl Interface 34 | } 35 | 36 | // NewRefreshErrorHandler returns a properly initialized *RefreshErrorHandler. 37 | // All arguments must not be nil. 
38 | func NewRefreshErrorHandler(logger *slog.Logger, errColl Interface) (h *RefreshErrorHandler) { 39 | return &RefreshErrorHandler{ 40 | logger: logger, 41 | errColl: errColl, 42 | } 43 | } 44 | 45 | // type check 46 | var _ service.ErrorHandler = (*RefreshErrorHandler)(nil) 47 | 48 | // Handle implements the [service.ErrorHandler] interface for 49 | // *RefreshErrorHandler. 50 | func (h *RefreshErrorHandler) Handle(ctx context.Context, err error) { 51 | Collect(ctx, h.errColl, h.logger, "refreshing", err) 52 | } 53 | -------------------------------------------------------------------------------- /internal/errcoll/writer.go: -------------------------------------------------------------------------------- 1 | package errcoll 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "time" 8 | 9 | "github.com/AdguardTeam/golibs/errors" 10 | ) 11 | 12 | // WriterErrorCollector is an [Interface] implementation that writes errors to 13 | // an [io.Writer]. 14 | type WriterErrorCollector struct { 15 | w io.Writer 16 | } 17 | 18 | // NewWriterErrorCollector returns a new properly initialized 19 | // *WriterErrorCollector. 20 | func NewWriterErrorCollector(w io.Writer) (c *WriterErrorCollector) { 21 | return &WriterErrorCollector{ 22 | w: w, 23 | } 24 | } 25 | 26 | // type check 27 | var _ Interface = (*WriterErrorCollector)(nil) 28 | 29 | // Collect implements the [Interface] interface for *WriterErrorCollector. 
30 | func (c *WriterErrorCollector) Collect(ctx context.Context, err error) { 31 | var ( 32 | sentryRepErr SentryReportableError 33 | isIface bool 34 | isReportable bool 35 | ) 36 | if isIface = errors.As(err, &sentryRepErr); isIface { 37 | isReportable = sentryRepErr.IsSentryReportable() 38 | } 39 | 40 | _, _ = fmt.Fprintf( 41 | c.w, 42 | "%s: caught error: %s (sentry iface: %t, reportable: %t)\n", 43 | time.Now(), 44 | err, 45 | isIface, 46 | isReportable, 47 | ) 48 | } 49 | -------------------------------------------------------------------------------- /internal/errcoll/writer_test.go: -------------------------------------------------------------------------------- 1 | package errcoll_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | 8 | "github.com/AdguardTeam/AdGuardDNS/internal/errcoll" 9 | "github.com/AdguardTeam/golibs/errors" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestWriterErrorCollector(t *testing.T) { 14 | buf := &bytes.Buffer{} 15 | c := errcoll.NewWriterErrorCollector(buf) 16 | c.Collect(context.Background(), errors.Error("test error")) 17 | 18 | wantRx := `.*: caught error: test error.*` 19 | got := buf.String() 20 | assert.Regexp(t, wantRx, got) 21 | } 22 | -------------------------------------------------------------------------------- /internal/experiment/experiment.go: -------------------------------------------------------------------------------- 1 | // Package experiment occasionally contains code for one-off experiments. 2 | // Experiments can be enabled using the EXPERIMENTS environment variable, which 3 | // is a comma-separated list of experiment IDs. 4 | // 5 | // Please keep every experiment in its own file. 6 | // 7 | // Since the code living here is short-living, the following requirements do not 8 | // apply: 9 | // 10 | // - Comments may be skipped. 11 | // - Some errors may be logged or ignored. 12 | // - Tests may be lacking. 
13 | // - The environment may be read here as opposed to package cmd. 14 | package experiment 15 | 16 | import ( 17 | "log/slog" 18 | "os" 19 | 20 | "github.com/AdguardTeam/AdGuardDNS/internal/metrics" 21 | "github.com/AdguardTeam/golibs/stringutil" 22 | "github.com/prometheus/client_golang/prometheus" 23 | ) 24 | 25 | func Init(l *slog.Logger, reg prometheus.Registerer) (err error) { 26 | expStr := os.Getenv("EXPERIMENTS") 27 | if expStr == "" { 28 | return nil 29 | } 30 | 31 | expIDs := stringutil.SplitTrimmed(expStr, ",") 32 | for _, id := range expIDs { 33 | switch id { 34 | // NOTE: Add experiments here in the following format: 35 | // case idMyExp: 36 | // enableMyExp() 37 | default: 38 | l.Error("no such experiment", "id", id) 39 | } 40 | } 41 | 42 | return metrics.SetExperimentGauge(reg, prometheus.Labels{ 43 | // NOTE: Add experiments here in the following format: 44 | // idMyExp: metrics.BoolString(expMyExpEnabled), 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /internal/filter/filter_test.go: -------------------------------------------------------------------------------- 1 | package filter_test 2 | 3 | import "strings" 4 | 5 | // Common long strings for tests. 6 | var ( 7 | testLongStr = strings.Repeat("a", 200) 8 | ) 9 | -------------------------------------------------------------------------------- /internal/filter/filterstorage/filterstorage.go: -------------------------------------------------------------------------------- 1 | // Package filterstorage defines an interface for a storage of filters as well 2 | // as the default implementation and the filter configuration. 3 | package filterstorage 4 | 5 | import ( 6 | "github.com/AdguardTeam/AdGuardDNS/internal/filter" 7 | ) 8 | 9 | // Additional synthetic filter IDs for refreshable indexes. 10 | // 11 | // TODO(a.garipov): Consider using a separate type. 
12 | const ( 13 | FilterIDBlockedServiceIndex filter.ID = "blocked_service_index" 14 | FilterIDRuleListIndex filter.ID = "rule_list_index" 15 | FilterIDStandardProfileAccess filter.ID = "standard_profile_access" 16 | ) 17 | 18 | // Filenames for filter indexes. 19 | const ( 20 | indexFileNameBlockedServices = "services.json" 21 | indexFileNameRuleLists = "filters.json" 22 | indexFileNameStandardProfileAccess = "standard_profile_access.json" 23 | ) 24 | 25 | // Constants that define cache identifiers for the cache manager. 26 | const ( 27 | // cachePrefixSafeSearch is used as a cache prefix for safe-search filters. 28 | cachePrefixSafeSearch = "filters/safe_search" 29 | 30 | // cachePrefixRuleList is used a cache prefix for rule-list filters. 31 | cachePrefixRuleList = "filters/rulelist" 32 | ) 33 | -------------------------------------------------------------------------------- /internal/filter/filterstorage/index_internal_test.go: -------------------------------------------------------------------------------- 1 | package filterstorage 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestIndexRespFilter_compare(t *testing.T) { 11 | var ( 12 | fltA = &indexRespFilter{ 13 | Key: "a", 14 | } 15 | fltB = &indexRespFilter{ 16 | Key: "b", 17 | } 18 | ) 19 | 20 | want := []*indexRespFilter{ 21 | fltA, 22 | fltB, 23 | nil, 24 | nil, 25 | } 26 | 27 | got := []*indexRespFilter{ 28 | fltB, 29 | nil, 30 | fltA, 31 | nil, 32 | } 33 | 34 | slices.SortStableFunc(got, (*indexRespFilter).compare) 35 | 36 | assert.Equal(t, want, got) 37 | } 38 | -------------------------------------------------------------------------------- /internal/filter/filterstorage/testdata/TestStandardAccess_cache/bad_version/standard_profile_access.json: -------------------------------------------------------------------------------- 1 | { 2 | "unknown_field": "value", 3 | "schema_version": 0 4 | } 5 | 
-------------------------------------------------------------------------------- /internal/filter/filterstorage/testdata/TestStandardAccess_cache/success/standard_profile_access.json: -------------------------------------------------------------------------------- 1 | { 2 | "allowed_nets": [ 3 | "192.0.2.1/32" 4 | ], 5 | "blocked_nets": [ 6 | "192.0.2.2/32" 7 | ], 8 | "allowed_asns": [ 9 | 10 10 | ], 11 | "blocked_asns": [ 12 | 20 13 | ], 14 | "rules": [ 15 | "blocked.std.test", 16 | "@@allowed.std.test" 17 | ], 18 | "schema_version": 1 19 | } 20 | -------------------------------------------------------------------------------- /internal/filter/hashprefix/hashprefix.go: -------------------------------------------------------------------------------- 1 | // Package hashprefix defines a storage of hashes of domain names used for 2 | // filtering and serving TXT records with domain-name hashes. 3 | package hashprefix 4 | 5 | import "crypto/sha256" 6 | 7 | // Hash and hash part length constants. 8 | const ( 9 | // PrefixLen is the length of the hash prefix of the filtered hostname. 10 | PrefixLen = 2 11 | 12 | // PrefixEncLen is the encoded length of the hash prefix. Two text 13 | // bytes per one binary byte. 14 | PrefixEncLen = PrefixLen * 2 15 | 16 | // hashLen is the length of the whole hash of the checked hostname. 17 | hashLen = sha256.Size 18 | 19 | // suffixLen is the length of the hash suffix of the filtered hostname. 20 | suffixLen = hashLen - PrefixLen 21 | 22 | // hashEncLen is the encoded length of the hash. Two text bytes per one 23 | // binary byte. 24 | hashEncLen = hashLen * 2 25 | ) 26 | 27 | // Prefix is the type of the SHA256 hash prefix used to match against the 28 | // domain-name database. 29 | type Prefix [PrefixLen]byte 30 | 31 | // suffix is the type of the rest of a SHA256 hash of the filtered domain names. 
32 | type suffix [suffixLen]byte 33 | -------------------------------------------------------------------------------- /internal/filter/hashprefix/hashprefix_test.go: -------------------------------------------------------------------------------- 1 | package hashprefix_test 2 | 3 | import ( 4 | "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" 5 | ) 6 | 7 | // testHashes is the host data for tests. 8 | const testHashes = filtertest.HostAdultContent + "\n" 9 | 10 | // testHashesData is the host data for tests. 11 | var testHashesData = []byte(testHashes) 12 | -------------------------------------------------------------------------------- /internal/filter/hashprefix/metrics.go: -------------------------------------------------------------------------------- 1 | package hashprefix 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // Metrics is an interface used for collection if the hashprefix filter 8 | // statistics. 9 | type Metrics interface { 10 | // IncrementLookups increments the number of lookups. hit is true if the 11 | // lookup returned a value. 12 | IncrementLookups(ctx context.Context, hit bool) 13 | 14 | // UpdateCacheSize is called when the cache size is updated. 15 | UpdateCacheSize(ctx context.Context, cacheLen int) 16 | } 17 | 18 | // EmptyMetrics is the implementation of the [Metrics] interface that does nothing. 19 | type EmptyMetrics struct{} 20 | 21 | // type check 22 | var _ Metrics = EmptyMetrics{} 23 | 24 | // IncrementLookups implements the [Metrics] interface for EmptyMetrics. 25 | func (EmptyMetrics) IncrementLookups(_ context.Context, _ bool) {} 26 | 27 | // UpdateCacheSize implements the [Metrics] interface for EmptyMetrics. 
28 | func (EmptyMetrics) UpdateCacheSize(_ context.Context, _ int) {} 29 | -------------------------------------------------------------------------------- /internal/filter/internal/composite/composite_internal_test.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/filter" 8 | "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/filtertest" 9 | "github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/rulelist" 10 | "github.com/AdguardTeam/urlfilter" 11 | "github.com/miekg/dns" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func BenchmarkFilter_FilterReqWithRuleLists(b *testing.B) { 16 | blockingRL := rulelist.NewFromString( 17 | filtertest.RuleBlockStr+"\n", 18 | "test", 19 | "", 20 | rulelist.EmptyResultCache{}, 21 | ) 22 | 23 | f := New(&Config{ 24 | URLFilterRequest: &urlfilter.DNSRequest{}, 25 | URLFilterResult: &urlfilter.DNSResult{}, 26 | RuleLists: []*rulelist.Refreshable{blockingRL}, 27 | }) 28 | 29 | ctx := context.Background() 30 | req := filtertest.NewRequest(b, "", filtertest.HostBlocked, filtertest.IPv4Client, dns.TypeA) 31 | 32 | var result filter.Result 33 | 34 | b.ReportAllocs() 35 | for b.Loop() { 36 | result = f.filterReqWithRuleLists(ctx, req) 37 | } 38 | 39 | assert.NotNil(b, result) 40 | 41 | // Most recent results: 42 | // goos: linux 43 | // goarch: amd64 44 | // pkg: github.com/AdguardTeam/AdGuardDNS/internal/filter/internal/composite 45 | // cpu: AMD Ryzen 7 PRO 4750U with Radeon Graphics 46 | // BenchmarkFilter_FilterReqWithRuleLists-16 748243 1394 ns/op 468 B/op 8 allocs/op 47 | } 48 | -------------------------------------------------------------------------------- /internal/filter/internal/composite/request.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "context" 5 | 6 | 
"github.com/AdguardTeam/AdGuardDNS/internal/filter" 7 | "github.com/AdguardTeam/urlfilter" 8 | ) 9 | 10 | // RequestFilter can filter a request based on the request info. 11 | type RequestFilter interface { 12 | // FilterRequest filters a DNS request based on the information provided 13 | // about the request. req must be valid. 14 | FilterRequest(ctx context.Context, req *filter.Request) (r filter.Result, err error) 15 | } 16 | 17 | // RequestFilterUF can filter a request based on the request info and using 18 | // URLFilter data to optimize allocations. 19 | type RequestFilterUF interface { 20 | // FilterRequestUF filters a DNS request based on the information provided 21 | // about the request and using URLFilter data to optimize allocations. req 22 | // must be valid. ufReq and ufRes must not be nil and must be reset. 23 | FilterRequestUF( 24 | ctx context.Context, 25 | req *filter.Request, 26 | ufReq *urlfilter.DNSRequest, 27 | ufRes *urlfilter.DNSResult, 28 | ) (r filter.Result, err error) 29 | } 30 | 31 | // ufRequestFilter is a wrapper around a [RequestFilterUF] that uses the 32 | // provided URLFilter data. 33 | type ufRequestFilter struct { 34 | flt RequestFilterUF 35 | req *urlfilter.DNSRequest 36 | res *urlfilter.DNSResult 37 | } 38 | 39 | // type check 40 | var _ RequestFilter = (*ufRequestFilter)(nil) 41 | 42 | // FilterRequest implements the [RequestFilter] interface for *ufRequestFilter. 
43 | func (f *ufRequestFilter) FilterRequest( 44 | ctx context.Context, 45 | req *filter.Request, 46 | ) (r filter.Result, err error) { 47 | f.req.Reset() 48 | f.res.Reset() 49 | 50 | return f.flt.FilterRequestUF(ctx, req, f.req, f.res) 51 | } 52 | -------------------------------------------------------------------------------- /internal/filter/internal/filtertest/refresh.go: -------------------------------------------------------------------------------- 1 | package filtertest 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "net/url" 8 | "os" 9 | "path/filepath" 10 | "testing" 11 | 12 | "github.com/AdguardTeam/AdGuardDNS/internal/agdhttp" 13 | "github.com/AdguardTeam/golibs/httphdr" 14 | "github.com/AdguardTeam/golibs/testutil" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | // PrepareRefreshable launches an HTTP server serving the given text and code, 19 | // as well as creates a cache file. If reqCh not nil, a signal is sent every 20 | // time the server is called. The server uses [ServerName] as the value of the 21 | // Server header. 22 | // 23 | // TODO(a.garipov): Rewrite to use []byte for text. 
24 | func PrepareRefreshable( 25 | tb testing.TB, 26 | reqCh chan<- struct{}, 27 | text string, 28 | code int, 29 | ) (cachePath string, srvURL *url.URL) { 30 | tb.Helper() 31 | 32 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 33 | pt := testutil.PanicT{} 34 | if reqCh != nil { 35 | testutil.RequireSend(pt, reqCh, struct{}{}, Timeout) 36 | } 37 | 38 | w.Header().Set(httphdr.Server, ServerName) 39 | 40 | w.WriteHeader(code) 41 | 42 | _, writeErr := io.WriteString(w, text) 43 | require.NoError(pt, writeErr) 44 | })) 45 | tb.Cleanup(srv.Close) 46 | 47 | srvURL, err := agdhttp.ParseHTTPURL(srv.URL) 48 | require.NoError(tb, err) 49 | 50 | cacheDir := tb.TempDir() 51 | cacheFile, err := os.CreateTemp(cacheDir, filepath.Base(tb.Name())) 52 | require.NoError(tb, err) 53 | require.NoError(tb, cacheFile.Close()) 54 | 55 | return cacheFile.Name(), srvURL 56 | } 57 | -------------------------------------------------------------------------------- /internal/filter/internal/rulelist/immutable.go: -------------------------------------------------------------------------------- 1 | package rulelist 2 | 3 | import "github.com/AdguardTeam/AdGuardDNS/internal/filter" 4 | 5 | // Immutable is a rule-list filter that doesn't refresh or change. It is used 6 | // for users' custom rule-lists as well as in service blocking. 7 | // 8 | // TODO(a.garipov): Consider not using rule-list engines for service and custom 9 | // filters at all. It could be faster to simply go through all enabled rules 10 | // sequentially instead. Alternatively, rework the [urlfilter.DNSEngine] and 11 | // make it use the sequential scan if the number of rules is less than some 12 | // constant value. 13 | // 14 | // See AGDNS-342. 15 | type Immutable struct { 16 | // TODO(a.garipov): Find ways to embed it in a way that shows the methods, 17 | // doesn't result in double dereferences, and doesn't cause naming issues. 
18 | *baseFilter 19 | } 20 | 21 | // NewImmutable returns a new immutable DNS request and response filter using 22 | // the provided rule text and IDs. 23 | func NewImmutable( 24 | rulesData []byte, 25 | id filter.ID, 26 | svcID filter.BlockedServiceID, 27 | cache ResultCache, 28 | ) (f *Immutable) { 29 | return &Immutable{ 30 | baseFilter: newBaseFilter(rulesData, id, svcID, cache), 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/filter/metrics.go: -------------------------------------------------------------------------------- 1 | package filter 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // TODO(a.garipov): Consider re-adding some metrics for custom filters after 9 | // AGDNS-1519. 10 | 11 | // Metrics is the interface for metrics of filters. 12 | type Metrics interface { 13 | // SetFilterStatus sets the status of a filter by its id. If err is not 14 | // nil, updTime and ruleCount are ignored. 15 | SetFilterStatus( 16 | ctx context.Context, 17 | id string, 18 | updTime time.Time, 19 | ruleCount int, 20 | err error, 21 | ) 22 | } 23 | 24 | // EmptyMetrics is the implementation of the [Metrics] interface that does 25 | // nothing. 26 | type EmptyMetrics struct{} 27 | 28 | // type check 29 | var _ Metrics = EmptyMetrics{} 30 | 31 | // SetFilterStatus implements the [Metrics] interface for EmptyMetrics. 32 | func (EmptyMetrics) SetFilterStatus(_ context.Context, _ string, _ time.Time, _ int, _ error) {} 33 | -------------------------------------------------------------------------------- /internal/filter/storage.go: -------------------------------------------------------------------------------- 1 | package filter 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // StoragePrefix is a common prefix for logging and refreshes of the filter 8 | // storage. 9 | // 10 | // TODO(a.garipov): Consider extracting these kinds of IDs to agdcache or some 11 | // other package. 
12 | const StoragePrefix = "filters/storage"
13 | 
14 | // Storage is the interface for filter storages that can build a filter based
15 | // on a configuration.
16 | type Storage interface {
17 | 	// ForConfig returns a filter created from the configuration. If c is nil,
18 | 	// f is [filter.Empty].
19 | 	ForConfig(ctx context.Context, c Config) (f Interface)
20 | 
21 | 	// HasListID returns true if id is known to the storage.
22 | 	HasListID(id ID) (ok bool)
23 | }
24 | 
--------------------------------------------------------------------------------
/internal/geoip/error.go:
--------------------------------------------------------------------------------
1 | package geoip
2 | 
3 | import "fmt"
4 | 
5 | // NotACountryError is returned from NewCountry when the string isn't a valid
6 | // ISO 3166-1 alpha-2 country code.
7 | type NotACountryError struct {
8 | 	// Code is the code presented to NewCountry.
9 | 	Code string
10 | }
11 | 
12 | // Error implements the error interface for *NotACountryError.
13 | func (err *NotACountryError) Error() (msg string) {
14 | 	return fmt.Sprintf("%q is not a valid iso 3166-1 alpha-2 code", err.Code)
15 | }
16 | 
17 | // NotAContinentError is returned from NewContinent when the string isn't one
18 | // of the known continent codes.
19 | type NotAContinentError struct {
20 | 	// Code is the code presented to NewContinent.
21 | 	Code string
22 | }
23 | 
24 | // Error implements the error interface for *NotAContinentError.
25 | func (err *NotAContinentError) Error() (msg string) {
26 | 	return fmt.Sprintf("%q is not a valid continent code", err.Code)
27 | }
28 | 
--------------------------------------------------------------------------------
/internal/geoip/geoip.go:
--------------------------------------------------------------------------------
1 | // Package geoip contains implementations of the GeoIP database for AdGuard DNS.
2 | package geoip
3 | 
4 | import (
5 | 	"context"
6 | 	"net/netip"
7 | 
8 | 	"github.com/AdguardTeam/golibs/netutil"
9 | )
10 | 
11 | // Interface is the interface for the GeoIP database that stores the geographic
12 | // data about an IP address.
13 | type Interface interface {
14 | 	// SubnetByLocation returns the default subnet for location, if there is
15 | 	// one. If there isn't, n is an unspecified subnet. fam must be either
16 | 	// [netutil.AddrFamilyIPv4] or [netutil.AddrFamilyIPv6].
17 | 	SubnetByLocation(
18 | 		ctx context.Context,
19 | 		l *Location,
20 | 		fam netutil.AddrFamily,
21 | 	) (n netip.Prefix, err error)
22 | 
23 | 	// Data returns the GeoIP data for ip. It may use host to get cached GeoIP
24 | 	// data if ip is netip.Addr{}.
25 | 	Data(ctx context.Context, host string, ip netip.Addr) (l *Location, err error)
26 | }
27 | 
--------------------------------------------------------------------------------
/internal/geoip/location.go:
--------------------------------------------------------------------------------
1 | package geoip
2 | 
3 | // Location Types And Constants
4 | 
5 | // Location represents the GeoIP location data about an IP address.
6 | type Location struct {
7 | 	// Country is the country whose subnets contain the IP address.
8 | 	Country Country
9 | 
10 | 	// Continent is the continent whose subnets contain the IP address.
11 | 	Continent Continent
12 | 
13 | 	// TopSubdivision is the ISO-code of the political subdivision of a country
14 | 	// whose subnets contain the IP address. This field may be empty.
15 | 	TopSubdivision string
16 | 
17 | 	// ASN is the number of the autonomous system whose subnets contain the IP
18 | 	// address.
19 | 	ASN ASN
20 | }
21 | 
22 | // ASN is the autonomous system number of an IP address.
23 | //
24 | // See also https://datatracker.ietf.org/doc/html/rfc7300.
25 | type ASN uint32
26 | 
27 | // Continent represents a continent code used by MaxMind.
28 | type Continent string
29 | 
30 | // Continent code constants.
31 | const (
32 | 	// ContinentNone is an unknown continent code.
33 | 	ContinentNone Continent = ""
34 | 	// ContinentAF is Africa.
35 | 	ContinentAF   Continent = "AF"
36 | 	// ContinentAN is Antarctica.
37 | 	ContinentAN   Continent = "AN"
38 | 	// ContinentAS is Asia.
39 | 	ContinentAS   Continent = "AS"
40 | 	// ContinentEU is Europe.
41 | 	ContinentEU   Continent = "EU"
42 | 	// ContinentNA is North America.
43 | 	ContinentNA   Continent = "NA"
44 | 	// ContinentOC is Oceania.
45 | 	ContinentOC   Continent = "OC"
46 | 	// ContinentSA is South America.
47 | 	ContinentSA   Continent = "SA"
48 | )
49 | 
50 | // NewContinent converts s into a Continent while also validating it. Prefer to
51 | // use this instead of a plain conversion.
52 | func NewContinent(s string) (c Continent, err error) {
53 | 	switch c = Continent(s); c {
54 | 	case
55 | 		ContinentAF,
56 | 		ContinentAN,
57 | 		ContinentAS,
58 | 		ContinentEU,
59 | 		ContinentNA,
60 | 		ContinentOC,
61 | 		ContinentSA,
62 | 		ContinentNone:
63 | 		return c, nil
64 | 	default:
65 | 		return ContinentNone, &NotAContinentError{Code: s}
66 | 	}
67 | }
68 | 
--------------------------------------------------------------------------------
/internal/geoip/metrics.go:
--------------------------------------------------------------------------------
1 | package geoip
2 | 
3 | import "context"
4 | 
5 | // Metrics is an interface that is used for the collection of the GeoIP database
6 | // statistics.
7 | type Metrics interface {
8 | 	// HandleASNUpdateStatus updates the GeoIP ASN database update status.
9 | 	HandleASNUpdateStatus(ctx context.Context, err error)
10 | 
11 | 	// HandleCountryUpdateStatus updates the GeoIP countries database update
12 | 	// status.
13 | 	HandleCountryUpdateStatus(ctx context.Context, err error)
14 | 
15 | 	// IncrementHostCacheLookups increments the number of GeoIP cache lookups
16 | 	// for hostnames.
17 | 	IncrementHostCacheLookups(ctx context.Context, hit bool)
18 | 
19 | 	// IncrementIPCacheLookups increments the number of GeoIP cache lookups for
20 | 	// IP addresses.
21 | 	IncrementIPCacheLookups(ctx context.Context, hit bool)
22 | }
23 | 
24 | // EmptyMetrics is the implementation of the [Metrics] interface that does
25 | // nothing.
26 | type EmptyMetrics struct{}
27 | 
28 | // type check
29 | var _ Metrics = EmptyMetrics{}
30 | 
31 | // HandleASNUpdateStatus implements the [Metrics] interface for EmptyMetrics.
32 | func (EmptyMetrics) HandleASNUpdateStatus(_ context.Context, _ error) {}
33 | 
34 | // HandleCountryUpdateStatus implements the [Metrics] interface for
35 | // EmptyMetrics.
36 | func (EmptyMetrics) HandleCountryUpdateStatus(_ context.Context, _ error) {}
37 | 
38 | // IncrementHostCacheLookups implements the [Metrics] interface for
39 | // EmptyMetrics.
40 | func (EmptyMetrics) IncrementHostCacheLookups(_ context.Context, _ bool) {}
41 | 
42 | // IncrementIPCacheLookups implements the [Metrics] interface for EmptyMetrics.
43 | func (EmptyMetrics) IncrementIPCacheLookups(_ context.Context, _ bool) {}
44 | 
--------------------------------------------------------------------------------
/internal/geoip/testdata/GeoIP2-City-Test.mmdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdguardTeam/AdGuardDNS/5da2a9fd26b9b4ebda83f5b20cdbff2fb772487b/internal/geoip/testdata/GeoIP2-City-Test.mmdb
--------------------------------------------------------------------------------
/internal/geoip/testdata/GeoIP2-Country-Test.mmdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdguardTeam/AdGuardDNS/5da2a9fd26b9b4ebda83f5b20cdbff2fb772487b/internal/geoip/testdata/GeoIP2-Country-Test.mmdb
--------------------------------------------------------------------------------
/internal/geoip/testdata/GeoIP2-ISP-Test.mmdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdguardTeam/AdGuardDNS/5da2a9fd26b9b4ebda83f5b20cdbff2fb772487b/internal/geoip/testdata/GeoIP2-ISP-Test.mmdb
--------------------------------------------------------------------------------
/internal/metrics/access.go:
--------------------------------------------------------------------------------
1 | package metrics
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"time"
7 | 
8 | 	"github.com/prometheus/client_golang/prometheus"
9 | )
10 | 
11 | // AccessProfile is the Prometheus-based implementation of the
12 | // [access.ProfileMetrics] interface.
13 | type AccessProfile struct {
14 | 	// accessProfileInitDuration is a histogram with the duration of a profile
15 | 	// access internal engine initialization.
16 | 	accessProfileInitDuration prometheus.Histogram
17 | }
18 | 
19 | // NewAccessProfile registers the profile access engine metrics in reg and
20 | // returns a properly initialized [AccessProfile].
21 | func NewAccessProfile(namespace string, reg prometheus.Registerer) (m *AccessProfile, err error) {
22 | 	const (
23 | 		accessProfileInitDuration = "profile_init_engine_duration_seconds"
24 | 	)
25 | 
26 | 	m = &AccessProfile{
27 | 		accessProfileInitDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
28 | 			Name:      accessProfileInitDuration,
29 | 			Namespace: namespace,
30 | 			Subsystem: subsystemAccess,
31 | 			Help:      "Time elapsed on profile access engine initialization.",
32 | 			Buckets:   []float64{0.001, 0.01, 0.1, 1},
33 | 		}),
34 | 	}
35 | 
36 | 	err = reg.Register(m.accessProfileInitDuration)
37 | 	if err != nil {
38 | 		return nil, fmt.Errorf("registering metrics %q: %w", accessProfileInitDuration, err)
39 | 	}
40 | 
41 | 	return m, nil
42 | }
43 | 
44 | // ObserveProfileInit implements the [access.ProfileMetrics] interface for
45 | // *AccessProfile.
46 | func (m *AccessProfile) ObserveProfileInit(_ context.Context, dur time.Duration) {
47 | 	m.accessProfileInitDuration.Observe(dur.Seconds())
48 | }
49 | 
--------------------------------------------------------------------------------
/internal/metrics/dnsmsg.go:
--------------------------------------------------------------------------------
1 | package metrics
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/prometheus/client_golang/prometheus"
7 | )
8 | 
9 | // ClonerStat is the Prometheus-based implementation of the [dnsmsg.ClonerStat]
10 | // interface.
11 | type ClonerStat struct {
12 | 	// dnsMsgFullClones is a counter with the total number of ECS cache full
13 | 	// clones.
14 | 	dnsMsgFullClones prometheus.Counter
15 | 
16 | 	// dnsMsgPartialClones is a counter with the total number of ECS cache
17 | 	// partial clones.
18 | 	dnsMsgPartialClones prometheus.Counter
19 | }
20 | 
21 | // NewClonerStat registers the DNS message cloner metrics in reg and returns a
22 | // properly initialized [ClonerStat].
23 | func NewClonerStat(namespace string, reg prometheus.Registerer) (m *ClonerStat, err error) {
24 | 	const (
25 | 		fullClonesTotal = "total_full_clones"
26 | 	)
27 | 
28 | 	// fullClones is a counter with the total number of cloned messages using
29 | 	// our custom cloner. "full" is either "1" (cloned entirely using the
30 | 	// cloner) or "0" (cloned using miekg/dns.Copy).
31 | 	fullClones := prometheus.NewCounterVec(prometheus.CounterOpts{
32 | 		Name:      fullClonesTotal,
33 | 		Subsystem: subsystemDNSMsg,
34 | 		Namespace: namespace,
35 | 		Help: "Total number of (not) full clones using the cloner. " +
36 | 			"full=1 means that a message was cloned fully using the cloner.",
37 | 	}, []string{"full"})
38 | 
39 | 	m = &ClonerStat{
40 | 		dnsMsgFullClones: fullClones.With(prometheus.Labels{
41 | 			"full": "1",
42 | 		}),
43 | 		dnsMsgPartialClones: fullClones.With(prometheus.Labels{
44 | 			"full": "0",
45 | 		}),
46 | 	}
47 | 
48 | 	err = reg.Register(fullClones)
49 | 	if err != nil {
50 | 		return nil, fmt.Errorf("registering metrics %q: %w", fullClonesTotal, err)
51 | 	}
52 | 
53 | 	return m, nil
54 | }
55 | 
56 | // The type check is performed in the test file to prevent a dependency.
57 | 
58 | // OnClone implements the [dnsmsg.ClonerStat] interface for *ClonerStat.
59 | func (m *ClonerStat) OnClone(isFull bool) {
60 | 	IncrementCond(isFull, m.dnsMsgFullClones, m.dnsMsgPartialClones)
61 | }
62 | 
--------------------------------------------------------------------------------
/internal/metrics/dnssvc.go:
--------------------------------------------------------------------------------
1 | package metrics
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 
7 | 	"github.com/AdguardTeam/golibs/container"
8 | 	"github.com/AdguardTeam/golibs/errors"
9 | 	"github.com/prometheus/client_golang/prometheus"
10 | )
11 | 
12 | // InitialMiddleware is the Prometheus-based implementation of the
13 | // [dnssvc.InitialMiddlewareMetrics] interface.
14 | type InitialMiddleware struct {
15 | 	specialRequestsTotal *prometheus.CounterVec
16 | }
17 | 
18 | // NewInitialMiddleware registers the initial-middleware metrics in reg and
19 | // returns a properly initialized *InitialMiddleware. All arguments must be
20 | // set.
21 | func NewInitialMiddleware(
22 | 	namespace string,
23 | 	reg prometheus.Registerer,
24 | ) (m *InitialMiddleware, err error) {
25 | 	const (
26 | 		specialRequestsTotal = "special_requests_total"
27 | 	)
28 | 
29 | 	m = &InitialMiddleware{
30 | 		specialRequestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
31 | 			Name:      specialRequestsTotal,
32 | 			Namespace: namespace,
33 | 			Subsystem: subsystemDNSSvc,
34 | 			Help:      "The number of DNS requests for special domain names.",
35 | 		}, []string{"kind"}),
36 | 	}
37 | 
38 | 	var errs []error
39 | 	collectors := container.KeyValues[string, prometheus.Collector]{{
40 | 		Key:   specialRequestsTotal,
41 | 		Value: m.specialRequestsTotal,
42 | 	}}
43 | 
44 | 	for _, c := range collectors {
45 | 		err = reg.Register(c.Value)
46 | 		if err != nil {
47 | 			errs = append(errs, fmt.Errorf("registering metrics %q: %w", c.Key, err))
48 | 		}
49 | 	}
50 | 
51 | 	if err = errors.Join(errs...); err != nil {
52 | 		return nil, err
53 | 	}
54 | 
55 | 	return m, nil
56 | }
57 | 
58 | // IncrementRequestsTotal implements the [dnssvc.InitialMiddlewareMetrics]
59 | // interface for *InitialMiddleware.
60 | func (m *InitialMiddleware) IncrementRequestsTotal(_ context.Context, kind string) {
61 | 	m.specialRequestsTotal.WithLabelValues(kind).Inc()
62 | }
63 | 
--------------------------------------------------------------------------------
/internal/metrics/research.go:
--------------------------------------------------------------------------------
1 | package metrics
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"github.com/prometheus/client_golang/prometheus"
7 | )
8 | 
9 | // SetExperimentGauge registers the gauge used to inform about running
10 | // experiments in reg and sets it to 1. reg must not be nil.
11 | func SetExperimentGauge(reg prometheus.Registerer, constLabels prometheus.Labels) (err error) {
12 | 	gauge := prometheus.NewGauge(prometheus.GaugeOpts{
13 | 		Name:      "experiment_enabled",
14 | 		Namespace: namespace,
15 | 		Subsystem: subsystemResearch,
16 | 		Help: `A metric with a constant value of 1 labeled by experiments that are available ` +
17 | 			`and enabled.`,
18 | 		ConstLabels: constLabels,
19 | 	})
20 | 
21 | 	err = reg.Register(gauge)
22 | 	if err != nil {
23 | 		return fmt.Errorf("registering experiment_enabled metric: %w", err)
24 | 	}
25 | 
26 | 	gauge.Set(1)
27 | 
28 | 	return nil
29 | }
30 | 
--------------------------------------------------------------------------------
/internal/profiledb/customdomaindb.go:
--------------------------------------------------------------------------------
1 | package profiledb
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
7 | )
8 | 
9 | // CustomDomainDB is a database of custom-domain data. All methods must be safe
10 | // for concurrent use.
11 | type CustomDomainDB interface {
12 | 	// AddCertificate adds information about a current certificate. domains
13 | 	// must contain only valid domain names and wildcards like
14 | 	// "*.domain.example". state must not be nil and must be valid.
15 | 	AddCertificate(
16 | 		ctx context.Context,
17 | 		profID agd.ProfileID,
18 | 		domains []string,
19 | 		state *agd.CustomDomainStateCurrent,
20 | 	)
21 | 
22 | 	// DeleteAllWellKnownPaths removes all data about well-known paths for
23 | 	// certificate validation.
24 | 	DeleteAllWellKnownPaths(ctx context.Context)
25 | 
26 | 	// SetWellKnownPath adds a well-known path for certificate validation to the
27 | 	// database and sets the expiration time. s must not be nil and must be
28 | 	// valid.
29 | 	SetWellKnownPath(ctx context.Context, s *agd.CustomDomainStatePending)
30 | }
31 | 
32 | // EmptyCustomDomainDB is the implementation of the [CustomDomainDB] interface
33 | // that does nothing.
34 | type EmptyCustomDomainDB struct{}
35 | 
36 | // type check
37 | var _ CustomDomainDB = EmptyCustomDomainDB{}
38 | 
39 | // AddCertificate implements the [CustomDomainDB] interface for
40 | // EmptyCustomDomainDB.
41 | func (EmptyCustomDomainDB) AddCertificate(
42 | 	_ context.Context,
43 | 	_ agd.ProfileID,
44 | 	_ []string,
45 | 	_ *agd.CustomDomainStateCurrent,
46 | ) {
47 | }
48 | 
49 | // DeleteAllWellKnownPaths implements the [CustomDomainDB] interface for
50 | // EmptyCustomDomainDB.
51 | func (EmptyCustomDomainDB) DeleteAllWellKnownPaths(_ context.Context) {}
52 | 
53 | // SetWellKnownPath implements the [CustomDomainDB] interface for
54 | // EmptyCustomDomainDB.
55 | func (EmptyCustomDomainDB) SetWellKnownPath(_ context.Context, _ *agd.CustomDomainStatePending) {}
56 | 
--------------------------------------------------------------------------------
/internal/profiledb/internal/filecachepb/unsafe.go:
--------------------------------------------------------------------------------
1 | package filecachepb
2 | 
3 | import "unsafe"
4 | 
5 | // unsafelyConvertStrSlice checks if []T1 can be converted to []T2 at compile
6 | // time and, if so, converts the slice using package unsafe.
7 | //
8 | // Slices resulting from this conversion must not be mutated.
9 | func unsafelyConvertStrSlice[T1, T2 ~string](s []T1) (res []T2) {
10 | 	if s == nil {
11 | 		return nil
12 | 	}
13 | 
14 | 	// #nosec G103 -- Conversion between two slices with the same underlying
15 | 	// element type is safe.
16 | 	return *(*[]T2)(unsafe.Pointer(&s))
17 | }
18 | 
--------------------------------------------------------------------------------
/internal/querylog/metrics.go:
--------------------------------------------------------------------------------
1 | package querylog
2 | 
3 | import (
4 | 	"context"
5 | 	"time"
6 | 
7 | 	"github.com/c2h5oh/datasize"
8 | )
9 | 
10 | // Metrics is an interface that is used for the collection of the query log
11 | // statistics.
12 | type Metrics interface {
13 | 	// ObserveItemSize stores the size of a written query log entry.
14 | 	ObserveItemSize(ctx context.Context, size datasize.ByteSize)
15 | 
16 | 	// ObserveWrite stores the duration of the write operation and increments
17 | 	// the write counter.
18 | 	ObserveWrite(ctx context.Context, dur time.Duration)
19 | }
20 | 
21 | // EmptyMetrics is the implementation of the [Metrics] interface that does
22 | // nothing.
23 | type EmptyMetrics struct{}
24 | 
25 | // type check
26 | var _ Metrics = EmptyMetrics{}
27 | 
28 | // ObserveItemSize implements the [Metrics] interface for EmptyMetrics.
29 | func (EmptyMetrics) ObserveItemSize(_ context.Context, _ datasize.ByteSize) {}
30 | 
31 | // ObserveWrite implements the [Metrics] interface for EmptyMetrics.
32 | func (EmptyMetrics) ObserveWrite(_ context.Context, _ time.Duration) {}
33 | 
--------------------------------------------------------------------------------
/internal/querylog/querylog.go:
--------------------------------------------------------------------------------
1 | // Package querylog defines the AdGuard DNS query log constants and types and
2 | // provides implementations of the log.
3 | package querylog
4 | 
5 | import (
6 | 	"context"
7 | )
8 | 
9 | // Interface is the query log interface. All methods must be safe for
10 | // concurrent use.
11 | type Interface interface {
12 | 	// Write writes the entry into the query log. e must not be nil.
13 | 	Write(ctx context.Context, e *Entry) (err error)
14 | }
15 | 
16 | // Empty is a query log that does nothing and returns nil values.
17 | type Empty struct{}
18 | 
19 | // type check
20 | var _ Interface = Empty{}
21 | 
22 | // Write implements the Interface interface for Empty. It does nothing and
23 | // returns nil.
24 | func (Empty) Write(_ context.Context, _ *Entry) (err error) {
25 | 	return nil
26 | }
27 | 
--------------------------------------------------------------------------------
/internal/querylog/querylog_test.go:
--------------------------------------------------------------------------------
1 | package querylog_test
2 | 
3 | import (
4 | 	"time"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agd"
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/filter"
8 | 	"github.com/AdguardTeam/AdGuardDNS/internal/geoip"
9 | 	"github.com/AdguardTeam/AdGuardDNS/internal/querylog"
10 | 	"github.com/miekg/dns"
11 | )
12 | 
13 | // testRequestID is the common request ID for tests.
14 | var testRequestID = agd.NewRequestID()
15 | 
16 | // testEntry returns an entry for tests.
17 | func testEntry() (e *querylog.Entry) {
18 | 	return &querylog.Entry{
19 | 		RequestResult: &filter.ResultBlocked{
20 | 			List: "adguard_dns_filter",
21 | 			Rule: "||example.com^",
22 | 		},
23 | 		ResponseResult:  nil,
24 | 		Time:            time.Unix(123, 0),
25 | 		RequestID:       testRequestID,
26 | 		ProfileID:       "prof1234",
27 | 		DeviceID:        "dev1234",
28 | 		ClientCountry:   geoip.CountryRU,
29 | 		ResponseCountry: geoip.CountryUS,
30 | 		DomainFQDN:      "example.com.",
31 | 		Protocol:        agd.ProtoDNS,
32 | 		ClientASN:       1234,
33 | 		Elapsed:         5 * time.Millisecond,
34 | 		RequestType:     dns.TypeA,
35 | 		ResponseCode:    dns.RcodeSuccess,
36 | 		DNSSEC:          true,
37 | 	}
38 | }
39 | 
--------------------------------------------------------------------------------
/internal/remotekv/cachekv.go:
--------------------------------------------------------------------------------
1 | package remotekv
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
7 | )
8 | 
9 | // Cache is a local cache implementation of the [Interface] interface.
10 | type Cache struct {
11 | 	cache agdcache.Interface[string, []byte]
12 | }
13 | 
14 | // CacheConfig is the configuration for the local cache [Interface]
15 | // implementation. All fields must not be empty.
16 | type CacheConfig struct {
17 | 	// Cache is the underlying cache.
18 | 	Cache agdcache.Interface[string, []byte]
19 | }
20 | 
21 | // NewCache returns a new *Cache. c must not be nil.
22 | func NewCache(c *CacheConfig) (kv *Cache) {
23 | 	return &Cache{
24 | 		cache: c.Cache,
25 | 	}
26 | }
27 | 
28 | // type check
29 | var _ Interface = (*Cache)(nil)
30 | 
31 | // Get implements the [Interface] interface for *Cache. err is always nil.
32 | func (kv *Cache) Get(_ context.Context, key string) (val []byte, ok bool, err error) {
33 | 	val, ok = kv.cache.Get(key)
34 | 
35 | 	return val, ok, nil
36 | }
37 | 
38 | // Set implements the [Interface] interface for *Cache. err is always nil.
39 | func (kv *Cache) Set(_ context.Context, key string, val []byte) (err error) {
40 | 	kv.cache.Set(key, val)
41 | 
42 | 	return nil
43 | }
44 | 
--------------------------------------------------------------------------------
/internal/remotekv/cachekv_test.go:
--------------------------------------------------------------------------------
1 | package remotekv_test
2 | 
3 | import (
4 | 	"testing"
5 | 	"time"
6 | 
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdcache"
8 | 	"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
9 | 	"github.com/AdguardTeam/golibs/testutil"
10 | 	"github.com/stretchr/testify/assert"
11 | 	"github.com/stretchr/testify/require"
12 | )
13 | 
14 | // testTimeout is the common timeout for tests and contexts.
15 | const testTimeout = 1 * time.Second
16 | 
17 | func TestNewCache(t *testing.T) {
18 | 	const testKey = "key"
19 | 
20 | 	testVal := []byte{1, 2, 3}
21 | 
22 | 	cache := remotekv.NewCache(&remotekv.CacheConfig{
23 | 		Cache: agdcache.NewLRU[string, []byte](&agdcache.LRUConfig{
24 | 			Count: 1,
25 | 		}),
26 | 	})
27 | 
28 | 	ctx := testutil.ContextWithTimeout(t, testTimeout)
29 | 	err := cache.Set(ctx, testKey, testVal)
30 | 	require.NoError(t, err)
31 | 
32 | 	got, ok, err := cache.Get(ctx, testKey)
33 | 	require.NoError(t, err)
34 | 	require.True(t, ok)
35 | 
36 | 	assert.Equal(t, testVal, got)
37 | }
38 | 
--------------------------------------------------------------------------------
/internal/remotekv/consulkv/error.go:
--------------------------------------------------------------------------------
1 | package consulkv
2 | 
3 | import (
4 | 	"context"
5 | 
6 | 	"github.com/AdguardTeam/AdGuardDNS/internal/errcoll"
7 | 	"github.com/AdguardTeam/golibs/errors"
8 | )
9 | 
10 | // ErrRateLimited is returned by [KV.Get] when the request is rate
11 | // limited.
12 | const ErrRateLimited errors.Error = "rate limited"
13 | 
14 | // httpError is an error returned by the Consul KV database HTTP client.
15 | type httpError struct {
16 | 	err error
17 | }
18 | 
19 | // type check
20 | var _ error = httpError{}
21 | 
22 | // Error implements the error interface for httpError.
23 | func (err httpError) Error() (msg string) {
24 | 	return err.err.Error()
25 | }
26 | 
27 | // type check
28 | var _ errors.Wrapper = httpError{}
29 | 
30 | // Unwrap implements the [errors.Wrapper] interface for httpError.
31 | func (err httpError) Unwrap() (unwrapped error) {
32 | 	return err.err
33 | }
34 | 
35 | // type check
36 | var _ errcoll.SentryReportableError = httpError{}
37 | 
38 | // IsSentryReportable implements the [errcoll.SentryReportableError] interface
39 | // for httpError.
40 | func (err httpError) IsSentryReportable() (ok bool) {
41 | 	return !errors.Is(err.err, ErrRateLimited) &&
42 | 		!errors.Is(err.err, context.Canceled) &&
43 | 		!errors.Is(err.err, context.DeadlineExceeded)
44 | }
45 | 
--------------------------------------------------------------------------------
/internal/remotekv/keynamespace.go:
--------------------------------------------------------------------------------
1 | package remotekv
2 | 
3 | import "context"
4 | 
5 | // KeyNamespaceConfig is the configuration structure for [KeyNamespace].
6 | type KeyNamespaceConfig struct {
7 | 	// KV is the key-value storage to be wrapped. It must not be nil.
8 | 	KV Interface
9 | 
10 | 	// Prefix is the custom prefix to be added to the keys. Prefix should be in
11 | 	// accordance with the wrapped KV storage keys.
12 | 	Prefix string
13 | }
14 | 
15 | // KeyNamespace is a wrapper around [Interface] that adds a custom prefix to the
16 | // keys.
17 | type KeyNamespace struct {
18 | 	// kv is the key-value storage to be wrapped.
19 | 	kv Interface
20 | 
21 | 	// prefix is the custom prefix to be added to the keys. prefix should be in
22 | 	// accordance with the wrapped KV storage keys.
23 | 	prefix string
24 | }
25 | 
26 | // NewKeyNamespace returns a properly initialized *KeyNamespace. conf must not
27 | // be nil.
28 | func NewKeyNamespace(conf *KeyNamespaceConfig) (n *KeyNamespace) {
29 | 	return &KeyNamespace{
30 | 		kv:     conf.KV,
31 | 		prefix: conf.Prefix,
32 | 	}
33 | }
34 | 
35 | // type check
36 | var _ Interface = (*KeyNamespace)(nil)
37 | 
38 | // Get implements the [Interface] interface for *KeyNamespace.
39 | func (n *KeyNamespace) Get(ctx context.Context, key string) (val []byte, ok bool, err error) {
40 | 	// TODO(s.chzhen): Improve memory allocation.
41 | 	prefixed := n.prefix + key
42 | 
43 | 	return n.kv.Get(ctx, prefixed)
44 | }
45 | 
46 | // Set implements the [Interface] interface for *KeyNamespace.
47 | func (n *KeyNamespace) Set(ctx context.Context, key string, val []byte) (err error) {
48 | 	// TODO(s.chzhen): Improve memory allocation.
49 | 	prefixed := n.prefix + key
50 | 
51 | 	return n.kv.Set(ctx, prefixed, val)
52 | }
53 | 
--------------------------------------------------------------------------------
/internal/remotekv/keynamespace_test.go:
--------------------------------------------------------------------------------
1 | package remotekv_test
2 | 
3 | import (
4 | 	"context"
5 | 	"testing"
6 | 
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/agdtest"
8 | 	"github.com/AdguardTeam/AdGuardDNS/internal/remotekv"
9 | 	"github.com/AdguardTeam/golibs/testutil"
10 | 	"github.com/stretchr/testify/assert"
11 | 	"github.com/stretchr/testify/require"
12 | )
13 | 
14 | func TestNewKeyNamespace(t *testing.T) {
15 | 	const (
16 | 		testKey    = "key"
17 | 		testPrefix = "test"
18 | 	)
19 | 
20 | 	kv := &agdtest.RemoteKV{
21 | 		OnSet: func(_ context.Context, key string, _ []byte) (_ error) {
22 | 			require.Equal(t, testPrefix+testKey, key)
23 | 
24 | 			return assert.AnError
25 | 		},
26 | 		OnGet: func(_ context.Context, key string) (_ []byte, _ bool, _ error) {
27 | 			require.Equal(t, testPrefix+testKey, key)
28 | 
29 | 			return nil, false, assert.AnError
30 | 		},
31 | 	}
32 | 
33 | 	n := remotekv.NewKeyNamespace(&remotekv.KeyNamespaceConfig{
34 | 		KV:     kv,
35 | 		Prefix: testPrefix,
36 | 	})
37 | 
38 | 	assert.NotPanics(t, func() {
39 | 		ctx := testutil.ContextWithTimeout(t, testTimeout)
40 | 		err := n.Set(ctx, testKey, nil)
41 | 		assert.ErrorIs(t, err, assert.AnError)
42 | 
43 | 		_, _, err = n.Get(ctx, testKey)
44 | 		assert.ErrorIs(t, err, assert.AnError)
45 | 	})
46 | }
47 | 
--------------------------------------------------------------------------------
/internal/remotekv/remotekv.go:
--------------------------------------------------------------------------------
1 | // Package remotekv contains remote key-value storage interfaces, helpers, and
2 | // implementations.
3 | package remotekv
4 | 
5 | import (
6 | 	"context"
7 | )
8 | 
9 | // Interface is the remote key-value storage interface.
10 | type Interface interface {
11 | 	// Get returns val by key from the storage. ok is true if val by key
12 | 	// exists.
13 | 	Get(ctx context.Context, key string) (val []byte, ok bool, err error)
14 | 
15 | 	// Set sets val into the storage by key.
16 | 	Set(ctx context.Context, key string, val []byte) (err error)
17 | }
18 | 
19 | // Empty is the [Interface] implementation that does nothing.
20 | type Empty struct{}
21 | 
22 | // type check
23 | var _ Interface = Empty{}
24 | 
25 | // Get implements the [Interface] interface for Empty. ok is always false.
26 | func (Empty) Get(_ context.Context, _ string) (val []byte, ok bool, err error) {
27 | 	return val, false, nil
28 | }
29 | 
30 | // Set implements the [Interface] interface for Empty.
31 | func (Empty) Set(_ context.Context, _ string, _ []byte) (err error) {
32 | 	return nil
33 | }
34 | 
--------------------------------------------------------------------------------
/internal/rulestat/metrics.go:
--------------------------------------------------------------------------------
1 | package rulestat
2 | 
3 | import (
4 | 	"context"
5 | )
6 | 
7 | // Metrics is an interface that is used for the collection of the filtering rule
8 | // statistics.
9 | type Metrics interface {
10 | 	// SetHitCount sets the number of rule hits that have not yet been uploaded.
11 | 	SetHitCount(ctx context.Context, count int64)
12 | 
13 | 	// HandleUploadStatus handles the upload status of the filtering rule
14 | 	// statistics.
15 | 	HandleUploadStatus(ctx context.Context, err error)
16 | }
17 | 
18 | // EmptyMetrics is the implementation of the [Metrics] interface that does
19 | // nothing.
20 | type EmptyMetrics struct{}
21 | 
22 | // type check
23 | var _ Metrics = EmptyMetrics{}
24 | 
25 | // SetHitCount implements the [Metrics] interface for EmptyMetrics.
26 | func (EmptyMetrics) SetHitCount(_ context.Context, _ int64) {}
27 | 
28 | // HandleUploadStatus implements the [Metrics] interface for EmptyMetrics.
29 | func (EmptyMetrics) HandleUploadStatus(_ context.Context, _ error) {}
30 | 
--------------------------------------------------------------------------------
/internal/rulestat/rulestat.go:
--------------------------------------------------------------------------------
1 | // Package rulestat contains the filtering rule statistics collector and API.
2 | package rulestat
3 | 
4 | import (
5 | 	"context"
6 | 
7 | 	"github.com/AdguardTeam/AdGuardDNS/internal/filter"
8 | )
9 | 
10 | // Interface is an ephemeral storage of the filtering rule list statistics
11 | // interface.
12 | //
13 | // All methods must be safe for concurrent use.
14 | type Interface interface {
15 | 	Collect(ctx context.Context, id filter.ID, r filter.RuleText)
16 | }
17 | 
18 | // type check
19 | var _ Interface = Empty{}
20 | 
21 | // Empty is an Interface implementation that does nothing.
22 | type Empty struct{}
23 | 
24 | // Collect implements the Interface interface for Empty.
25 | func (Empty) Collect(_ context.Context, _ filter.ID, _ filter.RuleText) {}
26 | 
--------------------------------------------------------------------------------
/internal/rulestat/rulestat_test.go:
--------------------------------------------------------------------------------
1 | package rulestat_test
2 | 
3 | import "time"
4 | 
5 | // testTimeout is the common timeout for tests.
6 | const testTimeout = 1 * time.Second 7 | -------------------------------------------------------------------------------- /internal/tlsconfig/customdomainstorage.go: -------------------------------------------------------------------------------- 1 | package tlsconfig 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/AdguardTeam/AdGuardDNS/internal/agd" 7 | "github.com/AdguardTeam/golibs/errors" 8 | ) 9 | 10 | // CustomDomainStorage retrieves certificate data for a custom domain by the 11 | // certificate name. 12 | type CustomDomainStorage interface { 13 | // CertificateData returns the certificate data for the name. If err is 14 | // nil, cert and key must not be nil. If the certificate could not be 15 | // found, err must contain [ErrCertificateNotFound]. 16 | CertificateData(ctx context.Context, name agd.CertificateName) (cert, key []byte, err error) 17 | } 18 | 19 | // ErrCertificateNotFound is returned (optionally wrapped) by 20 | // [CustomDomainStorage.CertificateData] when a certificate with that name 21 | // was not found. 22 | const ErrCertificateNotFound errors.Error = "certificate not found" 23 | 24 | // EmptyCustomDomainStorage is the implementation of the [CustomDomainStorage] 25 | // interface that does nothing. 26 | type EmptyCustomDomainStorage struct{} 27 | 28 | // type check 29 | var _ CustomDomainStorage = EmptyCustomDomainStorage{} 30 | 31 | // CertificateData implements the [CustomDomainStorage] interface for 32 | // EmptyCustomDomainStorage 33 | func (EmptyCustomDomainStorage) CertificateData( 34 | _ context.Context, 35 | _ agd.CertificateName, 36 | ) (_, _ []byte, _ error) { 37 | return nil, nil, nil 38 | } 39 | -------------------------------------------------------------------------------- /internal/tlsconfig/tlsconfig.go: -------------------------------------------------------------------------------- 1 | // Package tlsconfig contains TLS-related interfaces, helpers, and 2 | // implementations. 
3 | package tlsconfig 4 | -------------------------------------------------------------------------------- /internal/tlsconfig/tlsconfig_test.go: -------------------------------------------------------------------------------- 1 | package tlsconfig_test 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/pem" 8 | "math/big" 9 | "os" 10 | "testing" 11 | "time" 12 | 13 | "github.com/AdguardTeam/golibs/logutil/slogutil" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | // testTimeout is the common timeout for tests and contexts. 18 | const testTimeout = 1 * time.Second 19 | 20 | // testLogger is the common logger for tests. 21 | var testLogger = slogutil.NewDiscardLogger() 22 | 23 | // newCertAndKey is a helper function that generates certificate and key. 24 | func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) { 25 | tb.Helper() 26 | 27 | key, err := rsa.GenerateKey(rand.Reader, 2048) 28 | require.NoError(tb, err) 29 | 30 | certTmpl := &x509.Certificate{ 31 | SerialNumber: big.NewInt(n), 32 | } 33 | 34 | certDER, err = x509.CreateCertificate(rand.Reader, certTmpl, certTmpl, &key.PublicKey, key) 35 | require.NoError(tb, err) 36 | 37 | return certDER, key 38 | } 39 | 40 | // writeCertAndKey is a helper function that writes certificate and key to 41 | // specified paths. 
42 | func writeCertAndKey( 43 | tb testing.TB, 44 | certDER []byte, 45 | certPath string, 46 | key *rsa.PrivateKey, 47 | keyPath string, 48 | ) { 49 | tb.Helper() 50 | 51 | certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) 52 | require.NoError(tb, err) 53 | 54 | defer func() { 55 | err = certFile.Close() 56 | require.NoError(tb, err) 57 | }() 58 | 59 | err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) 60 | require.NoError(tb, err) 61 | 62 | keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) 63 | require.NoError(tb, err) 64 | 65 | defer func() { 66 | err = keyFile.Close() 67 | require.NoError(tb, err) 68 | }() 69 | 70 | err = pem.Encode(keyFile, &pem.Block{ 71 | Type: "RSA PRIVATE KEY", 72 | Bytes: x509.MarshalPKCS1PrivateKey(key), 73 | }) 74 | require.NoError(tb, err) 75 | } 76 | -------------------------------------------------------------------------------- /internal/version/norace.go: -------------------------------------------------------------------------------- 1 | //go:build !race 2 | 3 | package version 4 | 5 | // RaceEnabled is true if the current binary has been built with --race. 6 | const RaceEnabled = false 7 | -------------------------------------------------------------------------------- /internal/version/race.go: -------------------------------------------------------------------------------- 1 | //go:build race 2 | 3 | package version 4 | 5 | // RaceEnabled is true if the current binary has been built with --race. 6 | const RaceEnabled = true 7 | -------------------------------------------------------------------------------- /internal/version/version.go: -------------------------------------------------------------------------------- 1 | // Package version contains AdGuardDNS version information. 2 | package version 3 | 4 | // These can be set by the linker.
Unfortunately, we cannot set constants 5 | // during linking, and Go doesn't have a concept of immutable variables, so to 6 | // be thorough we have to only export them through getters. 7 | var ( 8 | branch string 9 | committime string 10 | revision string 11 | version string 12 | 13 | name = "AdGuardDNS" 14 | ) 15 | 16 | // Branch returns the compiled-in value of the Git branch. 17 | func Branch() (b string) { 18 | return branch 19 | } 20 | 21 | // CommitTime returns the compiled-in value of the commit time as a string. 22 | func CommitTime() (t string) { 23 | return committime 24 | } 25 | 26 | // Revision returns the compiled-in value of the Git revision. 27 | func Revision() (r string) { 28 | return revision 29 | } 30 | 31 | // Version returns the compiled-in value of the AdGuardDNS version as a 32 | // string. 33 | func Version() (v string) { 34 | return version 35 | } 36 | 37 | // Name returns the compiled-in name of the program, "AdGuardDNS" by default. 38 | func Name() (n string) { 39 | return name 40 | } 41 | -------------------------------------------------------------------------------- /internal/websvc/metrics.go: -------------------------------------------------------------------------------- 1 | package websvc 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // RequestType is a type alias for string that represents the request type 8 | // for web service metrics. 9 | type RequestType = string 10 | 11 | // List of web service requests of type RequestType. 12 | // 13 | // NOTE: Keep in sync with [metrics.RequestType].
14 | const ( 15 | RequestTypeError404 RequestType = "error404" 16 | RequestTypeError500 RequestType = "error500" 17 | RequestTypeStaticContent RequestType = "static_content" 18 | RequestTypeDNSCheckTest RequestType = "dnscheck_test" 19 | RequestTypeRobotsTxt RequestType = "robots_txt" 20 | RequestTypeRootRedirect RequestType = "root_redirect" 21 | RequestTypeLinkedIPProxy RequestType = "linkip" 22 | RequestTypeAdultBlockingPage RequestType = "adult_blocking_page" 23 | RequestTypeGeneralBlockingPage RequestType = "general_blocking_page" 24 | RequestTypeSafeBrowsingPage RequestType = "safe_browsing_page" 25 | ) 26 | 27 | // Metrics is an interface for collecting web service request statistics. 28 | type Metrics interface { 29 | // IncrementReqCount increments the web service request count for a given 30 | // RequestType. reqType must be one of the RequestType values. 31 | IncrementReqCount(ctx context.Context, reqType RequestType) 32 | } 33 | 34 | // EmptyMetrics is the implementation of the [Metrics] interface that does 35 | // nothing and keeps no state, so it is safe for concurrent use. 36 | type EmptyMetrics struct{} 37 | 38 | // type check 39 | var _ Metrics = EmptyMetrics{} 40 | 41 | // IncrementReqCount implements the [Metrics] interface for EmptyMetrics. 42 | func (EmptyMetrics) IncrementReqCount(_ context.Context, _ RequestType) {} 43 | -------------------------------------------------------------------------------- /internal/websvc/servergroup.go: -------------------------------------------------------------------------------- 1 | package websvc 2 | 3 | // ServerGroup is a semantic alias for names of server groups. 4 | type ServerGroup = string 5 | 6 | // Valid server groups.
7 | const ( 8 | ServerGroupAdultBlockingPage ServerGroup = "adult_blocking_page" 9 | ServerGroupGeneralBlockingPage ServerGroup = "general_blocking_page" 10 | ServerGroupLinkedIP ServerGroup = "linked_ip" 11 | ServerGroupNonDoH ServerGroup = "non_doh" 12 | ServerGroupSafeBrowsingPage ServerGroup = "safe_browsing_page" 13 | ) 14 | 15 | // loggerKeyGroup is the logger key used by server groups. 16 | const loggerKeyGroup = "group" 17 | -------------------------------------------------------------------------------- /internal/websvc/static.go: -------------------------------------------------------------------------------- 1 | package websvc 2 | 3 | import ( 4 | "maps" 5 | "net/http" 6 | 7 | "github.com/AdguardTeam/golibs/logutil/slogutil" 8 | ) 9 | 10 | // StaticContent serves static files, with their headers, by URL path. 11 | // Elements must not be nil. 12 | type StaticContent map[string]*StaticFile 13 | 14 | // StaticFile is a single file in a [StaticContent]. 15 | type StaticFile struct { 16 | // Headers contains headers of the HTTP response. 17 | Headers http.Header 18 | 19 | // Content is the file content. 20 | Content []byte 21 | } 22 | 23 | // type check 24 | var _ http.Handler = StaticContent(nil) 25 | 26 | // ServeHTTP implements the [http.Handler] interface for StaticContent.
27 | func (sc StaticContent) ServeHTTP(w http.ResponseWriter, r *http.Request) { 28 | p := r.URL.Path 29 | f, ok := sc[p] 30 | if !ok { 31 | http.NotFound(w, r) 32 | 33 | return 34 | } 35 | 36 | // Clone the headers so that the response never shares value slices with f. 37 | maps.Copy(w.Header(), f.Headers.Clone()) 38 | 39 | w.WriteHeader(http.StatusOK) 40 | _, err := w.Write(f.Content) 41 | if err != nil { 42 | ctx := r.Context() 43 | l := slogutil.MustLoggerFromContext(ctx) 44 | l.Log(ctx, levelForError(err), "writing static content", slogutil.KeyError, err) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /internal/websvc/static_test.go: -------------------------------------------------------------------------------- 1 | package websvc_test 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/AdguardTeam/AdGuardDNS/internal/agdtest" 8 | "github.com/AdguardTeam/AdGuardDNS/internal/websvc" 9 | "github.com/AdguardTeam/golibs/httphdr" 10 | "github.com/AdguardTeam/golibs/testutil/servicetest" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestService_ServeHTTP_static(t *testing.T) { 16 | t.Parallel() 17 | 18 | staticContent := websvc.StaticContent{ 19 | "/favicon.ico": { 20 | Content: []byte{}, 21 | Headers: http.Header{ 22 | httphdr.ContentType: []string{"image/x-icon"}, 23 | }, 24 | }, 25 | } 26 | 27 | c := &websvc.Config{ 28 | Logger: testLogger, 29 | CertificateValidator: testCertValidator, 30 | StaticContent: staticContent, 31 | DNSCheck: http.NotFoundHandler(), 32 | ErrColl: agdtest.NewErrorCollector(), 33 | Metrics: websvc.EmptyMetrics{}, 34 | Timeout: testTimeout, 35 | } 36 | 37 | svc := websvc.New(c) 38 | require.NotNil(t, svc) 39 | 40 | servicetest.RequireRun(t, svc, testTimeout) 41 | 42 | respHdr := http.Header{ 43 | httphdr.ContentType: []string{"image/x-icon"}, 44 | } 45 | assertResponseWithHeaders(t, svc, "/favicon.ico", http.StatusOK, respHdr) 46 | } 47 | 48 | // assertResponseWithHeaders is a
helper function that checks status code and 49 | // headers of the HTTP response for path. 50 | func assertResponseWithHeaders( 51 | t *testing.T, 52 | svc *websvc.Service, 53 | path string, 54 | statusCode int, 55 | respHdr http.Header, 56 | ) { 57 | t.Helper() 58 | 59 | rw := assertResponse(t, svc, path, statusCode) 60 | 61 | assert.Equal(t, respHdr, rw.Header()) 62 | } 63 | -------------------------------------------------------------------------------- /internal/websvc/testdata/block_page.html: -------------------------------------------------------------------------------- 1 | Block page 2 | -------------------------------------------------------------------------------- /internal/websvc/websvc_internal_test.go: -------------------------------------------------------------------------------- 1 | package websvc 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/AdguardTeam/golibs/errors" 8 | ) 9 | 10 | // LocalAddrs returns the local addresses of the servers in group g. Addrs may 11 | // contain nils. It panics if g is not a known server group. 12 | func (svc *Service) LocalAddrs(g ServerGroup) (addrs []net.Addr) { 13 | switch g { 14 | case ServerGroupAdultBlockingPage: 15 | return serverAddrs(svc.adultBlocking) 16 | case ServerGroupGeneralBlockingPage: 17 | return serverAddrs(svc.generalBlocking) 18 | case ServerGroupLinkedIP: 19 | return serverAddrs(svc.linkedIP) 20 | case ServerGroupNonDoH: 21 | return serverAddrs(svc.nonDoH) 22 | case ServerGroupSafeBrowsingPage: 23 | return serverAddrs(svc.safeBrowsing) 24 | default: 25 | panic(fmt.Errorf("server group: %w: %q", errors.ErrBadEnumValue, g)) 26 | } 27 | } 28 | 29 | // serverAddrs collects the local addresses of srvs in order.
30 | func serverAddrs(srvs []*server) (addrs []net.Addr) { 31 | for _, s := range srvs { 32 | addrs = append(addrs, s.localAddr()) 33 | } 34 | 35 | return addrs 36 | } 37 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2022-2025 AdGuard Software Ltd. 2 | // 3 | // This program is free software: you can redistribute it and/or modify it under 4 | // the terms of the GNU Affero General Public License as published by the Free 5 | // Software Foundation, version 3. 6 | 7 | package main 8 | 9 | import "github.com/AdguardTeam/AdGuardDNS/internal/cmd" 10 | 11 | func main() { 12 | cmd.Main(nil) 13 | } 14 | -------------------------------------------------------------------------------- /scripts/backend/main.go: -------------------------------------------------------------------------------- 1 | // Package main implements a mock gRPC server for backend services defined by 2 | // BILLSTAT_URL, PROFILES_URL, and REMOTE_KV_URL environment variables.
3 | package main 4 | 5 | import ( 6 | "net" 7 | "os" 8 | 9 | "github.com/AdguardTeam/AdGuardDNS/internal/backendpb" 10 | "github.com/AdguardTeam/golibs/logutil/slogutil" 11 | "github.com/AdguardTeam/golibs/osutil" 12 | "google.golang.org/grpc" 13 | ) 14 | 15 | func main() { 16 | l := slogutil.New(nil) 17 | 18 | const listenAddr = "localhost:6062" 19 | 20 | lsnr, err := net.Listen("tcp", listenAddr) 21 | if err != nil { 22 | l.Error("getting listener", slogutil.KeyError, err) 23 | 24 | os.Exit(osutil.ExitCodeFailure) 25 | } 26 | 27 | grpcSrv := grpc.NewServer() 28 | dnsSrv := newMockDNSServiceServer(l.With(slogutil.KeyPrefix, "dns")) 29 | backendpb.RegisterDNSServiceServer(grpcSrv, dnsSrv) 30 | 31 | kvSrv := newMockRemoteKVServiceServer(l.With(slogutil.KeyPrefix, "remote_kv")) 32 | backendpb.RegisterRemoteKVServiceServer(grpcSrv, kvSrv) 33 | 34 | rateLimitSrv := newMockRateLimitServiceServer(l.With(slogutil.KeyPrefix, "rate_limiter")) 35 | backendpb.RegisterRateLimitServiceServer(grpcSrv, rateLimitSrv) 36 | 37 | sessTickSrv := newMockSessionTicketServiceServer(l.With(slogutil.KeyPrefix, "session_ticket")) 38 | backendpb.RegisterSessionTicketServiceServer(grpcSrv, sessTickSrv) 39 | 40 | l.Info("starting serving", "laddr", listenAddr) 41 | err = grpcSrv.Serve(lsnr) 42 | if err != nil { 43 | l.Error("serving grpc", slogutil.KeyError, err) 44 | 45 | os.Exit(osutil.ExitCodeFailure) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /scripts/make/go-bench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | verbose="${VERBOSE:-0}" 4 | readonly verbose 5 | 6 | # Verbosity levels, set with the VERBOSE environment variable: 7 | # 0 = Don't print anything except for errors. 8 | # 1 = Print commands, but not nested commands. 9 | # 2 = Print everything.
10 | if [ "$verbose" -gt '1' ]; then 11 | set -x 12 | v_flags='-v=1' 13 | x_flags='-x=1' 14 | elif [ "$verbose" -gt '0' ]; then 15 | set -x 16 | v_flags='-v=1' 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | v_flags='-v=0' 21 | x_flags='-x=0' 22 | fi 23 | readonly v_flags x_flags 24 | 25 | set -e -f -u 26 | 27 | if [ "${RACE:-1}" -eq '0' ]; then 28 | race_flags='--race=0' 29 | else 30 | race_flags='--race=1' 31 | fi 32 | readonly race_flags 33 | 34 | go="${GO:-go}" 35 | 36 | count_flags='--count=1' 37 | shuffle_flags='--shuffle=on' 38 | timeout_flags="${TIMEOUT_FLAGS:---timeout=30s}" 39 | readonly go count_flags shuffle_flags timeout_flags 40 | 41 | "$go" test \ 42 | "$count_flags" \ 43 | "$shuffle_flags" \ 44 | "$race_flags" \ 45 | "$timeout_flags" \ 46 | "$x_flags" \ 47 | "$v_flags" \ 48 | --bench='.' \ 49 | --benchmem \ 50 | --benchtime='1s' \ 51 | --run='^$' \ 52 | work 53 | -------------------------------------------------------------------------------- /scripts/make/go-deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script.
5 | # 6 | # AdGuard-Project-Version: 2 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" mod download "$x_flags" 30 | -------------------------------------------------------------------------------- /scripts/make/go-fuzz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | verbose="${VERBOSE:-0}" 4 | readonly verbose 5 | 6 | # Verbosity levels, set with the VERBOSE environment variable: 7 | # 0 = Don't print anything except for errors. 8 | # 1 = Print commands, but not nested commands. 9 | # 2 = Print everything. 10 | if [ "$verbose" -gt '1' ]; then 11 | set -x 12 | v_flags='-v=1' 13 | x_flags='-x=1' 14 | elif [ "$verbose" -gt '0' ]; then 15 | set -x 16 | v_flags='-v=1' 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | v_flags='-v=0' 21 | x_flags='-x=0' 22 | fi 23 | readonly v_flags x_flags 24 | 25 | set -e -f -u 26 | 27 | if [ "${RACE:-1}" -eq '0' ]; then 28 | race_flags='--race=0' 29 | else 30 | race_flags='--race=1' 31 | fi 32 | readonly race_flags 33 | 34 | go="${GO:-go}" 35 | 36 | count_flags='--count=1' 37 | shuffle_flags='--shuffle=on' 38 | timeout_flags="${TIMEOUT_FLAGS:---timeout=30s}" 39 | fuzztime_flags="${FUZZTIME_FLAGS:---fuzztime=20s}" 40 | 41 | readonly go count_flags shuffle_flags timeout_flags fuzztime_flags 42 | 43 | # TODO(a.garipov): File an issue about using --fuzz with multiple packages. Until then, each fuzz target is run by a separate go test invocation.
44 | "$go" test \ 45 | "$count_flags" \ 46 | "$shuffle_flags" \ 47 | "$race_flags" \ 48 | "$timeout_flags" \ 49 | "$x_flags" \ 50 | "$v_flags" \ 51 | "$fuzztime_flags" \ 52 | --fuzz='FuzzCloner_Clone' \ 53 | ./internal/dnsmsg/ \ 54 | ; 55 | 56 | "$go" test \ 57 | "$count_flags" \ 58 | "$shuffle_flags" \ 59 | "$race_flags" \ 60 | "$timeout_flags" \ 61 | "$x_flags" \ 62 | "$v_flags" \ 63 | "$fuzztime_flags" \ 64 | --fuzz='FuzzHumanIDParser_ParseNormalized' \ 65 | ./internal/agd/ \ 66 | ; 67 | 68 | "$go" test \ 69 | "$count_flags" \ 70 | "$shuffle_flags" \ 71 | "$race_flags" \ 72 | "$timeout_flags" \ 73 | "$x_flags" \ 74 | "$v_flags" \ 75 | "$fuzztime_flags" \ 76 | --fuzz='FuzzDefault' \ 77 | ./internal/agdcache/ \ 78 | ; 79 | -------------------------------------------------------------------------------- /scripts/make/go-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 7 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | set -x 13 | v_flags='-v=1' 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | v_flags='-v=1' 18 | x_flags='-x=0' 19 | else 20 | set +x 21 | v_flags='-v=0' 22 | x_flags='-x=0' 23 | fi 24 | readonly v_flags x_flags 25 | 26 | set -e -f -u 27 | 28 | # Reset GOARCH and GOOS to make sure we install the tools for the native 29 | # architecture even when we're cross-compiling the main binary, and also to 30 | # prevent the "cannot install cross-compiled binaries when GOBIN is set" error.
31 | env \ 32 | GOARCH="" \ 33 | GOBIN="${PWD}/bin" \ 34 | GOOS="" \ 35 | GOWORK='off' \ 36 | "${GO:-go}" install "$v_flags" "$x_flags" tool \ 37 | ; 38 | -------------------------------------------------------------------------------- /scripts/make/go-upd-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 4 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | if [ "$verbose" -gt '1' ]; then 12 | env 13 | set -x 14 | x_flags='-x=1' 15 | elif [ "$verbose" -gt '0' ]; then 16 | set -x 17 | x_flags='-x=0' 18 | else 19 | set +x 20 | x_flags='-x=0' 21 | fi 22 | readonly x_flags 23 | 24 | set -e -f -u 25 | 26 | go="${GO:-go}" 27 | readonly go 28 | 29 | "$go" get -u "$x_flags" tool 30 | "$go" mod tidy "$x_flags" 31 | -------------------------------------------------------------------------------- /scripts/make/md-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script. Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Don't use -f, because we use globs in this script. 12 | set -e -u 13 | 14 | if [ "$verbose" -gt '0' ]; then 15 | set -x 16 | fi 17 | 18 | markdownlint \ 19 | ./*.md \ 20 | ./doc/*.md \ 21 | ; 22 | -------------------------------------------------------------------------------- /scripts/make/sh-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This comment is used to simplify checking local copies of the script.
Bump 4 | # this number every time a significant change is made to this script. 5 | # 6 | # AdGuard-Project-Version: 3 7 | 8 | verbose="${VERBOSE:-0}" 9 | readonly verbose 10 | 11 | # Don't use -f, because we use globs in this script. 12 | set -e -u 13 | 14 | if [ "$verbose" -gt '0' ]; then 15 | set -x 16 | fi 17 | 18 | # Source the common helpers, including not_found and run_linter. 19 | . ./scripts/make/helper.sh 20 | 21 | run_linter -e shfmt --binary-next-line -d -p -s \ 22 | ./scripts/hooks/* \ 23 | ./scripts/make/*.sh \ 24 | ; 25 | 26 | shellcheck -e 'SC2250' -f 'gcc' -o 'all' -x -- \ 27 | ./scripts/hooks/* \ 28 | ./scripts/make/*.sh \ 29 | ; 30 | -------------------------------------------------------------------------------- /scripts/test/bindtodevice.docker: -------------------------------------------------------------------------------- 1 | # Use golang:alpine as the base image, since it already has most of the 2 | # necessary ip(8) tooling installed. 3 | FROM golang:alpine 4 | 5 | RUN apk add bind-tools bmake gcc git libc-dev &&\ 6 | ln /usr/bin/bmake /usr/bin/make &&\ 7 | mkdir /test/ &&\ 8 | git config --global --add safe.directory /test 9 | 10 | WORKDIR /test/ 11 | 12 | ENV ADGUARD_DNS_TEST_NET_INTERFACE='eth0' 13 | 14 | # The ip route operations must be here and not in the RUN instruction above, 15 | # because they require --cap-add='NET_ADMIN', which is unavailable during build 16 | # time. See ./bindtodevice.sh.
17 | ENTRYPOINT ip route del '172.17.0.0/16' dev 'eth0' &&\ 18 | ip route add local '172.17.0.0/16' dev 'eth0' &&\ 19 | exec sh 20 | -------------------------------------------------------------------------------- /scripts/test/bindtodevice.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e -f -u -x 4 | 5 | use_sudo="${USE_SUDO:-0}" 6 | readonly use_sudo 7 | 8 | maybe_sudo() { 9 | if [ "$use_sudo" -eq 0 ] 10 | then 11 | "$@" 12 | else 13 | sudo "$@" 14 | fi 15 | } 16 | 17 | maybe_sudo docker build\ 18 | -t agdns_bindtodevice_test\ 19 | -\ 20 | < ./scripts/test/bindtodevice.docker 21 | 22 | maybe_sudo docker run\ 23 | --cap-add='NET_ADMIN'\ 24 | --name='agdns_bindtodevice_test'\ 25 | --rm\ 26 | -i\ 27 | -t\ 28 | -v "$PWD":'/test'\ 29 | -v "$( "${GO:-go}" env GOMODCACHE )":'/go/pkg/mod'\ 30 | agdns_bindtodevice_test 31 | -------------------------------------------------------------------------------- /staticcheck.conf: -------------------------------------------------------------------------------- 1 | # This comment is used to simplify checking local copies of the staticcheck 2 | # configuration. Bump this number every time a significant change is made to 3 | # this file. 4 | # 5 | # AdGuard-Project-Version: 1 6 | checks = ["all"] 7 | initialisms = [ 8 | # See https://github.com/dominikh/go-tools/blob/master/config/config.go. 9 | # 10 | # Do not add "PTR" since we use "Ptr" as a suffix. 11 | "inherit" 12 | , "ASN" 13 | , "DHCP" 14 | , "DNSSEC" 15 | # E.g. SentryDSN. 16 | , "DSN" 17 | , "ECS" 18 | , "EDNS" 19 | , "MX" 20 | , "QUIC" 21 | , "RA" 22 | , "RRSIG" 23 | , "RTT" 24 | , "SDNS" 25 | , "SLAAC" 26 | , "SOA" 27 | , "SVCB" 28 | , "TLD" 29 | , "WHOIS" 30 | ] 31 | dot_import_whitelist = [] 32 | http_status_code_whitelist = [] 33 | --------------------------------------------------------------------------------