├── .codecov.yml ├── .github ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── codeql-analysis.yml │ └── go.yml ├── .gitignore ├── AUTHORS ├── CODEOWNERS ├── CONTRIBUTORS ├── COPYRIGHT ├── LICENSE ├── Makefile.fuzz ├── Makefile.release ├── README.md ├── acceptfunc.go ├── acceptfunc_test.go ├── client.go ├── client_test.go ├── clientconfig.go ├── clientconfig_test.go ├── dane.go ├── defaults.go ├── dns.go ├── dns_bench_test.go ├── dns_test.go ├── dnssec.go ├── dnssec_keygen.go ├── dnssec_keyscan.go ├── dnssec_privkey.go ├── dnssec_test.go ├── dnsutil ├── util.go └── util_test.go ├── doc.go ├── duplicate.go ├── duplicate_generate.go ├── duplicate_test.go ├── dyn_test.go ├── edns.go ├── edns_test.go ├── example_test.go ├── format.go ├── format_test.go ├── fuzz.go ├── fuzz_test.go ├── generate.go ├── generate_test.go ├── go.mod ├── go.sum ├── hash.go ├── issue_test.go ├── labels.go ├── labels_test.go ├── leak_test.go ├── length_test.go ├── listen_no_socket_options.go ├── listen_socket_options.go ├── msg.go ├── msg_generate.go ├── msg_helpers.go ├── msg_helpers_test.go ├── msg_test.go ├── msg_truncate.go ├── msg_truncate_test.go ├── nsecx.go ├── nsecx_test.go ├── parse_test.go ├── privaterr.go ├── privaterr_test.go ├── reverse.go ├── rr_test.go ├── sanitize.go ├── sanitize_test.go ├── scan.go ├── scan_rr.go ├── scan_test.go ├── serve_mux.go ├── serve_mux_test.go ├── server.go ├── server_test.go ├── sig0.go ├── sig0_test.go ├── smimea.go ├── svcb.go ├── svcb_test.go ├── tlsa.go ├── tmpdir_darwin_test.go ├── tmpdir_test.go ├── tools.go ├── tsig.go ├── tsig_test.go ├── types.go ├── types_generate.go ├── types_test.go ├── udp.go ├── udp_no_control.go ├── udp_test.go ├── update.go ├── update_test.go ├── version.go ├── version_test.go ├── xfr.go ├── xfr_test.go ├── zduplicate.go ├── zmsg.go └── ztypes.go /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | 
target: 40% 6 | threshold: null 7 | patch: false 8 | changes: false 9 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | groups: 8 | all: 9 | patterns: 10 | - "*" 11 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Thanks for your pull request; please note the following: 2 | 3 | * If your PR introduces backward incompatible changes it will very likely not be merged. 4 | 5 | * We support the last two major Go versions; if your PR uses features from a Go version that is too new, it 6 | will not be merged. 7 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: "Code scanning - action" 2 | 3 | on: 4 | push: 5 | branches: [master, ] 6 | pull_request: 7 | branches: [master] 8 | schedule: 9 | - cron: '0 23 * * 5' 10 | 11 | jobs: 12 | CodeQL-Build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@v3 19 | with: 20 | fetch-depth: 2 21 | 22 | - run: git checkout HEAD^2 23 | if: ${{ github.event_name == 'pull_request' }} 24 | 25 | - name: Initialize CodeQL 26 | uses: github/codeql-action/init@v2 27 | 28 | - name: Autobuild 29 | uses: github/codeql-action/autobuild@v2 30 | 31 | - name: Perform CodeQL Analysis 32 | uses: github/codeql-action/analyze@v2 33 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on:
[push, pull_request] 3 | jobs: 4 | 5 | build: 6 | name: Build and Test 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | go: [ 1.22.x, 1.23.x, 1.24.x ] 11 | steps: 12 | 13 | - name: Set up Go 14 | uses: actions/setup-go@v3 15 | with: 16 | go-version: ${{ matrix.go }} 17 | 18 | - name: Check out code 19 | uses: actions/checkout@v3 20 | 21 | - name: Build 22 | run: go build -v ./... 23 | 24 | - name: Test 25 | run: go test -v ./... 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.6 2 | tags 3 | test.out 4 | a.out 5 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Miek Gieben 2 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @miekg @tmthrgd 2 | -------------------------------------------------------------------------------- /CONTRIBUTORS: -------------------------------------------------------------------------------- 1 | Alex A. Skinner 2 | Andrew Tunnell-Jones 3 | Ask Bjørn Hansen 4 | Dave Cheney 5 | Dusty Wilson 6 | Marek Majkowski 7 | Peter van Dijk 8 | Omri Bahumi 9 | Alex Sergeyev 10 | James Hartig 11 | -------------------------------------------------------------------------------- /COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright 2009 The Go Authors. All rights reserved. Use of this source code 2 | is governed by a BSD-style license that can be found in the LICENSE file. 3 | Extensions of the original work are copyright (c) 2011 Miek Gieben 4 | 5 | Copyright 2011 Miek Gieben. All rights reserved. Use of this source code is 6 | governed by a BSD-style license that can be found in the LICENSE file. 
7 | 8 | Copyright 2014 CloudFlare. All rights reserved. Use of this source code is 9 | governed by a BSD-style license that can be found in the LICENSE file. 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2009, The Go Authors. Extensions copyright (c) 2011, Miek Gieben. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /Makefile.fuzz: -------------------------------------------------------------------------------- 1 | # Makefile for fuzzing 2 | # 3 | # Use go-fuzz and needs the tools installed. 4 | # See https://blog.cloudflare.com/dns-parser-meet-go-fuzzer/ 5 | # 6 | # Installing go-fuzz: 7 | # $ make -f Makefile.fuzz get 8 | # Installs: 9 | # * github.com/dvyukov/go-fuzz/go-fuzz 10 | # * get github.com/dvyukov/go-fuzz/go-fuzz-build 11 | 12 | all: build 13 | 14 | .PHONY: build 15 | build: 16 | go-fuzz-build -tags fuzz github.com/miekg/dns 17 | 18 | .PHONY: build-newrr 19 | build-newrr: 20 | go-fuzz-build -func FuzzNewRR -tags fuzz github.com/miekg/dns 21 | 22 | .PHONY: fuzz 23 | fuzz: 24 | go-fuzz -bin=dns-fuzz.zip -workdir=fuzz 25 | 26 | .PHONY: get 27 | get: 28 | go get github.com/dvyukov/go-fuzz/go-fuzz 29 | go get github.com/dvyukov/go-fuzz/go-fuzz-build 30 | 31 | .PHONY: clean 32 | clean: 33 | rm *-fuzz.zip 34 | -------------------------------------------------------------------------------- /Makefile.release: -------------------------------------------------------------------------------- 1 | # Makefile for releasing. 2 | # 3 | # The release is controlled from version.go. The version found there is 4 | # used to tag the git repo, we're not building any artifacts so there is nothing 5 | # to upload to github. 
6 | # 7 | # * Up the version in version.go 8 | # * Run: make -f Makefile.release release 9 | # * will *commit* your change with 'Release $VERSION' 10 | # * push to github 11 | # 12 | 13 | define GO 14 | //+build ignore 15 | 16 | package main 17 | 18 | import ( 19 | "fmt" 20 | 21 | "github.com/miekg/dns" 22 | ) 23 | 24 | func main() { 25 | fmt.Println(dns.Version.String()) 26 | } 27 | endef 28 | 29 | $(file > version_release.go,$(GO)) 30 | VERSION:=$(shell go run version_release.go) 31 | TAG="v$(VERSION)" 32 | 33 | all: 34 | @echo Use the \'release\' target to start a release $(VERSION) 35 | rm -f version_release.go 36 | 37 | .PHONY: release 38 | release: commit push 39 | @echo Released $(VERSION) 40 | rm -f version_release.go 41 | 42 | .PHONY: commit 43 | commit: 44 | @echo Committing release $(VERSION) 45 | git commit -am"Release $(VERSION)" 46 | git tag $(TAG) 47 | 48 | .PHONY: push 49 | push: 50 | @echo Pushing release $(VERSION) to master 51 | git push --tags 52 | git push 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/miekg/dns.svg?branch=master)](https://travis-ci.org/miekg/dns) 2 | [![Code Coverage](https://img.shields.io/codecov/c/github/miekg/dns/master.svg)](https://codecov.io/github/miekg/dns?branch=master) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/miekg/dns)](https://goreportcard.com/report/miekg/dns) 4 | [![](https://godoc.org/github.com/miekg/dns?status.svg)](https://godoc.org/github.com/miekg/dns) 5 | 6 | # Alternative (more granular) approach to a DNS library 7 | 8 | > Less is more. 9 | 10 | Complete and usable DNS library. All Resource Records are supported, including the DNSSEC types. 11 | It follows a lean and mean philosophy. If there is stuff you should know as a DNS programmer there 12 | isn't a convenience function for it. 
Server side and client side programming is supported, i.e. you 13 | can build servers and resolvers with it. 14 | 15 | We try to keep the "master" branch as sane as possible and at the bleeding edge of standards, 16 | avoiding breaking changes wherever reasonable. We support the last two versions of Go. 17 | 18 | # Goals 19 | 20 | * KISS; 21 | * Fast; 22 | * Small API. If it's easy to code in Go, don't make a function for it. 23 | 24 | # Users 25 | 26 | A not-so-up-to-date-list-that-may-be-actually-current: 27 | 28 | * https://github.com/coredns/coredns 29 | * https://github.com/abh/geodns 30 | * https://github.com/baidu/bfe 31 | * http://www.statdns.com/ 32 | * http://www.dnsinspect.com/ 33 | * https://github.com/chuangbo/jianbing-dictionary-dns 34 | * http://www.dns-lg.com/ 35 | * https://github.com/fcambus/rrda 36 | * https://github.com/kenshinx/godns 37 | * https://github.com/skynetservices/skydns 38 | * https://github.com/hashicorp/consul 39 | * https://github.com/DevelopersPL/godnsagent 40 | * https://github.com/duedil-ltd/discodns 41 | * https://github.com/StalkR/dns-reverse-proxy 42 | * https://github.com/tianon/rawdns 43 | * https://mesosphere.github.io/mesos-dns/ 44 | * https://github.com/fcambus/statzone 45 | * https://github.com/benschw/dns-clb-go 46 | * https://github.com/corny/dnscheck for 47 | * https://github.com/miekg/unbound 48 | * https://github.com/miekg/exdns 49 | * https://dnslookup.org 50 | * https://github.com/looterz/grimd 51 | * https://github.com/phamhongviet/serf-dns 52 | * https://github.com/mehrdadrad/mylg 53 | * https://github.com/bamarni/dockness 54 | * https://github.com/fffaraz/microdns 55 | * https://github.com/ipdcode/hades 56 | * https://github.com/StackExchange/dnscontrol/ 57 | * https://www.dnsperf.com/ 58 | * https://dnssectest.net/ 59 | * https://github.com/oif/apex 60 | * https://github.com/jedisct1/dnscrypt-proxy 61 | * https://github.com/jedisct1/rpdns 62 | * https://github.com/xor-gate/sshfp 63 | * 
https://github.com/rs/dnstrace 64 | * https://blitiri.com.ar/p/dnss ([github mirror](https://github.com/albertito/dnss)) 65 | * https://render.com 66 | * https://github.com/peterzen/goresolver 67 | * https://github.com/folbricht/routedns 68 | * https://domainr.com/ 69 | * https://zonedb.org/ 70 | * https://router7.org/ 71 | * https://github.com/fortio/dnsping 72 | * https://github.com/Luzilla/dnsbl_exporter 73 | * https://github.com/bodgit/tsig 74 | * https://github.com/v2fly/v2ray-core (test only) 75 | * https://kuma.io/ 76 | * https://www.misaka.io/services/dns 77 | * https://ping.sx/dig 78 | * https://fleetdeck.io/ 79 | * https://github.com/markdingo/autoreverse 80 | * https://github.com/slackhq/nebula 81 | * https://addr.tools/ 82 | * https://dnscheck.tools/ 83 | * https://github.com/egbakou/domainverifier 84 | * https://github.com/semihalev/sdns 85 | * https://github.com/wintbiit/NineDNS 86 | * https://linuxcontainers.org/incus/ 87 | * https://ifconfig.es 88 | * https://github.com/zmap/zdns 89 | * https://framagit.org/bortzmeyer/check-soa 90 | 91 | Send pull request if you want to be listed here. 92 | 93 | # Features 94 | 95 | * UDP/TCP queries, IPv4 and IPv6 96 | * RFC 1035 zone file parsing ($INCLUDE, $ORIGIN, $TTL and $GENERATE (for all record types) are supported 97 | * Fast 98 | * Server side programming (mimicking the net/http package) 99 | * Client side programming 100 | * DNSSEC: signing, validating and key generation for DSA, RSA, ECDSA and Ed25519 101 | * EDNS0, NSID, Cookies 102 | * AXFR/IXFR 103 | * TSIG, SIG(0) 104 | * DNS over TLS (DoT): encrypted connection between client and server over TCP 105 | * DNS name compression 106 | 107 | Have fun! 108 | 109 | Miek Gieben - 2010-2012 - 110 | DNS Authors 2012- 111 | 112 | # Building 113 | 114 | This library uses Go modules and uses semantic versioning. 
Building is done with the `go` tool, so 115 | the following should work: 116 | 117 | go get github.com/miekg/dns 118 | go build github.com/miekg/dns 119 | 120 | ## Examples 121 | 122 | A short "how to use the API" is at the beginning of doc.go (this also will show when you call `godoc 123 | github.com/miekg/dns`). 124 | 125 | Example programs can be found in the `github.com/miekg/exdns` repository. 126 | 127 | ## Supported RFCs 128 | 129 | *all of them* 130 | 131 | * 103{4,5} - DNS standard 132 | * 1183 - ISDN, X25 and other deprecated records 133 | * 1348 - NSAP record (removed the record) 134 | * 1982 - Serial Arithmetic 135 | * 1876 - LOC record 136 | * 1995 - IXFR 137 | * 1996 - DNS notify 138 | * 2136 - DNS Update (dynamic updates) 139 | * 2181 - RRset definition - there is no RRset type though, just []RR 140 | * 2537 - RSAMD5 DNS keys 141 | * 2065 - DNSSEC (updated in later RFCs) 142 | * 2671 - EDNS record 143 | * 2782 - SRV record 144 | * 2845 - TSIG record 145 | * 2915 - NAPTR record 146 | * 2929 - DNS IANA Considerations 147 | * 3110 - RSASHA1 DNS keys 148 | * 3123 - APL record 149 | * 3225 - DO bit (DNSSEC OK) 150 | * 340{1,2,3} - NAPTR record 151 | * 3445 - Limiting the scope of (DNS)KEY 152 | * 3596 - AAAA record 153 | * 3597 - Unknown RRs 154 | * 4025 - A Method for Storing IPsec Keying Material in DNS 155 | * 403{3,4,5} - DNSSEC + validation functions 156 | * 4255 - SSHFP record 157 | * 4343 - Case insensitivity 158 | * 4408 - SPF record 159 | * 4509 - SHA256 Hash in DS 160 | * 4592 - Wildcards in the DNS 161 | * 4635 - HMAC SHA TSIG 162 | * 4701 - DHCID 163 | * 4892 - id.server 164 | * 5001 - NSID 165 | * 5155 - NSEC3 record 166 | * 5205 - HIP record 167 | * 5702 - SHA2 in the DNS 168 | * 5936 - AXFR 169 | * 5966 - TCP implementation recommendations 170 | * 6605 - ECDSA 171 | * 6725 - IANA Registry Update 172 | * 6742 - ILNP DNS 173 | * 6840 - Clarifications and Implementation Notes for DNS Security 174 | * 6844 - CAA record 175 | * 6891 - EDNS0 
update 176 | * 6895 - DNS IANA considerations 177 | * 6944 - DNSSEC DNSKEY Algorithm Status 178 | * 6975 - Algorithm Understanding in DNSSEC 179 | * 7043 - EUI48/EUI64 records 180 | * 7314 - DNS (EDNS) EXPIRE Option 181 | * 7477 - CSYNC RR 182 | * 7828 - edns-tcp-keepalive EDNS0 Option 183 | * 7553 - URI record 184 | * 7858 - DNS over TLS: Initiation and Performance Considerations 185 | * 7871 - EDNS0 Client Subnet 186 | * 7873 - Domain Name System (DNS) Cookies 187 | * 8080 - EdDSA for DNSSEC 188 | * 8490 - DNS Stateful Operations 189 | * 8499 - DNS Terminology 190 | * 8659 - DNS Certification Authority Authorization (CAA) Resource Record 191 | * 8777 - DNS Reverse IP Automatic Multicast Tunneling (AMT) Discovery 192 | * 8914 - Extended DNS Errors 193 | * 8976 - Message Digest for DNS Zones (ZONEMD RR) 194 | * 9460 - Service Binding and Parameter Specification via the DNS 195 | * 9461 - Service Binding Mapping for DNS Servers 196 | * 9462 - Discovery of Designated Resolvers 197 | * 9460 - SVCB and HTTPS Records 198 | * 9606 - DNS Resolver Information 199 | * Draft - Compact Denial of Existence in DNSSEC 200 | 201 | ## Loosely Based Upon 202 | 203 | * ldns - 204 | * NSD - 205 | * Net::DNS - 206 | * GRONG - 207 | -------------------------------------------------------------------------------- /acceptfunc.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // MsgAcceptFunc is used early in the server code to accept or reject a message with RcodeFormatError. 4 | // It returns a MsgAcceptAction to indicate what should happen with the message. 
5 | type MsgAcceptFunc func(dh Header) MsgAcceptAction 6 | 7 | // DefaultMsgAcceptFunc checks the request and will reject if: 8 | // 9 | // * isn't a request (don't respond in that case) 10 | // 11 | // * opcode isn't OpcodeQuery or OpcodeNotify 12 | // 13 | // * does not have exactly 1 question in the question section 14 | // 15 | // * has more than 1 RR in the Answer section 16 | // 17 | // * has more than 1 RR in the Authority section 18 | // 19 | // * has more than 2 RRs in the Additional section 20 | var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc 21 | 22 | // MsgAcceptAction represents the action to be taken. 23 | type MsgAcceptAction int 24 | 25 | // Allowed returned values from a MsgAcceptFunc. 26 | const ( 27 | MsgAccept MsgAcceptAction = iota // Accept the message 28 | MsgReject // Reject the message with a RcodeFormatError 29 | MsgIgnore // Ignore the error and send nothing back. 30 | MsgRejectNotImplemented // Reject the message with a RcodeNotImplemented 31 | ) 32 | 33 | func defaultMsgAcceptFunc(dh Header) MsgAcceptAction { 34 | if isResponse := dh.Bits&_QR != 0; isResponse { 35 | return MsgIgnore 36 | } 37 | 38 | // Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs. 39 | opcode := int(dh.Bits>>11) & 0xF 40 | if opcode != OpcodeQuery && opcode != OpcodeNotify { 41 | return MsgRejectNotImplemented 42 | } 43 | 44 | if dh.Qdcount != 1 { 45 | return MsgReject 46 | } 47 | // NOTIFY requests can have a SOA in the ANSWER section. See RFC 1996 Section 3.7 and 3.11. 48 | if dh.Ancount > 1 { 49 | return MsgReject 50 | } 51 | // IXFR request could have one SOA RR in the NS section. See RFC 1995, section 3.
52 | if dh.Nscount > 1 { 53 | return MsgReject 54 | } 55 | if dh.Arcount > 2 { 56 | return MsgReject 57 | } 58 | return MsgAccept 59 | } 60 | -------------------------------------------------------------------------------- /acceptfunc_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "encoding/binary" 5 | "net" 6 | "testing" 7 | ) 8 | 9 | func TestAcceptNotify(t *testing.T) { 10 | HandleFunc("example.org.", handleNotify) 11 | s, addrstr, _, err := RunLocalUDPServer(":0") 12 | if err != nil { 13 | t.Fatalf("unable to run test server: %v", err) 14 | } 15 | defer s.Shutdown() 16 | 17 | m := new(Msg) 18 | m.SetNotify("example.org.") 19 | // Set a SOA hint in the answer section, this is allowed according to RFC 1996. 20 | soa, _ := NewRR("example.org. IN SOA sns.dns.icann.org. noc.dns.icann.org. 2018112827 7200 3600 1209600 3600") 21 | m.Answer = []RR{soa} 22 | 23 | c := new(Client) 24 | resp, _, err := c.Exchange(m, addrstr) 25 | if err != nil { 26 | t.Fatalf("failed to exchange: %v", err) 27 | } 28 | if resp.Rcode != RcodeSuccess { 29 | t.Errorf("expected %s, got %s", RcodeToString[RcodeSuccess], RcodeToString[resp.Rcode]) 30 | } 31 | } 32 | 33 | func handleNotify(w ResponseWriter, req *Msg) { 34 | m := new(Msg) 35 | m.SetReply(req) 36 | w.WriteMsg(m) 37 | } 38 | 39 | func TestInvalidMsg(t *testing.T) { 40 | HandleFunc("example.org.", func(ResponseWriter, *Msg) { 41 | t.Fatal("the handler must not be called in any of these tests") 42 | }) 43 | s, addrstr, _, err := RunLocalTCPServer(":0") 44 | if err != nil { 45 | t.Fatalf("unable to run test server: %v", err) 46 | } 47 | defer s.Shutdown() 48 | 49 | s.MsgAcceptFunc = func(dh Header) MsgAcceptAction { 50 | switch dh.Id { 51 | case 0x0001: 52 | return MsgAccept 53 | case 0x0002: 54 | return MsgReject 55 | case 0x0003: 56 | return MsgIgnore 57 | case 0x0004: 58 | return MsgRejectNotImplemented 59 | default: 60 | t.Errorf("unexpected ID %x",
dh.Id) 61 | return -1 62 | } 63 | } 64 | 65 | invalidErrors := make(chan error) 66 | s.MsgInvalidFunc = func(m []byte, err error) { 67 | invalidErrors <- err 68 | } 69 | 70 | c, err := net.Dial("tcp", addrstr) 71 | if err != nil { 72 | t.Fatalf("cannot connect to test server: %v", err) 73 | } 74 | defer c.Close() 75 | write := func(m []byte) { 76 | var length [2]byte 77 | binary.BigEndian.PutUint16(length[:], uint16(len(m))) 78 | _, err := c.Write(length[:]) 79 | if err != nil { 80 | t.Fatalf("length write failed: %v", err) 81 | } 82 | _, err = c.Write(m) 83 | if err != nil { 84 | t.Fatalf("content write failed: %v", err) 85 | } 86 | } 87 | 88 | /* Message is too short, so there is no header to accept or reject. */ 89 | 90 | tooShortMessage := make([]byte, 11) 91 | tooShortMessage[1] = 0x3 // ID = 3, would be ignored if it were parsable. 92 | 93 | write(tooShortMessage) 94 | // Expect an error to be reported. 95 | <-invalidErrors 96 | 97 | /* Message is accepted but is actually invalid. */ 98 | 99 | badMessage := make([]byte, 13) 100 | badMessage[1] = 0x1 // ID = 1, Accept. 101 | badMessage[5] = 1 // QDCOUNT = 1 102 | badMessage[12] = 99 // Bad question section. Invalid! 103 | 104 | write(badMessage) 105 | // Expect an error to be reported. 106 | <-invalidErrors 107 | 108 | /* Message is rejected before it can be determined to be invalid. */ 109 | 110 | close(invalidErrors) // A call to InvalidMsgFunc would panic due to the closed chan.
111 | 112 | badMessage[1] = 0x2 // ID = 2, Reject 113 | write(badMessage) 114 | 115 | badMessage[1] = 0x3 // ID = 3, Ignore 116 | write(badMessage) 117 | 118 | badMessage[1] = 0x4 // ID = 4, RejectNotImplemented 119 | write(badMessage) 120 | } 121 | -------------------------------------------------------------------------------- /clientconfig.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "os" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // ClientConfig wraps the contents of the /etc/resolv.conf file. 12 | type ClientConfig struct { 13 | Servers []string // servers to use 14 | Search []string // suffixes to append to local name 15 | Port string // what port to use 16 | Ndots int // number of dots in name to trigger absolute lookup 17 | Timeout int // seconds before giving up on packet 18 | Attempts int // lost packets before giving up on server, not used in the package dns 19 | } 20 | 21 | // ClientConfigFromFile parses a resolv.conf(5) like file and returns 22 | // a *ClientConfig.
23 | func ClientConfigFromFile(resolvconf string) (*ClientConfig, error) { 24 | file, err := os.Open(resolvconf) 25 | if err != nil { 26 | return nil, err 27 | } 28 | defer file.Close() 29 | return ClientConfigFromReader(file) 30 | } 31 | 32 | // ClientConfigFromReader works like ClientConfigFromFile but takes an io.Reader as argument. It returns an error if reading from resolvconf fails. 33 | func ClientConfigFromReader(resolvconf io.Reader) (*ClientConfig, error) { 34 | c := new(ClientConfig) 35 | scanner := bufio.NewScanner(resolvconf) 36 | c.Servers = make([]string, 0) 37 | c.Search = make([]string, 0) 38 | c.Port = "53" 39 | c.Ndots = 1 40 | c.Timeout = 5 41 | c.Attempts = 2 42 | 43 | for scanner.Scan() { 44 | line := scanner.Text() 45 | f := strings.Fields(line) 46 | if len(f) < 1 { 47 | continue 48 | } 49 | switch f[0] { 50 | case "nameserver": // add one name server 51 | if len(f) > 1 { 52 | // The server is stored verbatim; it is not validated 53 | // as an IP address and is not resolved here, so a 54 | // hostname entry is passed through unchanged.
55 | name := f[1] 56 | c.Servers = append(c.Servers, name) 57 | } 58 | 59 | case "domain": // set search path to just this domain 60 | if len(f) > 1 { 61 | c.Search = make([]string, 1) 62 | c.Search[0] = f[1] 63 | } else { 64 | c.Search = make([]string, 0) 65 | } 66 | 67 | case "search": // set search path to given domains 68 | c.Search = cloneSlice(f[1:]) 69 | 70 | case "options": // magic options 71 | for _, s := range f[1:] { 72 | switch { 73 | case len(s) >= 6 && s[:6] == "ndots:": 74 | n, _ := strconv.Atoi(s[6:]) 75 | if n < 0 { 76 | n = 0 77 | } else if n > 15 { 78 | n = 15 79 | } 80 | c.Ndots = n 81 | case len(s) >= 8 && s[:8] == "timeout:": 82 | n, _ := strconv.Atoi(s[8:]) 83 | if n < 1 { 84 | n = 1 85 | } 86 | c.Timeout = n 87 | case len(s) >= 9 && s[:9] == "attempts:": 88 | n, _ := strconv.Atoi(s[9:]) 89 | if n < 1 { 90 | n = 1 91 | } 92 | c.Attempts = n 93 | case s == "rotate": 94 | /* not imp */ 95 | } 96 | } 97 | } 98 | } 99 | if err := scanner.Err(); err != nil { // Scan returns false on EOF and on read errors alike; only Err tells them apart. 100 | return nil, err 101 | } 102 | return c, nil 103 | } 104 | 105 | // NameList returns all of the names that should be queried based on the 106 | // config. It is based off of go's net/dns name building, but it does not 107 | // check the length of the resulting names. 108 | func (c *ClientConfig) NameList(name string) []string { 109 | // if this domain is already fully qualified, no append needed. 110 | if IsFqdn(name) { 111 | return []string{name} 112 | } 113 | 114 | // Check to see if the name has more labels than Ndots. Do this before making 115 | // the domain fully qualified. 116 | hasNdots := CountLabel(name) > c.Ndots 117 | // Make the domain fully qualified. 118 | name = Fqdn(name) 119 | 120 | // Make a list of names based off search. 121 | names := []string{} 122 | 123 | // If name has enough dots, try that first. 
124 | if hasNdots { 125 | names = append(names, name) 126 | } 127 | for _, s := range c.Search { 128 | names = append(names, Fqdn(name+s)) 129 | } 130 | // If we didn't have enough dots, try after suffixes.
131 | if !hasNdots { 132 | names = append(names, name) 133 | } 134 | return names 135 | } 136 | -------------------------------------------------------------------------------- /clientconfig_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | const normal string = ` 11 | # Comment 12 | domain somedomain.com 13 | nameserver 10.28.10.2 14 | nameserver 11.28.10.1 15 | ` 16 | 17 | const missingNewline string = ` 18 | domain somedomain.com 19 | nameserver 10.28.10.2 20 | nameserver 11.28.10.1` // <- NOTE: NO newline. 21 | 22 | func testConfig(t *testing.T, data string) { 23 | cc, err := ClientConfigFromReader(strings.NewReader(data)) 24 | if err != nil { 25 | t.Fatalf("error parsing resolv.conf: %v", err) 26 | } 27 | if l := len(cc.Servers); l != 2 { 28 | t.Errorf("incorrect number of nameservers detected: %d", l) 29 | } 30 | if l := len(cc.Search); l != 1 { 31 | t.Errorf("domain directive not parsed correctly: %v", cc.Search) 32 | } else { 33 | if cc.Search[0] != "somedomain.com" { 34 | t.Errorf("domain is unexpected: %v", cc.Search[0]) 35 | } 36 | } 37 | } 38 | 39 | func TestNameserver(t *testing.T) { testConfig(t, normal) } 40 | func TestMissingFinalNewLine(t *testing.T) { testConfig(t, missingNewline) } 41 | 42 | func TestNdots(t *testing.T) { 43 | ndotsVariants := map[string]int{ 44 | "options ndots:0": 0, 45 | "options ndots:1": 1, 46 | "options ndots:15": 15, 47 | "options ndots:16": 15, 48 | "options ndots:-1": 0, 49 | "": 1, 50 | } 51 | 52 | for data := range ndotsVariants { 53 | cc, err := ClientConfigFromReader(strings.NewReader(data)) 54 | if err != nil { 55 | t.Fatalf("error parsing resolv.conf: %v", err) 56 | } 57 | if cc.Ndots != ndotsVariants[data] { 58 | t.Errorf("Ndots not properly parsed: (Expected: %d / Was: %d)", ndotsVariants[data], cc.Ndots) 59 | } 60 | } 61 | } 62 | 63 | func
TestClientConfigFromReaderAttempts(t *testing.T) { 64 | testCases := []struct { 65 | data string 66 | expected int 67 | }{ 68 | {data: "options attempts:0", expected: 1}, 69 | {data: "options attempts:1", expected: 1}, 70 | {data: "options attempts:15", expected: 15}, 71 | {data: "options attempts:16", expected: 16}, 72 | {data: "options attempts:-1", expected: 1}, 73 | {data: "options attempt:", expected: 2}, 74 | } 75 | 76 | for _, test := range testCases { 77 | test := test 78 | t.Run(strings.Replace(test.data, ":", " ", -1), func(t *testing.T) { 79 | t.Parallel() 80 | 81 | cc, err := ClientConfigFromReader(strings.NewReader(test.data)) 82 | if err != nil { 83 | t.Fatalf("error parsing resolv.conf: %v", err) 84 | } 85 | if cc.Attempts != test.expected { 86 | t.Errorf("A attempts not properly parsed: (Expected: %d / Was: %d)", test.expected, cc.Attempts) 87 | } 88 | }) 89 | } 90 | } 91 | 92 | func TestReadFromFile(t *testing.T) { 93 | tempDir := t.TempDir() 94 | 95 | path := filepath.Join(tempDir, "resolv.conf") 96 | if err := os.WriteFile(path, []byte(normal), 0o644); err != nil { 97 | t.Fatalf("writeFile: %v", err) 98 | } 99 | cc, err := ClientConfigFromFile(path) 100 | if err != nil { 101 | t.Fatalf("error parsing resolv.conf: %v", err) 102 | } 103 | if l := len(cc.Servers); l != 2 { 104 | t.Errorf("incorrect number of nameservers detected: %d", l) 105 | } 106 | if l := len(cc.Search); l != 1 { 107 | t.Errorf("domain directive not parsed correctly: %v", cc.Search) 108 | } else { 109 | if cc.Search[0] != "somedomain.com" { 110 | t.Errorf("domain is unexpected: %v", cc.Search[0]) 111 | } 112 | } 113 | } 114 | 115 | func TestNameListNdots1(t *testing.T) { 116 | cfg := ClientConfig{ 117 | Ndots: 1, 118 | } 119 | // fqdn should be only result returned 120 | names := cfg.NameList("miek.nl.") 121 | if len(names) != 1 { 122 | t.Errorf("NameList returned != 1 names: %v", names) 123 | } else if names[0] != "miek.nl."
{ 124 | t.Errorf("NameList didn't return sent fqdn domain: %v", names[0]) 125 | } 126 | 127 | cfg.Search = []string{ 128 | "test", 129 | } 130 | // Sent domain has NDots and search 131 | names = cfg.NameList("miek.nl") 132 | if len(names) != 2 { 133 | t.Errorf("NameList returned != 2 names: %v", names) 134 | } else if names[0] != "miek.nl." { 135 | t.Errorf("NameList didn't return sent domain first: %v", names[0]) 136 | } else if names[1] != "miek.nl.test." { 137 | t.Errorf("NameList didn't return search last: %v", names[1]) 138 | } 139 | } 140 | 141 | func TestNameListNdots2(t *testing.T) { 142 | cfg := ClientConfig{ 143 | Ndots: 2, 144 | } 145 | 146 | // Sent domain has less than NDots and search 147 | cfg.Search = []string{ 148 | "test", 149 | } 150 | names := cfg.NameList("miek.nl") 151 | 152 | if len(names) != 2 { 153 | t.Errorf("NameList returned != 2 names: %v", names) 154 | } else if names[0] != "miek.nl.test." { 155 | t.Errorf("NameList didn't return search first: %v", names[0]) 156 | } else if names[1] != "miek.nl." { 157 | t.Errorf("NameList didn't return sent domain last: %v", names[1]) 158 | } 159 | } 160 | 161 | func TestNameListNdots0(t *testing.T) { 162 | cfg := ClientConfig{ 163 | Ndots: 0, 164 | } 165 | cfg.Search = []string{ 166 | "test", 167 | } 168 | // Sent domain has more labels than NDots, so it is tried before search 169 | names := cfg.NameList("miek") 170 | if len(names) != 2 { 171 | t.Errorf("NameList returned != 2 names: %v", names) 172 | } else if names[0] != "miek." { 173 | t.Errorf("NameList didn't return sent domain first: %v", names[0]) 174 | } else if names[1] != "miek.test."
{ 175 | t.Errorf("NameList didn't return search last: %v", names[1]) 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /dane.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/sha256" 5 | "crypto/sha512" 6 | "crypto/x509" 7 | "encoding/hex" 8 | "errors" 9 | ) 10 | 11 | // CertificateToDANE converts a certificate to a hex string as used in the TLSA or SMIMEA records. 12 | func CertificateToDANE(selector, matchingType uint8, cert *x509.Certificate) (string, error) { 13 | switch matchingType { 14 | case 0: 15 | switch selector { 16 | case 0: 17 | return hex.EncodeToString(cert.Raw), nil 18 | case 1: 19 | return hex.EncodeToString(cert.RawSubjectPublicKeyInfo), nil 20 | } 21 | case 1: 22 | h := sha256.New() 23 | switch selector { 24 | case 0: 25 | h.Write(cert.Raw) 26 | return hex.EncodeToString(h.Sum(nil)), nil 27 | case 1: 28 | h.Write(cert.RawSubjectPublicKeyInfo) 29 | return hex.EncodeToString(h.Sum(nil)), nil 30 | } 31 | case 2: 32 | h := sha512.New() 33 | switch selector { 34 | case 0: 35 | h.Write(cert.Raw) 36 | return hex.EncodeToString(h.Sum(nil)), nil 37 | case 1: 38 | h.Write(cert.RawSubjectPublicKeyInfo) 39 | return hex.EncodeToString(h.Sum(nil)), nil 40 | } 41 | } 42 | return "", errors.New("dns: bad MatchingType or Selector") 43 | } 44 | -------------------------------------------------------------------------------- /dns.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "encoding/hex" 5 | "strconv" 6 | ) 7 | 8 | const ( 9 | year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits. 10 | defaultTtl = 3600 // Default internal TTL. 11 | 12 | // DefaultMsgSize is the standard default for messages larger than 512 bytes. 13 | DefaultMsgSize = 4096 14 | // MinMsgSize is the minimal size of a DNS packet.
15 | MinMsgSize = 512 16 | // MaxMsgSize is the largest possible DNS packet. 17 | MaxMsgSize = 65535 18 | ) 19 | 20 | // Error represents a DNS error. 21 | type Error struct{ err string } 22 | 23 | func (e *Error) Error() string { 24 | if e == nil { 25 | return "dns: " 26 | } 27 | return "dns: " + e.err 28 | } 29 | 30 | // An RR represents a resource record. 31 | type RR interface { 32 | // Header returns the header of a resource record. The header contains 33 | // everything up to the rdata. 34 | Header() *RR_Header 35 | // String returns the text representation of the resource record. 36 | String() string 37 | 38 | // copy returns a copy of the RR 39 | copy() RR 40 | 41 | // len returns the length (in octets) of the compressed or uncompressed RR in wire format. 42 | // 43 | // If compression is nil, the uncompressed size will be returned, otherwise the compressed 44 | // size will be returned and domain names will be added to the map for future compression. 45 | len(off int, compression map[string]struct{}) int 46 | 47 | // pack packs the records RDATA into wire format. The header will 48 | // already have been packed into msg. 49 | pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) 50 | 51 | // unpack unpacks an RR from wire format. 52 | // 53 | // This will only be called on a new and empty RR type with only the header populated. It 54 | // will only be called if the record's RDATA is non-empty. 55 | unpack(msg []byte, off int) (off1 int, err error) 56 | 57 | // parse parses an RR from zone file format. 58 | // 59 | // This will only be called on a new and empty RR type with only the header populated. 60 | parse(c *zlexer, origin string) *ParseError 61 | 62 | // isDuplicate returns whether the two RRs are duplicates. 63 | isDuplicate(r2 RR) bool 64 | } 65 | 66 | // RR_Header is the header all DNS resource records share.
67 | type RR_Header struct { 68 | Name string `dns:"cdomain-name"` 69 | Rrtype uint16 70 | Class uint16 71 | Ttl uint32 72 | Rdlength uint16 // Length of data after header. 73 | } 74 | 75 | // Header returns itself. This is here to make RR_Header implements the RR interface. 76 | func (h *RR_Header) Header() *RR_Header { return h } 77 | 78 | // Just to implement the RR interface. 79 | func (h *RR_Header) copy() RR { return nil } 80 | 81 | func (h *RR_Header) String() string { 82 | var s string 83 | 84 | if h.Rrtype == TypeOPT { 85 | s = ";" 86 | // and maybe other things 87 | } 88 | 89 | s += sprintName(h.Name) + "\t" 90 | s += strconv.FormatInt(int64(h.Ttl), 10) + "\t" 91 | s += Class(h.Class).String() + "\t" 92 | s += Type(h.Rrtype).String() + "\t" 93 | return s 94 | } 95 | 96 | func (h *RR_Header) len(off int, compression map[string]struct{}) int { 97 | l := domainNameLen(h.Name, off, compression, true) 98 | l += 10 // rrtype(2) + class(2) + ttl(4) + rdlength(2) 99 | return l 100 | } 101 | 102 | func (h *RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) { 103 | // RR_Header has no RDATA to pack. 104 | return off, nil 105 | } 106 | 107 | func (h *RR_Header) unpack(msg []byte, off int) (int, error) { 108 | panic("dns: internal error: unpack should never be called on RR_Header") 109 | } 110 | 111 | func (h *RR_Header) parse(c *zlexer, origin string) *ParseError { 112 | panic("dns: internal error: parse should never be called on RR_Header") 113 | } 114 | 115 | // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597. 
func (rr *RFC3597) ToRFC3597(r RR) error {
	// Pack r uncompressed into a buffer sized by Len(r); packRR reports where
	// the header ends and where the whole record ends.
	buf := make([]byte, Len(r))
	headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
	if err != nil {
		return err
	}
	buf = buf[:off]

	// Start from a fresh RFC3597 carrying r's header, with Rdlength set to the
	// number of RDATA octets actually packed.
	*rr = RFC3597{Hdr: *r.Header()}
	rr.Hdr.Rdlength = uint16(off - headerEnd)

	if noRdata(rr.Hdr) {
		return nil
	}

	// Decode the raw RDATA (everything after the header) into rr.
	_, err = rr.unpack(buf, headerEnd)
	return err
}

// fromRFC3597 converts an unknown RR representation from RFC 3597 to the known RR type.
// It overwrites r's header with rr.Hdr and unpacks the hex-decoded Rdata into r.
func (rr *RFC3597) fromRFC3597(r RR) error {
	hdr := r.Header()
	*hdr = rr.Hdr

	// Can't overflow uint16 as the length of Rdata is validated in (*RFC3597).parse.
	// We can only get here when rr was constructed with that method.
	hdr.Rdlength = uint16(hex.DecodedLen(len(rr.Rdata)))

	if noRdata(*hdr) {
		// Dynamic update.
		return nil
	}

	// rr.pack requires an extra allocation and a copy so we just decode Rdata
	// manually, it's simpler anyway.
	msg, err := hex.DecodeString(rr.Rdata)
	if err != nil {
		return err
	}

	_, err = r.unpack(msg, 0)
	return err
}
-------------------------------------------------------------------------------- /dnssec_keygen.go: --------------------------------------------------------------------------------
package dns

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"math/big"
)

// Generate generates a DNSKEY of the given bit size.
// The public part is put inside the DNSKEY record.
// The Algorithm in the key must be set as this will define
// what kind of DNSKEY will be generated.
// The ECDSA algorithms imply a fixed keysize, in that case
// bits should be set to the size of the algorithm.
func (k *DNSKEY) Generate(bits int) (crypto.PrivateKey, error) {
	// First pass: validate bits against the algorithm. Accepted sizes are
	// RSASHA1/RSASHA256/RSASHA1NSEC3SHA1: 512-4096, RSASHA512: 1024-4096,
	// ECDSAP256SHA256: 256, ECDSAP384SHA384: 384, ED25519: 256.
	switch k.Algorithm {
	case RSASHA1, RSASHA256, RSASHA1NSEC3SHA1:
		if bits < 512 || bits > 4096 {
			return nil, ErrKeySize
		}
	case RSASHA512:
		if bits < 1024 || bits > 4096 {
			return nil, ErrKeySize
		}
	case ECDSAP256SHA256:
		if bits != 256 {
			return nil, ErrKeySize
		}
	case ECDSAP384SHA384:
		if bits != 384 {
			return nil, ErrKeySize
		}
	case ED25519:
		if bits != 256 {
			return nil, ErrKeySize
		}
	default:
		return nil, ErrAlg
	}

	// Second pass: generate the key pair and store the public half in k.
	// The returned private key is *rsa.PrivateKey, *ecdsa.PrivateKey or
	// ed25519.PrivateKey depending on the algorithm.
	switch k.Algorithm {
	case RSASHA1, RSASHA256, RSASHA512, RSASHA1NSEC3SHA1:
		priv, err := rsa.GenerateKey(rand.Reader, bits)
		if err != nil {
			return nil, err
		}
		k.setPublicKeyRSA(priv.PublicKey.E, priv.PublicKey.N)
		return priv, nil
	case ECDSAP256SHA256, ECDSAP384SHA384:
		var c elliptic.Curve
		switch k.Algorithm {
		case ECDSAP256SHA256:
			c = elliptic.P256()
		case ECDSAP384SHA384:
			c = elliptic.P384()
		}
		priv, err := ecdsa.GenerateKey(c, rand.Reader)
		if err != nil {
			return nil, err
		}
		k.setPublicKeyECDSA(priv.PublicKey.X, priv.PublicKey.Y)
		return priv, nil
	case ED25519:
		pub, priv, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return nil, err
		}
		k.setPublicKeyED25519(pub)
		return priv, nil
	default:
		return nil, ErrAlg
	}
}

// setPublicKeyRSA stores the RSA public key (exponent E and modulus N) in
// k.PublicKey as base64 of the RFC 3110 layout: exponent length, exponent,
// modulus. It returns false and leaves k untouched if E is 0 or N is nil.
func (k *DNSKEY) setPublicKeyRSA(_E int, _N *big.Int) bool {
	if _E == 0 || _N == nil {
		return false
	}
	buf := exponentToBuf(_E)
	buf = append(buf, _N.Bytes()...)
	k.PublicKey = toBase64(buf)
	return true
}

// setPublicKeyECDSA stores the curve point (X, Y) in k.PublicKey as base64
// of the two fixed-width coordinates concatenated (32 octets each for
// ECDSAP256SHA256, 48 for ECDSAP384SHA384). It returns false if either
// coordinate is nil.
// NOTE(review): for any other k.Algorithm, intlen stays 0; callers in this
// file only reach here for the two ECDSA algorithms.
func (k *DNSKEY) setPublicKeyECDSA(_X, _Y *big.Int) bool {
	if _X == nil || _Y == nil {
		return false
	}
	var intlen int
	switch k.Algorithm {
	case ECDSAP256SHA256:
		intlen = 32
	case ECDSAP384SHA384:
		intlen = 48
	}
	k.PublicKey = toBase64(curveToBuf(_X, _Y, intlen))
	return true
}

// setPublicKeyED25519 stores the raw Ed25519 public key in k.PublicKey as
// base64. It returns false if the key is nil.
func (k *DNSKEY) setPublicKeyED25519(_K ed25519.PublicKey) bool {
	if _K == nil {
		return false
	}
	k.PublicKey = toBase64(_K)
	return true
}

// exponentToBuf encodes the RSA public exponent with its length prefix as
// described in RFC 3110, Section 2 (RSA Public KEY Resource Records): a
// single length octet when the exponent is shorter than 256 octets,
// otherwise a zero octet followed by a two-octet big-endian length.
func exponentToBuf(_E int) []byte {
	var buf []byte
	i := big.NewInt(int64(_E)).Bytes()
	if len(i) < 256 {
		buf = make([]byte, 1, 1+len(i))
		buf[0] = uint8(len(i))
	} else {
		buf = make([]byte, 3, 3+len(i))
		buf[0] = 0
		buf[1] = uint8(len(i) >> 8)
		buf[2] = uint8(len(i))
	}
	buf = append(buf, i...)
	return buf
}

// curveToBuf returns X followed by Y, each encoded by intToBytes at width
// intlen. The two values are just concatenated.
func curveToBuf(_X, _Y *big.Int, intlen int) []byte {
	buf := intToBytes(_X, intlen)
	buf = append(buf, intToBytes(_Y, intlen)...)
	return buf
}
-------------------------------------------------------------------------------- /dnssec_keyscan.go: --------------------------------------------------------------------------------
package dns

import (
	"bufio"
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rsa"
	"io"
	"math/big"
	"strconv"
	"strings"
)

// NewPrivateKey returns a PrivateKey by parsing the string s.
// s should be in the same form of the BIND private key files.
17 | func (k *DNSKEY) NewPrivateKey(s string) (crypto.PrivateKey, error) { 18 | if s == "" || s[len(s)-1] != '\n' { // We need a closing newline 19 | return k.ReadPrivateKey(strings.NewReader(s+"\n"), "") 20 | } 21 | return k.ReadPrivateKey(strings.NewReader(s), "") 22 | } 23 | 24 | // ReadPrivateKey reads a private key from the io.Reader q. The string file is 25 | // only used in error reporting. 26 | // The public key must be known, because some cryptographic algorithms embed 27 | // the public inside the privatekey. 28 | func (k *DNSKEY) ReadPrivateKey(q io.Reader, file string) (crypto.PrivateKey, error) { 29 | m, err := parseKey(q, file) 30 | if m == nil { 31 | return nil, err 32 | } 33 | if _, ok := m["private-key-format"]; !ok { 34 | return nil, ErrPrivKey 35 | } 36 | if m["private-key-format"] != "v1.2" && m["private-key-format"] != "v1.3" { 37 | return nil, ErrPrivKey 38 | } 39 | // TODO(mg): check if the pubkey matches the private key 40 | algoStr, _, _ := strings.Cut(m["algorithm"], " ") 41 | algo, err := strconv.ParseUint(algoStr, 10, 8) 42 | if err != nil { 43 | return nil, ErrPrivKey 44 | } 45 | switch uint8(algo) { 46 | case RSASHA1, RSASHA1NSEC3SHA1, RSASHA256, RSASHA512: 47 | priv, err := readPrivateKeyRSA(m) 48 | if err != nil { 49 | return nil, err 50 | } 51 | pub := k.publicKeyRSA() 52 | if pub == nil { 53 | return nil, ErrKey 54 | } 55 | priv.PublicKey = *pub 56 | return priv, nil 57 | case ECDSAP256SHA256, ECDSAP384SHA384: 58 | priv, err := readPrivateKeyECDSA(m) 59 | if err != nil { 60 | return nil, err 61 | } 62 | pub := k.publicKeyECDSA() 63 | if pub == nil { 64 | return nil, ErrKey 65 | } 66 | priv.PublicKey = *pub 67 | return priv, nil 68 | case ED25519: 69 | return readPrivateKeyED25519(m) 70 | default: 71 | return nil, ErrAlg 72 | } 73 | } 74 | 75 | // Read a private key (file) string and create a public key. Return the private key. 
76 | func readPrivateKeyRSA(m map[string]string) (*rsa.PrivateKey, error) { 77 | p := new(rsa.PrivateKey) 78 | p.Primes = []*big.Int{nil, nil} 79 | for k, v := range m { 80 | switch k { 81 | case "modulus", "publicexponent", "privateexponent", "prime1", "prime2": 82 | v1, err := fromBase64([]byte(v)) 83 | if err != nil { 84 | return nil, err 85 | } 86 | switch k { 87 | case "modulus": 88 | p.PublicKey.N = new(big.Int).SetBytes(v1) 89 | case "publicexponent": 90 | i := new(big.Int).SetBytes(v1) 91 | p.PublicKey.E = int(i.Int64()) // int64 should be large enough 92 | case "privateexponent": 93 | p.D = new(big.Int).SetBytes(v1) 94 | case "prime1": 95 | p.Primes[0] = new(big.Int).SetBytes(v1) 96 | case "prime2": 97 | p.Primes[1] = new(big.Int).SetBytes(v1) 98 | } 99 | case "exponent1", "exponent2", "coefficient": 100 | // not used in Go (yet) 101 | case "created", "publish", "activate": 102 | // not used in Go (yet) 103 | } 104 | } 105 | return p, nil 106 | } 107 | 108 | func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) { 109 | p := new(ecdsa.PrivateKey) 110 | p.D = new(big.Int) 111 | // TODO: validate that the required flags are present 112 | for k, v := range m { 113 | switch k { 114 | case "privatekey": 115 | v1, err := fromBase64([]byte(v)) 116 | if err != nil { 117 | return nil, err 118 | } 119 | p.D.SetBytes(v1) 120 | case "created", "publish", "activate": 121 | /* not used in Go (yet) */ 122 | } 123 | } 124 | return p, nil 125 | } 126 | 127 | func readPrivateKeyED25519(m map[string]string) (ed25519.PrivateKey, error) { 128 | var p ed25519.PrivateKey 129 | // TODO: validate that the required flags are present 130 | for k, v := range m { 131 | switch k { 132 | case "privatekey": 133 | p1, err := fromBase64([]byte(v)) 134 | if err != nil { 135 | return nil, err 136 | } 137 | if len(p1) != ed25519.SeedSize { 138 | return nil, ErrPrivKey 139 | } 140 | p = ed25519.NewKeyFromSeed(p1) 141 | case "created", "publish", "activate": 142 | /* not used 
in Go (yet) */ 143 | } 144 | } 145 | return p, nil 146 | } 147 | 148 | // parseKey reads a private key from r. It returns a map[string]string, 149 | // with the key-value pairs, or an error when the file is not correct. 150 | func parseKey(r io.Reader, file string) (map[string]string, error) { 151 | m := make(map[string]string) 152 | var k string 153 | 154 | c := newKLexer(r) 155 | 156 | for l, ok := c.Next(); ok; l, ok = c.Next() { 157 | // It should alternate 158 | switch l.value { 159 | case zKey: 160 | k = l.token 161 | case zValue: 162 | if k == "" { 163 | return nil, &ParseError{file: file, err: "no private key seen", lex: l} 164 | } 165 | 166 | m[strings.ToLower(k)] = l.token 167 | k = "" 168 | } 169 | } 170 | 171 | // Surface any read errors from r. 172 | if err := c.Err(); err != nil { 173 | return nil, &ParseError{file: file, err: err.Error()} 174 | } 175 | 176 | return m, nil 177 | } 178 | 179 | type klexer struct { 180 | br io.ByteReader 181 | 182 | readErr error 183 | 184 | line int 185 | column int 186 | 187 | key bool 188 | 189 | eol bool // end-of-line 190 | } 191 | 192 | func newKLexer(r io.Reader) *klexer { 193 | br, ok := r.(io.ByteReader) 194 | if !ok { 195 | br = bufio.NewReaderSize(r, 1024) 196 | } 197 | 198 | return &klexer{ 199 | br: br, 200 | 201 | line: 1, 202 | 203 | key: true, 204 | } 205 | } 206 | 207 | func (kl *klexer) Err() error { 208 | if kl.readErr == io.EOF { 209 | return nil 210 | } 211 | 212 | return kl.readErr 213 | } 214 | 215 | // readByte returns the next byte from the input 216 | func (kl *klexer) readByte() (byte, bool) { 217 | if kl.readErr != nil { 218 | return 0, false 219 | } 220 | 221 | c, err := kl.br.ReadByte() 222 | if err != nil { 223 | kl.readErr = err 224 | return 0, false 225 | } 226 | 227 | // delay the newline handling until the next token is delivered, 228 | // fixes off-by-one errors when reporting a parse error. 
229 | if kl.eol { 230 | kl.line++ 231 | kl.column = 0 232 | kl.eol = false 233 | } 234 | 235 | if c == '\n' { 236 | kl.eol = true 237 | } else { 238 | kl.column++ 239 | } 240 | 241 | return c, true 242 | } 243 | 244 | func (kl *klexer) Next() (lex, bool) { 245 | var ( 246 | l lex 247 | 248 | str strings.Builder 249 | 250 | commt bool 251 | ) 252 | 253 | for x, ok := kl.readByte(); ok; x, ok = kl.readByte() { 254 | l.line, l.column = kl.line, kl.column 255 | 256 | switch x { 257 | case ':': 258 | if commt || !kl.key { 259 | break 260 | } 261 | 262 | kl.key = false 263 | 264 | // Next token is a space, eat it 265 | kl.readByte() 266 | 267 | l.value = zKey 268 | l.token = str.String() 269 | return l, true 270 | case ';': 271 | commt = true 272 | case '\n': 273 | if commt { 274 | // Reset a comment 275 | commt = false 276 | } 277 | 278 | if kl.key && str.Len() == 0 { 279 | // ignore empty lines 280 | break 281 | } 282 | 283 | kl.key = true 284 | 285 | l.value = zValue 286 | l.token = str.String() 287 | return l, true 288 | default: 289 | if commt { 290 | break 291 | } 292 | 293 | str.WriteByte(x) 294 | } 295 | } 296 | 297 | if kl.readErr != nil && kl.readErr != io.EOF { 298 | // Don't return any tokens after a read error occurs. 299 | return lex{value: zEOF}, false 300 | } 301 | 302 | if str.Len() > 0 { 303 | // Send remainder 304 | l.value = zValue 305 | l.token = str.String() 306 | return l, true 307 | } 308 | 309 | return lex{value: zEOF}, false 310 | } 311 | -------------------------------------------------------------------------------- /dnssec_privkey.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto" 5 | "crypto/ecdsa" 6 | "crypto/ed25519" 7 | "crypto/rsa" 8 | "math/big" 9 | "strconv" 10 | ) 11 | 12 | const format = "Private-key-format: v1.3\n" 13 | 14 | var bigIntOne = big.NewInt(1) 15 | 16 | // PrivateKeyString converts a PrivateKey to a string. 
// This string has the same
// format as the private-key-file of BIND9 (Private-key-format: v1.3).
// It needs some info from the key (the algorithm), so it's a method of the DNSKEY.
// It supports *rsa.PrivateKey, *ecdsa.PrivateKey and ed25519.PrivateKey.
// Any other key type yields the empty string.
func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
	// Rendered as e.g. "8 (RSASHA256)".
	algorithm := strconv.Itoa(int(r.Algorithm))
	algorithm += " (" + AlgorithmToString[r.Algorithm] + ")"

	switch p := p.(type) {
	case *rsa.PrivateKey:
		// NOTE(review): Primes[0] and Primes[1] are indexed unconditionally;
		// a key with fewer than two primes (or nil D) panics here.
		modulus := toBase64(p.PublicKey.N.Bytes())
		e := big.NewInt(int64(p.PublicKey.E))
		publicExponent := toBase64(e.Bytes())
		privateExponent := toBase64(p.D.Bytes())
		prime1 := toBase64(p.Primes[0].Bytes())
		prime2 := toBase64(p.Primes[1].Bytes())
		// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
		// and from: http://code.google.com/p/go/issues/detail?id=987
		p1 := new(big.Int).Sub(p.Primes[0], bigIntOne)
		q1 := new(big.Int).Sub(p.Primes[1], bigIntOne)
		exp1 := new(big.Int).Mod(p.D, p1)
		exp2 := new(big.Int).Mod(p.D, q1)
		coeff := new(big.Int).ModInverse(p.Primes[1], p.Primes[0])

		exponent1 := toBase64(exp1.Bytes())
		exponent2 := toBase64(exp2.Bytes())
		coefficient := toBase64(coeff.Bytes())

		return format +
			"Algorithm: " + algorithm + "\n" +
			"Modulus: " + modulus + "\n" +
			"PublicExponent: " + publicExponent + "\n" +
			"PrivateExponent: " + privateExponent + "\n" +
			"Prime1: " + prime1 + "\n" +
			"Prime2: " + prime2 + "\n" +
			"Exponent1: " + exponent1 + "\n" +
			"Exponent2: " + exponent2 + "\n" +
			"Coefficient: " + coefficient + "\n"

	case *ecdsa.PrivateKey:
		// The private scalar is written at the curve's fixed width; intlen
		// stays 0 if r.Algorithm is not one of the two ECDSA algorithms.
		var intlen int
		switch r.Algorithm {
		case ECDSAP256SHA256:
			intlen = 32
		case ECDSAP384SHA384:
			intlen = 48
		}
		private := toBase64(intToBytes(p.D, intlen))
		return format +
			"Algorithm: " + algorithm + "\n" +
			"PrivateKey: " + private + "\n"

	case ed25519.PrivateKey:
		// Only the seed is stored; readers expand it back with NewKeyFromSeed.
		private := toBase64(p.Seed())
		return format +
			"Algorithm: " + algorithm + "\n" +
			"PrivateKey: " + private + "\n"

	default:
		return ""
	}
}
-------------------------------------------------------------------------------- /dnsutil/util.go: --------------------------------------------------------------------------------
// Package dnsutil contains higher-level methods useful with the dns
// package. While package dns implements the DNS protocols itself,
// these functions are related but not directly required for protocol
// processing. They are often useful in preparing input/output of the
// functions in package dns.
package dnsutil

import (
	"strings"

	"github.com/miekg/dns"
)

// AddOrigin adds origin to s if s is not already a FQDN.
// Note that the result may not be a FQDN. If origin does not end
// with a ".", the result won't either.
// This implements the zonefile convention (specified in RFC 1035,
// Section "5.1. Format") that "@" represents the
// apex (bare) domain. i.e. AddOrigin("@", "foo.com.") returns "foo.com.".
func AddOrigin(s, origin string) string {
	// ("foo.", "origin.") -> "foo." (already a FQDN)
	// ("foo", "origin.") -> "foo.origin."
	// ("foo", "origin") -> "foo.origin"
	// ("foo", ".") -> "foo." (Same as dns.Fqdn())
	// ("foo.", ".") -> "foo." (Same as dns.Fqdn())
	// ("@", "origin.") -> "origin." (@ represents the apex (bare) domain)
	// ("", "origin.") -> "origin." (not obvious)
	// ("foo", "") -> "foo" (not obvious)

	if dns.IsFqdn(s) {
		return s // s is already a FQDN, no need to mess with it.
	}
	if origin == "" {
		return s // Nothing to append.
	}
	if s == "@" || s == "" {
		return origin // Expand apex.
	}
	if origin == "." {
		return dns.Fqdn(s)
	}

	return s + "." + origin // The simple case.
}

// TrimDomainName trims origin from s if s is a subdomain.
// This function will never return "", but returns "@" instead (@ represents the apex domain).
// If s is not a subdomain of origin, s is returned unchanged.
func TrimDomainName(s, origin string) string {
	// An apex (bare) domain is always returned as "@".
	// If the return value ends in a ".", the domain was not the suffix.
	// origin can end in "." or not. Either way the results should be the same.

	if s == "" {
		return "@"
	}
	// Someone is using TrimDomainName(s, ".") to remove a dot if it exists.
	if origin == "." {
		return strings.TrimSuffix(s, origin)
	}

	// Compare fully-qualified forms, but return the caller's spelling when
	// nothing is trimmed.
	original := s
	s = dns.Fqdn(s)
	origin = dns.Fqdn(origin)

	if !dns.IsSubDomain(origin, s) {
		return original
	}

	// slabels holds the start offset of each label in s; m counts the
	// trailing labels shared with origin.
	slabels := dns.Split(s)
	olabels := dns.Split(origin)
	m := dns.CompareDomainName(s, origin)
	if len(olabels) == m {
		if len(olabels) == len(slabels) {
			return "@" // origin == s
		}
		if (s[0] == '.') && (len(slabels) == (len(olabels) + 1)) {
			return "@" // TrimDomainName(".foo.", "foo.")
		}
	}

	// Return the first (len-m) labels:
	// cut just before the dot that precedes the first shared label.
	return s[:slabels[len(slabels)-m]-1]
}
-------------------------------------------------------------------------------- /dnsutil/util_test.go: --------------------------------------------------------------------------------
package dnsutil

import "testing"

func TestAddOrigin(t *testing.T) {
	var tests = []struct{ e1, e2, expected string }{
		{"@", "example.com", "example.com"},
		{"foo", "example.com", "foo.example.com"},
		{"foo.", "example.com", "foo."},
		{"@", "example.com.", "example.com."},
		{"foo", "example.com.", "foo.example.com."},
		{"foo.", "example.com.", "foo."},
		{"example.com", ".", "example.com."},
		{"example.com.", ".", "example.com."},
		// Oddball tests:
		// In general origin should not be "" or "." but at least
		// these tests verify we don't crash and will keep results
		// from changing unexpectedly.
		{"*.", "", "*."},
		{"@", "", "@"},
		{"foobar", "", "foobar"},
		{"foobar.", "", "foobar."},
		{"*.", ".", "*."},
		{"@", ".", "."},
		{"foobar", ".", "foobar."},
		{"foobar.", ".", "foobar."},
	}
	for _, test := range tests {
		actual := AddOrigin(test.e1, test.e2)
		if test.expected != actual {
			t.Errorf("AddOrigin(%#v, %#v) expected %#v, got %#v\n", test.e1, test.e2, test.expected, actual)
		}
	}
}

func TestTrimDomainName(t *testing.T) {
	// Basic tests.
	// Try trimming "example.com" and "example.com." from typical use cases.
	testsEx := []struct{ experiment, expected string }{
		{"foo.example.com", "foo"},
		{"foo.example.com.", "foo"},
		{".foo.example.com", ".foo"},
		{".foo.example.com.", ".foo"},
		{"*.example.com", "*"},
		{"example.com", "@"},
		{"example.com.", "@"},
		{"com.", "com."},
		{"foo.", "foo."},
		{"serverfault.com.", "serverfault.com."},
		{"serverfault.com", "serverfault.com"},
		{".foo.ronco.com", ".foo.ronco.com"},
		{".foo.ronco.com.", ".foo.ronco.com."},
	}
	for _, dom := range []string{"example.com", "example.com."} {
		for i, test := range testsEx {
			actual := TrimDomainName(test.experiment, dom)
			if test.expected != actual {
				t.Errorf("%d TrimDomainName(%#v, %#v): expected %v, got %v\n", i, test.experiment, dom, test.expected, actual)
			}
		}
	}

	// Paranoid tests.
	// These tests shouldn't be needed but I was wary of off-by-one errors.
	// In theory, these can't happen because there are no single-letter TLDs,
	// but it is good to exercise the code this way.
	tests := []struct{ experiment, expected string }{
		{"", "@"},
		{".", "."},
		{"a.b.c.d.e.f.", "a.b.c.d.e"},
		{"b.c.d.e.f.", "b.c.d.e"},
		{"c.d.e.f.", "c.d.e"},
		{"d.e.f.", "d.e"},
		{"e.f.", "e"},
		{"f.", "@"},
		{".a.b.c.d.e.f.", ".a.b.c.d.e"},
		{".b.c.d.e.f.", ".b.c.d.e"},
		{".c.d.e.f.", ".c.d.e"},
		{".d.e.f.", ".d.e"},
		{".e.f.", ".e"},
		{".f.", "@"},
		{"a.b.c.d.e.f", "a.b.c.d.e"},
		{"a.b.c.d.e.", "a.b.c.d.e."},
		{"a.b.c.d.e", "a.b.c.d.e"},
		{"a.b.c.d.", "a.b.c.d."},
		{"a.b.c.d", "a.b.c.d"},
		{"a.b.c.", "a.b.c."},
		{"a.b.c", "a.b.c"},
		{"a.b.", "a.b."},
		{"a.b", "a.b"},
		{"a.", "a."},
		{"a", "a"},
		{".a.b.c.d.e.f", ".a.b.c.d.e"},
		{".a.b.c.d.e.", ".a.b.c.d.e."},
		{".a.b.c.d.e", ".a.b.c.d.e"},
		{".a.b.c.d.", ".a.b.c.d."},
		{".a.b.c.d", ".a.b.c.d"},
		{".a.b.c.", ".a.b.c."},
		{".a.b.c", ".a.b.c"},
		{".a.b.", ".a.b."},
		{".a.b", ".a.b"},
		{".a.", ".a."},
		{".a", ".a"},
	}
	for _, dom := range []string{"f", "f."} {
		for i, test := range tests {
			actual := TrimDomainName(test.experiment, dom)
			if test.expected != actual {
				t.Errorf("%d TrimDomainName(%#v, %#v): expected %v, got %v\n", i, test.experiment, dom, test.expected, actual)
			}
		}
	}

	// Test cases for bugs found in the wild.
	// These test cases provide both origin, s, and the expected result.
	// If you find a bug in the wild, this is probably the easiest place
	// to add it as a test case.
	var testsWild = []struct{ e1, e2, expected string }{
		{"mathoverflow.net.", ".", "mathoverflow.net"},
		{"mathoverflow.net", ".", "mathoverflow.net"},
		{"", ".", "@"},
		{"@", ".", "@"},
	}
	for i, test := range testsWild {
		actual := TrimDomainName(test.e1, test.e2)
		if test.expected != actual {
			t.Errorf("%d TrimDomainName(%#v, %#v): expected %v, got %v\n", i, test.e1, test.e2, test.expected, actual)
		}
	}
}
-------------------------------------------------------------------------------- /duplicate.go: --------------------------------------------------------------------------------
package dns

//go:generate go run duplicate_generate.go

// IsDuplicate checks if r1 and r2 are duplicates of each other, excluding the TTL.
// So this means the header data is equal *and* the RDATA is the same. Returns true
// if so, otherwise false. It's a protocol violation to have identical RRs in a message.
func IsDuplicate(r1, r2 RR) bool {
	// Check whether the record header is identical.
	if !r1.Header().isDuplicate(r2.Header()) {
		return false
	}

	// Check whether the RDATA is identical.
	return r1.isDuplicate(r2)
}

// isDuplicate compares two headers by class, type and owner name only;
// TTL and Rdlength are ignored. It returns false if _r2 is not an *RR_Header.
func (r1 *RR_Header) isDuplicate(_r2 RR) bool {
	r2, ok := _r2.(*RR_Header)
	if !ok {
		return false
	}
	if r1.Class != r2.Class {
		return false
	}
	if r1.Rrtype != r2.Rrtype {
		return false
	}
	if !isDuplicateName(r1.Name, r2.Name) {
		return false
	}
	// ignore TTL
	return true
}

// isDuplicateName checks if the domain names s1 and s2 are equal.
37 | func isDuplicateName(s1, s2 string) bool { return equal(s1, s2) } 38 | -------------------------------------------------------------------------------- /duplicate_generate.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | // types_generate.go is meant to run with go generate. It will use 5 | // go/{importer,types} to track down all the RR struct types. Then for each type 6 | // it will generate conversion tables (TypeToRR and TypeToString) and banal 7 | // methods (len, Header, copy) based on the struct tags. The generated source is 8 | // written to ztypes.go, and is meant to be checked into git. 9 | package main 10 | 11 | import ( 12 | "bytes" 13 | "fmt" 14 | "go/format" 15 | "go/types" 16 | "log" 17 | "os" 18 | 19 | "golang.org/x/tools/go/packages" 20 | ) 21 | 22 | var packageHdr = ` 23 | // Code generated by "go run duplicate_generate.go"; DO NOT EDIT. 24 | 25 | package dns 26 | 27 | ` 28 | 29 | func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 30 | st, ok := t.Underlying().(*types.Struct) 31 | if !ok { 32 | return nil, false 33 | } 34 | if st.NumFields() == 0 { 35 | return nil, false 36 | } 37 | if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 38 | return st, false 39 | } 40 | if st.Field(0).Anonymous() { 41 | st, _ := getTypeStruct(st.Field(0).Type(), scope) 42 | return st, true 43 | } 44 | return nil, false 45 | } 46 | 47 | // loadModule retrieves package description for a given module. 
48 | func loadModule(name string) (*types.Package, error) { 49 | conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 50 | pkgs, err := packages.Load(&conf, name) 51 | if err != nil { 52 | return nil, err 53 | } 54 | return pkgs[0].Types, nil 55 | } 56 | 57 | func main() { 58 | // Import and type-check the package 59 | pkg, err := loadModule("github.com/miekg/dns") 60 | fatalIfErr(err) 61 | scope := pkg.Scope() 62 | 63 | // Collect actual types (*X) 64 | var namedTypes []string 65 | for _, name := range scope.Names() { 66 | o := scope.Lookup(name) 67 | if o == nil || !o.Exported() { 68 | continue 69 | } 70 | 71 | if st, _ := getTypeStruct(o.Type(), scope); st == nil { 72 | continue 73 | } 74 | 75 | if name == "PrivateRR" || name == "OPT" { 76 | continue 77 | } 78 | 79 | namedTypes = append(namedTypes, o.Name()) 80 | } 81 | 82 | b := &bytes.Buffer{} 83 | b.WriteString(packageHdr) 84 | 85 | // Generate the duplicate check for each type. 86 | fmt.Fprint(b, "// isDuplicate() functions\n\n") 87 | for _, name := range namedTypes { 88 | 89 | o := scope.Lookup(name) 90 | st, _ := getTypeStruct(o.Type(), scope) 91 | fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name) 92 | fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name) 93 | fmt.Fprint(b, "if !ok { return false }\n") 94 | fmt.Fprint(b, "_ = r2\n") 95 | for i := 1; i < st.NumFields(); i++ { 96 | field := st.Field(i).Name() 97 | o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) } 98 | o3 := func(s string) { fmt.Fprintf(b, s+"\n", field, field, field) } 99 | 100 | // For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are 101 | // *indirectly* defined as a slice in the net package). 
102 | if _, ok := st.Field(i).Type().(*types.Slice); ok { 103 | o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}") 104 | 105 | if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` { 106 | o3(`for i := 0; i < len(r1.%s); i++ { 107 | if !isDuplicateName(r1.%s[i], r2.%s[i]) { 108 | return false 109 | } 110 | }`) 111 | 112 | continue 113 | } 114 | 115 | if st.Tag(i) == `dns:"apl"` { 116 | o3(`for i := 0; i < len(r1.%s); i++ { 117 | if !r1.%s[i].equals(&r2.%s[i]) { 118 | return false 119 | } 120 | }`) 121 | 122 | continue 123 | } 124 | 125 | if st.Tag(i) == `dns:"pairs"` { 126 | o2(`if !areSVCBPairArraysEqual(r1.%s, r2.%s) { 127 | return false 128 | }`) 129 | 130 | continue 131 | } 132 | 133 | o3(`for i := 0; i < len(r1.%s); i++ { 134 | if r1.%s[i] != r2.%s[i] { 135 | return false 136 | } 137 | }`) 138 | 139 | continue 140 | } 141 | 142 | switch st.Tag(i) { 143 | case `dns:"-"`: 144 | // ignored 145 | case `dns:"a"`, `dns:"aaaa"`: 146 | o2("if !r1.%s.Equal(r2.%s) {\nreturn false\n}") 147 | case `dns:"cdomain-name"`, `dns:"domain-name"`: 148 | o2("if !isDuplicateName(r1.%s, r2.%s) {\nreturn false\n}") 149 | case `dns:"ipsechost"`, `dns:"amtrelayhost"`: 150 | o2(`switch r1.GatewayType { 151 | case IPSECGatewayIPv4, IPSECGatewayIPv6: 152 | if !r1.GatewayAddr.Equal(r2.GatewayAddr) { 153 | return false 154 | } 155 | case IPSECGatewayHost: 156 | if !isDuplicateName(r1.%s, r2.%s) { 157 | return false 158 | } 159 | } 160 | `) 161 | default: 162 | o2("if r1.%s != r2.%s {\nreturn false\n}") 163 | } 164 | } 165 | fmt.Fprint(b, "return true\n}\n\n") 166 | } 167 | 168 | // gofmt 169 | res, err := format.Source(b.Bytes()) 170 | if err != nil { 171 | b.WriteTo(os.Stderr) 172 | log.Fatal(err) 173 | } 174 | 175 | // write result 176 | f, err := os.Create("zduplicate.go") 177 | fatalIfErr(err) 178 | defer f.Close() 179 | f.Write(res) 180 | } 181 | 182 | func fatalIfErr(err error) { 183 | if err != nil { 184 | log.Fatal(err) 185 | } 186 | } 187 | 
-------------------------------------------------------------------------------- /duplicate_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "testing" 4 | 5 | func TestDuplicateA(t *testing.T) { 6 | a1, _ := NewRR("www.example.org. 2700 IN A 127.0.0.1") 7 | a2, _ := NewRR("www.example.org. IN A 127.0.0.1") 8 | if !IsDuplicate(a1, a2) { 9 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 10 | } 11 | 12 | a2, _ = NewRR("www.example.org. IN A 127.0.0.2") 13 | if IsDuplicate(a1, a2) { 14 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 15 | } 16 | } 17 | 18 | func TestDuplicateTXT(t *testing.T) { 19 | a1, _ := NewRR("www.example.org. IN TXT \"aa\"") 20 | a2, _ := NewRR("www.example.org. IN TXT \"aa\"") 21 | 22 | if !IsDuplicate(a1, a2) { 23 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 24 | } 25 | 26 | a2, _ = NewRR("www.example.org. IN TXT \"aa\" \"bb\"") 27 | if IsDuplicate(a1, a2) { 28 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 29 | } 30 | 31 | a1, _ = NewRR("www.example.org. IN TXT \"aa\" \"bc\"") 32 | if IsDuplicate(a1, a2) { 33 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 34 | } 35 | } 36 | 37 | func TestDuplicateSVCB(t *testing.T) { 38 | a1, _ := NewRR(`example.com. 3600 IN SVCB 1 . ipv6hint=1::3:3:3:3 key65300=\254\032\030\000\ \043,\;`) 39 | a2, _ := NewRR(`example.com. 3600 IN SVCB 1 . ipv6hint=1:0::3:3:3:3 key65300="\254\ \030\000 +\,;"`) 40 | 41 | if !IsDuplicate(a1, a2) { 42 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 43 | } 44 | 45 | a2, _ = NewRR(`example.com. 3600 IN SVCB 1 . 
ipv6hint=1::3:3:3:3 key65300="\255\ \030\000 +\,;"`) 46 | 47 | if IsDuplicate(a1, a2) { 48 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 49 | } 50 | 51 | a1, _ = NewRR(`example.com. 3600 IN SVCB 1 . ipv6hint=1::3:3:3:3`) 52 | 53 | if IsDuplicate(a1, a2) { 54 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 55 | } 56 | 57 | a2, _ = NewRR(`example.com. 3600 IN SVCB 1 . ipv4hint=1.1.1.1`) 58 | 59 | if IsDuplicate(a1, a2) { 60 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 61 | } 62 | 63 | a1, _ = NewRR(`example.com. 3600 IN SVCB 1 . ipv4hint=1.1.1.1,1.0.2.1`) 64 | 65 | if IsDuplicate(a1, a2) { 66 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 67 | } 68 | } 69 | 70 | func TestDuplicateOwner(t *testing.T) { 71 | a1, _ := NewRR("www.example.org. IN A 127.0.0.1") 72 | a2, _ := NewRR("www.example.org. IN A 127.0.0.1") 73 | if !IsDuplicate(a1, a2) { 74 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 75 | } 76 | 77 | a2, _ = NewRR("WWw.exaMPle.org. IN A 127.0.0.2") 78 | if IsDuplicate(a1, a2) { 79 | t.Errorf("expected %s/%s not to be duplicates, but got true", a1.String(), a2.String()) 80 | } 81 | } 82 | 83 | func TestDuplicateDomain(t *testing.T) { 84 | a1, _ := NewRR("www.example.org. IN CNAME example.org.") 85 | a2, _ := NewRR("www.example.org. IN CNAME example.org.") 86 | if !IsDuplicate(a1, a2) { 87 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 88 | } 89 | 90 | a2, _ = NewRR("www.example.org. IN CNAME exAMPLe.oRG.") 91 | if !IsDuplicate(a1, a2) { 92 | t.Errorf("expected %s/%s to be duplicates, but got false", a1.String(), a2.String()) 93 | } 94 | } 95 | 96 | func TestDuplicateWrongRrtype(t *testing.T) { 97 | // Test that IsDuplicate won't panic for a record that's lying about 98 | // its Rrtype.
99 | 100 | r1 := &A{Hdr: RR_Header{Rrtype: TypeA}} 101 | r2 := &AAAA{Hdr: RR_Header{Rrtype: TypeA}} 102 | if IsDuplicate(r1, r2) { 103 | t.Errorf("expected %s/%s not to be duplicates, but got true", r1.String(), r2.String()) 104 | } 105 | 106 | r3 := &AAAA{Hdr: RR_Header{Rrtype: TypeA}} 107 | r4 := &A{Hdr: RR_Header{Rrtype: TypeA}} 108 | if IsDuplicate(r3, r4) { 109 | t.Errorf("expected %s/%s not to be duplicates, but got true", r3.String(), r4.String()) 110 | } 111 | 112 | r5 := &AAAA{Hdr: RR_Header{Rrtype: TypeA}} 113 | r6 := &AAAA{Hdr: RR_Header{Rrtype: TypeA}} 114 | if !IsDuplicate(r5, r6) { 115 | t.Errorf("expected %s/%s to be duplicates, but got false", r5.String(), r6.String()) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /dyn_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // Find better solution 4 | -------------------------------------------------------------------------------- /edns_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "bytes" 5 | "net" 6 | "testing" 7 | ) 8 | 9 | func TestOPTTtl(t *testing.T) { 10 | e := &OPT{} 11 | e.Hdr.Name = "." 12 | e.Hdr.Rrtype = TypeOPT 13 | 14 | // verify the default setting of DO=0 15 | if e.Do() { 16 | t.Errorf("DO bit should be zero") 17 | } 18 | 19 | // There are 6 possible invocations of SetDo(): 20 | // 21 | // 1. Starting with DO=0, using SetDo() 22 | // 2. Starting with DO=0, using SetDo(true) 23 | // 3. Starting with DO=0, using SetDo(false) 24 | // 4. Starting with DO=1, using SetDo() 25 | // 5. Starting with DO=1, using SetDo(true) 26 | // 6. 
Starting with DO=1, using SetDo(false) 27 | 28 | // verify that invoking SetDo() sets DO=1 (TEST #1) 29 | e.SetDo() 30 | if !e.Do() { 31 | t.Errorf("DO bit should be non-zero") 32 | } 33 | // verify that using SetDo(true) works when DO=1 (TEST #5) 34 | e.SetDo(true) 35 | if !e.Do() { 36 | t.Errorf("DO bit should still be non-zero") 37 | } 38 | // verify that we can use SetDo(false) to set DO=0 (TEST #6) 39 | e.SetDo(false) 40 | if e.Do() { 41 | t.Errorf("DO bit should be zero") 42 | } 43 | // verify that if we call SetDo(false) when DO=0 that it is unchanged (TEST #3) 44 | e.SetDo(false) 45 | if e.Do() { 46 | t.Errorf("DO bit should still be zero") 47 | } 48 | // verify that using SetDo(true) works for DO=0 (TEST #2) 49 | e.SetDo(true) 50 | if !e.Do() { 51 | t.Errorf("DO bit should be non-zero") 52 | } 53 | // verify that using SetDo() works for DO=1 (TEST #4) 54 | e.SetDo() 55 | if !e.Do() { 56 | t.Errorf("DO bit should be non-zero") 57 | } 58 | 59 | // CO (Compact Answers OK) flag tests follow the same pattern as DO tests 60 | // verify that invoking SetCo() sets CO=1 61 | e.SetCo() 62 | if !e.Co() { 63 | t.Errorf("CO bit should be non-zero") 64 | } 65 | 66 | // verify that using SetCo(true) works when CO=1 67 | e.SetCo(true) 68 | if !e.Co() { 69 | t.Errorf("CO bit should still be non-zero") 70 | } 71 | // verify that we can use SetCo(false) to set CO=0 72 | e.SetCo(false) 73 | if e.Co() { 74 | t.Errorf("CO bit should be zero") 75 | } 76 | // verify that if we call SetCo(false) when CO=0 that it is unchanged 77 | e.SetCo(false) 78 | if e.Co() { 79 | t.Errorf("CO bit should still be zero") 80 | } 81 | // verify that using SetCo(true) works for CO=0 82 | e.SetCo(true) 83 | if !e.Co() { 84 | t.Errorf("CO bit should be non-zero") 85 | } 86 | // verify that using SetCo() works for CO=1 87 | e.SetCo() 88 | if !e.Co() { 89 | t.Errorf("CO bit should be non-zero") 90 | } 91 | 92 | if e.Version() != 0 { 93 | t.Errorf("version should be zero") 94 | } 95 | 96 |
e.SetVersion(42) 97 | if e.Version() != 42 { 98 | t.Errorf("set 42, expected %d, got %d", 42, e.Version()) 99 | } 100 | 101 | e.SetExtendedRcode(42) 102 | // ExtendedRcode has the last 4 bits set to 0. 103 | if e.ExtendedRcode() != 42&0xFFFFFFF0 { 104 | t.Errorf("set 42, expected %d, got %d", 42&0xFFFFFFF0, e.ExtendedRcode()) 105 | } 106 | 107 | // This will reset the 8 upper bits of the extended rcode 108 | e.SetExtendedRcode(RcodeNotAuth) 109 | if e.ExtendedRcode() != 0 { 110 | t.Errorf("Setting a non-extended rcode is expected to set extended rcode to 0, got: %d", e.ExtendedRcode()) 111 | } 112 | } 113 | 114 | func TestEDNS0_SUBNETUnpack(t *testing.T) { 115 | for _, ip := range []net.IP{ 116 | net.IPv4(0xde, 0xad, 0xbe, 0xef), 117 | net.ParseIP("192.0.2.1"), 118 | net.ParseIP("2001:db8::68"), 119 | } { 120 | var s1 EDNS0_SUBNET 121 | s1.Address = ip 122 | 123 | if ip.To4() == nil { 124 | s1.Family = 2 125 | s1.SourceNetmask = net.IPv6len * 8 126 | } else { 127 | s1.Family = 1 128 | s1.SourceNetmask = net.IPv4len * 8 129 | } 130 | 131 | b, err := s1.pack() 132 | if err != nil { 133 | t.Fatalf("failed to pack: %v", err) 134 | } 135 | 136 | var s2 EDNS0_SUBNET 137 | if err := s2.unpack(b); err != nil { 138 | t.Fatalf("failed to unpack: %v", err) 139 | } 140 | 141 | if !ip.Equal(s2.Address) { 142 | t.Errorf("address different after unpacking; expected %s, got %s", ip, s2.Address) 143 | } 144 | } 145 | } 146 | 147 | func TestEDNS0_UL(t *testing.T) { 148 | cases := []struct { 149 | l uint32 150 | kl uint32 151 | }{ 152 | {0x01234567, 0}, 153 | {0x76543210, 0xFEDCBA98}, 154 | } 155 | for _, c := range cases { 156 | expect := EDNS0_UL{EDNS0UL, c.l, c.kl} 157 | b, err := expect.pack() 158 | if err != nil { 159 | t.Fatalf("failed to pack: %v", err) 160 | } 161 | actual := EDNS0_UL{EDNS0UL, ^uint32(0), ^uint32(0)} 162 | if err := actual.unpack(b); err != nil { 163 | t.Fatalf("failed to unpack: %v", err) 164 | } 165 | if expect != actual { 166 | t.Errorf("unpacked option is 
different; expected %v, got %v", expect, actual) 167 | } 168 | } 169 | } 170 | 171 | func TestZ(t *testing.T) { 172 | e := &OPT{} 173 | e.Hdr.Name = "." 174 | e.Hdr.Rrtype = TypeOPT 175 | e.SetVersion(8) 176 | e.SetDo() 177 | e.SetCo() 178 | if e.Z() != 0 { 179 | t.Errorf("expected Z of 0, got %d", e.Z()) 180 | } 181 | e.SetZ(5) 182 | if e.Z() != 5 { 183 | t.Errorf("expected Z of 5, got %d", e.Z()) 184 | } 185 | e.SetZ(0xFFFF) 186 | if e.Z() != 0x3FFF { 187 | t.Errorf("expected Z of 0x3FFF, got %d", e.Z()) 188 | } 189 | if e.Version() != 8 { 190 | t.Errorf("expected version to still be 8, got %d", e.Version()) 191 | } 192 | if !e.Do() { 193 | t.Error("expected DO to be set") 194 | } 195 | } 196 | 197 | func TestEDNS0_ESU(t *testing.T) { 198 | p := []byte{ 199 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 200 | 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x29, 0x04, 201 | 0xC4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 202 | 0x04, 0x00, 0x24, 0x73, 0x69, 0x70, 0x3A, 0x2B, 203 | 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 204 | 0x39, 0x40, 0x74, 0x65, 0x73, 0x74, 0x2E, 0x63, 205 | 0x6F, 0x6D, 0x3B, 0x75, 0x73, 0x65, 0x72, 0x3D, 206 | 0x63, 0x67, 0x72, 0x61, 0x74, 0x65, 0x73, 207 | } 208 | 209 | m := new(Msg) 210 | if err := m.Unpack(p); err != nil { 211 | t.Fatalf("failed to unpack: %v", err) 212 | } 213 | opt := m.IsEdns0() 214 | if opt == nil { 215 | t.Fatalf("expected edns0 option") 216 | } 217 | if len(opt.Option) != 1 { 218 | t.Fatalf("expected only one option: %v", opt.Option) 219 | } 220 | edns0 := opt.Option[0] 221 | esu, ok := edns0.(*EDNS0_ESU) 222 | if !ok { 223 | t.Fatalf("expected option of type EDNS0_ESU, got %T", edns0) 224 | } 225 | expect := "sip:+123456789@test.com;user=cgrates" 226 | if esu.Uri != expect { 227 | t.Errorf("unpacked option is different; expected %v, got %v", expect, esu.Uri) 228 | } 229 | } 230 | 231 | func TestEDNS0_TCP_KEEPALIVE_unpack(t *testing.T) { 232 | cases := []struct { 233 | name string 234 | b []byte 235 | expected uint16
236 | expectedErr bool 237 | }{ 238 | { 239 | name: "empty", 240 | b: []byte{}, 241 | expected: 0, 242 | }, 243 | { 244 | name: "timeout 1", 245 | b: []byte{0, 1}, 246 | expected: 1, 247 | }, 248 | { 249 | name: "invalid", 250 | b: []byte{0, 1, 3}, 251 | expectedErr: true, 252 | }, 253 | } 254 | 255 | for _, tc := range cases { 256 | t.Run(tc.name, func(t *testing.T) { 257 | e := &EDNS0_TCP_KEEPALIVE{} 258 | err := e.unpack(tc.b) 259 | if err != nil && !tc.expectedErr { 260 | t.Error("failed to unpack, expected no error") 261 | } 262 | if err == nil && tc.expectedErr { 263 | t.Error("unpacked, but expected an error") 264 | } 265 | if e.Timeout != tc.expected { 266 | t.Errorf("invalid timeout, actual: %d, expected: %d", e.Timeout, tc.expected) 267 | } 268 | }) 269 | } 270 | } 271 | 272 | func TestEDNS0_TCP_KEEPALIVE_pack(t *testing.T) { 273 | cases := []struct { 274 | name string 275 | edns *EDNS0_TCP_KEEPALIVE 276 | expected []byte 277 | }{ 278 | { 279 | name: "empty", 280 | edns: &EDNS0_TCP_KEEPALIVE{ 281 | Code: EDNS0TCPKEEPALIVE, 282 | Timeout: 0, 283 | }, 284 | expected: nil, 285 | }, 286 | { 287 | name: "timeout 1", 288 | edns: &EDNS0_TCP_KEEPALIVE{ 289 | Code: EDNS0TCPKEEPALIVE, 290 | Timeout: 1, 291 | }, 292 | expected: []byte{0, 1}, 293 | }, 294 | } 295 | 296 | for _, tc := range cases { 297 | t.Run(tc.name, func(t *testing.T) { 298 | b, err := tc.edns.pack() 299 | if err != nil { 300 | t.Error("expected no error") 301 | } 302 | 303 | if tc.expected == nil && b != nil { 304 | t.Errorf("invalid result, expected nil") 305 | } 306 | 307 | res := bytes.Compare(b, tc.expected) 308 | if res != 0 { 309 | t.Errorf("invalid result, expected: %v, actual: %v", tc.expected, b) 310 | } 311 | }) 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package dns_test 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | 
"net" 8 | 9 | "github.com/miekg/dns" 10 | ) 11 | 12 | // Retrieve the MX records for miek.nl. 13 | func ExampleMX() { 14 | config, _ := dns.ClientConfigFromFile("/etc/resolv.conf") 15 | c := new(dns.Client) 16 | m := new(dns.Msg) 17 | m.SetQuestion("miek.nl.", dns.TypeMX) 18 | m.RecursionDesired = true 19 | r, _, err := c.Exchange(m, net.JoinHostPort(config.Servers[0], config.Port)) 20 | if err != nil { 21 | return 22 | } 23 | if r.Rcode != dns.RcodeSuccess { 24 | return 25 | } 26 | for _, a := range r.Answer { 27 | if mx, ok := a.(*dns.MX); ok { 28 | fmt.Printf("%s\n", mx.String()) 29 | } 30 | } 31 | } 32 | 33 | // Retrieve the DNSKEY records of a zone and convert them 34 | // to DS records for SHA1, SHA256 and SHA384. 35 | func ExampleDS() { 36 | config, _ := dns.ClientConfigFromFile("/etc/resolv.conf") 37 | c := new(dns.Client) 38 | m := new(dns.Msg) 39 | zone := "miek.nl" 40 | m.SetQuestion(dns.Fqdn(zone), dns.TypeDNSKEY) 41 | m.SetEdns0(4096, true) 42 | r, _, err := c.Exchange(m, net.JoinHostPort(config.Servers[0], config.Port)) 43 | if err != nil { 44 | return 45 | } 46 | if r.Rcode != dns.RcodeSuccess { 47 | return 48 | } 49 | for _, k := range r.Answer { 50 | if key, ok := k.(*dns.DNSKEY); ok { 51 | for _, alg := range []uint8{dns.SHA1, dns.SHA256, dns.SHA384} { 52 | fmt.Printf("%s; %d\n", key.ToDS(alg).String(), key.Flags) 53 | } 54 | } 55 | } 56 | } 57 | 58 | const TypeAPAIR = 0x0F99 59 | 60 | type APAIR struct { 61 | addr [2]net.IP 62 | } 63 | 64 | func NewAPAIR() dns.PrivateRdata { return new(APAIR) } 65 | 66 | func (rd *APAIR) String() string { return rd.addr[0].String() + " " + rd.addr[1].String() } 67 | 68 | func (rd *APAIR) Parse(txt []string) error { 69 | if len(txt) != 2 { 70 | return errors.New("two addresses required for APAIR") 71 | } 72 | for i, s := range txt { 73 | ip := net.ParseIP(s) 74 | if ip == nil { 75 | return errors.New("invalid IP in APAIR text representation") 76 | } 77 | rd.addr[i] = ip 78 | } 79 | return nil 80 | } 81 | 82 | func 
(rd *APAIR) Pack(buf []byte) (int, error) { 83 | b := append([]byte(rd.addr[0].To4()), []byte(rd.addr[1].To4())...) // wire form is 4+4 bytes; Parse stores 16-byte IPs 84 | n := copy(buf, b) 85 | if n != len(b) { 86 | return n, dns.ErrBuf 87 | } 88 | return n, nil 89 | } 90 | 91 | func (rd *APAIR) Unpack(buf []byte) (int, error) { 92 | ln := net.IPv4len * 2 93 | if len(buf) != ln { 94 | return 0, errors.New("invalid length of APAIR rdata") 95 | } 96 | cp := make([]byte, ln) 97 | copy(cp, buf) // clone bytes to use them in IPs 98 | 99 | rd.addr[0] = net.IP(cp[:4]) 100 | rd.addr[1] = net.IP(cp[4:]) 101 | 102 | return len(buf), nil 103 | } 104 | 105 | func (rd *APAIR) Copy(dest dns.PrivateRdata) error { 106 | cp := make([]byte, rd.Len()) 107 | _, err := rd.Pack(cp) 108 | if err != nil { 109 | return err 110 | } 111 | 112 | d := dest.(*APAIR) 113 | d.addr[0] = net.IP(cp[:4]) 114 | d.addr[1] = net.IP(cp[4:]) 115 | return nil 116 | } 117 | 118 | func (rd *APAIR) Len() int { 119 | return net.IPv4len * 2 120 | } 121 | 122 | func ExamplePrivateHandle() { 123 | dns.PrivateHandle("APAIR", TypeAPAIR, NewAPAIR) 124 | defer dns.PrivateHandleRemove(TypeAPAIR) 125 | var oldId = dns.Id 126 | dns.Id = func() uint16 { return 3 } 127 | defer func() { dns.Id = oldId }() 128 | 129 | rr, err := dns.NewRR("miek.nl. APAIR (1.2.3.4 1.2.3.5)") 130 | if err != nil { 131 | log.Fatal("could not parse APAIR record: ", err) 132 | } 133 | fmt.Println(rr) // see first line of Output below 134 | 135 | m := new(dns.Msg) 136 | m.SetQuestion("miek.nl.", TypeAPAIR) 137 | m.Answer = append(m.Answer, rr) 138 | 139 | fmt.Println(m) 140 | // Output: miek.nl. 3600 IN APAIR 1.2.3.4 1.2.3.5 141 | // ;; opcode: QUERY, status: NOERROR, id: 3 142 | // ;; flags: rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0 143 | // 144 | // ;; QUESTION SECTION: 145 | // ;miek.nl. IN APAIR 146 | // 147 | // ;; ANSWER SECTION: 148 | // miek.nl.
3600 IN APAIR 1.2.3.4 1.2.3.5 149 | } 150 | -------------------------------------------------------------------------------- /format.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "net" 5 | "reflect" 6 | "strconv" 7 | ) 8 | 9 | // NumField returns the number of rdata fields r has. 10 | func NumField(r RR) int { 11 | return reflect.ValueOf(r).Elem().NumField() - 1 // Remove RR_Header 12 | } 13 | 14 | // Field returns the rdata field i as a string. Fields are indexed starting from 1. 15 | // RR types that holds slice data, for instance the NSEC type bitmap will return a single 16 | // string where the types are concatenated using a space. 17 | // Accessing non existing fields will cause a panic. 18 | func Field(r RR, i int) string { 19 | if i == 0 { 20 | return "" 21 | } 22 | d := reflect.ValueOf(r).Elem().Field(i) 23 | switch d.Kind() { 24 | case reflect.String: 25 | return d.String() 26 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 27 | return strconv.FormatInt(d.Int(), 10) 28 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 29 | return strconv.FormatUint(d.Uint(), 10) 30 | case reflect.Slice: 31 | switch reflect.ValueOf(r).Elem().Type().Field(i).Tag { 32 | case `dns:"a"`: 33 | // TODO(miek): Hmm store this as 16 bytes 34 | if d.Len() < net.IPv4len { 35 | return "" 36 | } 37 | if d.Len() < net.IPv6len { 38 | return net.IPv4(byte(d.Index(0).Uint()), 39 | byte(d.Index(1).Uint()), 40 | byte(d.Index(2).Uint()), 41 | byte(d.Index(3).Uint())).String() 42 | } 43 | return net.IPv4(byte(d.Index(12).Uint()), 44 | byte(d.Index(13).Uint()), 45 | byte(d.Index(14).Uint()), 46 | byte(d.Index(15).Uint())).String() 47 | case `dns:"aaaa"`: 48 | if d.Len() < net.IPv6len { 49 | return "" 50 | } 51 | return net.IP{ 52 | byte(d.Index(0).Uint()), 53 | byte(d.Index(1).Uint()), 54 | byte(d.Index(2).Uint()), 55 | byte(d.Index(3).Uint()), 56 | 
byte(d.Index(4).Uint()), 57 | byte(d.Index(5).Uint()), 58 | byte(d.Index(6).Uint()), 59 | byte(d.Index(7).Uint()), 60 | byte(d.Index(8).Uint()), 61 | byte(d.Index(9).Uint()), 62 | byte(d.Index(10).Uint()), 63 | byte(d.Index(11).Uint()), 64 | byte(d.Index(12).Uint()), 65 | byte(d.Index(13).Uint()), 66 | byte(d.Index(14).Uint()), 67 | byte(d.Index(15).Uint()), 68 | }.String() 69 | case `dns:"nsec"`: 70 | if d.Len() == 0 { 71 | return "" 72 | } 73 | s := Type(d.Index(0).Uint()).String() 74 | for i := 1; i < d.Len(); i++ { 75 | s += " " + Type(d.Index(i).Uint()).String() 76 | } 77 | return s 78 | default: 79 | // if it does not have a tag its a string slice 80 | fallthrough 81 | case `dns:"txt"`: 82 | if d.Len() == 0 { 83 | return "" 84 | } 85 | s := d.Index(0).String() 86 | for i := 1; i < d.Len(); i++ { 87 | s += " " + d.Index(i).String() 88 | } 89 | return s 90 | } 91 | } 92 | return "" 93 | } 94 | -------------------------------------------------------------------------------- /format_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestFieldEmptyAOrAAAAData(t *testing.T) { 8 | res := Field(new(A), 1) 9 | if res != "" { 10 | t.Errorf("expected empty string but got %v", res) 11 | } 12 | res = Field(new(AAAA), 1) 13 | if res != "" { 14 | t.Errorf("expected empty string but got %v", res) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /fuzz.go: -------------------------------------------------------------------------------- 1 | //go:build fuzz 2 | // +build fuzz 3 | 4 | package dns 5 | 6 | import "strings" 7 | 8 | func Fuzz(data []byte) int { 9 | msg := new(Msg) 10 | 11 | if err := msg.Unpack(data); err != nil { 12 | return 0 13 | } 14 | if _, err := msg.Pack(); err != nil { 15 | return 0 16 | } 17 | 18 | return 1 19 | } 20 | 21 | func FuzzNewRR(data []byte) int { 22 | str := string(data) 23 | // Do not fuzz 
lines that include the $INCLUDE keyword and hint the fuzzer 24 | // at avoiding them. 25 | // See GH#1025 for context. 26 | if strings.Contains(strings.ToUpper(str), "$INCLUDE") { 27 | return -1 28 | } 29 | if _, err := NewRR(str); err != nil { 30 | return 0 31 | } 32 | return 1 33 | } 34 | -------------------------------------------------------------------------------- /fuzz_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | ) 7 | 8 | // TestPackDataOpt tests generated using fuzz.go and with a message pack 9 | // containing the following bytes: 10 | // "0000\x00\x00000000\x00\x00/00000" + 11 | // "0\x00\v\x00#\b00000000\x00\x00)000" + 12 | // "000\x00\x1c00\x00\x0000\x00\x01000\x00\x00\x00\b" + 13 | // "\x00\v\x00\x02\x0000000000" 14 | // That bytes sequence created the overflow error. 15 | func TestPackDataOpt(t *testing.T) { 16 | type args struct { 17 | option []EDNS0 18 | msg []byte 19 | off int 20 | } 21 | tests := []struct { 22 | name string 23 | args args 24 | want int 25 | wantErr bool 26 | wantErrMsg string 27 | }{ 28 | { 29 | name: "overflow", 30 | args: args{ 31 | option: []EDNS0{ 32 | &EDNS0_LOCAL{Code: 0x3030, Data: []uint8{}}, 33 | &EDNS0_LOCAL{Code: 0x3030, Data: []uint8{0x30}}, 34 | &EDNS0_LOCAL{Code: 0x3030, Data: []uint8{}}, 35 | &EDNS0_SUBNET{ 36 | Code: 0x0, Family: 0x2, 37 | SourceNetmask: 0x0, SourceScope: 0x30, 38 | Address: net.IP{0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}}, 39 | }, 40 | msg: []byte{ 41 | 0x30, 0x30, 0x30, 0x30, 0x00, 0x00, 0x00, 0x2, 42 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2f, 0x30, 43 | 0x30, 0x30, 0x30, 0x30, 0x30, 0x00, 0x0b, 0x00, 44 | 0x23, 0x08, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 45 | 0x30, 0x30, 0x00, 0x00, 0x29, 0x30, 0x30, 0x30, 46 | 0x30, 0x30, 0x30, 0x00, 0x00, 0x30, 0x30, 0x00, 47 | 0x00, 0x30, 0x30, 0x00, 0x01, 0x30, 0x00, 0x00, 48 | 0x00, 49 | }, 50 | off: 54, 51 | 
}, 52 | wantErr: true, 53 | wantErrMsg: "dns: overflow packing opt", 54 | want: 57, 55 | }, 56 | } 57 | for _, tt := range tests { 58 | t.Run(tt.name, func(t *testing.T) { 59 | got, err := packDataOpt(tt.args.option, tt.args.msg, tt.args.off) 60 | if (err != nil) != tt.wantErr { 61 | t.Errorf("packDataOpt() error = %v, wantErr %v", err, tt.wantErr) 62 | return 63 | } 64 | if err != nil && tt.wantErrMsg != err.Error() { 65 | t.Errorf("packDataOpt() error msg = %v, wantErrMsg %v", err.Error(), tt.wantErrMsg) 66 | return 67 | } 68 | if got != tt.want { 69 | t.Errorf("packDataOpt() = %v, want %v", got, tt.want) 70 | } 71 | }) 72 | } 73 | } 74 | 75 | // TestCrashNSEC tests generated using fuzz.go and with a message pack 76 | // containing the following bytes: 77 | // "0000\x00\x00000000\x00\x00/00000" + 78 | // "0\x00\v\x00#\b00000\x00\x00\x00\x00\x00\x1a000" + 79 | // "000\x00\x00\x00\x00\x1a000000\x00\x00\x00\x00\x1a0" + 80 | // "00000\x00\v00\a0000000\x00" 81 | // That byte sequence, when Unpack() and subsequent Pack() created a 82 | // panic: runtime error: slice bounds out of range 83 | // which was attributed to the fact that NSEC RR length computation was different (and smaller) 84 | // then when within packDataNsec. 
85 | func TestCrashNSEC(t *testing.T) { 86 | compression := make(map[string]struct{}) 87 | nsec := &NSEC{ 88 | Hdr: RR_Header{ 89 | Name: ".", 90 | Rrtype: 0x2f, 91 | Class: 0x3030, 92 | Ttl: 0x30303030, 93 | Rdlength: 0xb, 94 | }, 95 | NextDomain: ".", 96 | TypeBitMap: []uint16{ 97 | 0x2302, 0x2303, 0x230a, 0x230b, 98 | 0x2312, 0x2313, 0x231a, 0x231b, 99 | 0x2322, 0x2323, 100 | }, 101 | } 102 | expectedLength := 19 103 | l := nsec.len(0, compression) 104 | if l != expectedLength { 105 | t.Fatalf("expected length of %d, got %d", expectedLength, l) 106 | } 107 | } 108 | 109 | // TestCrashNSEC3 tests generated using fuzz.go and with a message pack 110 | // containing the following bytes: 111 | // "0000\x00\x00000000\x00\x00200000" + 112 | // "0\x00\v0000\x00\x00#\x0300\x00\x00\x00\x1a000" + 113 | // "000\x00\v00\x0200\x00\x03000\x00" 114 | // That byte sequence, when Unpack() and subsequent Pack() created a 115 | // panic: runtime error: slice bounds out of range 116 | // which was attributed to the fact that NSEC3 RR length computation was 117 | // different (and smaller) then within NSEC3.pack (which relies on 118 | // packDataNsec). 119 | func TestCrashNSEC3(t *testing.T) { 120 | compression := make(map[string]struct{}) 121 | nsec3 := &NSEC3{ 122 | Hdr: RR_Header{ 123 | Name: ".", 124 | Rrtype: 0x32, 125 | Class: 0x3030, 126 | Ttl: 0x30303030, 127 | Rdlength: 0xb, 128 | }, 129 | Hash: 0x30, 130 | Flags: 0x30, 131 | Iterations: 0x3030, 132 | SaltLength: 0x0, 133 | Salt: "", 134 | HashLength: 0x0, 135 | NextDomain: ".", 136 | TypeBitMap: []uint16{ 137 | 0x2302, 0x2303, 0x230a, 0x230b, 138 | }, 139 | } 140 | expectedLength := 24 141 | l := nsec3.len(0, compression) 142 | if l != expectedLength { 143 | t.Fatalf("expected length of %d, got %d", expectedLength, l) 144 | } 145 | } 146 | 147 | // TestNewRRCommentLengthCrasherString test inputs to NewRR that generated crashes. 
148 | func TestNewRRCommentLengthCrasherString(t *testing.T) { 149 | tests := []struct { 150 | name string 151 | in string 152 | err string 153 | }{ 154 | 155 | { 156 | "HINFO1", " HINFO ;;;;;;;;;;;;;" + 157 | ";;;;;;;;\x00\x19;;;;;;;;;;" + 158 | ";\u007f;;;;;;;;;;;;;;;;;;" + 159 | ";;}mP_Qq_3sJ_1_84X_5" + 160 | "45iW_3K4p8J8_v9_LT3_" + 161 | "6_0l_3D4VT3xq6N_3K__" + 162 | "_U_xX2m;;;;;;(;;;;;;" + 163 | ";;;;;;;;;;;;;;;\x1d;;;;" + 164 | ";;;;;;-0x804dBDe8ba " + 165 | "\t \t\tr HINFO \" \t\t\tve" + 166 | "k1xH11e__P6_dk1_51bo" + 167 | "g8gJK1V_O_v84_Bw4_1_" + 168 | "72jQ3_0J3V_S5iYn4h5X" + 169 | "R_2n___51J nN_ \t\tm " + 170 | "aa_XO4_5\t \t\t \t\tg6b" + 171 | "p_KI_1_YWc_K8c2b___A" + 172 | "e_Y1m__4Y_R_avy6t08x" + 173 | "b5Cp9_7uS_yLa\t\t\t d " + 174 | "EKe1Q83vS___ a \t\t " + 175 | "\tmP_Qq_3sJ_1_84X_545" + 176 | "iW_3K4p8J8_v9_LT3_6_" + 177 | "0l_3D4VT3xq6N_3K___U" + 178 | "_xX2\"\" \t \t_fL Ogl5" + 179 | "_09i_9__3O7C__QMAG2U" + 180 | "35IO8RRU6aJ9_6_57_6_" + 181 | "b05BMoX5I__4833_____" + 182 | "yfD_2_OPs__sqzM_pqQi" + 183 | "_\t\t \tN__GuY4_Trath_0" + 184 | "yy___cAK_a__0J0q5 L_" + 185 | "p63Fzdva_Lb_29V7_R__" + 186 | "Go_H2_8m_4__FJM5B_Y5" + 187 | "Slw_ghp_55l_X2_Pnt6Y" + 188 | "_Wd_hM7jRZ_\t\t \tm \t" + 189 | " \t\ta md rK \x00 7_\"sr " + 190 | "- sg o -0x804dBDe8b" + 191 | "a \t \t\tN_W6J3PBS_W__C" + 192 | "yJu__k6F_jY0INI_LC27" + 193 | "7x14b_1b___Y8f_K_3y_" + 194 | "0055yaP_LKu_72g_T_32" + 195 | "iBk1Zm_o 9i1P44_S0_" + 196 | "_4AXUpo2__H55tL_g78_" + 197 | "8V_8l0yg6bp_KI_1_YWc" + 198 | "_K8c2b \t \tmaa_XO4_5" + 199 | "rg6bp_KI_1_YWc_K8c2b" + 200 | " _C20w i_4 \t\t u_k d" + 201 | " rKsg09099 \"\"2335779" + 202 | "05047986112651e025 \t" + 203 | " \t\tN_W6J3PBS_W__CyJu" + 204 | "__k6F_jY0INI_LC277x1" + 205 | "4b_1b___Y8f_K_3y_005" + 206 | "5yaP_LKu_72g_T_32iBk" + 207 | "1Zm_o 9i1P44_S0__4A" + 208 | "XUpo2__H55tL_g78_8V_" + 209 | "8l0y_9K9_C__6af__wj_" + 210 | "UbSYy_ge29S_s_Qe259q" + 211 | "_kGod \t\t\t\t :0xb1AF1F" + 212 | "b71D2ACeaB3FEce2ssg " + 213 | "o 
dr-0x804dBDe8ba \t " + 214 | "\t\t$ Y5 _BzOc6S_Lk0K" + 215 | "y43j1TzV__9367tbX56_" + 216 | "6B3__q6_v8_4_0_t_2q_" + 217 | "nJ2gV3j9_tkOrx_H__a}" + 218 | "mT 0g6bp_KI_1_YWc_K8" + 219 | "c2b\t_ a\t \t54KM8f9_63" + 220 | "zJ2Q_c1_C_Zf4ICF4m0q" + 221 | "_RVm_3Zh4vr7yI_H2 a" + 222 | " m 0yq__TiqA_FQBv_SS" + 223 | "_Hm_8T8__M8F2_53TTo_" + 224 | "k_o2__u_W6Vr__524q9l" + 225 | "9CQsC_kOU___g_94 \"" + 226 | " ~a_j_16_6iUSu_96V1W" + 227 | "5r01j____gn157__8_LO" + 228 | "0y_08Jr6OR__WF8__JK_" + 229 | "N_wx_k_CGB_SjJ9R74i_" + 230 | "7_1t_6 m NULLNULLNUL" + 231 | "L \t \t\t\t drK\t\x00 7_\"\" 5" + 232 | "_5_y732S43__D_8U9FX2" + 233 | "27_k\t\tg6bp_KI_1_YWc_" + 234 | "K8c2b_J_wx8yw1CMw27j" + 235 | "___f_a8uw_ Er9gB_L2 " + 236 | "\t\t \t\t\tm aa_XO4_5 Y_" + 237 | " I_T7762_zlMi_n8_FjH" + 238 | "vy62p__M4S_8__r092af" + 239 | "P_T_vhp6__SA_jVF13c5" + 240 | "2__8J48K__S4YcjoY91X" + 241 | "_iNf06 am aa_XO4_5\t" + 242 | " d _ am_SYY4G__2h4QL" + 243 | "iUIDd \t\t \tXXp__KFjR" + 244 | "V__JU3o\"\" d \t_Iks_ " + 245 | "aa_XO4_5= 0 { 25 | if i+1 == len(token) { 26 | return zp.setParseError("bad step in $GENERATE range", l) 27 | } 28 | 29 | s, err := strconv.ParseInt(token[i+1:], 10, 64) 30 | if err != nil || s <= 0 { 31 | return zp.setParseError("bad step in $GENERATE range", l) 32 | } 33 | 34 | step = s 35 | token = token[:i] 36 | } 37 | 38 | startStr, endStr, ok := strings.Cut(token, "-") 39 | if !ok { 40 | return zp.setParseError("bad start-stop in $GENERATE range", l) 41 | } 42 | 43 | start, err := strconv.ParseInt(startStr, 10, 64) 44 | if err != nil { 45 | return zp.setParseError("bad start in $GENERATE range", l) 46 | } 47 | 48 | end, err := strconv.ParseInt(endStr, 10, 64) 49 | if err != nil { 50 | return zp.setParseError("bad stop in $GENERATE range", l) 51 | } 52 | if end < 0 || start < 0 || end < start || (end-start)/step > 65535 { 53 | return zp.setParseError("bad range in $GENERATE range", l) 54 | } 55 | 56 | // _BLANK 57 | l, ok = zp.c.Next() 58 | if !ok || l.value != 
zBlank { 59 | return zp.setParseError("garbage after $GENERATE range", l) 60 | } 61 | 62 | // Create a complete new string, which we then parse again. 63 | var s string 64 | for l, ok := zp.c.Next(); ok; l, ok = zp.c.Next() { 65 | if l.err { 66 | return zp.setParseError("bad data in $GENERATE directive", l) 67 | } 68 | if l.value == zNewline { 69 | break 70 | } 71 | 72 | s += l.token 73 | } 74 | 75 | r := &generateReader{ 76 | s: s, 77 | 78 | cur: start, 79 | start: start, 80 | end: end, 81 | step: step, 82 | 83 | file: zp.file, 84 | lex: &l, 85 | } 86 | zp.sub = NewZoneParser(r, zp.origin, zp.file) 87 | zp.sub.includeDepth, zp.sub.includeAllowed = zp.includeDepth, zp.includeAllowed 88 | zp.sub.generateDisallowed = true 89 | zp.sub.SetDefaultTTL(defaultTtl) 90 | return zp.subNext() 91 | } 92 | 93 | type generateReader struct { 94 | s string 95 | si int 96 | 97 | cur int64 98 | start int64 99 | end int64 100 | step int64 101 | 102 | mod bytes.Buffer 103 | 104 | escape bool 105 | 106 | eof bool 107 | 108 | file string 109 | lex *lex 110 | } 111 | 112 | func (r *generateReader) parseError(msg string, end int) *ParseError { 113 | r.eof = true // Make errors sticky. 114 | 115 | l := *r.lex 116 | l.token = r.s[r.si-1 : end] 117 | l.column += r.si // l.column starts one zBLANK before r.s 118 | 119 | return &ParseError{file: r.file, err: msg, lex: l} 120 | } 121 | 122 | func (r *generateReader) Read(p []byte) (int, error) { 123 | // NewZLexer, through NewZoneParser, should use ReadByte and 124 | // not end up here. 
125 | 126 | panic("not implemented") 127 | } 128 | 129 | func (r *generateReader) ReadByte() (byte, error) { 130 | if r.eof { 131 | return 0, io.EOF 132 | } 133 | if r.mod.Len() > 0 { 134 | return r.mod.ReadByte() 135 | } 136 | 137 | if r.si >= len(r.s) { 138 | r.si = 0 139 | r.cur += r.step 140 | 141 | r.eof = r.cur > r.end || r.cur < 0 142 | return '\n', nil 143 | } 144 | 145 | si := r.si 146 | r.si++ 147 | 148 | switch r.s[si] { 149 | case '\\': 150 | if r.escape { 151 | r.escape = false 152 | return '\\', nil 153 | } 154 | 155 | r.escape = true 156 | return r.ReadByte() 157 | case '$': 158 | if r.escape { 159 | r.escape = false 160 | return '$', nil 161 | } 162 | 163 | mod := "%d" 164 | 165 | if si >= len(r.s)-1 { 166 | // End of the string 167 | fmt.Fprintf(&r.mod, mod, r.cur) 168 | return r.mod.ReadByte() 169 | } 170 | 171 | if r.s[si+1] == '$' { 172 | r.si++ 173 | return '$', nil 174 | } 175 | 176 | var offset int64 177 | 178 | // Search for { and } 179 | if r.s[si+1] == '{' { 180 | // Modifier block 181 | sep := strings.Index(r.s[si+2:], "}") 182 | if sep < 0 { 183 | return 0, r.parseError("bad modifier in $GENERATE", len(r.s)) 184 | } 185 | 186 | var errMsg string 187 | mod, offset, errMsg = modToPrintf(r.s[si+2 : si+2+sep]) 188 | if errMsg != "" { 189 | return 0, r.parseError(errMsg, si+3+sep) 190 | } 191 | if r.start+offset < 0 || r.end+offset > 1<<31-1 { 192 | return 0, r.parseError("bad offset in $GENERATE", si+3+sep) 193 | } 194 | 195 | r.si += 2 + sep // Jump to it 196 | } 197 | 198 | fmt.Fprintf(&r.mod, mod, r.cur+offset) 199 | return r.mod.ReadByte() 200 | default: 201 | if r.escape { // Pretty useless here 202 | r.escape = false 203 | return r.ReadByte() 204 | } 205 | 206 | return r.s[si], nil 207 | } 208 | } 209 | 210 | // Convert a $GENERATE modifier 0,0,d to something Printf can deal with. 
211 | func modToPrintf(s string) (string, int64, string) { 212 | // Modifier is { offset [ ,width [ ,base ] ] } - provide default 213 | // values for optional width and type, if necessary. 214 | offStr, s, ok0 := strings.Cut(s, ",") 215 | widthStr, s, ok1 := strings.Cut(s, ",") 216 | base, _, ok2 := strings.Cut(s, ",") 217 | if !ok0 { 218 | widthStr = "0" 219 | } 220 | if !ok1 { 221 | base = "d" 222 | } 223 | if ok2 { 224 | return "", 0, "bad modifier in $GENERATE" 225 | } 226 | 227 | switch base { 228 | case "o", "d", "x", "X": 229 | default: 230 | return "", 0, "bad base in $GENERATE" 231 | } 232 | 233 | offset, err := strconv.ParseInt(offStr, 10, 64) 234 | if err != nil { 235 | return "", 0, "bad offset in $GENERATE" 236 | } 237 | 238 | width, err := strconv.ParseUint(widthStr, 10, 8) 239 | if err != nil { 240 | return "", 0, "bad width in $GENERATE" 241 | } 242 | 243 | if width == 0 { 244 | return "%" + base, offset, "" 245 | } 246 | 247 | return "%0" + widthStr + base, offset, "" 248 | } 249 | -------------------------------------------------------------------------------- /generate_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestGenerateRangeGuard(t *testing.T) { 12 | tmpdir := t.TempDir() 13 | 14 | for i := 0; i <= 1; i++ { 15 | path := filepath.Join(tmpdir, fmt.Sprintf("%04d.conf", i)) 16 | data := []byte(fmt.Sprintf("dhcp-%04d A 10.0.0.%d", i, i)) 17 | 18 | if err := os.WriteFile(path, data, 0o644); err != nil { 19 | t.Fatalf("could not create tmpfile for test: %v", err) 20 | } 21 | } 22 | 23 | tests := [...]struct { 24 | zone string 25 | fail bool 26 | }{ 27 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 28 | $GENERATE 0-1 dhcp-${0,4,d} A 10.0.0.$ 29 | `, false}, 30 | {`@ IN SOA ns.test. hostmaster.test. 
( 1 8h 2h 7d 1d ) 31 | $GENERATE 0-1 dhcp-${0,0,x} A 10.0.0.$ 32 | `, false}, 33 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 34 | $GENERATE 128-129 dhcp-${-128,4,d} A 10.0.0.$ 35 | `, false}, 36 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 37 | $GENERATE 128-129 dhcp-${-129,4,d} A 10.0.0.$ 38 | `, true}, 39 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 40 | $GENERATE 0-2 dhcp-${2147483647,4,d} A 10.0.0.$ 41 | `, true}, 42 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 43 | $GENERATE 0-1 dhcp-${2147483646,4,d} A 10.0.0.$ 44 | `, false}, 45 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 46 | $GENERATE 0-1/step dhcp-${0,4,d} A 10.0.0.$ 47 | `, true}, 48 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 49 | $GENERATE 0-1/ dhcp-${0,4,d} A 10.0.0.$ 50 | `, true}, 51 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 52 | $GENERATE 0-10/2 dhcp-${0,4,d} A 10.0.0.$ 53 | `, false}, 54 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 55 | $GENERATE 0-1/0 dhcp-${0,4,d} A 10.0.0.$ 56 | `, true}, 57 | {`@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 58 | $GENERATE 0-1 $$INCLUDE ` + tmpdir + string(filepath.Separator) + `${0,4,d}.conf 59 | `, false}, 60 | {`@ IN SOA ns.test. hostmaster.test. 
( 1 8h 2h 7d 1d ) 61 | $GENERATE 0-1 dhcp-${0,4,d} A 10.0.0.$ 62 | $GENERATE 0-2 dhcp-${0,4,d} A 10.1.0.$ 63 | `, false}, 64 | } 65 | 66 | for i := range tests { 67 | z := NewZoneParser(strings.NewReader(tests[i].zone), "test.", "test") 68 | z.SetIncludeAllowed(true) 69 | 70 | for _, ok := z.Next(); ok; _, ok = z.Next() { 71 | } 72 | 73 | err := z.Err() 74 | if err != nil && !tests[i].fail { 75 | t.Errorf("expected \n\n%s\nto be parsed, but got %v", tests[i].zone, err) 76 | } else if err == nil && tests[i].fail { 77 | t.Errorf("expected \n\n%s\nto fail, but got no error", tests[i].zone) 78 | } 79 | } 80 | } 81 | 82 | func TestGenerateIncludeDepth(t *testing.T) { 83 | tmpfile, err := os.CreateTemp("", "dns") 84 | if err != nil { 85 | t.Fatalf("could not create tmpfile for test: %v", err) 86 | } 87 | defer os.Remove(tmpfile.Name()) 88 | 89 | zone := `@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 90 | $GENERATE 0-1 $$INCLUDE ` + tmpfile.Name() + ` 91 | ` 92 | if _, err := tmpfile.WriteString(zone); err != nil { 93 | t.Fatalf("could not write to tmpfile for test: %v", err) 94 | } 95 | if err := tmpfile.Close(); err != nil { 96 | t.Fatalf("could not close tmpfile for test: %v", err) 97 | } 98 | 99 | zp := NewZoneParser(strings.NewReader(zone), ".", tmpfile.Name()) 100 | zp.SetIncludeAllowed(true) 101 | 102 | for _, ok := zp.Next(); ok; _, ok = zp.Next() { 103 | } 104 | 105 | const expected = "too deeply nested $INCLUDE" 106 | if err := zp.Err(); err == nil || !strings.Contains(err.Error(), expected) { 107 | t.Errorf("expected error to include %q, got %v", expected, err) 108 | } 109 | } 110 | 111 | func TestGenerateIncludeDisallowed(t *testing.T) { 112 | const zone = `@ IN SOA ns.test. hostmaster.test. 
( 1 8h 2h 7d 1d ) 113 | $GENERATE 0-1 $$INCLUDE test.conf 114 | ` 115 | zp := NewZoneParser(strings.NewReader(zone), ".", "") 116 | 117 | for _, ok := zp.Next(); ok; _, ok = zp.Next() { 118 | } 119 | 120 | const expected = "$INCLUDE directive not allowed" 121 | if err := zp.Err(); err == nil || !strings.Contains(err.Error(), expected) { 122 | t.Errorf("expected error to include %q, got %v", expected, err) 123 | } 124 | } 125 | 126 | func TestGenerateSurfacesErrors(t *testing.T) { 127 | const zone = `@ IN SOA ns.test. hostmaster.test. ( 1 8h 2h 7d 1d ) 128 | $GENERATE 0-1 dhcp-${0,4,dd} A 10.0.0.$ 129 | ` 130 | zp := NewZoneParser(strings.NewReader(zone), ".", "test") 131 | 132 | for _, ok := zp.Next(); ok; _, ok = zp.Next() { 133 | } 134 | 135 | const expected = `test: dns: bad base in $GENERATE: "${0,4,dd}" at line: 2:20` 136 | if err := zp.Err(); err == nil || err.Error() != expected { 137 | t.Errorf("expected specific error, wanted %q, got %v", expected, err) 138 | } 139 | } 140 | 141 | func TestGenerateSurfacesLexerErrors(t *testing.T) { 142 | const zone = `@ IN SOA ns.test. hostmaster.test. 
( 1 8h 2h 7d 1d ) 143 | $GENERATE 0-1 dhcp-${0,4,d} A 10.0.0.$ ) 144 | ` 145 | zp := NewZoneParser(strings.NewReader(zone), ".", "test") 146 | 147 | for _, ok := zp.Next(); ok; _, ok = zp.Next() { 148 | } 149 | 150 | const expected = `test: dns: bad data in $GENERATE directive: "extra closing brace" at line: 2:40` 151 | if err := zp.Err(); err == nil || err.Error() != expected { 152 | t.Errorf("expected specific error, wanted %q, got %v", expected, err) 153 | } 154 | } 155 | 156 | func TestGenerateModToPrintf(t *testing.T) { 157 | tests := []struct { 158 | mod string 159 | wantFmt string 160 | wantOffset int64 161 | wantErr bool 162 | }{ 163 | {"0,0,d", "%d", 0, false}, 164 | {"0,0", "%d", 0, false}, 165 | {"0", "%d", 0, false}, 166 | {"3,2,d", "%02d", 3, false}, 167 | {"3,2", "%02d", 3, false}, 168 | {"3", "%d", 3, false}, 169 | {"0,0,o", "%o", 0, false}, 170 | {"0,0,x", "%x", 0, false}, 171 | {"0,0,X", "%X", 0, false}, 172 | {"0,0,z", "", 0, true}, 173 | {"0,0,0,d", "", 0, true}, 174 | {"-100,0,d", "%d", -100, false}, 175 | } 176 | for _, test := range tests { 177 | gotFmt, gotOffset, errMsg := modToPrintf(test.mod) 178 | switch { 179 | case errMsg != "" && !test.wantErr: 180 | t.Errorf("modToPrintf(%q) - expected empty-error, but got %v", test.mod, errMsg) 181 | case errMsg == "" && test.wantErr: 182 | t.Errorf("modToPrintf(%q) - expected error, but got empty-error", test.mod) 183 | case gotFmt != test.wantFmt: 184 | t.Errorf("modToPrintf(%q) - expected format %q, but got %q", test.mod, test.wantFmt, gotFmt) 185 | case gotOffset != test.wantOffset: 186 | t.Errorf("modToPrintf(%q) - expected offset %d, but got %d", test.mod, test.wantOffset, gotOffset) 187 | } 188 | } 189 | } 190 | 191 | func BenchmarkGenerate(b *testing.B) { 192 | const zone = `@ IN SOA ns.test. hostmaster.test. 
( 1 8h 2h 7d 1d ) 193 | $GENERATE 32-158 dhcp-${-32,4,d} A 10.0.0.$ 194 | ` 195 | 196 | for n := 0; n < b.N; n++ { 197 | zp := NewZoneParser(strings.NewReader(zone), ".", "") 198 | 199 | for _, ok := zp.Next(); ok; _, ok = zp.Next() { 200 | } 201 | 202 | if err := zp.Err(); err != nil { 203 | b.Fatal(err) 204 | } 205 | } 206 | } 207 | 208 | func TestCrasherString(t *testing.T) { 209 | tests := []struct { 210 | in string 211 | err string 212 | }{ 213 | {"$GENERATE 0-300103\"$$GENERATE 2-2", "bad range in $GENERATE"}, 214 | {"$GENERATE 0-5414137360", "bad range in $GENERATE"}, 215 | {"$GENERATE 11522-3668518066406258", "bad range in $GENERATE"}, 216 | {"$GENERATE 0-200\"(;00000000000000\n$$GENERATE 0-0", "dns: garbage after $GENERATE range: \"\\\"\" at line: 1:16"}, 217 | {"$GENERATE 6-2048 $$GENERATE 6-036160 $$$$ORIGIN \\$", `dns: nested $GENERATE directive not allowed: "6-036160" at line: 1:19`}, 218 | } 219 | for _, tc := range tests { 220 | t.Run(tc.in, func(t *testing.T) { 221 | _, err := NewRR(tc.in) 222 | if err == nil { 223 | t.Errorf("Expecting error for crasher line %s", tc.in) 224 | } 225 | if !strings.Contains(err.Error(), tc.err) { 226 | t.Errorf("Expecting error %s, got %s", tc.err, err.Error()) 227 | } 228 | }) 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/miekg/dns 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | golang.org/x/net v0.39.0 9 | golang.org/x/sync v0.13.0 10 | golang.org/x/sys v0.32.0 11 | golang.org/x/tools v0.32.0 12 | ) 13 | 14 | require golang.org/x/mod v0.24.0 // indirect 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 2 | 
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 3 | golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= 4 | golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 5 | golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= 6 | golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= 7 | golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= 8 | golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= 9 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= 10 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= 11 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 12 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 13 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= 14 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 15 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 16 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 17 | golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= 18 | golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 19 | golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= 20 | golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= 21 | golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU= 22 | golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s= 23 | -------------------------------------------------------------------------------- /hash.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "bytes" 5 | "crypto" 6 | "hash" 7 | ) 8 | 9 | // identityHash 
will not hash, it only buffers the data written into it and returns it as-is. 10 | type identityHash struct { 11 | b *bytes.Buffer // pointer, so the value-receiver methods below all share one buffer 12 | } 13 | 14 | // Implement the hash.Hash interface. 15 | 16 | func (i identityHash) Write(b []byte) (int, error) { return i.b.Write(b) } 17 | func (i identityHash) Size() int { return i.b.Len() } // number of buffered bytes, not a fixed digest length 18 | func (i identityHash) BlockSize() int { return 1024 } 19 | func (i identityHash) Reset() { i.b.Reset() } 20 | func (i identityHash) Sum(b []byte) []byte { return append(b, i.b.Bytes()...) } // appends the buffered data to b, unhashed 21 | // hashFromAlgorithm returns a new hash.Hash and the crypto.Hash that AlgorithmToHash maps alg to, or ErrAlg if alg is not in the map. A zero crypto.Hash yields an identityHash, which passes the written data through unhashed. 22 | func hashFromAlgorithm(alg uint8) (hash.Hash, crypto.Hash, error) { 23 | hashnumber, ok := AlgorithmToHash[alg] 24 | if !ok { 25 | return nil, 0, ErrAlg 26 | } 27 | if hashnumber == 0 { 28 | return identityHash{b: &bytes.Buffer{}}, hashnumber, nil 29 | } 30 | return hashnumber.New(), hashnumber, nil 31 | } 32 | -------------------------------------------------------------------------------- /issue_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // Tests that each target a specific issue. 4 | 5 | import ( 6 | "strings" 7 | "testing" 8 | ) 9 | // TestNSEC3MissingSalt checks that an NSEC3 record's salt survives a Pack/Unpack round trip. 10 | func TestNSEC3MissingSalt(t *testing.T) { 11 | rr := testRR("ji6neoaepv8b5o6k4ev33abha8ht9fgc.example. NSEC3 1 1 12 aabbccdd K8UDEMVP1J2F7EG6JEBPS17VP3N8I58H") 12 | m := new(Msg) 13 | m.Answer = []RR{rr} 14 | mb, err := m.Pack() 15 | if err != nil { 16 | t.Fatalf("expected to pack message. err: %s", err) 17 | } 18 | if err := m.Unpack(mb); err != nil { 19 | t.Fatalf("expected to unpack message. missing salt? err: %s", err) 20 | } 21 | in := rr.(*NSEC3).Salt 22 | out := m.Answer[0].(*NSEC3).Salt 23 | if in != out { 24 | t.Fatalf("expected salts to match. packed: `%s`. returned: `%s`", in, out) 25 | } 26 | } 27 | // TestNSEC3MixedNextDomain checks that a lower-case NSEC3 NextDomain value comes back upper-cased after a Pack/Unpack round trip. 28 | func TestNSEC3MixedNextDomain(t *testing.T) { 29 | rr := testRR("ji6neoaepv8b5o6k4ev33abha8ht9fgc.example.
NSEC3 1 1 12 - k8udemvp1j2f7eg6jebps17vp3n8i58h") 30 | m := new(Msg) 31 | m.Answer = []RR{rr} 32 | mb, err := m.Pack() 33 | if err != nil { 34 | t.Fatalf("expected to pack message. err: %s", err) 35 | } 36 | if err := m.Unpack(mb); err != nil { 37 | t.Fatalf("expected to unpack message. err: %s", err) 38 | } 39 | in := strings.ToUpper(rr.(*NSEC3).NextDomain) 40 | out := m.Answer[0].(*NSEC3).NextDomain 41 | if in != out { 42 | t.Fatalf("expected round trip to produce NextDomain `%s`, instead `%s`", in, out) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /labels.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // Helper functions for dealing with domain-name labels. 4 | 5 | // SplitDomainName splits a name string into its labels. 6 | // www.miek.nl. returns []string{"www", "miek", "nl"} 7 | // .www.miek.nl. returns []string{"", "www", "miek", "nl"}. 8 | // The root label (.) returns nil. Note that using 9 | // strings.Split(s) will work in most cases, but does not handle 10 | // escaped dots (\.) for instance. 11 | // s must be a syntactically valid domain name, see IsDomainName. 12 | func SplitDomainName(s string) (labels []string) { 13 | if s == "" { 14 | return nil 15 | } 16 | fqdnEnd := 0 // offset of the final '.' or the length of the name 17 | idx := Split(s) 18 | begin := 0 19 | if IsFqdn(s) { 20 | fqdnEnd = len(s) - 1 21 | } else { 22 | fqdnEnd = len(s) 23 | } 24 | 25 | switch len(idx) { 26 | case 0: // Split returns nil for the root name "." 27 | return nil 28 | case 1: 29 | // no-op: the single label is appended after the switch 30 | default: 31 | for _, end := range idx[1:] { 32 | labels = append(labels, s[begin:end-1]) // end is one past the '.', so end-1 drops the separator 33 | begin = end 34 | } 35 | } 36 | 37 | return append(labels, s[begin:fqdnEnd]) 38 | } 39 | 40 | // CompareDomainName compares the names s1 and s2 and 41 | // returns how many labels they have in common starting from the *right*. 42 | // The comparison stops at the first inequality.
The names are downcased 43 | // before the comparison. 44 | // 45 | // www.miek.nl. and miek.nl. have two labels in common: miek and nl 46 | // www.miek.nl. and www.bla.nl. have one label in common: nl 47 | // 48 | // s1 and s2 must be syntactically valid domain names. 49 | func CompareDomainName(s1, s2 string) (n int) { 50 | // the first check: root label 51 | if s1 == "." || s2 == "." { 52 | return 0 53 | } 54 | 55 | l1 := Split(s1) 56 | l2 := Split(s2) 57 | 58 | j1 := len(l1) - 1 // end 59 | i1 := len(l1) - 2 // start 60 | j2 := len(l2) - 1 61 | i2 := len(l2) - 2 62 | // the second check can be done here: last/only label 63 | // before we fall through into the for-loop below 64 | if equal(s1[l1[j1]:], s2[l2[j2]:]) { 65 | n++ 66 | } else { 67 | return 68 | } 69 | for { 70 | if i1 < 0 || i2 < 0 { 71 | break 72 | } 73 | if equal(s1[l1[i1]:l1[j1]], s2[l2[i2]:l2[j2]]) { 74 | n++ 75 | } else { 76 | break 77 | } 78 | j1-- 79 | i1-- 80 | j2-- 81 | i2-- 82 | } 83 | return 84 | } 85 | 86 | // CountLabel counts the number of labels in the string s. 87 | // s must be a syntactically valid domain name. 88 | func CountLabel(s string) (labels int) { 89 | if s == "." { 90 | return 91 | } 92 | off := 0 93 | end := false 94 | for { 95 | off, end = NextLabel(s, off) 96 | labels++ 97 | if end { 98 | return 99 | } 100 | } 101 | } 102 | 103 | // Split splits a name s into its label indexes. 104 | // www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}. 105 | // The root name (.) returns nil. Also see SplitDomainName. 106 | // s must be a syntactically valid domain name. 107 | func Split(s string) []int { 108 | if s == "." 
{ 109 | return nil 110 | } 111 | idx := make([]int, 1, 3) 112 | off := 0 113 | end := false 114 | 115 | for { 116 | off, end = NextLabel(s, off) 117 | if end { 118 | return idx 119 | } 120 | idx = append(idx, off) 121 | } 122 | } 123 | 124 | // NextLabel returns the index of the start of the next label in the 125 | // string s starting at offset. A negative offset will cause a panic. 126 | // The bool end is true when the end of the string has been reached. 127 | // Also see PrevLabel. 128 | func NextLabel(s string, offset int) (i int, end bool) { 129 | if s == "" { 130 | return 0, true 131 | } 132 | for i = offset; i < len(s)-1; i++ { 133 | if s[i] != '.' { 134 | continue 135 | } 136 | j := i - 1 137 | for j >= 0 && s[j] == '\\' { 138 | j-- 139 | } 140 | 141 | if (j-i)%2 == 0 { 142 | continue 143 | } 144 | 145 | return i + 1, false 146 | } 147 | return i + 1, true 148 | } 149 | 150 | // PrevLabel returns the index of the label when starting from the right and 151 | // jumping n labels to the left. 152 | // The bool start is true when the start of the string has been overshot. 153 | // Also see NextLabel. 154 | func PrevLabel(s string, n int) (i int, start bool) { 155 | if s == "" { 156 | return 0, true 157 | } 158 | if n == 0 { 159 | return len(s), false 160 | } 161 | 162 | l := len(s) - 1 163 | if s[l] == '.' { 164 | l-- 165 | } 166 | 167 | for ; l >= 0 && n > 0; l-- { 168 | if s[l] != '.' { 169 | continue 170 | } 171 | j := l - 1 172 | for j >= 0 && s[j] == '\\' { 173 | j-- 174 | } 175 | 176 | if (j-l)%2 == 0 { 177 | continue 178 | } 179 | 180 | n-- 181 | if n == 0 { 182 | return l + 1, false 183 | } 184 | } 185 | 186 | return 0, n > 1 187 | } 188 | 189 | // equal compares a and b while ignoring case. It returns true when equal otherwise false. 190 | func equal(a, b string) bool { 191 | // might be lifted into API function. 
192 | la := len(a) 193 | lb := len(b) 194 | if la != lb { 195 | return false 196 | } 197 | 198 | for i := la - 1; i >= 0; i-- { 199 | ai := a[i] 200 | bi := b[i] 201 | if ai >= 'A' && ai <= 'Z' { 202 | ai |= 'a' - 'A' 203 | } 204 | if bi >= 'A' && bi <= 'Z' { 205 | bi |= 'a' - 'A' 206 | } 207 | if ai != bi { 208 | return false 209 | } 210 | } 211 | return true 212 | } 213 | -------------------------------------------------------------------------------- /leak_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "runtime" 7 | "sort" 8 | "strings" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // copied from net/http/main_test.go 14 | 15 | func interestingGoroutines() (gs []string) { 16 | buf := make([]byte, 2<<20) 17 | buf = buf[:runtime.Stack(buf, true)] 18 | for _, g := range strings.Split(string(buf), "\n\n") { 19 | sl := strings.SplitN(g, "\n", 2) 20 | if len(sl) != 2 { 21 | continue 22 | } 23 | stack := strings.TrimSpace(sl[1]) 24 | if stack == "" || 25 | strings.Contains(stack, "testing.(*M).before.func1") || 26 | strings.Contains(stack, "os/signal.signal_recv") || 27 | strings.Contains(stack, "created by net.startServer") || 28 | strings.Contains(stack, "created by testing.RunTests") || 29 | strings.Contains(stack, "closeWriteAndWait") || 30 | strings.Contains(stack, "testing.Main(") || 31 | strings.Contains(stack, "testing.(*T).Run(") || 32 | // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) 33 | strings.Contains(stack, "runtime.goexit") || 34 | strings.Contains(stack, "created by runtime.gc") || 35 | strings.Contains(stack, "dns.interestingGoroutines") || 36 | strings.Contains(stack, "runtime.MHeap_Scavenger") { 37 | continue 38 | } 39 | gs = append(gs, stack) 40 | } 41 | sort.Strings(gs) 42 | return 43 | } 44 | 45 | func goroutineLeaked() error { 46 | if testing.Short() { 47 | // Don't worry about goroutine leaks in -short mode or in 48 | // benchmark 
mode. Too distracting when there are false positives. 49 | return nil 50 | } 51 | 52 | var stackCount map[string]int 53 | for i := 0; i < 5; i++ { 54 | n := 0 55 | stackCount = make(map[string]int) 56 | gs := interestingGoroutines() 57 | for _, g := range gs { 58 | stackCount[g]++ 59 | n++ 60 | } 61 | if n == 0 { 62 | return nil 63 | } 64 | // Wait for goroutines to schedule and die off: 65 | time.Sleep(100 * time.Millisecond) 66 | } 67 | for stack, count := range stackCount { 68 | fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack) 69 | } 70 | return fmt.Errorf("too many goroutines running after dns test(s)") 71 | } 72 | -------------------------------------------------------------------------------- /listen_no_socket_options.go: -------------------------------------------------------------------------------- 1 | //go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd 2 | // +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd 3 | 4 | package dns 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | ) 10 | 11 | const ( 12 | supportsReusePort = false 13 | supportsReuseAddr = false 14 | ) 15 | 16 | func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) { 17 | if reuseport || reuseaddr { 18 | // TODO(tmthrgd): return an error? 19 | } 20 | 21 | return net.Listen(network, addr) 22 | } 23 | 24 | func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) { 25 | if reuseport || reuseaddr { 26 | // TODO(tmthrgd): return an error? 
27 | } 28 | 29 | return net.ListenPacket(network, addr) 30 | } 31 | 32 | // this is just for test compatibility 33 | func checkReuseport(fd uintptr) (bool, error) { 34 | return false, fmt.Errorf("not supported") 35 | } 36 | 37 | // this is just for test compatibility 38 | func checkReuseaddr(fd uintptr) (bool, error) { 39 | return false, fmt.Errorf("not supported") 40 | } 41 | -------------------------------------------------------------------------------- /listen_socket_options.go: -------------------------------------------------------------------------------- 1 | //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd 2 | // +build aix darwin dragonfly freebsd linux netbsd openbsd 3 | 4 | package dns 5 | 6 | import ( 7 | "context" 8 | "net" 9 | "syscall" 10 | 11 | "golang.org/x/sys/unix" 12 | ) 13 | 14 | const supportsReusePort = true 15 | 16 | func reuseportControl(network, address string, c syscall.RawConn) error { 17 | var opErr error 18 | err := c.Control(func(fd uintptr) { 19 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) 20 | }) 21 | if err != nil { 22 | return err 23 | } 24 | 25 | return opErr 26 | } 27 | 28 | const supportsReuseAddr = true 29 | 30 | func reuseaddrControl(network, address string, c syscall.RawConn) error { 31 | var opErr error 32 | err := c.Control(func(fd uintptr) { 33 | opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) 34 | }) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | return opErr 40 | } 41 | 42 | func reuseaddrandportControl(network, address string, c syscall.RawConn) error { 43 | err := reuseaddrControl(network, address, c) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | return reuseportControl(network, address, c) 49 | } 50 | 51 | // this is just for test compatibility 52 | func checkReuseport(fd uintptr) (bool, error) { 53 | v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT) 54 | if err != nil { 55 | return false, 
err 56 | } 57 | 58 | return v == 1, nil 59 | } 60 | 61 | // checkReuseaddr reports whether SO_REUSEADDR is enabled on fd; this is just for test compatibility 62 | func checkReuseaddr(fd uintptr) (bool, error) { 63 | v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR) 64 | if err != nil { 65 | return false, err 66 | } 67 | 68 | return v == 1, nil 69 | } 70 | // listenTCP listens on network/addr via net.ListenConfig, installing a Control hook that enables SO_REUSEADDR and/or SO_REUSEPORT when requested. 71 | func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) { 72 | var lc net.ListenConfig 73 | switch { 74 | case reuseaddr && reuseport: 75 | lc.Control = reuseaddrandportControl 76 | case reuseport: 77 | lc.Control = reuseportControl 78 | case reuseaddr: 79 | lc.Control = reuseaddrControl 80 | } 81 | 82 | return lc.Listen(context.Background(), network, addr) 83 | } 84 | // listenUDP opens a packet listener on network/addr via net.ListenConfig, installing a Control hook that enables SO_REUSEADDR and/or SO_REUSEPORT when requested. 85 | func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) { 86 | var lc net.ListenConfig 87 | switch { 88 | case reuseaddr && reuseport: 89 | lc.Control = reuseaddrandportControl 90 | case reuseport: 91 | lc.Control = reuseportControl 92 | case reuseaddr: 93 | lc.Control = reuseaddrControl 94 | } 95 | 96 | return lc.ListenPacket(context.Background(), network, addr) 97 | } 98 | -------------------------------------------------------------------------------- /msg_truncate.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // Truncate ensures the reply message will fit into the requested buffer 4 | // size by removing records that exceed the requested size. 5 | // 6 | // It will first check if the reply fits without compression and then with 7 | // compression. If it won't fit with compression, Truncate then walks the 8 | // records, adding as many as possible without exceeding the 9 | // requested buffer size. 10 | // 11 | // If the message fits within the requested size without compression, 12 | // Truncate will set the message's Compress attribute to false.
It is 13 | // the caller's responsibility to set it back to true if they wish to 14 | // compress the payload regardless of size. 15 | // 16 | // The TC bit will be set if any records were excluded from the message. 17 | // If the TC bit is already set on the message it will be retained. 18 | // TC indicates that the client should retry over TCP. 19 | // 20 | // According to RFC 2181, the TC bit should only be set if not all of the 21 | // "required" RRs can be included in the response. Unfortunately, we have 22 | // no way of knowing which RRs are required so we set the TC bit if any RR 23 | // had to be omitted from the response. 24 | // 25 | // The appropriate buffer size can be retrieved from the request's OPT 26 | // record, if present, and is transport-specific otherwise. dns.MinMsgSize 27 | // should be used for UDP requests without an OPT record, and 28 | // dns.MaxMsgSize for TCP requests without an OPT record. 29 | func (dns *Msg) Truncate(size int) { 30 | if dns.IsTsig() != nil { 31 | // To simplify this implementation, we don't perform 32 | // truncation on responses with a TSIG record. 33 | return 34 | } 35 | 36 | // RFC 6891 mandates that the payload size in an OPT record 37 | // less than 512 (MinMsgSize) bytes must be treated as equal to 512 bytes. 38 | // 39 | // For ease of use, we impose that restriction here. 40 | if size < MinMsgSize { 41 | size = MinMsgSize 42 | } 43 | 44 | l := msgLenWithCompressionMap(dns, nil) // uncompressed length 45 | if l <= size { 46 | // Don't waste effort compressing this message. 47 | dns.Compress = false 48 | return 49 | } 50 | 51 | dns.Compress = true // too large uncompressed: compress, then trim records below 52 | 53 | edns0 := dns.popEdns0() // detached here; appended back to Extra once the sections are trimmed 54 | if edns0 != nil { 55 | // Account for the OPT record that gets added at the end, 56 | // by subtracting that length from our budget. 57 | // 58 | // The EDNS(0) OPT record must have the root domain and 59 | // its length is thus unaffected by compression.
60 | size -= Len(edns0) 61 | } 62 | 63 | compression := make(map[string]struct{}) 64 | 65 | l = headerSize 66 | for _, r := range dns.Question { 67 | l += r.len(l, compression) 68 | } 69 | 70 | var numAnswer int 71 | if l < size { 72 | l, numAnswer = truncateLoop(dns.Answer, size, l, compression) 73 | } 74 | 75 | var numNS int 76 | if l < size { 77 | l, numNS = truncateLoop(dns.Ns, size, l, compression) 78 | } 79 | 80 | var numExtra int 81 | if l < size { 82 | _, numExtra = truncateLoop(dns.Extra, size, l, compression) 83 | } 84 | 85 | // See the function documentation for when we set this. 86 | dns.Truncated = dns.Truncated || len(dns.Answer) > numAnswer || 87 | len(dns.Ns) > numNS || len(dns.Extra) > numExtra 88 | 89 | dns.Answer = dns.Answer[:numAnswer] 90 | dns.Ns = dns.Ns[:numNS] 91 | dns.Extra = dns.Extra[:numExtra] 92 | 93 | if edns0 != nil { 94 | // Add the OPT record back onto the additional section. 95 | dns.Extra = append(dns.Extra, edns0) 96 | } 97 | } 98 | 99 | func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) { 100 | for i, r := range rrs { 101 | if r == nil { 102 | continue 103 | } 104 | 105 | l += r.len(l, compression) 106 | if l > size { 107 | // Return size, rather than l prior to this record, 108 | // to prevent any further records being added. 109 | return size, i 110 | } 111 | if l == size { 112 | return l, i + 1 113 | } 114 | } 115 | 116 | return l, len(rrs) 117 | } 118 | -------------------------------------------------------------------------------- /msg_truncate_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestRequestTruncateAnswer(t *testing.T) { 9 | m := new(Msg) 10 | m.SetQuestion("large.example.com.", TypeSRV) 11 | 12 | reply := new(Msg) 13 | reply.SetReply(m) 14 | for i := 1; i < 200; i++ { 15 | reply.Answer = append(reply.Answer, testRR( 16 | fmt.Sprintf("large.example.com. 
10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) 17 | } 18 | 19 | reply.Truncate(MinMsgSize) 20 | if want, got := MinMsgSize, reply.Len(); want < got { 21 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 22 | } 23 | if !reply.Truncated { 24 | t.Errorf("truncated bit should be set") 25 | } 26 | } 27 | 28 | func TestRequestTruncateExtra(t *testing.T) { 29 | m := new(Msg) 30 | m.SetQuestion("large.example.com.", TypeSRV) 31 | 32 | reply := new(Msg) 33 | reply.SetReply(m) 34 | for i := 1; i < 200; i++ { 35 | reply.Extra = append(reply.Extra, testRR( 36 | fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) 37 | } 38 | 39 | reply.Truncate(MinMsgSize) 40 | if want, got := MinMsgSize, reply.Len(); want < got { 41 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 42 | } 43 | if !reply.Truncated { 44 | t.Errorf("truncated bit should be set") 45 | } 46 | } 47 | 48 | func TestRequestTruncateExtraEdns0(t *testing.T) { 49 | const size = 4096 50 | 51 | m := new(Msg) 52 | m.SetQuestion("large.example.com.", TypeSRV) 53 | m.SetEdns0(size, true) 54 | 55 | reply := new(Msg) 56 | reply.SetReply(m) 57 | for i := 1; i < 200; i++ { 58 | reply.Extra = append(reply.Extra, testRR( 59 | fmt.Sprintf("large.example.com. 
10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) 60 | } 61 | reply.SetEdns0(size, true) 62 | 63 | reply.Truncate(size) 64 | if want, got := size, reply.Len(); want < got { 65 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 66 | } 67 | if !reply.Truncated { 68 | t.Errorf("truncated bit should be set") 69 | } 70 | opt := reply.Extra[len(reply.Extra)-1] 71 | if opt.Header().Rrtype != TypeOPT { 72 | t.Errorf("expected last RR to be OPT") 73 | } 74 | } 75 | 76 | func TestRequestTruncateExtraRegression(t *testing.T) { 77 | const size = 2048 78 | 79 | m := new(Msg) 80 | m.SetQuestion("large.example.com.", TypeSRV) 81 | m.SetEdns0(size, true) 82 | 83 | reply := new(Msg) 84 | reply.SetReply(m) 85 | for i := 1; i < 33; i++ { 86 | reply.Answer = append(reply.Answer, testRR( 87 | fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) 88 | } 89 | for i := 1; i < 33; i++ { 90 | reply.Extra = append(reply.Extra, testRR( 91 | fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i))) 92 | } 93 | reply.SetEdns0(size, true) 94 | 95 | reply.Truncate(size) 96 | if want, got := size, reply.Len(); want < got { 97 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 98 | } 99 | if !reply.Truncated { 100 | t.Errorf("truncated bit should be set") 101 | } 102 | opt := reply.Extra[len(reply.Extra)-1] 103 | if opt.Header().Rrtype != TypeOPT { 104 | t.Errorf("expected last RR to be OPT") 105 | } 106 | } 107 | 108 | func TestTruncation(t *testing.T) { 109 | reply := new(Msg) 110 | 111 | for i := 0; i < 61; i++ { 112 | reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("http.service.tcp.srv.k8s.example.org. 5 IN SRV 0 0 80 10-144-230-%d.default.pod.k8s.example.org.", i))) 113 | } 114 | 115 | for i := 0; i < 5; i++ { 116 | reply.Extra = append(reply.Extra, testRR(fmt.Sprintf("ip-10-10-52-5%d.subdomain.example.org. 
5 IN A 10.10.52.5%d", i, i))) 117 | } 118 | 119 | for i := 0; i < 5; i++ { 120 | reply.Ns = append(reply.Ns, testRR(fmt.Sprintf("srv.subdomain.example.org. 5 IN NS ip-10-10-33-6%d.subdomain.example.org.", i))) 121 | } 122 | 123 | for bufsize := 1024; bufsize <= 4096; bufsize += 12 { 124 | m := new(Msg) 125 | m.SetQuestion("http.service.tcp.srv.k8s.example.org.", TypeSRV) 126 | m.SetEdns0(uint16(bufsize), true) 127 | 128 | copy := reply.Copy() 129 | copy.SetReply(m) 130 | 131 | copy.Truncate(bufsize) 132 | if want, got := bufsize, copy.Len(); want < got { 133 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 134 | } 135 | } 136 | } 137 | 138 | func TestRequestTruncateAnswerExact(t *testing.T) { 139 | const size = 867 // Bit fiddly, but this hits the rl == size break clause in Truncate, 52 RRs should remain. 140 | 141 | m := new(Msg) 142 | m.SetQuestion("large.example.com.", TypeSRV) 143 | m.SetEdns0(size, false) 144 | 145 | reply := new(Msg) 146 | reply.SetReply(m) 147 | for i := 1; i < 200; i++ { 148 | reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i))) 149 | } 150 | 151 | reply.Truncate(size) 152 | if want, got := size, reply.Len(); want < got { 153 | t.Errorf("message length should be below %d bytes, got %d bytes", want, got) 154 | } 155 | if expected := 52; len(reply.Answer) != expected { 156 | t.Errorf("wrong number of answers; expected %d, got %d", expected, len(reply.Answer)) 157 | } 158 | } 159 | 160 | func BenchmarkMsgTruncate(b *testing.B) { 161 | const size = 2048 162 | 163 | m := new(Msg) 164 | m.SetQuestion("example.com.", TypeA) 165 | m.SetEdns0(size, true) 166 | 167 | reply := new(Msg) 168 | reply.SetReply(m) 169 | for i := 1; i < 33; i++ { 170 | reply.Answer = append(reply.Answer, testRR( 171 | fmt.Sprintf("large.example.com. 
10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) 172 | } 173 | for i := 1; i < 33; i++ { 174 | reply.Extra = append(reply.Extra, testRR( 175 | fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i))) 176 | } 177 | 178 | b.ResetTimer() 179 | 180 | for i := 0; i < b.N; i++ { 181 | b.StopTimer() 182 | copy := reply.Copy() 183 | b.StartTimer() 184 | 185 | copy.Truncate(size) 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /nsecx.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/sha1" 5 | "encoding/hex" 6 | "strings" 7 | ) 8 | 9 | // HashName hashes a string (label) according to RFC 5155. It returns the hashed string in uppercase. 10 | func HashName(label string, ha uint8, iter uint16, salt string) string { 11 | if ha != SHA1 { 12 | return "" 13 | } 14 | 15 | wireSalt := make([]byte, hex.DecodedLen(len(salt))) 16 | n, err := packStringHex(salt, wireSalt, 0) 17 | if err != nil { 18 | return "" 19 | } 20 | wireSalt = wireSalt[:n] 21 | 22 | name := make([]byte, 255) 23 | off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false) 24 | if err != nil { 25 | return "" 26 | } 27 | name = name[:off] 28 | 29 | s := sha1.New() 30 | // k = 0 31 | s.Write(name) 32 | s.Write(wireSalt) 33 | nsec3 := s.Sum(nil) 34 | 35 | // k > 0 36 | for k := uint16(0); k < iter; k++ { 37 | s.Reset() 38 | s.Write(nsec3) 39 | s.Write(wireSalt) 40 | nsec3 = s.Sum(nsec3[:0]) 41 | } 42 | 43 | return toBase32(nsec3) 44 | } 45 | 46 | // Cover returns true if a name is covered by the NSEC3 record. 
47 | func (rr *NSEC3) Cover(name string) bool { 48 | nameHash := HashName(name, rr.Hash, rr.Iterations, rr.Salt) 49 | owner := strings.ToUpper(rr.Hdr.Name) 50 | labelIndices := Split(owner) 51 | if len(labelIndices) < 2 { 52 | return false 53 | } 54 | ownerHash := owner[:labelIndices[1]-1] 55 | ownerZone := owner[labelIndices[1]:] 56 | if !IsSubDomain(ownerZone, strings.ToUpper(name)) { // name is outside owner zone 57 | return false 58 | } 59 | 60 | nextHash := rr.NextDomain 61 | 62 | // if ownerHash == nextHash the interval wraps around to itself, so every hash other than ownerHash (e.g. a wildcard's) is covered 63 | if ownerHash == nextHash && nameHash != ownerHash { // empty interval 64 | return true 65 | } 66 | if ownerHash > nextHash { // end of zone 67 | if nameHash > ownerHash { // covered since there is nothing after ownerHash 68 | return true 69 | } 70 | return nameHash < nextHash // if nameHash is before beginning of zone it is covered 71 | } 72 | if nameHash < ownerHash { // nameHash is before ownerHash, not covered 73 | return false 74 | } 75 | return nameHash < nextHash // if nameHash is before nextHash it is covered (between ownerHash and nextHash) 76 | } 77 | 78 | // Match returns true if a name matches the NSEC3 record. 79 | func (rr *NSEC3) Match(name string) bool { 80 | nameHash := HashName(name, rr.Hash, rr.Iterations, rr.Salt) 81 | owner := strings.ToUpper(rr.Hdr.Name) 82 | labelIndices := Split(owner) 83 | if len(labelIndices) < 2 { 84 | return false 85 | } 86 | ownerHash := owner[:labelIndices[1]-1] 87 | ownerZone := owner[labelIndices[1]:] 88 | if !IsSubDomain(ownerZone, strings.ToUpper(name)) { // name is outside owner zone 89 | return false 90 | } 91 | if ownerHash == nameHash { 92 | return true 93 | } 94 | return false 95 | } 96 | -------------------------------------------------------------------------------- /nsecx_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | ) 7
| 8 | func TestPackNsec3(t *testing.T) { 9 | nsec3 := HashName("dnsex.nl.", SHA1, 0, "DEAD") 10 | if nsec3 != "ROCCJAE8BJJU7HN6T7NG3TNM8ACRS87J" { 11 | t.Error(nsec3) 12 | } 13 | 14 | nsec3 = HashName("a.b.c.example.org.", SHA1, 2, "DEAD") 15 | if nsec3 != "6LQ07OAHBTOOEU2R9ANI2AT70K5O0RCG" { 16 | t.Error(nsec3) 17 | } 18 | } 19 | 20 | func TestNsec3(t *testing.T) { 21 | nsec3 := testRR("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 SK4F38CQ0ATIEI8MH3RGD0P5I4II6QAN NS SOA TXT RRSIG DNSKEY NSEC3PARAM") 22 | if !nsec3.(*NSEC3).Match("nl.") { // name hash = sk4e8fj94u78smusb40o1n0oltbblu2r 23 | t.Fatal("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. should match sk4e8fj94u78smusb40o1n0oltbblu2r.nl.") 24 | } 25 | if !nsec3.(*NSEC3).Match("NL.") { // name hash = sk4e8fj94u78smusb40o1n0oltbblu2r 26 | t.Fatal("sk4e8fj94u78smusb40o1n0oltbblu2r.NL. should match sk4e8fj94u78smusb40o1n0oltbblu2r.nl.") 27 | } 28 | if nsec3.(*NSEC3).Match("com.") { // 29 | t.Fatal("com. is not in the zone nl.") 30 | } 31 | if nsec3.(*NSEC3).Match("test.nl.") { // name hash = gd0ptr5bnfpimpu2d3v6gd4n0bai7s0q 32 | t.Fatal("gd0ptr5bnfpimpu2d3v6gd4n0bai7s0q.nl. should not match sk4e8fj94u78smusb40o1n0oltbblu2r.nl.") 33 | } 34 | nsec3 = testRR("nl. IN NSEC3 1 1 5 F10E9F7EA83FC8F3 SK4F38CQ0ATIEI8MH3RGD0P5I4II6QAN NS SOA TXT RRSIG DNSKEY NSEC3PARAM") 35 | if nsec3.(*NSEC3).Match("nl.") { 36 | t.Fatal("sk4e8fj94u78smusb40o1n0oltbblu2r.nl. 
should not match a record without a owner hash") 37 | } 38 | 39 | for _, tc := range []struct { 40 | rr *NSEC3 41 | name string 42 | covers bool 43 | }{ 44 | // positive tests 45 | { // name hash between owner hash and next hash 46 | rr: &NSEC3{ 47 | Hdr: RR_Header{Name: "2N1TB3VAIRUOBL6RKDVII42N9TFMIALP.com."}, 48 | Hash: 1, 49 | Flags: 1, 50 | Iterations: 5, 51 | Salt: "F10E9F7EA83FC8F3", 52 | NextDomain: "PT3RON8N7PM3A0OE989IB84OOSADP7O8", 53 | }, 54 | name: "bsd.com.", 55 | covers: true, 56 | }, 57 | { // end of zone, name hash is after owner hash 58 | rr: &NSEC3{ 59 | Hdr: RR_Header{Name: "3v62ulr0nre83v0rja2vjgtlif9v6rab.com."}, 60 | Hash: 1, 61 | Flags: 1, 62 | Iterations: 5, 63 | Salt: "F10E9F7EA83FC8F3", 64 | NextDomain: "2N1TB3VAIRUOBL6RKDVII42N9TFMIALP", 65 | }, 66 | name: "csd.com.", 67 | covers: true, 68 | }, 69 | { // end of zone, name hash is before beginning of zone 70 | rr: &NSEC3{ 71 | Hdr: RR_Header{Name: "PT3RON8N7PM3A0OE989IB84OOSADP7O8.com."}, 72 | Hash: 1, 73 | Flags: 1, 74 | Iterations: 5, 75 | Salt: "F10E9F7EA83FC8F3", 76 | NextDomain: "3V62ULR0NRE83V0RJA2VJGTLIF9V6RAB", 77 | }, 78 | name: "asd.com.", 79 | covers: true, 80 | }, 81 | // negative tests 82 | { // too short owner name 83 | rr: &NSEC3{ 84 | Hdr: RR_Header{Name: "nl."}, 85 | Hash: 1, 86 | Flags: 1, 87 | Iterations: 5, 88 | Salt: "F10E9F7EA83FC8F3", 89 | NextDomain: "39P99DCGG0MDLARTCRMCF6OFLLUL7PR6", 90 | }, 91 | name: "asd.com.", 92 | covers: false, 93 | }, 94 | { // outside of zone 95 | rr: &NSEC3{ 96 | Hdr: RR_Header{Name: "39p91242oslggest5e6a7cci4iaeqvnk.nl."}, 97 | Hash: 1, 98 | Flags: 1, 99 | Iterations: 5, 100 | Salt: "F10E9F7EA83FC8F3", 101 | NextDomain: "39P99DCGG0MDLARTCRMCF6OFLLUL7PR6", 102 | }, 103 | name: "asd.com.", 104 | covers: false, 105 | }, 106 | { // empty interval 107 | rr: &NSEC3{ 108 | Hdr: RR_Header{Name: "2n1tb3vairuobl6rkdvii42n9tfmialp.com."}, 109 | Hash: 1, 110 | Flags: 1, 111 | Iterations: 5, 112 | Salt: "F10E9F7EA83FC8F3", 113 | NextDomain: 
"2N1TB3VAIRUOBL6RKDVII42N9TFMIALP", 114 | }, 115 | name: "asd.com.", 116 | covers: false, 117 | }, 118 | { // empty interval wildcard 119 | rr: &NSEC3{ 120 | Hdr: RR_Header{Name: "2n1tb3vairuobl6rkdvii42n9tfmialp.com."}, 121 | Hash: 1, 122 | Flags: 1, 123 | Iterations: 5, 124 | Salt: "F10E9F7EA83FC8F3", 125 | NextDomain: "2N1TB3VAIRUOBL6RKDVII42N9TFMIALP", 126 | }, 127 | name: "*.asd.com.", 128 | covers: true, 129 | }, 130 | { // name hash is before owner hash, not covered 131 | rr: &NSEC3{ 132 | Hdr: RR_Header{Name: "3V62ULR0NRE83V0RJA2VJGTLIF9V6RAB.com."}, 133 | Hash: 1, 134 | Flags: 1, 135 | Iterations: 5, 136 | Salt: "F10E9F7EA83FC8F3", 137 | NextDomain: "PT3RON8N7PM3A0OE989IB84OOSADP7O8", 138 | }, 139 | name: "asd.com.", 140 | covers: false, 141 | }, 142 | } { 143 | covers := tc.rr.Cover(tc.name) 144 | if tc.covers != covers { 145 | t.Fatalf("cover failed for %s: expected %t, got %t [record: %s]", tc.name, tc.covers, covers, tc.rr) 146 | } 147 | } 148 | } 149 | 150 | func TestNsec3EmptySalt(t *testing.T) { 151 | rr, _ := NewRR("CK0POJMG874LJREF7EFN8430QVIT8BSM.com. 86400 IN NSEC3 1 1 0 - CK0Q1GIN43N1ARRC9OSM6QPQR81H5M9A NS SOA RRSIG DNSKEY NSEC3PARAM") 152 | 153 | if !rr.(*NSEC3).Match("com.") { 154 | t.Fatalf("expected record to match com. label") 155 | } 156 | } 157 | 158 | func BenchmarkHashName(b *testing.B) { 159 | for _, iter := range []uint16{ 160 | 150, 2500, 5000, 10000, ^uint16(0), 161 | } { 162 | b.Run(strconv.Itoa(int(iter)), func(b *testing.B) { 163 | for n := 0; n < b.N; n++ { 164 | if HashName("some.example.org.", SHA1, iter, "deadbeef") == "" { 165 | b.Fatalf("HashName failed") 166 | } 167 | } 168 | }) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /privaterr.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "strings" 4 | 5 | // PrivateRdata is an interface used for implementing "Private Use" RR types, see 6 | // RFC 6895. 
This allows one to experiment with new RR types without requesting an 7 | // official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove. 8 | type PrivateRdata interface { 9 | // String returns the text presentation of the Rdata of the Private RR. 10 | String() string 11 | // Parse parses the Rdata of the private RR. 12 | Parse([]string) error 13 | // Pack is used when packing a private RR into a buffer. 14 | Pack([]byte) (int, error) 15 | // Unpack is used when unpacking a private RR from a buffer. 16 | Unpack([]byte) (int, error) 17 | // Copy copies the Rdata into the PrivateRdata argument. 18 | Copy(PrivateRdata) error 19 | // Len returns the length in octets of the Rdata. 20 | Len() int 21 | } 22 | 23 | // PrivateRR represents an RR that uses a PrivateRdata user-defined type. 24 | // It mocks normal RRs and implements the dns.RR interface. 25 | type PrivateRR struct { 26 | Hdr RR_Header 27 | Data PrivateRdata 28 | 29 | generator func() PrivateRdata // for copy 30 | } 31 | 32 | // Header returns the RR header of r. 33 | func (r *PrivateRR) Header() *RR_Header { return &r.Hdr } 34 | 35 | func (r *PrivateRR) String() string { return r.Hdr.String() + r.Data.String() } 36 | 37 | // Private len and copy parts to satisfy the RR interface.
38 | func (r *PrivateRR) len(off int, compression map[string]struct{}) int { 39 | l := r.Hdr.len(off, compression) 40 | l += r.Data.Len() 41 | return l 42 | } 43 | 44 | func (r *PrivateRR) copy() RR { 45 | // make new RR like this: 46 | rr := &PrivateRR{r.Hdr, r.generator(), r.generator} 47 | 48 | if err := r.Data.Copy(rr.Data); err != nil { 49 | panic("dns: got value that could not be used to copy Private rdata: " + err.Error()) 50 | } 51 | 52 | return rr 53 | } 54 | 55 | func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { 56 | n, err := r.Data.Pack(msg[off:]) 57 | if err != nil { 58 | return len(msg), err 59 | } 60 | off += n 61 | return off, nil 62 | } 63 | 64 | func (r *PrivateRR) unpack(msg []byte, off int) (int, error) { 65 | off1, err := r.Data.Unpack(msg[off:]) 66 | off += off1 67 | return off, err 68 | } 69 | 70 | func (r *PrivateRR) parse(c *zlexer, origin string) *ParseError { 71 | var l lex 72 | text := make([]string, 0, 2) // could be 0..N elements, median is probably 1 73 | Fetch: 74 | for { 75 | // TODO(miek): we could also be returning _QUOTE, this might or might not 76 | // be an issue (basically parsing TXT becomes hard) 77 | switch l, _ = c.Next(); l.value { 78 | case zNewline, zEOF: 79 | break Fetch 80 | case zString: 81 | text = append(text, l.token) 82 | } 83 | } 84 | 85 | err := r.Data.Parse(text) 86 | if err != nil { 87 | return &ParseError{wrappedErr: err, lex: l} 88 | } 89 | 90 | return nil 91 | } 92 | 93 | func (r *PrivateRR) isDuplicate(r2 RR) bool { return false } 94 | 95 | // PrivateHandle registers a private resource record type. It requires 96 | // string and numeric representation of private RR type and generator function as argument. 
97 | func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) { 98 | rtypestr = strings.ToUpper(rtypestr) 99 | 100 | TypeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator(), generator} } 101 | TypeToString[rtype] = rtypestr 102 | StringToType[rtypestr] = rtype 103 | } 104 | 105 | // PrivateHandleRemove removes definitions required to support private RR type. 106 | func PrivateHandleRemove(rtype uint16) { 107 | rtypestr, ok := TypeToString[rtype] 108 | if ok { 109 | delete(TypeToRR, rtype) 110 | delete(TypeToString, rtype) 111 | delete(StringToType, rtypestr) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /privaterr_test.go: -------------------------------------------------------------------------------- 1 | package dns_test 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/miekg/dns" 8 | ) 9 | 10 | const TypeISBN uint16 = 0xFF00 11 | 12 | // A crazy new RR type :) 13 | type ISBN struct { 14 | x string // rdata with 10 or 13 numbers, dashes or spaces allowed 15 | } 16 | 17 | func NewISBN() dns.PrivateRdata { return &ISBN{""} } 18 | 19 | func (rd *ISBN) Len() int { return len([]byte(rd.x)) } 20 | func (rd *ISBN) String() string { return rd.x } 21 | 22 | func (rd *ISBN) Parse(txt []string) error { 23 | rd.x = strings.TrimSpace(strings.Join(txt, " ")) 24 | return nil 25 | } 26 | 27 | func (rd *ISBN) Pack(buf []byte) (int, error) { 28 | b := []byte(rd.x) 29 | n := copy(buf, b) 30 | if n != len(b) { 31 | return n, dns.ErrBuf 32 | } 33 | return n, nil 34 | } 35 | 36 | func (rd *ISBN) Unpack(buf []byte) (int, error) { 37 | rd.x = string(buf) 38 | return len(buf), nil 39 | } 40 | 41 | func (rd *ISBN) Copy(dest dns.PrivateRdata) error { 42 | isbn, ok := dest.(*ISBN) 43 | if !ok { 44 | return dns.ErrRdata 45 | } 46 | isbn.x = rd.x 47 | return nil 48 | } 49 | 50 | var testrecord = strings.Join([]string{"example.org.", "3600", "IN", "ISBN", "12-3 456789-0-123"}, "\t") 
51 | 52 | func TestPrivateText(t *testing.T) { 53 | dns.PrivateHandle("ISBN", TypeISBN, NewISBN) 54 | defer dns.PrivateHandleRemove(TypeISBN) 55 | 56 | rr, err := dns.NewRR(testrecord) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | if rr.String() != testrecord { 61 | t.Errorf("record string representation did not match original %#v != %#v", rr.String(), testrecord) 62 | } 63 | } 64 | 65 | func TestPrivateByteSlice(t *testing.T) { 66 | dns.PrivateHandle("ISBN", TypeISBN, NewISBN) 67 | defer dns.PrivateHandleRemove(TypeISBN) 68 | 69 | rr, err := dns.NewRR(testrecord) 70 | if err != nil { 71 | t.Fatal(err) 72 | } 73 | 74 | buf := make([]byte, 100) 75 | off, err := dns.PackRR(rr, buf, 0, nil, false) 76 | if err != nil { 77 | t.Errorf("got error packing ISBN: %v", err) 78 | } 79 | 80 | custrr := rr.(*dns.PrivateRR) 81 | if ln := custrr.Data.Len() + len(custrr.Header().Name) + 11; ln != off { 82 | t.Errorf("offset is not matching to length of Private RR: %d!=%d", off, ln) 83 | } 84 | 85 | rr1, off1, err := dns.UnpackRR(buf[:off], 0) 86 | if err != nil { 87 | t.Errorf("got error unpacking ISBN: %v", err) 88 | return 89 | } 90 | 91 | if off1 != off { 92 | t.Errorf("offset after unpacking differs: %d != %d", off1, off) 93 | } 94 | 95 | if rr1.String() != testrecord { 96 | t.Errorf("record string representation did not match original %#v != %#v", rr1.String(), testrecord) 97 | } 98 | } 99 | 100 | const TypeVERSION uint16 = 0xFF01 101 | 102 | type VERSION struct { 103 | x string 104 | } 105 | 106 | func NewVersion() dns.PrivateRdata { return &VERSION{""} } 107 | 108 | func (rd *VERSION) String() string { return rd.x } 109 | func (rd *VERSION) Parse(txt []string) error { 110 | rd.x = strings.TrimSpace(strings.Join(txt, " ")) 111 | return nil 112 | } 113 | 114 | func (rd *VERSION) Pack(buf []byte) (int, error) { 115 | b := []byte(rd.x) 116 | n := copy(buf, b) 117 | if n != len(b) { 118 | return n, dns.ErrBuf 119 | } 120 | return n, nil 121 | } 122 | 123 | func (rd *VERSION) 
Unpack(buf []byte) (int, error) { 124 | rd.x = string(buf) 125 | return len(buf), nil 126 | } 127 | 128 | func (rd *VERSION) Copy(dest dns.PrivateRdata) error { 129 | isbn, ok := dest.(*VERSION) 130 | if !ok { 131 | return dns.ErrRdata 132 | } 133 | isbn.x = rd.x 134 | return nil 135 | } 136 | 137 | func (rd *VERSION) Len() int { 138 | return len([]byte(rd.x)) 139 | } 140 | 141 | var smallzone = `$ORIGIN example.org. 142 | @ 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. ( 143 | 2014091518 7200 3600 1209600 3600 144 | ) 145 | A 1.2.3.4 146 | ok ISBN 1231-92110-12 147 | go VERSION ( 148 | 1.3.1 ; comment 149 | ) 150 | www ISBN 1231-92110-16 151 | * CNAME @ 152 | ` 153 | 154 | func TestPrivateZoneParser(t *testing.T) { 155 | dns.PrivateHandle("ISBN", TypeISBN, NewISBN) 156 | dns.PrivateHandle("VERSION", TypeVERSION, NewVersion) 157 | defer dns.PrivateHandleRemove(TypeISBN) 158 | defer dns.PrivateHandleRemove(TypeVERSION) 159 | 160 | r := strings.NewReader(smallzone) 161 | z := dns.NewZoneParser(r, ".", "") 162 | 163 | for _, ok := z.Next(); ok; _, ok = z.Next() { 164 | } 165 | if err := z.Err(); err != nil { 166 | t.Fatal(err) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /reverse.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // StringToType is the reverse of TypeToString, needed for string parsing. 4 | var StringToType = reverseInt16(TypeToString) 5 | 6 | // StringToClass is the reverse of ClassToString, needed for string parsing. 7 | var StringToClass = reverseInt16(ClassToString) 8 | 9 | // StringToOpcode is a map of strings to opcodes (the reverse of OpcodeToString). 10 | var StringToOpcode = reverseInt(OpcodeToString) 11 | 12 | // StringToRcode is a map of strings to rcodes (the reverse of RcodeToString). 13 | var StringToRcode = reverseInt(RcodeToString) 14 | 15 | func init() { 16 | // Preserve previous NOTIMP typo, see github.com/miekg/dns/issues/733.
17 | StringToRcode["NOTIMPL"] = RcodeNotImplemented 18 | } 19 | 20 | // StringToAlgorithm is the reverse of AlgorithmToString. 21 | var StringToAlgorithm = reverseInt8(AlgorithmToString) 22 | 23 | // StringToHash is a map of names to hash IDs. 24 | var StringToHash = reverseInt8(HashToString) 25 | 26 | // StringToCertType is the reverse of CertTypeToString. 27 | var StringToCertType = reverseInt16(CertTypeToString) 28 | 29 | // StringToStatefulType is the reverse of StatefulTypeToString. 30 | var StringToStatefulType = reverseInt16(StatefulTypeToString) 31 | 32 | // Reverse a map 33 | func reverseInt8(m map[uint8]string) map[string]uint8 { 34 | n := make(map[string]uint8, len(m)) 35 | for u, s := range m { 36 | n[s] = u 37 | } 38 | return n 39 | } 40 | 41 | func reverseInt16(m map[uint16]string) map[string]uint16 { 42 | n := make(map[string]uint16, len(m)) 43 | for u, s := range m { 44 | n[s] = u 45 | } 46 | return n 47 | } 48 | 49 | func reverseInt(m map[int]string) map[string]int { 50 | n := make(map[string]int, len(m)) 51 | for u, s := range m { 52 | n[s] = u 53 | } 54 | return n 55 | } 56 | -------------------------------------------------------------------------------- /rr_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // testRR is a helper that wraps a call to NewRR and panics if the error is non-nil. 4 | func testRR(s string) RR { 5 | r, err := NewRR(s) 6 | if err != nil { 7 | panic(err) 8 | } 9 | 10 | return r 11 | } 12 | -------------------------------------------------------------------------------- /sanitize.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // Dedup removes identical RRs from rrs. It preserves the original ordering. 4 | // The lowest TTL of any duplicates is used in the remaining one. Dedup modifies 5 | // rrs. 6 | // m is used to store the RRs temporarily. If it is nil, a new map will be allocated.
7 | func Dedup(rrs []RR, m map[string]RR) []RR { 8 | 9 | if m == nil { 10 | m = make(map[string]RR) 11 | } 12 | // Save the keys, so we don't have to call normalizedString twice. 13 | keys := make([]*string, 0, len(rrs)) 14 | 15 | for _, r := range rrs { 16 | key := normalizedString(r) 17 | keys = append(keys, &key) 18 | if mr, ok := m[key]; ok { 19 | // Shortest TTL wins. 20 | rh, mrh := r.Header(), mr.Header() 21 | if mrh.Ttl > rh.Ttl { 22 | mrh.Ttl = rh.Ttl 23 | } 24 | continue 25 | } 26 | 27 | m[key] = r 28 | } 29 | // If the length of the result map equals the amount of RRs we got, 30 | // it means they were all different. We can then just return the original rrset. 31 | if len(m) == len(rrs) { 32 | return rrs 33 | } 34 | 35 | j := 0 36 | for i, r := range rrs { 37 | // If keys[i] lives in the map, we should copy and remove it. 38 | if _, ok := m[*keys[i]]; ok { 39 | delete(m, *keys[i]) 40 | rrs[j] = r 41 | j++ 42 | } 43 | 44 | if len(m) == 0 { 45 | break 46 | } 47 | } 48 | 49 | return rrs[:j] 50 | } 51 | 52 | // normalizedString returns a normalized string from r. The TTL 53 | // is removed and the domain name is lowercased. We go from this: 54 | // DomainNameTTLCLASSTYPERDATA to: 55 | // lowercasenameCLASSTYPE... 56 | func normalizedString(r RR) string { 57 | // A string Go DNS makes has: domainnameTTL... 58 | b := []byte(r.String()) 59 | 60 | // find the first non-escaped tab, then another, so we capture where the TTL lives. 61 | esc := false 62 | ttlStart, ttlEnd := 0, 0 63 | for i := 0; i < len(b) && ttlEnd == 0; i++ { 64 | switch { 65 | case b[i] == '\\': 66 | esc = !esc 67 | case b[i] == '\t' && !esc: 68 | if ttlStart == 0 { 69 | ttlStart = i 70 | continue 71 | } 72 | if ttlEnd == 0 { 73 | ttlEnd = i 74 | } 75 | case b[i] >= 'A' && b[i] <= 'Z' && !esc: 76 | b[i] += 32 77 | default: 78 | esc = false 79 | } 80 | } 81 | 82 | // remove TTL. 
83 | copy(b[ttlStart:], b[ttlEnd:]) 84 | cut := ttlEnd - ttlStart 85 | return string(b[:len(b)-cut]) 86 | } 87 | -------------------------------------------------------------------------------- /sanitize_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "testing" 4 | 5 | func TestDedup(t *testing.T) { 6 | testcases := map[[3]RR][]string{ 7 | [...]RR{ 8 | testRR("mIek.nl. IN A 127.0.0.1"), 9 | testRR("mieK.nl. IN A 127.0.0.1"), 10 | testRR("miek.Nl. IN A 127.0.0.1"), 11 | }: {"mIek.nl.\t3600\tIN\tA\t127.0.0.1"}, 12 | [...]RR{ 13 | testRR("miEk.nl. 2000 IN A 127.0.0.1"), 14 | testRR("mieK.Nl. 1000 IN A 127.0.0.1"), 15 | testRR("Miek.nL. 500 IN A 127.0.0.1"), 16 | }: {"miEk.nl.\t500\tIN\tA\t127.0.0.1"}, 17 | [...]RR{ 18 | testRR("miek.nl. IN A 127.0.0.1"), 19 | testRR("miek.nl. CH A 127.0.0.1"), 20 | testRR("miek.nl. IN A 127.0.0.1"), 21 | }: {"miek.nl.\t3600\tIN\tA\t127.0.0.1", 22 | "miek.nl.\t3600\tCH\tA\t127.0.0.1", 23 | }, 24 | [...]RR{ 25 | testRR("miek.nl. CH A 127.0.0.1"), 26 | testRR("miek.nl. IN A 127.0.0.1"), 27 | testRR("miek.de. IN A 127.0.0.1"), 28 | }: {"miek.nl.\t3600\tCH\tA\t127.0.0.1", 29 | "miek.nl.\t3600\tIN\tA\t127.0.0.1", 30 | "miek.de.\t3600\tIN\tA\t127.0.0.1", 31 | }, 32 | [...]RR{ 33 | testRR("miek.de. IN A 127.0.0.1"), 34 | testRR("miek.nl. 200 IN A 127.0.0.1"), 35 | testRR("miek.nl. 300 IN A 127.0.0.1"), 36 | }: {"miek.de.\t3600\tIN\tA\t127.0.0.1", 37 | "miek.nl.\t200\tIN\tA\t127.0.0.1", 38 | }, 39 | } 40 | 41 | for rr, expected := range testcases { 42 | out := Dedup([]RR{rr[0], rr[1], rr[2]}, nil) 43 | for i, o := range out { 44 | if o.String() != expected[i] { 45 | t.Fatalf("expected %v, got %v", expected[i], o.String()) 46 | } 47 | } 48 | } 49 | } 50 | 51 | func BenchmarkDedup(b *testing.B) { 52 | rrs := []RR{ 53 | testRR("miEk.nl. 2000 IN A 127.0.0.1"), 54 | testRR("mieK.Nl. 1000 IN A 127.0.0.1"), 55 | testRR("Miek.nL. 
500 IN A 127.0.0.1"), 56 | } 57 | m := make(map[string]RR) 58 | for i := 0; i < b.N; i++ { 59 | Dedup(rrs, m) 60 | } 61 | } 62 | 63 | func TestNormalizedString(t *testing.T) { 64 | tests := map[RR]string{ 65 | testRR("mIEk.Nl. 3600 IN A 127.0.0.1"): "miek.nl.\tIN\tA\t127.0.0.1", 66 | testRR("m\\ iek.nL. 3600 IN A 127.0.0.1"): "m\\ iek.nl.\tIN\tA\t127.0.0.1", 67 | testRR("m\\\tIeK.nl. 3600 in A 127.0.0.1"): "m\\009iek.nl.\tIN\tA\t127.0.0.1", 68 | } 69 | for tc, expected := range tests { 70 | n := normalizedString(tc) 71 | if n != expected { 72 | t.Errorf("expected %s, got %s", expected, n) 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /serve_mux.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // ServeMux is a DNS request multiplexer. It matches the zone name of 8 | // each incoming request against a list of registered patterns and calls 9 | // the handler for the pattern that most closely matches the zone name. 10 | // 11 | // ServeMux is DNSSEC aware, meaning that queries for the DS record are 12 | // redirected to the parent zone (if that is also registered), otherwise 13 | // the child gets the query. 14 | // 15 | // ServeMux is also safe for concurrent access from multiple goroutines. 16 | // 17 | // The zero ServeMux is empty and ready for use. 18 | type ServeMux struct { 19 | z map[string]Handler 20 | m sync.RWMutex 21 | } 22 | 23 | // NewServeMux allocates and returns a new ServeMux. 24 | func NewServeMux() *ServeMux { 25 | return new(ServeMux) 26 | } 27 | 28 | // DefaultServeMux is the default ServeMux used by Serve.
29 | var DefaultServeMux = NewServeMux() 30 | 31 | func (mux *ServeMux) match(q string, t uint16) Handler { 32 | mux.m.RLock() 33 | defer mux.m.RUnlock() 34 | if mux.z == nil { 35 | return nil 36 | } 37 | 38 | q = CanonicalName(q) 39 | 40 | var handler Handler 41 | for off, end := 0, false; !end; off, end = NextLabel(q, off) { 42 | if h, ok := mux.z[q[off:]]; ok { 43 | if t != TypeDS { 44 | return h 45 | } 46 | // Continue for DS to see if we have a parent too, if so delegate to the parent 47 | handler = h 48 | } 49 | } 50 | 51 | // Wildcard match, if we have found nothing try the root zone as a last resort. 52 | if h, ok := mux.z["."]; ok { 53 | return h 54 | } 55 | 56 | return handler 57 | } 58 | 59 | // Handle adds a handler to the ServeMux for pattern. 60 | func (mux *ServeMux) Handle(pattern string, handler Handler) { 61 | if pattern == "" { 62 | panic("dns: invalid pattern " + pattern) 63 | } 64 | mux.m.Lock() 65 | if mux.z == nil { 66 | mux.z = make(map[string]Handler) 67 | } 68 | mux.z[CanonicalName(pattern)] = handler 69 | mux.m.Unlock() 70 | } 71 | 72 | // HandleFunc adds a handler function to the ServeMux for pattern. 73 | func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { 74 | mux.Handle(pattern, HandlerFunc(handler)) 75 | } 76 | 77 | // HandleRemove deregisters the handler specific for pattern from the ServeMux. 78 | func (mux *ServeMux) HandleRemove(pattern string) { 79 | if pattern == "" { 80 | panic("dns: invalid pattern " + pattern) 81 | } 82 | mux.m.Lock() 83 | delete(mux.z, CanonicalName(pattern)) 84 | mux.m.Unlock() 85 | } 86 | 87 | // ServeDNS dispatches the request to the handler whose pattern most 88 | // closely matches the request message. 89 | // 90 | // ServeDNS is DNSSEC aware, meaning that queries for the DS record 91 | // are redirected to the parent zone (if that is also registered), 92 | // otherwise the child gets the query. 
93 | // 94 | // If no handler is found, or there is no question, a standard REFUSED 95 | // message is returned. 96 | func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) { 97 | var h Handler 98 | if len(req.Question) >= 1 { // allow more than one question; only the first one is used for matching 99 | h = mux.match(req.Question[0].Name, req.Question[0].Qtype) 100 | } 101 | 102 | if h != nil { 103 | h.ServeDNS(w, req) 104 | } else { 105 | handleRefused(w, req) 106 | } 107 | } 108 | 109 | // Handle registers the handler with the given pattern 110 | // in the DefaultServeMux. The documentation for 111 | // ServeMux explains how patterns are matched. 112 | func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } 113 | 114 | // HandleRemove deregisters the handler with the given pattern 115 | // in the DefaultServeMux. 116 | func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) } 117 | 118 | // HandleFunc registers the handler function with the given pattern 119 | // in the DefaultServeMux.
120 | func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { 121 | DefaultServeMux.HandleFunc(pattern, handler) 122 | } 123 | -------------------------------------------------------------------------------- /serve_mux_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "testing" 4 | 5 | func TestDotAsCatchAllWildcard(t *testing.T) { 6 | mux := NewServeMux() 7 | mux.Handle(".", HandlerFunc(HelloServer)) 8 | mux.Handle("example.com.", HandlerFunc(AnotherHelloServer)) 9 | 10 | handler := mux.match("www.miek.nl.", TypeTXT) 11 | if handler == nil { 12 | t.Error("wildcard match failed") 13 | } 14 | 15 | handler = mux.match("www.example.com.", TypeTXT) 16 | if handler == nil { 17 | t.Error("example.com match failed") 18 | } 19 | 20 | handler = mux.match("a.www.example.com.", TypeTXT) 21 | if handler == nil { 22 | t.Error("a.www.example.com match failed") 23 | } 24 | 25 | handler = mux.match("boe.", TypeTXT) 26 | if handler == nil { 27 | t.Error("boe.
match failed") 28 | } 29 | } 30 | 31 | func TestCaseFolding(t *testing.T) { 32 | mux := NewServeMux() 33 | mux.Handle("_udp.example.com.", HandlerFunc(HelloServer)) 34 | 35 | handler := mux.match("_dns._udp.example.com.", TypeSRV) 36 | if handler == nil { 37 | t.Error("case sensitive characters folded") 38 | } 39 | 40 | handler = mux.match("_DNS._UDP.EXAMPLE.COM.", TypeSRV) 41 | if handler == nil { 42 | t.Error("case insensitive characters not folded") 43 | } 44 | } 45 | 46 | func TestRootServer(t *testing.T) { 47 | mux := NewServeMux() 48 | mux.Handle(".", HandlerFunc(HelloServer)) 49 | 50 | handler := mux.match(".", TypeNS) 51 | if handler == nil { 52 | t.Error("root match failed") 53 | } 54 | } 55 | 56 | func BenchmarkMuxMatch(b *testing.B) { 57 | mux := NewServeMux() 58 | mux.Handle("_udp.example.com.", HandlerFunc(HelloServer)) 59 | 60 | bench := func(q string) func(*testing.B) { 61 | return func(b *testing.B) { 62 | for n := 0; n < b.N; n++ { 63 | handler := mux.match(q, TypeSRV) 64 | if handler == nil { 65 | b.Fatal("couldn't find match") 66 | } 67 | } 68 | } 69 | } 70 | b.Run("lowercase", bench("_dns._udp.example.com.")) 71 | b.Run("uppercase", bench("_DNS._UDP.EXAMPLE.COM.")) 72 | } 73 | -------------------------------------------------------------------------------- /sig0.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto" 5 | "crypto/ecdsa" 6 | "crypto/ed25519" 7 | "crypto/rsa" 8 | "encoding/binary" 9 | "math/big" 10 | "time" 11 | ) 12 | 13 | // Sign signs a dns.Msg. It fills the signature with the appropriate data. 14 | // The SIG record should have the SignerName, KeyTag, Algorithm, Inception 15 | // and Expiration set. It returns the packed message with the SIG RR appended.
16 | func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { 17 | if k == nil { 18 | return nil, ErrPrivKey 19 | } 20 | if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 { 21 | return nil, ErrKey 22 | } 23 | 24 | rr.Hdr = RR_Header{Name: ".", Rrtype: TypeSIG, Class: ClassANY, Ttl: 0} 25 | rr.OrigTtl, rr.TypeCovered, rr.Labels = 0, 0, 0 26 | 27 | buf := make([]byte, m.Len()+Len(rr)) 28 | mbuf, err := m.PackBuffer(buf) 29 | if err != nil { 30 | return nil, err 31 | } 32 | if &buf[0] != &mbuf[0] { 33 | return nil, ErrBuf 34 | } 35 | off, err := PackRR(rr, buf, len(mbuf), nil, false) 36 | if err != nil { 37 | return nil, err 38 | } 39 | buf = buf[:off:cap(buf)] 40 | 41 | h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | // Write SIG rdata (skip the 1-byte root owner name, type, class, TTL and rdlength) 47 | h.Write(buf[len(mbuf)+1+2+2+4+2:]) 48 | // Write message 49 | h.Write(buf[:len(mbuf)]) 50 | 51 | signature, err := sign(k, h.Sum(nil), cryptohash, rr.Algorithm) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | rr.Signature = toBase64(signature) 57 | 58 | buf = append(buf, signature...) 59 | if len(buf) > int(^uint16(0)) { 60 | return nil, ErrBuf 61 | } 62 | // Adjust sig data length 63 | rdoff := len(mbuf) + 1 + 2 + 2 + 4 64 | rdlen := binary.BigEndian.Uint16(buf[rdoff:]) 65 | rdlen += uint16(len(signature)) 66 | binary.BigEndian.PutUint16(buf[rdoff:], rdlen) 67 | // Adjust additional count 68 | adc := binary.BigEndian.Uint16(buf[10:]) 69 | adc++ 70 | binary.BigEndian.PutUint16(buf[10:], adc) 71 | return buf, nil 72 | } 73 | 74 | // Verify validates the message buf using the key k. 75 | // It's assumed that buf is a valid message from which rr was unpacked.
76 | func (rr *SIG) Verify(k *KEY, buf []byte) error { 77 | if k == nil { 78 | return ErrKey 79 | } 80 | if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 { 81 | return ErrKey 82 | } 83 | 84 | h, cryptohash, err := hashFromAlgorithm(rr.Algorithm) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | buflen := len(buf) 90 | qdc := binary.BigEndian.Uint16(buf[4:]) 91 | anc := binary.BigEndian.Uint16(buf[6:]) 92 | auc := binary.BigEndian.Uint16(buf[8:]) 93 | adc := binary.BigEndian.Uint16(buf[10:]) 94 | offset := headerSize 95 | for i := uint16(0); i < qdc && offset < buflen; i++ { 96 | _, offset, err = UnpackDomainName(buf, offset) 97 | if err != nil { 98 | return err 99 | } 100 | // Skip past Type and Class 101 | offset += 2 + 2 102 | } 103 | for i := uint16(1); i < anc+auc+adc && offset < buflen; i++ { 104 | _, offset, err = UnpackDomainName(buf, offset) 105 | if err != nil { 106 | return err 107 | } 108 | // Skip past Type, Class and TTL 109 | offset += 2 + 2 + 4 110 | if offset+1 >= buflen { 111 | continue 112 | } 113 | rdlen := binary.BigEndian.Uint16(buf[offset:]) 114 | offset += 2 115 | offset += int(rdlen) 116 | } 117 | if offset >= buflen { 118 | return &Error{err: "overflowing unpacking signed message"} 119 | } 120 | 121 | // offset should be just prior to SIG 122 | bodyend := offset 123 | // owner name SHOULD be root 124 | _, offset, err = UnpackDomainName(buf, offset) 125 | if err != nil { 126 | return err 127 | } 128 | // Skip Type, Class, TTL, RDLen 129 | offset += 2 + 2 + 4 + 2 130 | sigstart := offset 131 | // Skip Type Covered, Algorithm, Labels, Original TTL 132 | offset += 2 + 1 + 1 + 4 133 | if offset+4+4 >= buflen { 134 | return &Error{err: "overflow unpacking signed message"} 135 | } 136 | expire := binary.BigEndian.Uint32(buf[offset:]) 137 | offset += 4 138 | incept := binary.BigEndian.Uint32(buf[offset:]) 139 | offset += 4 140 | now := uint32(time.Now().Unix()) 141 | if now < incept || now > expire { 142 | return ErrTime 143 | }
144 | // Skip key tag 145 | offset += 2 146 | var signername string 147 | signername, offset, err = UnpackDomainName(buf, offset) 148 | if err != nil { 149 | return err 150 | } 151 | // If key has come from the DNS name compression might 152 | // have mangled the case of the name 153 | if !equal(signername, k.Header().Name) { 154 | return &Error{err: "signer name doesn't match key name"} 155 | } 156 | sigend := offset 157 | h.Write(buf[sigstart:sigend]) 158 | h.Write(buf[:10]) 159 | h.Write([]byte{ // ARCOUNT minus the SIG RR itself, big-endian (high byte first) 160 | byte((adc - 1) >> 8), 161 | byte(adc - 1), 162 | }) 163 | h.Write(buf[12:bodyend]) 164 | 165 | hashed := h.Sum(nil) 166 | sig := buf[sigend:] 167 | switch k.Algorithm { 168 | case RSASHA1, RSASHA256, RSASHA512: 169 | pk := k.publicKeyRSA() 170 | if pk != nil { 171 | return rsa.VerifyPKCS1v15(pk, cryptohash, hashed, sig) 172 | } 173 | case ECDSAP256SHA256, ECDSAP384SHA384: 174 | pk := k.publicKeyECDSA() 175 | r := new(big.Int).SetBytes(sig[:len(sig)/2]) 176 | s := new(big.Int).SetBytes(sig[len(sig)/2:]) 177 | if pk != nil { 178 | if ecdsa.Verify(pk, hashed, r, s) { 179 | return nil 180 | } 181 | return ErrSig 182 | } 183 | case ED25519: 184 | pk := k.publicKeyED25519() 185 | if pk != nil { 186 | if ed25519.Verify(pk, hashed, sig) { 187 | return nil 188 | } 189 | return ErrSig 190 | } 191 | } 192 | return ErrKeyAlg 193 | } 194 | -------------------------------------------------------------------------------- /sig0_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestSIG0(t *testing.T) { 10 | if testing.Short() { 11 | t.Skip("skipping test in short mode.") 12 | } 13 | m := new(Msg) 14 | m.SetQuestion("example.org.", TypeSOA) 15 | for _, alg := range []uint8{ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256, RSASHA512, ED25519} { 16 | algstr := AlgorithmToString[alg] 17 | keyrr := new(KEY) 18 | keyrr.Hdr.Name = algstr + "."
19 | keyrr.Hdr.Rrtype = TypeKEY 20 | keyrr.Hdr.Class = ClassINET 21 | keyrr.Algorithm = alg 22 | keysize := 512 23 | switch alg { 24 | case ECDSAP256SHA256, ED25519: 25 | keysize = 256 26 | case ECDSAP384SHA384: 27 | keysize = 384 28 | case RSASHA512: 29 | keysize = 1024 30 | } 31 | pk, err := keyrr.Generate(keysize) 32 | if err != nil { 33 | t.Errorf("failed to generate key for %q: %v", algstr, err) 34 | continue 35 | } 36 | now := uint32(time.Now().Unix()) 37 | sigrr := new(SIG) 38 | sigrr.Hdr.Name = "." 39 | sigrr.Hdr.Rrtype = TypeSIG 40 | sigrr.Hdr.Class = ClassANY 41 | sigrr.Algorithm = alg 42 | sigrr.Expiration = now + 300 43 | sigrr.Inception = now - 300 44 | sigrr.KeyTag = keyrr.KeyTag() 45 | sigrr.SignerName = keyrr.Hdr.Name 46 | mb, err := sigrr.Sign(pk.(crypto.Signer), m) 47 | if err != nil { 48 | t.Errorf("failed to sign message using %q: %v", algstr, err) 49 | continue 50 | } 51 | m := new(Msg) 52 | if err := m.Unpack(mb); err != nil { 53 | t.Errorf("failed to unpack message signed using %q: %v", algstr, err) 54 | continue 55 | } 56 | if len(m.Extra) != 1 { 57 | t.Errorf("missing SIG for message signed using %q", algstr) 58 | continue 59 | } 60 | var sigrrwire *SIG 61 | switch rr := m.Extra[0].(type) { 62 | case *SIG: 63 | sigrrwire = rr 64 | default: 65 | t.Errorf("expected SIG RR, instead: %v", rr) 66 | continue 67 | } 68 | for _, rr := range []*SIG{sigrr, sigrrwire} { 69 | id := "sigrr" 70 | if rr == sigrrwire { 71 | id = "sigrrwire" 72 | } 73 | if err := rr.Verify(keyrr, mb); err != nil { 74 | t.Errorf("failed to verify %q signed SIG(%s): %v", algstr, id, err) 75 | continue 76 | } 77 | } 78 | mb[13]++ 79 | if err := sigrr.Verify(keyrr, mb); err == nil { 80 | t.Errorf("verify succeeded on an altered message using %q", algstr) 81 | continue 82 | } 83 | sigrr.Expiration = 2 84 | sigrr.Inception = 1 85 | mb, _ = sigrr.Sign(pk.(crypto.Signer), m) 86 | if err := sigrr.Verify(keyrr, mb); err == nil { 87 | t.Errorf("verify succeeded on an expired message
using %q", algstr) 88 | continue 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /smimea.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/sha256" 5 | "crypto/x509" 6 | "encoding/hex" 7 | ) 8 | 9 | // Sign creates a SMIMEA record from an SSL certificate. 10 | func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) { 11 | r.Hdr.Rrtype = TypeSMIMEA 12 | r.Usage = uint8(usage) 13 | r.Selector = uint8(selector) 14 | r.MatchingType = uint8(matchingType) 15 | 16 | r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) 17 | return err 18 | } 19 | 20 | // Verify verifies a SMIMEA record against an SSL certificate. If it is OK 21 | // a nil error is returned. 22 | func (r *SMIMEA) Verify(cert *x509.Certificate) error { 23 | c, err := CertificateToDANE(r.Selector, r.MatchingType, cert) 24 | if err != nil { 25 | return err // Not also ErrSig? 26 | } 27 | if r.Certificate == c { 28 | return nil 29 | } 30 | return ErrSig // ErrSig, really? 31 | } 32 | 33 | // SMIMEAName returns the owner name of a SMIMEA resource record as per the 34 | // format specified in RFC 8162 (formerly draft-ietf-dane-smime), Sections 2 and 3. 35 | func SMIMEAName(email, domain string) (string, error) { 36 | hasher := sha256.New() 37 | hasher.Write([]byte(email)) 38 | 39 | // RFC Section 3: "The local-part is hashed using the SHA2-256 40 | // algorithm with the hash truncated to 28 octets and 41 | // represented in its hexadecimal representation to become the 42 | // left-most label in the prepared domain name" 43 | return hex.EncodeToString(hasher.Sum(nil)[:28]) + "." + "_smimecert."
+ domain, nil 44 | } 45 | -------------------------------------------------------------------------------- /svcb_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | // This tests everything valid about SVCB but parsing. 8 | // Parsing tests belong to parse_test.go. 9 | func TestSVCB(t *testing.T) { 10 | svcbs := []struct { 11 | key string 12 | data string 13 | }{ 14 | {`mandatory`, `alpn,key65000`}, 15 | {`alpn`, `h2,h2c`}, 16 | {`port`, `499`}, 17 | {`ipv4hint`, `3.4.3.2,1.1.1.1`}, 18 | {`no-default-alpn`, ``}, 19 | {`ipv6hint`, `1::4:4:4:4,1::3:3:3:3`}, 20 | {`ech`, `YUdWc2JHOD0=`}, 21 | {`dohpath`, `/dns-query{?dns}`}, 22 | {`key65000`, `4\ 3`}, 23 | {`key65001`, `\"\ `}, 24 | {`key65002`, ``}, 25 | {`key65003`, `=\"\"`}, 26 | {`key65004`, `\254\ \ \030\000`}, 27 | {`ohttp`, ``}, 28 | } 29 | 30 | for _, o := range svcbs { 31 | keyCode := svcbStringToKey(o.key) 32 | kv := makeSVCBKeyValue(keyCode) 33 | if kv == nil { 34 | t.Error("failed to parse svc key: ", o.key) 35 | continue 36 | } 37 | if kv.Key() != keyCode { 38 | t.Error("key constant is not in sync: ", keyCode) 39 | continue 40 | } 41 | err := kv.parse(o.data) 42 | if err != nil { 43 | t.Error("failed to parse svc pair: ", o.key) 44 | continue 45 | } 46 | b, err := kv.pack() 47 | if err != nil { 48 | t.Error("failed to pack value of svc pair: ", o.key, err) 49 | continue 50 | } 51 | if len(b) != int(kv.len()) { 52 | t.Errorf("expected packed svc value %s to be of length %d but got %d", o.key, int(kv.len()), len(b)) 53 | } 54 | err = kv.unpack(b) 55 | if err != nil { 56 | t.Error("failed to unpack value of svc pair: ", o.key, err) 57 | continue 58 | } 59 | if str := kv.String(); str != o.data { 60 | t.Errorf("`%s' should be equal to\n`%s', but is `%s'", o.key, o.data, str) 61 | } 62 | } 63 | } 64 | 65 | func TestDecodeBadSVCB(t *testing.T) { 66 | svcbs := []struct { 67 | key SVCBKey 68 | data []byte 69 | }{
70 | { 71 | key: SVCB_ALPN, 72 | data: []byte{3, 0, 0}, // There aren't three octets after 3 73 | }, 74 | { 75 | key: SVCB_NO_DEFAULT_ALPN, 76 | data: []byte{0}, 77 | }, 78 | { 79 | key: SVCB_PORT, 80 | data: []byte{}, 81 | }, 82 | { 83 | key: SVCB_IPV4HINT, 84 | data: []byte{0, 0, 0}, 85 | }, 86 | { 87 | key: SVCB_IPV6HINT, 88 | data: []byte{0, 0, 0}, 89 | }, 90 | { 91 | key: SVCB_OHTTP, 92 | data: []byte{0}, 93 | }, 94 | } 95 | for _, o := range svcbs { 96 | err := makeSVCBKeyValue(SVCBKey(o.key)).unpack(o.data) 97 | if err == nil { 98 | t.Error("accepted invalid svc value with key ", SVCBKey(o.key).String()) 99 | } 100 | } 101 | } 102 | 103 | func TestPresentationSVCBAlpn(t *testing.T) { 104 | tests := map[string]string{ 105 | "h2": "h2", 106 | "http": "http", 107 | "\xfa": `\250`, 108 | "some\"other,chars": `some\"other\\\044chars`, 109 | } 110 | for input, want := range tests { 111 | e := new(SVCBAlpn) 112 | e.Alpn = []string{input} 113 | if e.String() != want { 114 | t.Errorf("improper conversion with String(), wanted %v got %v", want, e.String()) 115 | } 116 | } 117 | } 118 | 119 | func TestSVCBAlpn(t *testing.T) { 120 | tests := map[string][]string{ 121 | `. 1 IN SVCB 10 one.test. alpn=h2`: {"h2"}, 122 | `. 2 IN SVCB 20 two.test. alpn=h2,h3-19`: {"h2", "h3-19"}, 123 | `. 3 IN SVCB 30 three.test. alpn="f\\\\oo\\,bar,h2"`: {`f\oo,bar`, "h2"}, 124 | `. 4 IN SVCB 40 four.test. alpn="part1,part2,part3\\,part4\\\\"`: {"part1", "part2", `part3,part4\`}, 125 | `. 5 IN SVCB 50 five.test.
alpn=part1\,\p\a\r\t2\044part3\092,part4\092\\`: {"part1", "part2", `part3,part4\`}, 126 | } 127 | for s, v := range tests { 128 | rr, err := NewRR(s) 129 | if err != nil { 130 | t.Error("failed to parse RR: ", err) 131 | continue 132 | } 133 | alpn := rr.(*SVCB).Value[0].(*SVCBAlpn).Alpn 134 | if len(v) != len(alpn) { 135 | t.Fatalf("parsing alpn failed, wanted %v got %v", v, alpn) 136 | } 137 | for i := range v { 138 | if v[i] != alpn[i] { 139 | t.Fatalf("parsing alpn failed, wanted %v got %v", v, alpn) 140 | } 141 | } 142 | } 143 | } 144 | 145 | func TestCompareSVCB(t *testing.T) { 146 | val1 := []SVCBKeyValue{ 147 | &SVCBPort{ 148 | Port: 117, 149 | }, 150 | &SVCBAlpn{ 151 | Alpn: []string{"h2", "h3"}, 152 | }, 153 | } 154 | val2 := []SVCBKeyValue{ 155 | &SVCBAlpn{ 156 | Alpn: []string{"h2", "h3"}, 157 | }, 158 | &SVCBPort{ 159 | Port: 117, 160 | }, 161 | } 162 | if !areSVCBPairArraysEqual(val1, val2) { 163 | t.Error("svcb pairs were compared without sorting") 164 | } 165 | if val1[0].Key() != SVCB_PORT || val2[0].Key() != SVCB_ALPN { 166 | t.Error("original svcb pairs were reordered during comparison") 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /tlsa.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/x509" 5 | "net" 6 | "strconv" 7 | ) 8 | 9 | // Sign creates a TLSA record from an SSL certificate. 10 | func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (err error) { 11 | r.Hdr.Rrtype = TypeTLSA 12 | r.Usage = uint8(usage) 13 | r.Selector = uint8(selector) 14 | r.MatchingType = uint8(matchingType) 15 | 16 | r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) 17 | return err 18 | } 19 | 20 | // Verify verifies a TLSA record against an SSL certificate. If it is OK 21 | // a nil error is returned.
22 | func (r *TLSA) Verify(cert *x509.Certificate) error { 23 | c, err := CertificateToDANE(r.Selector, r.MatchingType, cert) 24 | if err != nil { 25 | return err // Not also ErrSig? 26 | } 27 | if r.Certificate == c { 28 | return nil 29 | } 30 | return ErrSig // ErrSig, really? 31 | } 32 | 33 | // TLSAName returns the owner name of a TLSA resource record as per the 34 | // rules specified in RFC 6698, Section 3. 35 | func TLSAName(name, service, network string) (string, error) { 36 | if !IsFqdn(name) { 37 | return "", ErrFqdn 38 | } 39 | p, err := net.LookupPort(network, service) 40 | if err != nil { 41 | return "", err 42 | } 43 | return "_" + strconv.Itoa(p) + "._" + network + "." + name, nil 44 | } 45 | -------------------------------------------------------------------------------- /tmpdir_darwin_test.go: -------------------------------------------------------------------------------- 1 | //go:build darwin 2 | 3 | package dns 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | // tempFile creates a temporary directory for tests and returns a file path as 13 | // a result of concatenation of said temporary directory path and provided filename. 14 | // The reason for this is to work around some limitations in socket file name 15 | // lengths on darwin.
16 | // 17 | // Ref: 18 | // - https://github.com/golang/go/blob/go1.20.2/src/syscall/ztypes_darwin_arm64.go#L178 19 | // - https://github.com/golang/go/blob/go1.20.2/src/syscall/ztypes_linux_arm64.go#L175 20 | func tempFile(t *testing.T, filename string) string { 21 | t.Helper() 22 | 23 | dir, err := os.MkdirTemp("", strings.ReplaceAll(t.Name(), string(filepath.Separator), "-")) 24 | if err != nil { 25 | t.Fatalf("failed to create temp dir: %v", err) 26 | } 27 | t.Cleanup(func() { _ = os.RemoveAll(dir) }) // best-effort: unlike t.TempDir, os.MkdirTemp never removes the directory itself 28 | return filepath.Join(dir, filename) 29 | } 30 | -------------------------------------------------------------------------------- /tmpdir_test.go: -------------------------------------------------------------------------------- 1 | //go:build !darwin 2 | 3 | package dns 4 | 5 | import ( 6 | "path/filepath" 7 | "testing" 8 | ) 9 | 10 | // tempFile creates a temporary directory for tests and returns a file path as 11 | // a result of concatenation of said temporary directory path and provided filename. 12 | func tempFile(t *testing.T, filename string) string { 13 | t.Helper() 14 | 15 | return filepath.Join(t.TempDir(), filename) 16 | } 17 | -------------------------------------------------------------------------------- /tools.go: -------------------------------------------------------------------------------- 1 | //go:build tools 2 | // +build tools 3 | 4 | // We include our tool dependencies for `go generate` here to ensure they're 5 | // properly tracked by the go tool. See the Go Wiki for the rationale behind this: 6 | // https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module.
7 | 8 | package dns 9 | 10 | import _ "golang.org/x/tools/go/packages" 11 | -------------------------------------------------------------------------------- /types_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestCmToM(t *testing.T) { 8 | s := cmToM((0 << 4) + 0) 9 | if s != "0.00" { 10 | t.Error("0, 0") 11 | } 12 | 13 | s = cmToM((1 << 4) + 0) 14 | if s != "0.01" { 15 | t.Error("1, 0") 16 | } 17 | 18 | s = cmToM((3 << 4) + 1) 19 | if s != "0.30" { 20 | t.Error("3, 1") 21 | } 22 | 23 | s = cmToM((4 << 4) + 2) 24 | if s != "4" { 25 | t.Error("4, 2") 26 | } 27 | 28 | s = cmToM((5 << 4) + 3) 29 | if s != "50" { 30 | t.Error("5, 3") 31 | } 32 | 33 | s = cmToM((7 << 4) + 5) 34 | if s != "7000" { 35 | t.Error("7, 5") 36 | } 37 | 38 | s = cmToM((9 << 4) + 9) 39 | if s != "90000000" { 40 | t.Error("9, 9") 41 | } 42 | } 43 | 44 | func TestSplitN(t *testing.T) { 45 | xs := splitN("abc", 5) 46 | if len(xs) != 1 && xs[0] != "abc" { // NOTE(review): with && this only fails when both checks fail; || is likely intended — confirm splitN's output for these inputs before tightening 47 | t.Errorf("failure to split abc") 48 | } 49 | 50 | s := "" 51 | for i := 0; i < 255; i++ { 52 | s += "a" 53 | } 54 | 55 | xs = splitN(s, 255) 56 | if len(xs) != 1 && xs[0] != s { // NOTE(review): same && vs || concern as the check above 57 | t.Errorf("failure to split 255 char long string") 58 | } 59 | 60 | s += "b" 61 | xs = splitN(s, 255) 62 | if len(xs) != 2 || xs[1] != "b" { 63 | t.Errorf("failure to split 256 char long string: %d", len(xs)) 64 | } 65 | 66 | // Make s longer 67 | for i := 0; i < 255; i++ { 68 | s += "a" 69 | } 70 | xs = splitN(s, 255) 71 | if len(xs) != 3 || xs[2] != "a" { 72 | t.Errorf("failure to split 511 char long string: %d", len(xs)) 73 | } 74 | } 75 | 76 | func TestSprintName(t *testing.T) { 77 | tests := map[string]string{ 78 | // Non-numeric escaping of special printable characters.
79 | " '@;()\"\\..example": `\ \'\@\;\(\)\"\..example`, 80 | "\\032\\039\\064\\059\\040\\041\\034\\046\\092.example": `\ \'\@\;\(\)\"\.\\.example`, 81 | 82 | // Numeric escaping of nonprintable characters. 83 | "\x00\x07\x09\x0a\x1f.\x7f\x80\xad\xef\xff": `\000\007\009\010\031.\127\128\173\239\255`, 84 | "\\000\\007\\009\\010\\031.\\127\\128\\173\\239\\255": `\000\007\009\010\031.\127\128\173\239\255`, 85 | 86 | // No escaping of other printable characters, at least after a prior escape. 87 | ";[a-zA-Z0-9_]+/*.~": `\;[a-zA-Z0-9_]+/*.~`, 88 | ";\\091\\097\\045\\122\\065\\045\\090\\048\\045\\057\\095\\093\\043\\047\\042.\\126": `\;[a-zA-Z0-9_]+/*.~`, 89 | // "\\091\\097\\045\\122\\065\\045\\090\\048\\045\\057\\095\\093\\043\\047\\042.\\126": `[a-zA-Z0-9_]+/*.~`, 90 | 91 | // Incomplete "dangling" escapes are dropped regardless of prior escaping. 92 | "a\\": `a`, 93 | ";\\": `\;`, 94 | 95 | // Escaped dots stay escaped regardless of prior escaping. 96 | "a\\.\\046.\\.\\046": `a\.\..\.\.`, 97 | "a\\046\\..\\046\\.": `a\.\..\.\.`, 98 | } 99 | for input, want := range tests { 100 | got := sprintName(input) 101 | if got != want { 102 | t.Errorf("input %q: expected %q, got %q", input, want, got) 103 | } 104 | } 105 | } 106 | 107 | func TestSprintTxtOctet(t *testing.T) { 108 | got := sprintTxtOctet("abc\\.def\007\"\127@\255\x05\xef\\") 109 | 110 | if want := "\"abc\\.def\\007\\\"W@\\173\\005\\239\""; got != want { 111 | t.Errorf("expected %q, got %q", want, got) 112 | } 113 | } 114 | 115 | func TestSprintTxt(t *testing.T) { 116 | got := sprintTxt([]string{ 117 | "abc\\.def\007\"\127@\255\x05\xef\\", 118 | "example.com", 119 | }) 120 | 121 | if want := "\"abc.def\\007\\\"W@\\173\\005\\239\" \"example.com\""; got != want { 122 | t.Errorf("expected %q, got %q", want, got) 123 | } 124 | } 125 | 126 | func TestRPStringer(t *testing.T) { 127 | rp := &RP{ 128 | Hdr: RR_Header{ 129 | Name: "test.example.com.", 130 | Rrtype: TypeRP, 131 | Class: ClassINET, 132 | Ttl: 600, 133 | },
134 | Mbox: "\x05first.example.com.", 135 | Txt: "second.\x07example.com.", 136 | } 137 | 138 | const expected = "test.example.com.\t600\tIN\tRP\t\\005first.example.com. second.\\007example.com." 139 | if rp.String() != expected { 140 | t.Errorf("expected %v, got %v", expected, rp) 141 | } 142 | 143 | _, err := NewRR(rp.String()) 144 | if err != nil { 145 | t.Fatalf("error parsing %q: %v", rp, err) 146 | } 147 | } 148 | 149 | func BenchmarkSprintName(b *testing.B) { 150 | for n := 0; n < b.N; n++ { 151 | got := sprintName("abc\\.def\007\"\127@\255\x05\xef\\") 152 | 153 | if want := "abc\\.def\\007\\\"W\\@\\173\\005\\239"; got != want { 154 | b.Fatalf("expected %q, got %q", want, got) 155 | } 156 | } 157 | } 158 | 159 | func BenchmarkSprintName_NoEscape(b *testing.B) { 160 | for n := 0; n < b.N; n++ { 161 | got := sprintName("large.example.com") 162 | 163 | if want := "large.example.com"; got != want { 164 | b.Fatalf("expected %q, got %q", want, got) 165 | } 166 | } 167 | } 168 | 169 | func BenchmarkSprintTxtOctet(b *testing.B) { 170 | for n := 0; n < b.N; n++ { 171 | got := sprintTxtOctet("abc\\.def\007\"\127@\255\x05\xef\\") 172 | 173 | if want := "\"abc\\.def\\007\\\"W@\\173\\005\\239\""; got != want { 174 | b.Fatalf("expected %q, got %q", want, got) 175 | } 176 | } 177 | } 178 | 179 | func BenchmarkSprintTxt(b *testing.B) { 180 | txt := []string{ 181 | "abc\\.def\007\"\127@\255\x05\xef\\", 182 | "example.com", 183 | } 184 | 185 | b.ResetTimer() 186 | for n := 0; n < b.N; n++ { 187 | got := sprintTxt(txt) 188 | 189 | if want := "\"abc.def\\007\\\"W@\\173\\005\\239\" \"example.com\""; got != want { 190 | b.Fatalf("expected %q, got %q", want, got) 191 | } 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /udp.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !darwin 2 | // +build !windows,!darwin 3 | 4 | package dns 5 | 6 | import ( 7 | "net" 8 | 9 |
"golang.org/x/net/ipv4" 10 | "golang.org/x/net/ipv6" 11 | ) 12 | 13 | // This is the required size of the OOB buffer to pass to ReadMsgUDP. 14 | var udpOOBSize = func() int { 15 | // We can't know whether we'll get an IPv4 control message or an 16 | // IPv6 control message ahead of time. To get around this, we size 17 | // the buffer equal to the largest of the two. 18 | 19 | oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface) 20 | oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface) 21 | 22 | if len(oob4) > len(oob6) { 23 | return len(oob4) 24 | } 25 | 26 | return len(oob6) 27 | }() 28 | 29 | // SessionUDP holds the remote address and the associated 30 | // out-of-band data. 31 | type SessionUDP struct { 32 | raddr *net.UDPAddr 33 | context []byte 34 | } 35 | 36 | // RemoteAddr returns the remote network address. 37 | func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } 38 | 39 | // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a 40 | // net.UDPAddr. 41 | func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { 42 | oob := make([]byte, udpOOBSize) 43 | n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob) 44 | if err != nil { 45 | return n, nil, err 46 | } 47 | return n, &SessionUDP{raddr, oob[:oobn]}, err 48 | } 49 | 50 | // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. 51 | func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { 52 | oob := correctSource(session.context) 53 | n, _, err := conn.WriteMsgUDP(b, oob, session.raddr) 54 | return n, err 55 | } 56 | 57 | func setUDPSocketOptions(conn *net.UDPConn) error { 58 | // Try setting the flags for both families and ignore the errors unless they 59 | // both error.
60 | err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) 61 | err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) 62 | if err6 != nil && err4 != nil { 63 | return err4 64 | } 65 | return nil 66 | } 67 | 68 | // parseDstFromOOB takes oob data and returns the destination IP. 69 | func parseDstFromOOB(oob []byte) net.IP { 70 | // Start with IPv6 and then fallback to IPv4 71 | // TODO(fastest963): Figure out a way to prefer one or the other. Looking at 72 | // the lvl of the header for a 0 or 41 isn't cross-platform. 73 | cm6 := new(ipv6.ControlMessage) 74 | if cm6.Parse(oob) == nil && cm6.Dst != nil { 75 | return cm6.Dst 76 | } 77 | cm4 := new(ipv4.ControlMessage) 78 | if cm4.Parse(oob) == nil && cm4.Dst != nil { 79 | return cm4.Dst 80 | } 81 | return nil 82 | } 83 | 84 | // correctSource takes oob data and returns new oob data with the Src equal to the Dst 85 | func correctSource(oob []byte) []byte { 86 | dst := parseDstFromOOB(oob) 87 | if dst == nil { 88 | return nil 89 | } 90 | // If the dst is definitely an IPv6, then use ipv6's ControlMessage to 91 | // respond otherwise use ipv4's because ipv6's marshal ignores ipv4 92 | // addresses. 93 | if dst.To4() == nil { 94 | cm := new(ipv6.ControlMessage) 95 | cm.Src = dst 96 | oob = cm.Marshal() 97 | } else { 98 | cm := new(ipv4.ControlMessage) 99 | cm.Src = dst 100 | oob = cm.Marshal() 101 | } 102 | return oob 103 | } 104 | -------------------------------------------------------------------------------- /udp_no_control.go: -------------------------------------------------------------------------------- 1 | //go:build windows || darwin 2 | // +build windows darwin 3 | 4 | // TODO(tmthrgd): Remove this Windows-specific code if go.dev/issue/7175 and 5 | // go.dev/issue/7174 are ever fixed.
6 | 7 | // NOTICE(stek29): darwin supports PKTINFO in sendmsg, but it unbinds sockets, see https://github.com/miekg/dns/issues/724 8 | 9 | package dns 10 | 11 | import "net" 12 | 13 | // SessionUDP holds the remote address. 14 | type SessionUDP struct { 15 | raddr *net.UDPAddr 16 | } 17 | 18 | // RemoteAddr returns the remote network address. 19 | func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } 20 | 21 | // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a 22 | // net.UDPAddr. 23 | func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { 24 | n, raddr, err := conn.ReadFrom(b) 25 | if err != nil { 26 | return n, nil, err 27 | } 28 | return n, &SessionUDP{raddr.(*net.UDPAddr)}, err 29 | } 30 | 31 | // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. 32 | func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { 33 | return conn.WriteTo(b, session.raddr) 34 | } 35 | // Control messages are not used on these platforms, so these are no-op stubs with the same signatures as in udp.go. 36 | func setUDPSocketOptions(*net.UDPConn) error { return nil } 37 | func parseDstFromOOB([]byte) net.IP { return nil } 38 | -------------------------------------------------------------------------------- /udp_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux && !appengine 2 | // +build linux,!appengine 3 | 4 | package dns 5 | 6 | import ( 7 | "bytes" 8 | "net" 9 | "runtime" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "golang.org/x/net/ipv4" 15 | "golang.org/x/net/ipv6" 16 | ) 17 | 18 | func TestSetUDPSocketOptions(t *testing.T) { 19 | // returns an error if we cannot resolve that address 20 | testFamily := func(n, addr string) error { 21 | a, err := net.ResolveUDPAddr(n, addr) 22 | if err != nil { 23 | return err 24 | } 25 | c, err := net.ListenUDP(n, a) 26 | if err != nil { 27 | return err 28 | } 29 | if err := setUDPSocketOptions(c); err != nil { 30 |
t.Fatalf("failed to set socket options: %v", err) 31 | } 32 | ch := make(chan *SessionUDP) 33 | go func() { 34 | // Set some deadline so this goroutine doesn't hang forever 35 | c.SetReadDeadline(time.Now().Add(time.Minute)) 36 | b := make([]byte, 1) 37 | _, sess, err := ReadFromSessionUDP(c, b) 38 | if err != nil { 39 | t.Errorf("failed to read from conn: %v", err) 40 | // fallthrough to chan send below 41 | } 42 | ch <- sess 43 | }() 44 | 45 | c2, err := net.Dial("udp", c.LocalAddr().String()) 46 | if err != nil { 47 | t.Fatalf("failed to dial udp: %v", err) 48 | } 49 | if _, err := c2.Write([]byte{1}); err != nil { 50 | t.Fatalf("failed to write to conn: %v", err) 51 | } 52 | sess := <-ch 53 | if sess == nil { 54 | // t.Error was already called in the goroutine above. 55 | t.FailNow() 56 | } 57 | if len(sess.context) == 0 { 58 | t.Fatalf("empty session context: %v", sess) 59 | } 60 | ip := parseDstFromOOB(sess.context) 61 | if ip == nil { 62 | t.Fatalf("failed to parse dst: %v", sess) 63 | } 64 | if !strings.Contains(c.LocalAddr().String(), ip.String()) { 65 | t.Fatalf("dst was different than listen addr: %v != %v", ip.String(), c.LocalAddr().String()) 66 | } 67 | return nil 68 | } 69 | 70 | // we require that ipv4 be supported 71 | if err := testFamily("udp4", "127.0.0.1:0"); err != nil { 72 | t.Fatalf("failed to test socket options on IPv4: %v", err) 73 | } 74 | // IPv6 might not be supported so these will just log 75 | if err := testFamily("udp6", "[::1]:0"); err != nil { 76 | t.Logf("failed to test socket options on IPv6-only: %v", err) 77 | } 78 | if err := testFamily("udp", "[::1]:0"); err != nil { 79 | t.Logf("failed to test socket options on IPv6/IPv4: %v", err) 80 | } 81 | } 82 | 83 | func TestParseDstFromOOB(t *testing.T) { 84 | if runtime.GOARCH != "amd64" { 85 | // The cmsghdr struct differs in the width (32/64-bit) of 86 | // lengths and the struct padding between architectures.
87 | // The data below was only written with amd64 in mind, and 88 | // thus the test must be skipped on other architectures. 89 | t.Skip("skipping test on unsupported architecture") 90 | } 91 | 92 | // dst is ::ffff:100.100.100.100 93 | oob := []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 100, 100, 100, 100, 2, 0, 0, 0} 94 | dst := parseDstFromOOB(oob) 95 | dst4 := dst.To4() 96 | if dst4 == nil { 97 | t.Errorf("failed to parse IPv4 in IPv6: %v", dst) 98 | } else if dst4.String() != "100.100.100.100" { 99 | t.Errorf("unexpected IPv4: %v", dst4) 100 | } 101 | 102 | // dst is 2001:db8::1 103 | oob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0} 104 | dst = parseDstFromOOB(oob) 105 | dst6 := dst.To16() 106 | if dst6 == nil { 107 | t.Errorf("failed to parse IPv6: %v", dst) 108 | } else if dst6.String() != "2001:db8::1" { 109 | t.Errorf("unexpected IPv6: %v", dst6) 110 | } 111 | 112 | // dst is 100.100.100.100 but was received on 10.10.10.10 113 | oob = []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0, 10, 10, 10, 10, 100, 100, 100, 100, 0, 0, 0, 0} 114 | dst = parseDstFromOOB(oob) 115 | dst4 = dst.To4() 116 | if dst4 == nil { 117 | t.Errorf("failed to parse IPv4: %v", dst) 118 | } else if dst4.String() != "100.100.100.100" { 119 | t.Errorf("unexpected IPv4: %v", dst4) 120 | } 121 | } 122 | 123 | func TestCorrectSource(t *testing.T) { 124 | if runtime.GOARCH != "amd64" { 125 | // See comment above in TestParseDstFromOOB.
126 | t.Skip("skipping test on unsupported architecture") 127 | } 128 | 129 | // dst is :ffff:100.100.100.100 which should be counted as IPv4 130 | oob := []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 100, 100, 100, 100, 2, 0, 0, 0} 131 | soob := correctSource(oob) 132 | cm4 := new(ipv4.ControlMessage) 133 | cm4.Src = net.ParseIP("100.100.100.100") 134 | if !bytes.Equal(soob, cm4.Marshal()) { 135 | t.Errorf("unexpected oob for ipv4 address: %v", soob) 136 | } 137 | 138 | // dst is 2001:db8::1 139 | oob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0} 140 | soob = correctSource(oob) 141 | cm6 := new(ipv6.ControlMessage) 142 | cm6.Src = net.ParseIP("2001:db8::1") 143 | if !bytes.Equal(soob, cm6.Marshal()) { 144 | t.Errorf("unexpected oob for IPv6 address: %v", soob) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | // NameUsed sets the RRs in the prereq section to 4 | // "Name is in use" RRs. RFC 2136 section 2.4.4. 5 | // See [ANY] on how to make RRs without rdata. 6 | func (u *Msg) NameUsed(rr []RR) { 7 | if u.Answer == nil { 8 | u.Answer = make([]RR, 0, len(rr)) 9 | } 10 | for _, r := range rr { 11 | u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}) 12 | } 13 | } 14 | 15 | // NameNotUsed sets the RRs in the prereq section to 16 | // "Name is in not use" RRs. RFC 2136 section 2.4.5. 
17 | func (u *Msg) NameNotUsed(rr []RR) { 18 | if u.Answer == nil { 19 | u.Answer = make([]RR, 0, len(rr)) 20 | } 21 | for _, r := range rr { 22 | u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassNONE}}) 23 | } 24 | } 25 | 26 | // Used sets the RRs in the prereq section to 27 | // "RRset exists (value dependent -- with rdata)" RRs. RFC 2136 section 2.4.2. 28 | func (u *Msg) Used(rr []RR) { 29 | if len(u.Question) == 0 { 30 | panic("dns: empty question section") 31 | } 32 | if u.Answer == nil { 33 | u.Answer = make([]RR, 0, len(rr)) 34 | } 35 | for _, r := range rr { 36 | hdr := r.Header() 37 | hdr.Class = u.Question[0].Qclass 38 | hdr.Ttl = 0 39 | u.Answer = append(u.Answer, r) 40 | } 41 | } 42 | 43 | // RRsetUsed sets the RRs in the prereq section to 44 | // "RRset exists (value independent -- no rdata)" RRs. RFC 2136 section 2.4.1. 45 | // See [ANY] on how to make RRs without rdata. 46 | func (u *Msg) RRsetUsed(rr []RR) { 47 | if u.Answer == nil { 48 | u.Answer = make([]RR, 0, len(rr)) 49 | } 50 | for _, r := range rr { 51 | h := r.Header() 52 | u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassANY}}) 53 | } 54 | } 55 | 56 | // RRsetNotUsed sets the RRs in the prereq section to 57 | // "RRset does not exist" RRs. RFC 2136 section 2.4.3. 58 | // See [ANY] on how to make RRs without rdata. 59 | func (u *Msg) RRsetNotUsed(rr []RR) { 60 | if u.Answer == nil { 61 | u.Answer = make([]RR, 0, len(rr)) 62 | } 63 | for _, r := range rr { 64 | h := r.Header() 65 | u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassNONE}}) 66 | } 67 | } 68 | 69 | // Insert creates a dynamic update packet that adds an complete RRset, see RFC 2136 section 2.5.1. 70 | // See [ANY] on how to make RRs without rdata. 
71 | func (u *Msg) Insert(rr []RR) { 72 | if len(u.Question) == 0 { 73 | panic("dns: empty question section") 74 | } 75 | if u.Ns == nil { 76 | u.Ns = make([]RR, 0, len(rr)) 77 | } 78 | for _, r := range rr { 79 | r.Header().Class = u.Question[0].Qclass 80 | u.Ns = append(u.Ns, r) 81 | } 82 | } 83 | 84 | // RemoveRRset creates a dynamic update packet that deletes an RRset, see RFC 2136 section 2.5.2. 85 | // See [ANY] on how to make RRs without rdata. 86 | func (u *Msg) RemoveRRset(rr []RR) { 87 | if u.Ns == nil { 88 | u.Ns = make([]RR, 0, len(rr)) 89 | } 90 | for _, r := range rr { 91 | h := r.Header() 92 | u.Ns = append(u.Ns, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassANY}}) 93 | } 94 | } 95 | 96 | // RemoveName creates a dynamic update packet that deletes all RRsets of a name, see RFC 2136 section 2.5.3 97 | // See [ANY] on how to make RRs without rdata. 98 | func (u *Msg) RemoveName(rr []RR) { 99 | if u.Ns == nil { 100 | u.Ns = make([]RR, 0, len(rr)) 101 | } 102 | for _, r := range rr { 103 | u.Ns = append(u.Ns, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: TypeANY, Class: ClassANY}}) 104 | } 105 | } 106 | 107 | // Remove creates a dynamic update packet deletes RR from a RRSset, see RFC 2136 section 2.5.4 108 | // See [ANY] on how to make RRs without rdata. 109 | func (u *Msg) Remove(rr []RR) { 110 | if u.Ns == nil { 111 | u.Ns = make([]RR, 0, len(rr)) 112 | } 113 | for _, r := range rr { 114 | h := r.Header() 115 | h.Class = ClassNONE 116 | h.Ttl = 0 117 | u.Ns = append(u.Ns, r) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /update_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestDynamicUpdateParsing(t *testing.T) { 9 | const prefix = "example.com. 
IN " 10 | 11 | for typ, name := range TypeToString { 12 | switch typ { 13 | case TypeNone, TypeReserved: 14 | continue 15 | case TypeANY: 16 | // ANY is ambiguous here and ends up parsed as a CLASS. 17 | // 18 | // TODO(tmthrgd): Using TYPE255 here doesn't seem to work and also 19 | // seems to fail for some other record types. Investigate. 20 | continue 21 | } 22 | 23 | s := prefix + name 24 | if _, err := NewRR(s); err != nil { 25 | t.Errorf("failure to parse: %s: %v", s, err) 26 | } 27 | 28 | s += " \\# 0" 29 | if _, err := NewRR(s); err != nil { 30 | t.Errorf("failure to parse: %s: %v", s, err) 31 | } 32 | } 33 | } 34 | 35 | func TestDynamicUpdateUnpack(t *testing.T) { 36 | // From https://github.com/miekg/dns/issues/150#issuecomment-62296803 37 | // It should be an update message for the zone "example.", 38 | // deleting the A RRset "example." and then adding an A record at "example.". 39 | // class ANY, TYPE A 40 | buf := []byte{171, 68, 40, 0, 0, 1, 0, 0, 0, 2, 0, 0, 7, 101, 120, 97, 109, 112, 108, 101, 0, 0, 6, 0, 1, 192, 12, 0, 1, 0, 255, 0, 0, 0, 0, 0, 0, 192, 12, 0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 127, 0, 0, 1} 41 | msg := new(Msg) 42 | err := msg.Unpack(buf) 43 | if err != nil { 44 | t.Errorf("failed to unpack: %v\n%s", err, msg.String()) 45 | } 46 | } 47 | 48 | func TestDynamicUpdateZeroRdataUnpack(t *testing.T) { 49 | m := new(Msg) 50 | rr := &RR_Header{Name: ".", Rrtype: 0, Class: 1, Ttl: ^uint32(0), Rdlength: 0} 51 | m.Answer = []RR{rr, rr, rr, rr, rr} 52 | m.Ns = m.Answer 53 | for n, s := range TypeToString { 54 | rr.Rrtype = n 55 | bytes, err := m.Pack() 56 | if err != nil { 57 | t.Errorf("failed to pack %s: %v", s, err) 58 | continue 59 | } 60 | if err := new(Msg).Unpack(bytes); err != nil { 61 | t.Errorf("failed to unpack %s: %v", s, err) 62 | } 63 | } 64 | } 65 | 66 | func TestRemoveRRset(t *testing.T) { 67 | // Should add a zero data RR in Class ANY with a TTL of 0 68 | // for each set mentioned in the RRs provided to it. 69 | rr := testRR(". 
100 IN A 127.0.0.1") 70 | m := new(Msg) 71 | m.Ns = []RR{&RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY, Ttl: 0, Rdlength: 0}} 72 | expectstr := m.String() 73 | expect, err := m.Pack() 74 | if err != nil { 75 | t.Fatalf("error packing expected msg: %v", err) 76 | } 77 | 78 | m.Ns = nil 79 | m.RemoveRRset([]RR{rr}) 80 | actual, err := m.Pack() 81 | if err != nil { 82 | t.Fatalf("error packing actual msg: %v", err) 83 | } 84 | if !bytes.Equal(actual, expect) { 85 | tmp := new(Msg) 86 | if err := tmp.Unpack(actual); err != nil { 87 | t.Fatalf("error unpacking actual msg: %v\nexpected: %v\ngot: %v\n", err, expect, actual) 88 | } 89 | t.Errorf("expected msg:\n%s", expectstr) 90 | t.Errorf("actual msg:\n%v", tmp) 91 | } 92 | } 93 | 94 | func TestPreReqAndRemovals(t *testing.T) { 95 | // Build a list of multiple prereqs and then some removes followed by an insert. 96 | // We should be able to add multiple prereqs and updates. 97 | m := new(Msg) 98 | m.SetUpdate("example.org.") 99 | m.Id = 1234 100 | 101 | // Use a full set of RRs each time, so we are sure the rdata is stripped. 102 | rrName1 := testRR("name_used. 3600 IN A 127.0.0.1") 103 | rrName2 := testRR("name_not_used. 3600 IN A 127.0.0.1") 104 | rrRemove1 := testRR("remove1. 3600 IN A 127.0.0.1") 105 | rrRemove2 := testRR("remove2. 3600 IN A 127.0.0.1") 106 | rrRemove3 := testRR("remove3. 3600 IN A 127.0.0.1") 107 | rrInsert := testRR("insert. 3600 IN A 127.0.0.1") 108 | rrRrset1 := testRR("rrset_used1. 3600 IN A 127.0.0.1") 109 | rrRrset2 := testRR("rrset_used2. 3600 IN A 127.0.0.1") 110 | rrRrset3 := testRR("rrset_not_used. 3600 IN A 127.0.0.1") 111 | 112 | // Handle the prereqs. 113 | m.NameUsed([]RR{rrName1}) 114 | m.NameNotUsed([]RR{rrName2}) 115 | m.RRsetUsed([]RR{rrRrset1}) 116 | m.Used([]RR{rrRrset2}) 117 | m.RRsetNotUsed([]RR{rrRrset3}) 118 | 119 | // and now the updates. 
120 | m.RemoveName([]RR{rrRemove1}) 121 | m.RemoveRRset([]RR{rrRemove2}) 122 | m.Remove([]RR{rrRemove3}) 123 | m.Insert([]RR{rrInsert}) 124 | 125 | // This test function isn't a Example function because we print these RR with tabs at the 126 | // end and the Example function trim these, thus they never match. 127 | // TODO(miek): don't print these tabs and make this into an Example function. 128 | expect := `;; opcode: UPDATE, status: NOERROR, id: 1234 129 | ;; flags:; ZONE: 1, PREREQ: 5, UPDATE: 4, ADDITIONAL: 0 130 | 131 | ;; ZONE SECTION: 132 | ;example.org. IN SOA 133 | 134 | ;; PREREQUISITE SECTION: 135 | name_used. 0 CLASS255 ANY 136 | name_not_used. 0 NONE ANY 137 | rrset_used1. 0 CLASS255 A 138 | rrset_used2. 0 IN A 127.0.0.1 139 | rrset_not_used. 0 NONE A 140 | 141 | ;; UPDATE SECTION: 142 | remove1. 0 CLASS255 ANY 143 | remove2. 0 CLASS255 A 144 | remove3. 0 NONE A 127.0.0.1 145 | insert. 3600 IN A 127.0.0.1 146 | ` 147 | 148 | if m.String() != expect { 149 | t.Errorf("expected msg:\n%s", expect) 150 | t.Errorf("actual msg:\n%v", m.String()) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "fmt" 4 | 5 | // Version is current version of this library. 6 | var Version = v{1, 1, 66} 7 | 8 | // v holds the version of this library. 
9 | type v struct { 10 | Major, Minor, Patch int 11 | } 12 | 13 | func (v v) String() string { 14 | return fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch) 15 | } 16 | -------------------------------------------------------------------------------- /version_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import "testing" 4 | 5 | func TestVersion(t *testing.T) { 6 | v := v{1, 0, 0} 7 | if x := v.String(); x != "1.0.0" { 8 | t.Fatalf("Failed to convert version %v, got: %s", v, x) 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /xfr.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | // Envelope is used when doing a zone transfer with a remote server. 10 | type Envelope struct { 11 | RR []RR // The set of RRs in the answer section of the xfr reply message. 12 | Error error // If something went wrong, this contains the error. 13 | } 14 | 15 | // A Transfer defines parameters that are used during a zone transfer. 16 | type Transfer struct { 17 | *Conn 18 | DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds 19 | ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds 20 | WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds 21 | TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. 22 | TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) 23 | tsigTimersOnly bool 24 | TLS *tls.Config // TLS config. 
If Xfr over TLS will be attempted 25 | } 26 | 27 | func (t *Transfer) tsigProvider() TsigProvider { 28 | if t.TsigProvider != nil { 29 | return t.TsigProvider 30 | } 31 | if t.TsigSecret != nil { 32 | return tsigSecretProvider(t.TsigSecret) 33 | } 34 | return nil 35 | } 36 | 37 | // TODO: Think we need to away to stop the transfer 38 | 39 | // In performs an incoming transfer with the server in a. 40 | // If you would like to set the source IP, or some other attribute 41 | // of a Dialer for a Transfer, you can do so by specifying the attributes 42 | // in the Transfer.Conn: 43 | // 44 | // d := net.Dialer{LocalAddr: transfer_source} 45 | // con, err := d.Dial("tcp", master) 46 | // dnscon := &dns.Conn{Conn:con} 47 | // transfer = &dns.Transfer{Conn: dnscon} 48 | // channel, err := transfer.In(message, master) 49 | func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { 50 | switch q.Question[0].Qtype { 51 | case TypeAXFR, TypeIXFR: 52 | default: 53 | return nil, &Error{"unsupported question type"} 54 | } 55 | 56 | timeout := dnsTimeout 57 | if t.DialTimeout != 0 { 58 | timeout = t.DialTimeout 59 | } 60 | 61 | if t.Conn == nil { 62 | if t.TLS != nil { 63 | t.Conn, err = DialTimeoutWithTLS("tcp-tls", a, t.TLS, timeout) 64 | } else { 65 | t.Conn, err = DialTimeout("tcp", a, timeout) 66 | } 67 | if err != nil { 68 | return nil, err 69 | } 70 | } 71 | 72 | if err := t.WriteMsg(q); err != nil { 73 | return nil, err 74 | } 75 | 76 | env = make(chan *Envelope) 77 | switch q.Question[0].Qtype { 78 | case TypeAXFR: 79 | go t.inAxfr(q, env) 80 | case TypeIXFR: 81 | go t.inIxfr(q, env) 82 | } 83 | 84 | return env, nil 85 | } 86 | 87 | func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) { 88 | first := true 89 | defer func() { 90 | // First close the connection, then the channel. This allows functions blocked on 91 | // the channel to assume that the connection is closed and no further operations are 92 | // pending when they resume. 
93 | t.Close() 94 | close(c) 95 | }() 96 | timeout := dnsTimeout 97 | if t.ReadTimeout != 0 { 98 | timeout = t.ReadTimeout 99 | } 100 | for { 101 | t.Conn.SetReadDeadline(time.Now().Add(timeout)) 102 | in, err := t.ReadMsg() 103 | if err != nil { 104 | c <- &Envelope{nil, err} 105 | return 106 | } 107 | if q.Id != in.Id { 108 | c <- &Envelope{in.Answer, ErrId} 109 | return 110 | } 111 | if first { 112 | if in.Rcode != RcodeSuccess { 113 | c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}} 114 | return 115 | } 116 | if !isSOAFirst(in) { 117 | c <- &Envelope{in.Answer, ErrSoa} 118 | return 119 | } 120 | first = !first 121 | // only one answer that is SOA, receive more 122 | if len(in.Answer) == 1 { 123 | t.tsigTimersOnly = true 124 | c <- &Envelope{in.Answer, nil} 125 | continue 126 | } 127 | } 128 | 129 | if !first { 130 | t.tsigTimersOnly = true // Subsequent envelopes use this. 131 | if isSOALast(in) { 132 | c <- &Envelope{in.Answer, nil} 133 | return 134 | } 135 | c <- &Envelope{in.Answer, nil} 136 | } 137 | } 138 | } 139 | 140 | func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) { 141 | var serial uint32 // The first serial seen is the current server serial 142 | axfr := true 143 | n := 0 144 | qser := q.Ns[0].(*SOA).Serial 145 | defer func() { 146 | // First close the connection, then the channel. This allows functions blocked on 147 | // the channel to assume that the connection is closed and no further operations are 148 | // pending when they resume. 
149 | t.Close() 150 | close(c) 151 | }() 152 | timeout := dnsTimeout 153 | if t.ReadTimeout != 0 { 154 | timeout = t.ReadTimeout 155 | } 156 | for { 157 | t.SetReadDeadline(time.Now().Add(timeout)) 158 | in, err := t.ReadMsg() 159 | if err != nil { 160 | c <- &Envelope{nil, err} 161 | return 162 | } 163 | if q.Id != in.Id { 164 | c <- &Envelope{in.Answer, ErrId} 165 | return 166 | } 167 | if in.Rcode != RcodeSuccess { 168 | c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}} 169 | return 170 | } 171 | if n == 0 { 172 | // Check if the returned answer is ok 173 | if !isSOAFirst(in) { 174 | c <- &Envelope{in.Answer, ErrSoa} 175 | return 176 | } 177 | // This serial is important 178 | serial = in.Answer[0].(*SOA).Serial 179 | // Check if there are no changes in zone 180 | if qser >= serial { 181 | c <- &Envelope{in.Answer, nil} 182 | return 183 | } 184 | } 185 | // Now we need to check each message for SOA records, to see what we need to do 186 | t.tsigTimersOnly = true 187 | for _, rr := range in.Answer { 188 | if v, ok := rr.(*SOA); ok { 189 | if v.Serial == serial { 190 | n++ 191 | // quit if it's a full axfr or the servers' SOA is repeated the third time 192 | if axfr && n == 2 || n == 3 { 193 | c <- &Envelope{in.Answer, nil} 194 | return 195 | } 196 | } else if axfr { 197 | // it's an ixfr 198 | axfr = false 199 | } 200 | } 201 | } 202 | c <- &Envelope{in.Answer, nil} 203 | } 204 | } 205 | 206 | // Out performs an outgoing transfer with the client connecting in w. 
207 | // Basic use pattern: 208 | // 209 | // ch := make(chan *dns.Envelope) 210 | // tr := new(dns.Transfer) 211 | // var wg sync.WaitGroup 212 | // wg.Add(1) 213 | // go func() { 214 | // tr.Out(w, r, ch) 215 | // wg.Done() 216 | // }() 217 | // ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}} 218 | // close(ch) 219 | // wg.Wait() // wait until everything is written out 220 | // w.Close() // close connection 221 | // 222 | // The server is responsible for sending the correct sequence of RRs through the channel ch. 223 | func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { 224 | for x := range ch { 225 | r := new(Msg) 226 | // Compress? 227 | r.SetReply(q) 228 | r.Authoritative = true 229 | // assume it fits TODO(miek): fix 230 | r.Answer = append(r.Answer, x.RR...) 231 | if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil { 232 | r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix()) 233 | } 234 | if err := w.WriteMsg(r); err != nil { 235 | return err 236 | } 237 | w.TsigTimersOnly(true) 238 | } 239 | return nil 240 | } 241 | 242 | // ReadMsg reads a message from the transfer connection t. 243 | func (t *Transfer) ReadMsg() (*Msg, error) { 244 | m := new(Msg) 245 | p := make([]byte, MaxMsgSize) 246 | n, err := t.Read(p) 247 | if err != nil && n == 0 { 248 | return nil, err 249 | } 250 | p = p[:n] 251 | if err := m.Unpack(p); err != nil { 252 | return nil, err 253 | } 254 | 255 | if tp := t.tsigProvider(); tp != nil { 256 | // Need to work on the original message p, as that was used to calculate the tsig. 257 | err = TsigVerifyWithProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly) 258 | if ts := m.IsTsig(); ts != nil { 259 | t.tsigRequestMAC = ts.MAC 260 | } 261 | } 262 | return m, err 263 | } 264 | 265 | // WriteMsg writes a message through the transfer connection t. 
266 | func (t *Transfer) WriteMsg(m *Msg) (err error) { 267 | var out []byte 268 | if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil { 269 | out, t.tsigRequestMAC, err = TsigGenerateWithProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly) 270 | } else { 271 | out, err = m.Pack() 272 | } 273 | if err != nil { 274 | return err 275 | } 276 | _, err = t.Write(out) 277 | return err 278 | } 279 | 280 | func isSOAFirst(in *Msg) bool { 281 | return len(in.Answer) > 0 && 282 | in.Answer[0].Header().Rrtype == TypeSOA 283 | } 284 | 285 | func isSOALast(in *Msg) bool { 286 | return len(in.Answer) > 0 && 287 | in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA 288 | } 289 | 290 | const errXFR = "bad xfr rcode: %d" 291 | -------------------------------------------------------------------------------- /xfr_test.go: -------------------------------------------------------------------------------- 1 | package dns 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | var ( 11 | tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} 12 | xfrSoa = testRR(`miek.nl. 0 IN SOA linode.atoom.net. miek.miek.nl. 2009032802 21600 7200 604800 3600`) 13 | xfrA = testRR(`x.miek.nl. 1792 IN A 10.0.0.1`) 14 | xfrMX = testRR(`miek.nl. 
1800 IN MX 1 x.miek.nl.`) 15 | xfrTestData = []RR{xfrSoa, xfrA, xfrMX, xfrSoa} 16 | ) 17 | 18 | func InvalidXfrServer(w ResponseWriter, req *Msg) { 19 | ch := make(chan *Envelope) 20 | tr := new(Transfer) 21 | 22 | go tr.Out(w, req, ch) 23 | ch <- &Envelope{RR: []RR{}} 24 | close(ch) 25 | w.Hijack() 26 | } 27 | 28 | func SingleEnvelopeXfrServer(w ResponseWriter, req *Msg) { 29 | ch := make(chan *Envelope) 30 | tr := new(Transfer) 31 | 32 | go tr.Out(w, req, ch) 33 | ch <- &Envelope{RR: xfrTestData} 34 | close(ch) 35 | w.Hijack() 36 | } 37 | 38 | func MultipleEnvelopeXfrServer(w ResponseWriter, req *Msg) { 39 | ch := make(chan *Envelope) 40 | tr := new(Transfer) 41 | 42 | go tr.Out(w, req, ch) 43 | 44 | for _, rr := range xfrTestData { 45 | ch <- &Envelope{RR: []RR{rr}} 46 | } 47 | close(ch) 48 | w.Hijack() 49 | } 50 | 51 | func TestInvalidXfr(t *testing.T) { 52 | HandleFunc("miek.nl.", InvalidXfrServer) 53 | defer HandleRemove("miek.nl.") 54 | 55 | s, addrstr, _, err := RunLocalTCPServer(":0") 56 | if err != nil { 57 | t.Fatalf("unable to run test server: %s", err) 58 | } 59 | defer s.Shutdown() 60 | 61 | tr := new(Transfer) 62 | m := new(Msg) 63 | m.SetAxfr("miek.nl.") 64 | 65 | c, err := tr.In(m, addrstr) 66 | if err != nil { 67 | t.Fatal("failed to zone transfer in", err) 68 | } 69 | 70 | for msg := range c { 71 | if msg.Error == nil { 72 | t.Fatal("failed to catch 'no SOA' error") 73 | } 74 | } 75 | } 76 | 77 | func TestSingleEnvelopeXfr(t *testing.T) { 78 | HandleFunc("miek.nl.", SingleEnvelopeXfrServer) 79 | defer HandleRemove("miek.nl.") 80 | 81 | s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { 82 | srv.TsigSecret = tsigSecret 83 | }) 84 | if err != nil { 85 | t.Fatalf("unable to run test server: %s", err) 86 | } 87 | defer s.Shutdown() 88 | 89 | axfrTestingSuite(t, addrstr) 90 | } 91 | 92 | func TestSingleEnvelopeXfrTLS(t *testing.T) { 93 | HandleFunc("miek.nl.", SingleEnvelopeXfrServer) 94 | defer HandleRemove("miek.nl.") 95 | 96 | cert, 
err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 97 | if err != nil { 98 | t.Fatalf("unable to build certificate: %v", err) 99 | } 100 | 101 | tlsConfig := tls.Config{ 102 | Certificates: []tls.Certificate{cert}, 103 | } 104 | s, addrstr, _, err := RunLocalTLSServer(":0", &tlsConfig) 105 | if err != nil { 106 | t.Fatalf("unable to run test server: %s", err) 107 | } 108 | defer s.Shutdown() 109 | 110 | axfrTestingSuiteTLS(t, addrstr) 111 | } 112 | 113 | func TestMultiEnvelopeXfr(t *testing.T) { 114 | HandleFunc("miek.nl.", MultipleEnvelopeXfrServer) 115 | defer HandleRemove("miek.nl.") 116 | 117 | s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { 118 | srv.TsigSecret = tsigSecret 119 | }) 120 | if err != nil { 121 | t.Fatalf("unable to run test server: %s", err) 122 | } 123 | defer s.Shutdown() 124 | 125 | axfrTestingSuite(t, addrstr) 126 | } 127 | 128 | func axfrTestingSuite(t *testing.T, addrstr string) { 129 | tr := new(Transfer) 130 | m := new(Msg) 131 | m.SetAxfr("miek.nl.") 132 | 133 | c, err := tr.In(m, addrstr) 134 | if err != nil { 135 | t.Fatal("failed to zone transfer in", err) 136 | } 137 | 138 | var records []RR 139 | for msg := range c { 140 | if msg.Error != nil { 141 | t.Fatal(msg.Error) 142 | } 143 | records = append(records, msg.RR...) 
144 | } 145 | 146 | if len(records) != len(xfrTestData) { 147 | t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) 148 | } 149 | 150 | for i, rr := range records { 151 | if !IsDuplicate(rr, xfrTestData[i]) { 152 | t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) 153 | } 154 | } 155 | } 156 | 157 | func axfrTestingSuiteTLS(t *testing.T, addrstr string) { 158 | tr := new(Transfer) 159 | m := new(Msg) 160 | m.SetAxfr("miek.nl.") 161 | 162 | tr.TLS = &tls.Config{ 163 | InsecureSkipVerify: true, 164 | } 165 | c, err := tr.In(m, addrstr) 166 | if err != nil { 167 | t.Fatal("failed to zone transfer in", err) 168 | } 169 | 170 | var records []RR 171 | for msg := range c { 172 | if msg.Error != nil { 173 | t.Fatal(msg.Error) 174 | } 175 | records = append(records, msg.RR...) 176 | } 177 | 178 | if len(records) != len(xfrTestData) { 179 | t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) 180 | } 181 | 182 | for i, rr := range records { 183 | if !IsDuplicate(rr, xfrTestData[i]) { 184 | t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) 185 | } 186 | } 187 | } 188 | 189 | func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) { 190 | tr := new(Transfer) 191 | m := new(Msg) 192 | var err error 193 | tr.Conn, err = Dial("tcp", addrstr) 194 | if err != nil { 195 | t.Fatal("failed to dial", err) 196 | } 197 | tr.TsigProvider = provider 198 | m.SetAxfr("miek.nl.") 199 | m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix()) 200 | 201 | c, err := tr.In(m, addrstr) 202 | if err != nil { 203 | t.Fatal("failed to zone transfer in", err) 204 | } 205 | 206 | var records []RR 207 | for msg := range c { 208 | if msg.Error != nil { 209 | t.Fatal(msg.Error) 210 | } 211 | records = append(records, msg.RR...) 
212 | } 213 | 214 | if len(records) != len(xfrTestData) { 215 | t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) 216 | } 217 | 218 | for i, rr := range records { 219 | if !IsDuplicate(rr, xfrTestData[i]) { 220 | t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData) 221 | } 222 | } 223 | } 224 | 225 | func axfrTestingSuiteWithMsgNotSigned(t *testing.T, addrstr string, provider TsigProvider) { 226 | tr := new(Transfer) 227 | m := new(Msg) 228 | var err error 229 | tr.Conn, err = Dial("tcp", addrstr) 230 | if err != nil { 231 | t.Fatal("failed to dial", err) 232 | } 233 | tr.TsigProvider = provider 234 | m.SetAxfr("miek.nl.") 235 | 236 | c, err := tr.In(m, addrstr) 237 | if err != nil { 238 | t.Fatal("failed to zone transfer in", err) 239 | } 240 | 241 | for msg := range c { 242 | if !errors.Is(msg.Error, ErrNoSig) { 243 | t.Fatal("expecting ErrNoSig error") 244 | } 245 | } 246 | } 247 | 248 | func TestCustomTsigProvider(t *testing.T) { 249 | HandleFunc("miek.nl.", SingleEnvelopeXfrServer) 250 | defer HandleRemove("miek.nl.") 251 | 252 | s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { 253 | srv.TsigProvider = tsigSecretProvider(tsigSecret) 254 | }) 255 | if err != nil { 256 | t.Fatalf("unable to run test server: %s", err) 257 | } 258 | defer s.Shutdown() 259 | 260 | axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret)) 261 | } 262 | 263 | func TestTSIGNotSigned(t *testing.T) { 264 | HandleFunc("miek.nl.", SingleEnvelopeXfrServer) 265 | defer HandleRemove("miek.nl.") 266 | 267 | s, addrstr, _, err := RunLocalTCPServer(":0") 268 | if err != nil { 269 | t.Fatalf("unable to run test server: %s", err) 270 | } 271 | defer s.Shutdown() 272 | 273 | axfrTestingSuiteWithMsgNotSigned(t, addrstr, tsigSecretProvider(tsigSecret)) 274 | } 275 | --------------------------------------------------------------------------------