├── .travis.yml ├── LICENSE.txt ├── README.md ├── cache.go ├── config-example.json ├── config.go ├── daemon.go ├── daemon_linux.go ├── daemon_windows.go ├── dial.go ├── main.go ├── myip.go ├── test_script ├── config.json └── test.sh └── upstream.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go: 4 | - 1.7.1 5 | addons: 6 | apt: 7 | packages: 8 | - dnsutils 9 | script: 10 | - bash -x ./test_script/test.sh 11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2016] [ayanamist] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Local DNS服务器。[![Build Status](https://travis-ci.org/ayanamist/gdns-go.png?branch=master)](https://travis-ci.org/ayanamist/gdns-go) 2 | 3 | 1. 服务端是[Google HTTPS DNS](https://developers.google.com/speed/public-dns/docs/dns-over-https) 4 | 5 | 2. 通过ShadowSocks解决访问问题。 6 | 7 | 3. 通过传递探测到的公网IP作为edns0 subnet的参数来解决CDN解析出美国IP而不是中国IP的问题。 8 | 9 | 4. 公网IP探测使用了[淘宝的API](http://ip.taobao.com/instructions.php)。 10 | 11 | 5. 自带域名分流功能,但设计目标仅针对公司内网域名服务,不需要把常用国内网站加入,由于第3点,不会受到SS服务器IP的影响(SS服务器在美国依然能解析出中国IP) 12 | 13 | 6. 通过HTTP 2.0解决传统DNS over TCP缓慢的问题。 14 | 15 | ---- 16 | 17 | 已知问题: 18 | 19 | 1. 连接了AnyConnect VPN后,AnyConnect Windows Client会阻断访问本地53端口的DNS,导致无法使用。 -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/cloudflare/golibs/lrucache" 8 | "github.com/miekg/dns" 9 | ) 10 | 11 | type DNSCache struct { 12 | cache *lrucache.LRUCache 13 | } 14 | 15 | func NewDNSCache(size uint32) *DNSCache { 16 | return &DNSCache{ 17 | cache: lrucache.NewLRUCache(uint(size)), 18 | } 19 | } 20 | 21 | func questionKey(q dns.Question) string { 22 | return fmt.Sprintf("%s%d%d", q.Name, q.Qclass, q.Qtype) 23 | } 24 | 25 | func (d *DNSCache) Put(q dns.Question, m *dns.Msg) { 26 | if d.cache.Capacity() == 0 { 27 | return 28 | } 29 | var minTTL uint32 = 0xffffffff 30 | for _, rr := range m.Answer { 31 | ttl := rr.Header().Ttl 32 | if minTTL > ttl { 33 | minTTL = ttl 34 | } 35 | } 36 | 37 | d.cache.Set(questionKey(q), m, time.Now().Add(time.Duration(minTTL)*time.Second)) 38 | } 39 | 40 | func (d *DNSCache) Get(q dns.Question) *dns.Msg { 41 | v, _ := d.cache.GetNotStale(questionKey(q)) 42 | if v != nil { 43 | return v.(*dns.Msg).Copy() 44 | } else { 45 | return nil 46 | } 47 | } 48 | 49 | func (d *DNSCache) Purge() { 50 | d.cache.Clear() 51 | } 52 | -------------------------------------------------------------------------------- /config-example.json: -------------------------------------------------------------------------------- 1 | { 2 | "listen": "127.0.0.1:53", 3 | "proxy": "ss://method:pass@server:port", 4 | "mapping": { 5 | "taobao.com": "223.5.5.5" 6 | } 7 | } -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "os" 7 | ) 8 | 9 | type Config struct { 10 | Listen string `json:"listen"` 11 | Proxy string `json:"proxy"` 12 | MyIP string `json:"myip"` 13 | Mapping map[string]string `json:"mapping"` 14 | CacheSize *uint32 `json:"cache_size"` 15 | QueryTimeoutSec uint32 `json:"query_timeout_sec"` 16 | } 17 | 18 | func GetConfigFromFile(path string) (*Config, error) { 19 | f, err := os.Open(path) 20 | if err != nil { 21 | return nil, err 22 | } 23 | jsonBytes, err := ioutil.ReadAll(f) 24 | if err != nil { 25 | return nil, err 26 | } 27 | var config Config 28 | if err := json.Unmarshal(jsonBytes, &config); err != nil { 29 | return nil, err 30 | } 31 | return &config, nil 32 | } 33 | -------------------------------------------------------------------------------- /daemon.go: -------------------------------------------------------------------------------- 1 | // +build !linux,!windows 2 | 3 | package main 4 | 5 | import "syscall" 6 | 7 | var ( 8 | daemonSysprocAttr *syscall.SysProcAttr = nil 9 | ) 10 | -------------------------------------------------------------------------------- /daemon_linux.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "syscall" 4 | 5 | var ( 6 | daemonSysprocAttr = &syscall.SysProcAttr{ 7 | Setpgid: true, 8 | } 9 | ) 10 | -------------------------------------------------------------------------------- /daemon_windows.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "syscall" 4 | 5 | var ( 6 | daemonSysprocAttr = &syscall.SysProcAttr{ 7 | HideWindow: true, 8 | } 9 | ) 10 | -------------------------------------------------------------------------------- /dial.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "net/url" 8 | "time" 9 | 10 | "golang.org/x/net/proxy" 11 | ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" 12 | ) 13 | 14 | func NewDialFromURL(u *url.URL) (func(network, addr string) (net.Conn, error), error) { 15 | switch u.Scheme { 16 | case "ss": 17 | return newSSDial(u) 18 | case "socks5": 19 | dialer, err := proxy.FromURL(u, proxy.Direct) 20 | return dialer.Dial, err 21 | default: 22 | return nil, fmt.Errorf("unsupported scheme: %s", u.Scheme) 23 | } 24 | } 25 | 26 | func newSSDial(u *url.URL) (func(network, addr string) (net.Conn, error), error) { 27 | password, ok := u.User.Password() 28 | if !ok { 29 | return nil, errors.New("no password") 30 | } 31 | if _, err := ss.NewCipher(u.User.Username(), password); err != nil { 32 | return nil, err 33 | } 34 | return func(network, addr string) (net.Conn, error) { 35 | rawAddr, err := ss.RawAddr(addr) 36 | if err != nil { 37 | return nil, err 38 | } 39 | conn, err := net.DialTimeout("tcp", u.Host, 5*time.Second) 40 | if err != nil { 41 | return nil, err 42 | } 43 | cipher, _ := ss.NewCipher(u.User.Username(), password) 44 | c := ss.NewConn(conn, cipher) 45 | if _, err = c.Write(rawAddr); err != nil { 46 | c.Close() 47 | return nil, err 48 | } 49 | return c, nil 50 | }, nil 51 | } 52 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "flag" 7 | "log" 8 | "net" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "os/exec" 13 | "strings" 14 | "time" 15 | 16 | "github.com/miekg/dns" 17 | "golang.org/x/net/http2" 18 | ) 19 | 20 | const ( 21 | AliDNS = "223.5.5.5:53" 22 | ) 23 | 24 | var ( 25 | confFile = flag.String("conf", "config.json", "Specify config json path") 26 | daemon = flag.Bool("d", false, "Run as daemon") 27 | 28 | myIP *MyIP 29 | dnsCache *DNSCache 30 | possibleLoopDomains = []string{GoogleDnsHttpsDomain} 31 | dnsQueryTimeoutSec time.Duration 32 | fallbackUpstream *TcpUdpUpstream 33 | ) 34 | 35 | type MyHandler struct { 36 | upstreamMap map[string][]Upstream 37 | cache *DNSCache 38 | } 39 | 40 | func appendEdns0Subnet(m *dns.Msg, addr net.IP) { 41 | newOpt := true 42 | var o *dns.OPT 43 | for _, v := range m.Extra { 44 | if v.Header().Rrtype == dns.TypeOPT { 45 | o = v.(*dns.OPT) 46 | newOpt = false 47 | break 48 | } 49 | } 50 | if o == nil { 51 | o = new(dns.OPT) 52 | o.Hdr.Name = "." 53 | o.Hdr.Rrtype = dns.TypeOPT 54 | } 55 | e := new(dns.EDNS0_SUBNET) 56 | e.Code = dns.EDNS0SUBNET 57 | e.SourceScope = 0 58 | e.Address = addr 59 | if e.Address.To4() == nil { 60 | e.Family = 2 // IP6 61 | e.SourceNetmask = net.IPv6len * 8 62 | } else { 63 | e.Family = 1 // IP4 64 | e.SourceNetmask = net.IPv4len * 8 65 | } 66 | o.Option = append(o.Option, e) 67 | if newOpt { 68 | m.Extra = append(m.Extra, o) 69 | } 70 | } 71 | 72 | func (h *MyHandler) determineRoute(domain string) (u []Upstream) { 73 | for domain != "" && domain[len(domain)-1] == '.' { 74 | domain = domain[:len(domain)-1] 75 | } 76 | avoidLoop := false 77 | var ok bool 78 | for domain != "" { 79 | for _, d := range possibleLoopDomains { 80 | if domain == d { 81 | avoidLoop = true 82 | break 83 | } 84 | } 85 | u, ok = h.upstreamMap[domain] 86 | if ok { 87 | break 88 | } 89 | idx := strings.IndexByte(domain, '.') 90 | if idx < 0 { 91 | break 92 | } 93 | domain = domain[idx+1:] 94 | } 95 | if len(u) == 0 { 96 | u = h.upstreamMap[""] 97 | } 98 | if avoidLoop { 99 | ups := []Upstream{} 100 | for _, s := range u { 101 | if _, ok = s.(*GoogleHttpsUpstream); !ok { 102 | ups = append(ups, s) 103 | } 104 | } 105 | if len(ups) > 0 { 106 | u = ups 107 | } else { 108 | u = []Upstream{fallbackUpstream} 109 | } 110 | } 111 | return 112 | } 113 | 114 | func (h *MyHandler) ServeDNS(w dns.ResponseWriter, reqMsg *dns.Msg) { 115 | var err error 116 | addr := myIP.GetIP() 117 | if addr != nil && !addr.IsLoopback() { 118 | appendEdns0Subnet(reqMsg, addr) 119 | } 120 | 121 | type chanResp struct { 122 | m *dns.Msg 123 | err error 124 | } 125 | var respMsg *dns.Msg 126 | allQuestions := reqMsg.Question 127 | for qi, q := range allQuestions { 128 | typ, ok := dns.TypeToString[q.Qtype] 129 | if !ok { 130 | typ = "UnknownType" 131 | } 132 | 133 | respMsg = h.cache.Get(q) 134 | if respMsg == nil { 135 | up := h.determineRoute(q.Name) 136 | 137 | for i, u := range up { 138 | m := reqMsg.Copy() 139 | m.Question = allQuestions[qi : qi+1] 140 | 141 | log.Printf("%s#%d %d/%d query %v, type=%s => %s(%d)", w.RemoteAddr(), m.Id, qi+1, len(allQuestions), q.Name, typ, u.Name(), i) 142 | ch := make(chan chanResp) 143 | go func(i int, u Upstream) { 144 | start := time.Now() 145 | respMsg, err := u.Exchange(m) 146 | log.Printf("%s#%d %d/%d %s(%d) rtt=%dms, err=%v", w.RemoteAddr(), m.Id, qi+1, len(allQuestions), u.Name(), i, time.Since(start)/1e6, err) 147 | ch <- chanResp{respMsg, err} 148 | close(ch) 149 | }(i, u) 150 | select { 151 | case resp := <-ch: 152 | respMsg, err = resp.m, resp.err 153 | case <-time.After(dnsQueryTimeoutSec): 154 | go func() { 155 | <-ch 156 | }() 157 | respMsg, err = nil, errors.New("single timeout") 158 | } 159 | if err == nil { 160 | break 161 | } 162 | } 163 | 164 | if respMsg != nil { 165 | h.cache.Put(q, respMsg) 166 | } 167 | } else { 168 | respMsg.Id = reqMsg.Id 169 | log.Printf("%s#%d %d/%d query %v, type=%s => cache", w.RemoteAddr(), respMsg.Id, qi+1, len(allQuestions), q.Name, typ) 170 | } 171 | 172 | if respMsg != nil { 173 | if err := w.WriteMsg(respMsg); err != nil { 174 | log.Printf("WriteMsg: %v", err) 175 | } 176 | } 177 | } 178 | } 179 | 180 | func init() { 181 | log.SetOutput(os.Stdout) 182 | } 183 | 184 | func main() { 185 | flag.Parse() 186 | 187 | if *daemon { 188 | newArgs := make([]string, len(os.Args)-1) 189 | for i, j := 0, 0; i < len(os.Args); i++ { 190 | if os.Args[i] != "-d" { 191 | newArgs[j] = os.Args[i] 192 | j++ 193 | } 194 | } 195 | cmd := exec.Command(newArgs[0], newArgs[1:]...) 196 | cmd.Dir, _ = os.Getwd() 197 | cmd.Stdout = os.Stdout 198 | cmd.Stderr = os.Stderr 199 | cmd.Env = os.Environ() 200 | cmd.SysProcAttr = daemonSysprocAttr 201 | if err := cmd.Start(); err != nil { 202 | log.Fatalf("run as daemon error: %v", err) 203 | } 204 | return 205 | } 206 | 207 | config, err := GetConfigFromFile(*confFile) 208 | if err != nil { 209 | log.Fatalln(err) 210 | } 211 | 212 | dnsQueryTimeoutSec = time.Duration(config.QueryTimeoutSec) * time.Second 213 | if dnsQueryTimeoutSec == 0 { 214 | dnsQueryTimeoutSec = 5 * time.Second 215 | } 216 | 217 | fallbackUpstream = &TcpUdpUpstream{ 218 | NameServer: AliDNS, 219 | Network: "udp", 220 | Dial: (&net.Dialer{ 221 | Timeout: dnsQueryTimeoutSec, 222 | }).Dial, 223 | } 224 | 225 | myIP = new(MyIP) 226 | if config.MyIP == "" { 227 | myIP.Client = &http.Client{ 228 | Transport: &http.Transport{ 229 | Dial: (&net.Dialer{ 230 | Timeout: 3 * time.Second, 231 | }).Dial, 232 | ResponseHeaderTimeout: 30 * time.Second, 233 | IdleConnTimeout: 30 * time.Second, 234 | }, 235 | Timeout: 30 * time.Second, 236 | } 237 | myIP.SetIP(net.IP{127, 0, 0, 1}) 238 | myIP.StartTaobaoIPLoop(func(oldIP, newIP net.IP) { 239 | dnsCache.Purge() 240 | }) 241 | } else { 242 | myIP.SetIP(net.ParseIP(config.MyIP)) 243 | } 244 | 245 | dial := (&net.Dialer{ 246 | Timeout: 5 * time.Second, 247 | }).Dial 248 | if config.Proxy != "" { 249 | u, err := url.Parse(config.Proxy) 250 | if err != nil { 251 | log.Fatalf("invalid proxy url %s: %v", config.Proxy, err) 252 | } 253 | if dial, err = NewDialFromURL(u); err != nil { 254 | log.Fatalln(err) 255 | } 256 | domain := strings.SplitN(u.Host, ":", 2)[0] 257 | if net.ParseIP(domain) == nil { 258 | possibleLoopDomains = append(possibleLoopDomains, domain) 259 | } 260 | } 261 | defaultGoogleUpstream := &GoogleHttpsUpstream{ 262 | Client: &http.Client{ 263 | Transport: &http2.Transport{ 264 | DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { 265 | conn, err := dial(network, addr) 266 | if err != nil { 267 | return nil, err 268 | } 269 | return tls.Client(conn, cfg), nil 270 | }, 271 | }, 272 | Timeout: 2 * time.Second, 273 | }, 274 | } 275 | 276 | upstreamMap := make(map[string][]Upstream) 277 | for k, v := range config.Mapping { 278 | upstreams := []Upstream{} 279 | for _, v := range strings.Split(v, ",") { 280 | var upstream Upstream 281 | if v == "default" { 282 | upstream = defaultGoogleUpstream 283 | } else { 284 | if _, _, err := net.SplitHostPort(v); err != nil { 285 | if strings.Contains(err.Error(), "missing port in address") { 286 | v += ":53" 287 | } else { 288 | log.Fatalf("dns server %s invalid: %v", v, err) 289 | } 290 | } 291 | upstream = &TcpUdpUpstream{ 292 | NameServer: v, 293 | Network: "udp", 294 | Dial: (&net.Dialer{ 295 | Timeout: dnsQueryTimeoutSec, 296 | }).Dial, 297 | } 298 | } 299 | upstreams = append(upstreams, upstream) 300 | } 301 | if len(upstreams) > 0 { 302 | upstreamMap[k] = upstreams 303 | } 304 | } 305 | if _, ok := upstreamMap[""]; !ok { 306 | upstreamMap[""] = []Upstream{defaultGoogleUpstream} 307 | } 308 | 309 | listenAddr := "127.0.0.1:53" 310 | if config.Listen != "" { 311 | listenAddr = config.Listen 312 | } 313 | 314 | var cacheSize uint32 = 1000 315 | if config.CacheSize != nil { 316 | cacheSize = *config.CacheSize 317 | } 318 | dnsCache = NewDNSCache(cacheSize) 319 | server := &dns.Server{ 320 | Addr: listenAddr, 321 | Net: "udp", 322 | Handler: &MyHandler{ 323 | upstreamMap: upstreamMap, 324 | cache: dnsCache, 325 | }, 326 | TsigSecret: nil, 327 | } 328 | 329 | log.Printf("try to listen on %s", listenAddr) 330 | if err := server.ListenAndServe(); err != nil { 331 | log.Fatalf("Failed to setup the server: %v", err) 332 | } 333 | } 334 | -------------------------------------------------------------------------------- /myip.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "math/rand" 9 | "net" 10 | "net/http" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | const TaobaoIpURL = "http://ip.taobao.com/service/getIpInfo.php?ip=myip" 16 | 17 | type MyIP struct { 18 | Client *http.Client 19 | sync.RWMutex 20 | ip net.IP 21 | } 22 | 23 | func (m *MyIP) refreshFromTaobaoIP() error { 24 | req, err := http.NewRequest(http.MethodGet, TaobaoIpURL, nil) 25 | if err != nil { 26 | return err 27 | } 28 | client := m.Client 29 | if client == nil { 30 | client = http.DefaultClient 31 | } 32 | resp, err := client.Do(req) 33 | if err != nil { 34 | return err 35 | } 36 | defer resp.Body.Close() 37 | if resp.StatusCode != http.StatusOK { 38 | return fmt.Errorf("unexpected status: %v", resp.Status) 39 | } 40 | respBytes, err := ioutil.ReadAll(resp.Body) 41 | if err != nil { 42 | return err 43 | } 44 | tbRes := struct { 45 | Code int `json:"code"` 46 | Data struct { 47 | IP string `json:"ip"` 48 | } `json:"data"` 49 | }{} 50 | if err := json.Unmarshal(respBytes, &tbRes); err != nil || tbRes.Data.IP == "" { 51 | return fmt.Errorf("unexpected result, error=%v: %s", err, string(respBytes)) 52 | } 53 | ip := net.ParseIP(tbRes.Data.IP) 54 | if ip == nil { 55 | return fmt.Errorf("unexpected ip: %s", tbRes.Data.IP) 56 | } 57 | m.SetIP(ip) 58 | return nil 59 | } 60 | 61 | func (m *MyIP) GetIP() net.IP { 62 | m.RLock() 63 | defer m.RUnlock() 64 | return m.ip 65 | } 66 | 67 | func (m *MyIP) SetIP(ip net.IP) { 68 | m.Lock() 69 | m.ip = ip 70 | m.Unlock() 71 | } 72 | 73 | func (m *MyIP) StartTaobaoIPLoop(cb func(oldIP, newIP net.IP)) { 74 | go func() { 75 | oldIP := m.GetIP() 76 | for { 77 | if err := m.refreshFromTaobaoIP(); err != nil { 78 | log.Printf("refresh myip failed: %v", err) 79 | } else { 80 | newIP := m.GetIP() 81 | if !oldIP.Equal(newIP) { 82 | log.Printf("myip changed from %s to %s", oldIP, newIP) 83 | if cb != nil { 84 | go cb(oldIP, newIP) 85 | } 86 | oldIP = newIP 87 | } 88 | } 89 | time.Sleep(time.Duration(60+rand.Intn(60)) * time.Second) 90 | } 91 | }() 92 | } 93 | -------------------------------------------------------------------------------- /test_script/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "listen": "127.0.0.1:5353", 3 | "proxy": "ss://aes-128-cfb:password@127.0.0.1:8388", 4 | "mapping": { 5 | "dns.google.com": "8.8.8.8" 6 | } 7 | } -------------------------------------------------------------------------------- /test_script/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | go get github.com/shadowsocks/shadowsocks-go/cmd/shadowsocks-server 4 | $GOPATH/bin/shadowsocks-server -p 8388 -k password -m aes-128-cfb -t 60 & 5 | SS_PID=$! 6 | 7 | $GOPATH/bin/gdns-go -conf $(cd $(dirname ${BASH_SOURCE[0]}); pwd)/config.json | tee stdout & 8 | PID=$! 9 | sleep 1 10 | dig -p 5353 @127.0.0.1 www.google.com &&\ 11 | grep -qF ' => https://dns.google.com/resolve' stdout &&\ 12 | dig -p 5353 @127.0.0.1 dns.google.com &&\ 13 | grep -qF ' => udp://8.8.8.8:53' stdout 14 | CODE=$? 15 | kill -9 $SS_PID 16 | kill $PID 17 | sleep 1 18 | if kill -0 $PID 2>/dev/null; then 19 | echo "$PID is still alive" 20 | kill -9 $PID 21 | exit 1 22 | fi 23 | exit $CODE -------------------------------------------------------------------------------- /upstream.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "strconv" 11 | "strings" 12 | "sync/atomic" 13 | "time" 14 | 15 | "github.com/miekg/dns" 16 | ) 17 | 18 | type Upstream interface { 19 | Name() string 20 | Exchange(m *dns.Msg) (r *dns.Msg, err error) 21 | } 22 | 23 | type TcpUdpUpstream struct { 24 | NameServer string 25 | Network string 26 | Dial func(network, addr string) (net.Conn, error) 27 | trId uint32 28 | } 29 | 30 | func (t *TcpUdpUpstream) Name() string { 31 | return t.Network + "://" + t.NameServer 32 | } 33 | 34 | func (t *TcpUdpUpstream) Exchange(m *dns.Msg) (r *dns.Msg, err error) { 35 | co := new(dns.Conn) 36 | if co.Conn, err = t.Dial(t.Network, t.NameServer); err != nil { 37 | return nil, fmt.Errorf("Dial: %v", err) 38 | } 39 | defer co.Close() 40 | oldId := m.Id 41 | m.Id = uint16(atomic.AddUint32(&t.trId, 1)) 42 | defer func() { 43 | m.Id = oldId 44 | }() 45 | co.SetWriteDeadline(time.Now().Add(dnsQueryTimeoutSec)) 46 | if err = co.WriteMsg(m); err != nil { 47 | return nil, fmt.Errorf("WriteMsg: %v", err) 48 | } 49 | co.SetReadDeadline(time.Now().Add(dnsQueryTimeoutSec)) 50 | r, err = co.ReadMsg() 51 | if err != nil { 52 | err = fmt.Errorf("ReadMsg: %v", err) 53 | } 54 | if r != nil { 55 | r.Id = oldId 56 | } 57 | return r, err 58 | } 59 | 60 | const ( 61 | GoogleDnsHttpsDomain = "dns.google.com" 62 | GoogleDnsHttpsUrl = "https://" + GoogleDnsHttpsDomain + "/resolve" 63 | ) 64 | 65 | type GoogleHttpsUpstream struct { 66 | Client *http.Client 67 | } 68 | 69 | func (g *GoogleHttpsUpstream) Name() string { 70 | return GoogleDnsHttpsUrl 71 | } 72 | 73 | func extractEdns0Subnet(m *dns.Msg) *dns.EDNS0_SUBNET { 74 | for _, rr := range m.Extra { 75 | if rrOpt, ok := rr.(*dns.OPT); ok { 76 | for _, opt := range rrOpt.Option { 77 | if e, ok := opt.(*dns.EDNS0_SUBNET); ok { 78 | return e 79 | } 80 | } 81 | } 82 | } 83 | return nil 84 | } 85 | 86 | type GoogleDnsHttpsQuestion struct { 87 | Name string `json:"name"` 88 | Type uint16 `json:"type"` 89 | } 90 | 91 | type GoogleDnsHttpsAnswer struct { 92 | Name string `json:"name"` 93 | Type uint16 `json:"type"` 94 | TTL uint32 95 | Data string `json:"data"` 96 | } 97 | 98 | type GoogleDnsHttpsResponse struct { 99 | Status int 100 | TC bool 101 | RD bool 102 | RA bool 103 | AD bool 104 | CD bool 105 | Question []GoogleDnsHttpsQuestion 106 | Answer []GoogleDnsHttpsAnswer 107 | Authority []GoogleDnsHttpsAnswer 108 | Additional []struct { 109 | } 110 | Comment string 111 | } 112 | 113 | func extractRRHdr(a GoogleDnsHttpsAnswer) dns.RR_Header { 114 | return dns.RR_Header{ 115 | Name: a.Name, 116 | Rrtype: a.Type, 117 | Ttl: a.TTL, 118 | Class: dns.ClassINET, 119 | } 120 | } 121 | 122 | func (g *GoogleHttpsUpstream) Exchange(m *dns.Msg) (r *dns.Msg, err error) { 123 | params := url.Values{ 124 | "name": {m.Question[0].Name}, 125 | "type": {strconv.FormatUint(uint64(m.Question[0].Qtype), 10)}, 126 | } 127 | edns0Subnet := extractEdns0Subnet(m) 128 | if edns0Subnet != nil && edns0Subnet.Address != nil { 129 | params.Set("edns_client_subnet", edns0Subnet.Address.String()+"/"+strconv.Itoa(int(edns0Subnet.SourceNetmask))) 130 | } 131 | reqUrl := GoogleDnsHttpsUrl + "?" + params.Encode() 132 | req, err := http.NewRequest(http.MethodGet, reqUrl, nil) 133 | if err != nil { 134 | return nil, err 135 | } 136 | resp, err := g.Client.Do(req) 137 | if err != nil { 138 | return nil, err 139 | } 140 | defer resp.Body.Close() 141 | if resp.StatusCode != http.StatusOK { 142 | return nil, fmt.Errorf("status=%s", resp.Status) 143 | } 144 | respBytes, err := ioutil.ReadAll(resp.Body) 145 | if err != nil { 146 | return nil, err 147 | } 148 | var msgResp GoogleDnsHttpsResponse 149 | if err := json.Unmarshal(respBytes, &msgResp); err != nil { 150 | return nil, err 151 | } 152 | r = new(dns.Msg) 153 | r.Id = m.Id 154 | r.MsgHdr.Response = true 155 | r.MsgHdr.Opcode = dns.OpcodeQuery 156 | r.MsgHdr.Rcode = msgResp.Status 157 | r.MsgHdr.Truncated = msgResp.TC 158 | r.MsgHdr.RecursionDesired = msgResp.RD 159 | r.MsgHdr.RecursionAvailable = msgResp.RA 160 | r.MsgHdr.CheckingDisabled = msgResp.CD 161 | for _, q := range msgResp.Question { 162 | r.Question = append(r.Question, dns.Question{q.Name, q.Type, dns.ClassINET}) 163 | } 164 | for _, a := range msgResp.Answer { 165 | hdr := extractRRHdr(a) 166 | var rr dns.RR 167 | switch a.Type { 168 | case dns.TypeA: 169 | rr = &dns.A{ 170 | Hdr: hdr, 171 | A: net.ParseIP(a.Data), 172 | } 173 | case dns.TypeNS: 174 | rr = &dns.NS{ 175 | Hdr: hdr, 176 | Ns: a.Data, 177 | } 178 | case dns.TypeMD: 179 | rr = &dns.MD{ 180 | Hdr: hdr, 181 | Md: a.Data, 182 | } 183 | case dns.TypeMF: 184 | rr = &dns.MF{ 185 | Hdr: hdr, 186 | Mf: a.Data, 187 | } 188 | case dns.TypeCNAME: 189 | rr = &dns.CNAME{ 190 | Hdr: hdr, 191 | Target: a.Data, 192 | } 193 | case dns.TypeSOA: 194 | case dns.TypeMB: 195 | rr = &dns.MB{ 196 | Hdr: hdr, 197 | Mb: a.Data, 198 | } 199 | case dns.TypeMG: 200 | rr = &dns.MG{ 201 | Hdr: hdr, 202 | Mg: a.Data, 203 | } 204 | case dns.TypeMR: 205 | rr = &dns.MR{ 206 | Hdr: hdr, 207 | Mr: a.Data, 208 | } 209 | case dns.TypeNULL: 210 | case dns.TypePTR: 211 | rr = &dns.PTR{ 212 | Hdr: hdr, 213 | Ptr: a.Data, 214 | } 215 | case dns.TypeHINFO: 216 | case dns.TypeMINFO: 217 | case dns.TypeMX: 218 | mx := &dns.MX{ 219 | Hdr: hdr, 220 | } 221 | parts := strings.Split(a.Data, " ") 222 | if len(parts) < 2 { 223 | continue 224 | } 225 | var n uint64 226 | n, _ = strconv.ParseUint(parts[0], 10, 16) 227 | mx.Preference = uint16(n) 228 | mx.Mx = parts[1] 229 | rr = mx 230 | case dns.TypeTXT: 231 | rr = &dns.TXT{ 232 | Hdr: hdr, 233 | Txt: strings.Split(a.Data, " "), 234 | } 235 | case dns.TypeRP: 236 | rp := &dns.RP{ 237 | Hdr: hdr, 238 | } 239 | parts := strings.Split(a.Data, " ") 240 | if len(parts) < 2 { 241 | continue 242 | } 243 | rp.Mbox, rp.Txt = parts[0], parts[1] 244 | rr = rp 245 | case dns.TypeAAAA: 246 | rr = &dns.AAAA{ 247 | Hdr: hdr, 248 | AAAA: net.ParseIP(a.Data), 249 | } 250 | case dns.TypeSRV: 251 | srv := &dns.SRV{ 252 | Hdr: hdr, 253 | } 254 | parts := strings.Split(a.Data, " ") 255 | if len(parts) < 4 { 256 | continue 257 | } 258 | var n uint64 259 | n, _ = strconv.ParseUint(parts[0], 10, 16) 260 | srv.Priority = uint16(n) 261 | n, _ = strconv.ParseUint(parts[1], 10, 16) 262 | srv.Weight = uint16(n) 263 | n, _ = strconv.ParseUint(parts[2], 10, 16) 264 | srv.Port = uint16(n) 265 | srv.Target = parts[3] 266 | rr = srv 267 | case dns.TypeSPF: 268 | rr = &dns.SPF{ 269 | Hdr: hdr, 270 | Txt: strings.Split(a.Data, " "), 271 | } 272 | case dns.TypeDS: 273 | ds := &dns.DS{ 274 | Hdr: hdr, 275 | } 276 | parts := strings.Split(a.Data, " ") 277 | if len(parts) < 4 { 278 | continue 279 | } 280 | var n uint64 281 | n, _ = strconv.ParseUint(parts[0], 10, 16) 282 | ds.KeyTag = uint16(n) 283 | n, _ = strconv.ParseUint(parts[1], 10, 8) 284 | ds.Algorithm = uint8(n) 285 | n, _ = strconv.ParseUint(parts[2], 10, 8) 286 | ds.DigestType = uint8(n) 287 | ds.Digest = parts[3] 288 | rr = ds 289 | case dns.TypeSSHFP: 290 | sshfp := &dns.SSHFP{ 291 | Hdr: hdr, 292 | } 293 | parts := strings.Split(a.Data, " ") 294 | if len(parts) < 3 { 295 | continue 296 | } 297 | var n uint64 298 | n, _ = strconv.ParseUint(parts[0], 10, 8) 299 | sshfp.Algorithm = uint8(n) 300 | n, _ = strconv.ParseUint(parts[1], 10, 8) 301 | sshfp.Type = uint8(n) 302 | sshfp.FingerPrint = parts[2] 303 | rr = sshfp 304 | case dns.TypeRRSIG: 305 | rrsig := &dns.RRSIG{ 306 | Hdr: hdr, 307 | } 308 | parts := strings.Split(a.Data, " ") 309 | if len(parts) < 9 { 310 | continue 311 | } 312 | var ok bool 313 | if rrsig.TypeCovered, ok = dns.StringToType[strings.ToUpper(parts[0])]; !ok { 314 | continue 315 | } 316 | var n uint64 317 | n, _ = strconv.ParseUint(parts[1], 10, 8) 318 | rrsig.Algorithm = uint8(n) 319 | n, _ = strconv.ParseUint(parts[2], 10, 8) 320 | rrsig.Labels = uint8(n) 321 | n, _ = strconv.ParseUint(parts[3], 10, 32) 322 | rrsig.OrigTtl = uint32(n) 323 | n, _ = strconv.ParseUint(parts[4], 10, 32) 324 | rrsig.Expiration = uint32(n) 325 | n, _ = strconv.ParseUint(parts[5], 10, 32) 326 | rrsig.Inception = uint32(n) 327 | n, _ = strconv.ParseUint(parts[6], 10, 16) 328 | rrsig.KeyTag = uint16(n) 329 | rrsig.SignerName = parts[7] 330 | rrsig.Signature = parts[8] 331 | rr = rrsig 332 | case dns.TypeNSEC: 333 | nsec := &dns.NSEC{ 334 | Hdr: hdr, 335 | } 336 | parts := strings.Split(a.Data, " ") 337 | nsec.NextDomain = parts[0] 338 | for _, d := range parts[1:] { 339 | if typeBit, ok := dns.StringToType[strings.ToUpper(d)]; ok { 340 | nsec.TypeBitMap = append(nsec.TypeBitMap, typeBit) 341 | } 342 | } 343 | rr = nsec 344 | case dns.TypeDNSKEY: 345 | dnskey := &dns.DNSKEY{ 346 | Hdr: hdr, 347 | } 348 | parts := strings.Split(a.Data, " ") 349 | if len(parts) < 4 { 350 | continue 351 | } 352 | var n uint64 353 | n, _ = strconv.ParseUint(parts[0], 10, 16) 354 | dnskey.Flags = uint16(n) 355 | n, _ = strconv.ParseUint(parts[1], 10, 8) 356 | dnskey.Protocol = uint8(n) 357 | n, _ = strconv.ParseUint(parts[2], 10, 8) 358 | dnskey.Algorithm = uint8(n) 359 | dnskey.PublicKey = parts[3] 360 | rr = dnskey 361 | case dns.TypeNSEC3: 362 | nsec3 := &dns.NSEC3{ 363 | Hdr: hdr, 364 | } 365 | parts := strings.Split(a.Data, " ") 366 | if len(parts) < 7 { 367 | continue 368 | } 369 | var n uint64 370 | n, _ = strconv.ParseUint(parts[0], 10, 8) 371 | nsec3.Hash = uint8(n) 372 | n, _ = strconv.ParseUint(parts[1], 10, 8) 373 | nsec3.Flags = uint8(n) 374 | n, _ = strconv.ParseUint(parts[2], 10, 16) 375 | nsec3.Iterations = uint16(n) 376 | n, _ = strconv.ParseUint(parts[3], 10, 8) 377 | nsec3.SaltLength = uint8(n) 378 | nsec3.Salt = parts[4] 379 | n, _ = strconv.ParseUint(parts[5], 10, 8) 380 | nsec3.HashLength = uint8(n) 381 | nsec3.NextDomain = parts[6] 382 | for _, d := range parts[7:] { 383 | if t, ok := dns.StringToType[strings.ToUpper(d)]; ok { 384 | nsec3.TypeBitMap = append(nsec3.TypeBitMap, t) 385 | } 386 | } 387 | rr = nsec3 388 | case dns.TypeNSEC3PARAM: 389 | nsec3param := &dns.NSEC3PARAM{ 390 | Hdr: hdr, 391 | } 392 | parts := strings.Split(a.Data, " ") 393 | if len(parts) < 5 { 394 | continue 395 | } 396 | var n uint64 397 | n, _ = strconv.ParseUint(parts[0], 10, 8) 398 | nsec3param.Hash = uint8(n) 399 | n, _ = strconv.ParseUint(parts[1], 10, 8) 400 | nsec3param.Flags = uint8(n) 401 | n, _ = strconv.ParseUint(parts[2], 10, 16) 402 | nsec3param.Iterations = uint16(n) 403 | n, _ = strconv.ParseUint(parts[3], 10, 8) 404 | nsec3param.SaltLength = uint8(n) 405 | nsec3param.Salt = parts[4] 406 | rr = nsec3param 407 | } 408 | if rr != nil { 409 | r.Answer = append(r.Answer, rr) 410 | } 411 | } 412 | for _, a := range msgResp.Authority { 413 | hdr := extractRRHdr(a) 414 | var rr dns.RR 415 | switch a.Type { 416 | case dns.TypeSOA: 417 | soa := &dns.SOA{ 418 | Hdr: hdr, 419 | } 420 | parts := strings.Split(a.Data, " ") 421 | if len(parts) < 7 { 422 | continue 423 | } 424 | soa.Ns = parts[0] 425 | soa.Mbox = parts[1] 426 | var n uint64 427 | n, _ = strconv.ParseUint(parts[2], 10, 32) 428 | soa.Serial = uint32(n) 429 | n, _ = strconv.ParseUint(parts[3], 10, 32) 430 | soa.Refresh = uint32(n) 431 | n, _ = strconv.ParseUint(parts[4], 10, 32) 432 | soa.Retry = uint32(n) 433 | n, _ = strconv.ParseUint(parts[5], 10, 32) 434 | soa.Expire = uint32(n) 435 | n, _ = strconv.ParseUint(parts[6], 10, 32) 436 | soa.Minttl = uint32(n) 437 | rr = soa 438 | } 439 | r.Ns = append(r.Ns, rr) 440 | } 441 | err = nil 442 | return 443 | } 444 | --------------------------------------------------------------------------------