├── .ci ├── Dockerfile ├── Jenkinsfile └── resources │ └── static │ ├── footer.template.html │ └── header.template.html ├── .gitignore ├── Makefile ├── README.md ├── cmd └── dotproxy │ └── main.go ├── config.example.yaml ├── go.mod ├── init └── dotproxy.service ├── internal ├── data │ ├── doc.go │ ├── mru.go │ └── priority.go ├── log │ ├── console.go │ ├── doc.go │ ├── level.go │ └── logger.go ├── meta │ ├── config.go │ ├── doc.go │ └── version.go ├── metrics │ ├── doc.go │ └── hook.go ├── network │ ├── client.go │ ├── conn.go │ ├── doc.go │ ├── init.go │ ├── persistent.go │ ├── server.go │ └── sharding.go └── protocol │ ├── dns_proxy.go │ └── doc.go └── tools.go /.ci/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM docker.internal.kevinlin.info/infra/ci-base:0.3.1 2 | 3 | ENV APINDEX_VERSION e8ed53a76dfd2dfaf2aa444f666b4513d3108653 4 | 5 | # Release dependencies 6 | ADD https://storage.kevinlin.info/deploy/external/apindex/$APINDEX_VERSION/release.tar.gz apindex.tar.gz 7 | RUN sudo tar xvf apindex.tar.gz 8 | RUN sudo mv bin/* /usr/local/bin/ 9 | RUN sudo mv share/* /usr/local/share/ 10 | COPY resources/static/header.template.html /usr/local/share/apindex/header.template.html 11 | COPY resources/static/footer.template.html /usr/local/share/apindex/footer.template.html 12 | -------------------------------------------------------------------------------- /.ci/Jenkinsfile: -------------------------------------------------------------------------------- 1 | pipeline { 2 | agent { 3 | dockerfile { 4 | dir '.ci' 5 | label 'docker-executor' 6 | } 7 | } 8 | 9 | options { 10 | withAWS( 11 | endpointUrl: 'https://storage.kevinlin.info', 12 | credentials: 'storage-internal', 13 | ) 14 | } 15 | 16 | stages { 17 | stage('Install') { 18 | steps { 19 | sh 'go mod download -x' 20 | } 21 | } 22 | stage('Generate') { 23 | steps { 24 | sh 'make generate' 25 | } 26 | } 27 | stage('Lint') { 28 | steps { 29 | sh 'make lint' 30 | } 31 | } 32 | stage('Build') { 33 | environment { 34 | CGO_ENABLED = 0 35 | VERSION_SHA = "${GIT_COMMIT}" 36 | } 37 | parallel { 38 | stage('linux/386') { 39 | environment { 40 | GOOS = 'linux' 41 | GOARCH = '386' 42 | } 43 | steps { 44 | sh 'make' 45 | } 46 | } 47 | stage('linux/amd64') { 48 | environment { 49 | GOOS = 'linux' 50 | GOARCH = 'amd64' 51 | } 52 | steps { 53 | sh 'make' 54 | } 55 | } 56 | stage('linux/arm') { 57 | environment { 58 | GOOS = 'linux' 59 | GOARCH = 'arm' 60 | } 61 | steps { 62 | sh 'make' 63 | } 64 | } 65 | stage('linux/arm64') { 66 | environment { 67 | GOOS = 'linux' 68 | GOARCH = 'arm64' 69 | } 70 | steps { 71 | sh 'make' 72 | } 73 | } 74 | } 75 | } 76 | stage('Release') { 77 | environment { 78 | RELEASE_WORKDIR = sh( 79 | script: 'mktemp -d', 80 | returnStdout: true, 81 | ).trim() 82 | } 83 | steps { 84 | // Binary 85 | sh 'tar -cvzf release.tar.gz bin/ init/' 86 | s3Upload( 87 | bucket: 'internal', 88 | path: "deploy/${env.JOB_NAME}/${env.GIT_COMMIT}/", 89 | file: 'release.tar.gz', 90 | ) 91 | 92 | // Static site 93 | script { 94 | deleteDir() 95 | } 96 | git( 97 | url: env.GIT_URL - '.git', 98 | branch: 'static', 99 | ) 100 | // Download release archive 101 | s3Download( 102 | bucket: 'internal', 103 | path: "deploy/${env.JOB_NAME}/${env.GIT_COMMIT}/release.tar.gz", 104 | file: 'release.tar.gz', 105 | ) 106 | sh "tar -C ${RELEASE_WORKDIR} -xvf release.tar.gz" 107 | sh 'rm release.tar.gz' 108 | // Create release directory 109 | sh "mkdir -pv releases/${GIT_COMMIT}/" 110 | sh "ln -sTfv ${GIT_COMMIT} 
releases/latest" 111 | sh "mv -v ${RELEASE_WORKDIR}/bin/* releases/${GIT_COMMIT}/" 112 | // Generate page index 113 | sh 'apindex . .git,CNAME,release' 114 | // Create release 115 | sh "tar -cvzf release.tar.gz index.html releases/index.html releases/${GIT_COMMIT}/ releases/latest/" 116 | s3Upload( 117 | bucket: 'internal', 118 | path: "deploy/${env.JOB_NAME}-static/${env.GIT_COMMIT}/", 119 | file: 'release.tar.gz', 120 | ) 121 | } 122 | } 123 | stage('Deploy') { 124 | steps { 125 | build( 126 | job: 'task--static-deploy', 127 | parameters: [ 128 | string(name: 'RELEASE_ARTIFACT', value: "${env.JOB_NAME}-static"), 129 | string(name: 'RELEASE_VERSION', value: "${env.GIT_COMMIT}"), 130 | string(name: 'DOMAIN', value: 'dotproxy.static.kevinlin.info'), 131 | string(name: 'GIT_REMOTE_INTERNAL', value: "${env.GIT_URL}"), 132 | string(name: 'GIT_REMOTE_GITHUB', value: 'git@github.com:LINKIWI/dotproxy-static.git'), 133 | booleanParam(name: 'CLEAN_DEPLOY', value: false), 134 | ], 135 | wait: true, 136 | ) 137 | } 138 | } 139 | stage('Publish') { 140 | environment { 141 | PACKAGE_VERSION = sh( 142 | script: 'git show ' + 143 | '--no-patch ' + 144 | '--no-notes ' + 145 | "--date=format:\"%Y.%m.%d-%H.%M.%S-\$(echo ${GIT_COMMIT} | cut -c -8)\" " + 146 | "--pretty=format:'%cd' ${GIT_COMMIT}", 147 | returnStdout: true, 148 | ).trim() 149 | } 150 | parallel { 151 | stage('linux/amd64') { 152 | steps { 153 | build( 154 | job: 'task--package', 155 | parameters: [ 156 | string(name: 'RELEASE_ARTIFACT', value: "${env.JOB_NAME}"), 157 | string(name: 'RELEASE_VERSION', value: "${env.GIT_COMMIT}"), 158 | string(name: 'PACKAGE_NAME', value: "${env.JOB_NAME}"), 159 | string(name: 'PACKAGE_VERSION', value: "${env.PACKAGE_VERSION}"), 160 | string(name: 'PACKAGE_DESCRIPTION', value: 'High performance DNS-over-TLS proxy'), 161 | string(name: 'PACKAGE_ARCHITECTURE', value: 'amd64'), 162 | string(name: 'BINARY_SPEC', value: 'dotproxy:bin/dotproxy-linux-amd64'), 163 | string(name: 'SYSTEMD_SERVICES', value: 'init/dotproxy.service'), 164 | ], 165 | wait: true, 166 | ) 167 | } 168 | } 169 | stage('linux/arm') { 170 | steps { 171 | build( 172 | job: 'task--package', 173 | parameters: [ 174 | string(name: 'RELEASE_ARTIFACT', value: "${env.JOB_NAME}"), 175 | string(name: 'RELEASE_VERSION', value: "${env.GIT_COMMIT}"), 176 | string(name: 'PACKAGE_NAME', value: "${env.JOB_NAME}"), 177 | string(name: 'PACKAGE_VERSION', value: "${env.PACKAGE_VERSION}"), 178 | string(name: 'PACKAGE_DESCRIPTION', value: 'High performance DNS-over-TLS proxy'), 179 | string(name: 'PACKAGE_ARCHITECTURE', value: 'armhf'), 180 | string(name: 'BINARY_SPEC', value: 'dotproxy:bin/dotproxy-linux-arm'), 181 | string(name: 'SYSTEMD_SERVICES', value: 'init/dotproxy.service'), 182 | ], 183 | wait: true, 184 | ) 185 | } 186 | } 187 | } 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /.ci/resources/static/footer.template.html: -------------------------------------------------------------------------------- 1 | dotproxy 2 | -------------------------------------------------------------------------------- /.ci/resources/static/header.template.html: -------------------------------------------------------------------------------- 1 | 13 | 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated code 2 | internal/log/level_string.go 3 | 
internal/network/loadbalancingpolicy_string.go 4 | internal/network/transport_string.go 5 | 6 | # Application configuration 7 | config.yaml 8 | 9 | # Build artifacts 10 | bin/ 11 | 12 | # Tooling artifacts 13 | go.sum 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Name of the binary executable 2 | DOTPROXY = dotproxy 3 | 4 | # Output binary directory 5 | BIN_DIR = bin 6 | 7 | # OS and architecture to use for the build 8 | GOOS ?= $(shell go env GOOS) 9 | GOARCH ?= $(shell go env GOARCH) 10 | 11 | # Generated source code 12 | GENERATED_SOURCE = internal/log/level.go \ 13 | internal/network/server.go \ 14 | internal/network/sharding.go 15 | GENERATED_ARTIFACTS = internal/log/level_string.go \ 16 | internal/network/loadbalancingpolicy_string.go \ 17 | internal/network/transport_string.go 18 | 19 | binary: $(DOTPROXY) 20 | 21 | generate: $(GENERATED_ARTIFACTS) 22 | 23 | $(DOTPROXY): $(GENERATED_ARTIFACTS) 24 | go build \ 25 | -ldflags "-w -s -X dotproxy/internal/meta.VersionSHA=$(VERSION_SHA)" \ 26 | -o $(BIN_DIR)/$(DOTPROXY)-$(GOOS)-$(GOARCH) \ 27 | cmd/$(DOTPROXY)/main.go 28 | 29 | $(GENERATED_ARTIFACTS): $(GENERATED_SOURCE) 30 | go generate -v ./... 31 | 32 | lint: 33 | ! gofmt -s -d . | grep "^" 34 | go run golang.org/x/lint/golint --set_exit_status ./... 35 | go vet ./... 36 | 37 | clean: 38 | rm -f $(BIN_DIR)/* 39 | rm -f $(GENERATED_ARTIFACTS) 40 | 41 | .PHONY: binary generate lint clean 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dotproxy 2 | 3 | **dotproxy** is a high-performance and fault-tolerant DNS-over-TLS proxy. It listens on both TCP and UDP transports and proxies DNS traffic transparently to configurable TLS-enabled upstream server(s). 4 | 5 | dotproxy is intended to sit at the edge of a private network, encrypting traffic over an untrusted channel to and from external, public DNS servers like [Cloudflare DNS](https://developers.cloudflare.com/1.1.1.1/dns-over-tls/) or [Google DNS](https://developers.google.com/speed/public-dns/docs/dns-over-tls). As a plaintext-to-TLS proxy, dotproxy can be *transparently* inserted into existing network infrastructure without requiring DNS reconfiguration on existing clients. 6 | 7 | ## Features 8 | 9 | * Intelligent client-side connection persistence and pooling to minimize TCP and TLS latency overhead 10 | * Rudimentary load balancing policy among multiple upstream servers 11 | * Rich metrics reporting via `statsd`: connection establishment/teardown events, network I/O events, upstream latency, and RTT latency 12 | * Supports both TCP and UDP ingress (with automatic spec-compliant data reshaping to support UDP ingress to TCP/TLS egress, and vice versa) 13 | 14 | dotproxy is stateless and generally not protocol-aware. This sacrifices some features (like upstream response caching or domain-aware load balancing/sharding) in favor of slightly reduced proxy latency overhead (by not parsing request and response packets). 15 | 16 | ## Performance 17 | 18 | dotproxy maintains a pool of persistent, long-lived TCP connections to upstream server(s). This helps amortize the cost of establishing TCP connections and performing TLS handshakes with the server, thus providing the client near-UDP levels of performance.
Additionally, most network behavior parameters are exposed in application configuration, allowing the proxy to be performance-tuned specifically for the deployment's environment. 19 | 20 | Networks characterized by high request volume (in terms of QPS) will generally benefit from a larger upstream connection pool. On the other hand, networks characterized by low request volume will generally benefit from a smaller upstream connection pool; too large a connection pool will decrease average performance due to excessive connection churn from server-side TCP timeouts. Cloudflare's DNS servers, for example, close client TCP connections after a 10-second period of inactivity. 21 | 22 | Most use cases will benefit from a large number of maximum concurrent ingress UDP connections. Generally speaking, this value should be set to a reasonable estimate of the highest expected number of concurrent UDP clients. 23 | 24 | ## Usage 25 | 26 | Download a precompiled binary for the target platform/architecture at the [releases index](https://dotproxy.static.kevinlin.info/releases/latest). Currently, binaries are built for most flavors of Linux. 27 | 28 | Alternatively, to compile the project manually with a recent version of the Go toolchain: 29 | 30 | ```bash 31 | $ make 32 | $ ./bin/dotproxy-$OS-$ARCH --help 33 | ``` 34 | 35 | The versioned `systemd` unit file can serve as an example for how to daemonize the process. 36 | 37 | ## Configuration 38 | 39 | ### Configuration file 40 | 41 | dotproxy must be passed a YAML configuration file path with the `--config` flag. The versioned `config.example.yaml` in the repository root can serve as an example of a valid configuration file. 42 | 43 | The following table documents each field and its expected value: 44 | 45 | |Key|Required|Description| 46 | |-|-|-| 47 | |`metrics.statsd.addr`|No|Address of the statsd server for metrics reporting| 48 | |`metrics.statsd.sample_rate`|No|statsd sample rate, if enabled| 49 | |`listener.tcp.addr`|Yes|Address to bind to for the TCP listener| 50 | |`listener.tcp.read_timeout`|No|Time duration string for a client TCP read timeout| 51 | |`listener.tcp.write_timeout`|No|Time duration string for a client TCP write timeout| 52 | |`listener.udp.addr`|Yes|Address to bind to for the UDP listener| 53 | |`listener.udp.read_timeout`|No|Time duration string for a client UDP read timeout; should generally be omitted or set to 0| 54 | |`listener.udp.write_timeout`|No|Time duration string for a client UDP write timeout| 55 | |`upstream.load_balancing_policy`|No|One of the `LoadBalancingPolicy` constants to control how requests are sharded among all specified upstream servers| 56 | |`upstream.max_connection_retries`|No|Maximum number of times to retry an upstream I/O operation, per request| 57 | |`upstream.servers[].addr`|Yes|The address of the upstream TLS-enabled DNS server| 58 | |`upstream.servers[].server_name`|Yes|The TLS server hostname (used for server identity verification)| 59 | |`upstream.servers[].connection_pool_size`|No|Size of the connection pool to maintain for this server; environments with high traffic and/or request concurrency will generally benefit from a larger connection pool| 60 | |`upstream.servers[].connect_timeout`|No|Time duration string for an upstream TCP connection establishment timeout| 61 | |`upstream.servers[].handshake_timeout`|No|Time duration string for an upstream TLS handshake timeout| 62 | |`upstream.servers[].read_timeout`|No|Time duration string for an upstream TCP read timeout| 63 |
|`upstream.servers[].write_timeout`|No|Time duration string for an upstream TCP write timeout| 64 | |`upstream.servers[].stale_timeout`|No|Time duration string for the maximum idle time between consecutive uses of an open connection, after which the connection is considered stale and is reestablished| 65 | 66 | ### Load balancing policies 67 | 68 | When more than one upstream DNS server is configured, the `upstream.load_balancing_policy` field controls how dotproxy shards requests among the servers. The policies below are mostly stateless and protocol-agnostic. 69 | 70 | |Policy|Description| 71 | |-|-| 72 | |`RoundRobin`|Select servers in [round-robin](https://en.wikipedia.org/wiki/Round-robin_scheduling), circular order. Simple, fair, but not fault tolerant.| 73 | |`Random`|Select a server at random. Simple, fair, async-safe, but not fault tolerant.| 74 | |`HistoricalConnections`|Select the server that has, up until the time of request, provided the fewest connections. Ideal if it is important that all servers share an equal amount of load, without regard to fault tolerance.| 75 | |`Availability`|Randomly select an available server. A server is considered *available* if it is successful in providing a connection. Servers that fail to provide a connection are pulled out of the availability pool for exponentially increasing durations of time, preventing them from providing connections until their unavailability period has expired. Ideal for greatest fault tolerance while maintaining roughly equal load distribution and minimizing downstream latency impact, at the cost of running potentially expensive logic every time a connection is requested.| 76 | |`Failover`|Prioritize a single primary server and fail over to secondary server(s) only when the primary fails.
Ideal if one server should serve all traffic, but there is a need for fault tolerance.| 77 | -------------------------------------------------------------------------------- /cmd/dotproxy/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | 8 | "dotproxy/internal/log" 9 | "dotproxy/internal/meta" 10 | "dotproxy/internal/metrics" 11 | "dotproxy/internal/network" 12 | "dotproxy/internal/protocol" 13 | 14 | "github.com/getsentry/raven-go" 15 | ) 16 | 17 | func main() { 18 | configPath := flag.String( 19 | "config", 20 | os.Getenv("DOTPROXY_CONFIG"), 21 | "path to the configuration file on disk", 22 | ) 23 | version := flag.Bool( 24 | "version", 25 | false, 26 | "print the compiled dotproxy version SHA", 27 | ) 28 | verbosity := flag.String( 29 | "verbosity", 30 | "error", 31 | "desired logging verbosity: one of error, warn, info, debug", 32 | ) 33 | flag.Parse() 34 | 35 | // Report the compiled version and exit 36 | if *version { 37 | fmt.Printf("dotproxy/%s\n", meta.VersionSHA) 38 | return 39 | } 40 | 41 | // Logging configuration; default to log.Error verbosity 42 | level, _ := log.ParseLevel(*verbosity) 43 | logger := log.NewConsoleLogger(level) 44 | logger.Debug("main: initialized logger: level=%v", level) 45 | 46 | // Parse application configuration 47 | logger.Debug("main: reading and parsing config: path=%s", *configPath) 48 | config, err := meta.ParseConfig(*configPath) 49 | if err != nil { 50 | panic(err) 51 | } 52 | 53 | // Configure error reporting 54 | if config.Application != nil && config.Application.SentryDSN != "" { 55 | raven.SetDSN(config.Application.SentryDSN) 56 | raven.SetRelease(meta.VersionSHA) 57 | } 58 | 59 | // Configure metrics reporting 60 | clientCxLifecycleHook := metrics.NewNoopConnectionLifecycleHook() 61 | upstreamCxLifecycleHook := metrics.NewNoopConnectionLifecycleHook() 62 | clientCxIOHook := metrics.NewNoopConnectionIOHook() 63 | upstreamCxIOHook := metrics.NewNoopConnectionIOHook() 64 | proxyHook := metrics.NewNoopProxyHook() 65 | 66 | if config.Metrics != nil && config.Metrics.Statsd != nil { 67 | logger.Info( 68 | "main: configuring statsd metrics reporting: addr=%s sample_rate=%f", 69 | config.Metrics.Statsd.Address, 70 | config.Metrics.Statsd.SampleRate, 71 | ) 72 | 73 | if clientCxLifecycleHook, err = metrics.NewAsyncStatsdConnectionLifecycleHook( 74 | "client", 75 | config.Metrics.Statsd.Address, 76 | config.Metrics.Statsd.SampleRate, 77 | meta.VersionSHA, 78 | ); err != nil { 79 | panic(err) 80 | } 81 | 82 | if upstreamCxLifecycleHook, err = metrics.NewAsyncStatsdConnectionLifecycleHook( 83 | "upstream", 84 | config.Metrics.Statsd.Address, 85 | config.Metrics.Statsd.SampleRate, 86 | meta.VersionSHA, 87 | ); err != nil { 88 | panic(err) 89 | } 90 | 91 | if clientCxIOHook, err = metrics.NewAsyncStatsdConnectionIOHook( 92 | "client", 93 | config.Metrics.Statsd.Address, 94 | config.Metrics.Statsd.SampleRate, 95 | meta.VersionSHA, 96 | ); err != nil { 97 | panic(err) 98 | } 99 | 100 | if upstreamCxIOHook, err = metrics.NewAsyncStatsdConnectionIOHook( 101 | "upstream", 102 | config.Metrics.Statsd.Address, 103 | config.Metrics.Statsd.SampleRate, 104 | meta.VersionSHA, 105 | ); err != nil { 106 | panic(err) 107 | } 108 | 109 | if proxyHook, err = metrics.NewAsyncStatsdProxyHook( 110 | config.Metrics.Statsd.Address, 111 | config.Metrics.Statsd.SampleRate, 112 | meta.VersionSHA, 113 | ); err != nil { 114 | panic(err) 115 | } 116 | } else { 117 | 
logger.Warn("main: no metrics output engine specified; disabling metrics") 118 | } 119 | 120 | // Configure upstreams 121 | var servers []network.Client 122 | for _, server := range config.Upstream.Servers { 123 | opts := network.TLSClientOpts{ 124 | ConnectTimeout: server.ConnectTimeout, 125 | HandshakeTimeout: server.HandshakeTimeout, 126 | ReadTimeout: server.ReadTimeout, 127 | WriteTimeout: server.WriteTimeout, 128 | PoolOpts: network.PersistentConnPoolOpts{ 129 | Capacity: server.ConnectionPoolSize, 130 | StaleTimeout: server.StaleTimeout, 131 | }, 132 | } 133 | 134 | logger.Info( 135 | "main: starting TLS client for upstream server: addr=%s name=%s conns=%d", 136 | server.Address, 137 | server.ServerName, 138 | opts.PoolOpts.Capacity, 139 | ) 140 | 141 | client, err := network.NewTLSClient( 142 | server.Address, 143 | server.ServerName, 144 | upstreamCxLifecycleHook, 145 | opts, 146 | ) 147 | 148 | if err != nil { 149 | panic(err) 150 | } 151 | 152 | servers = append(servers, client) 153 | } 154 | 155 | // Create sharded client for all upstreams 156 | lbPolicy, ok := network.ParseLoadBalancingPolicy(config.Upstream.LoadBalancingPolicy) 157 | if !ok { 158 | logger.Warn( 159 | "main: unknown load balancing policy; use default: supplied=%s default=%s", 160 | config.Upstream.LoadBalancingPolicy, 161 | lbPolicy, 162 | ) 163 | } 164 | 165 | logger.Debug("main: using load balancing policy for request sharding: policy=%s", lbPolicy) 166 | client, _ := network.NewShardedClient(servers, lbPolicy) 167 | 168 | // Configure server listeners 169 | h := &protocol.DNSProxyHandler{ 170 | Upstream: client, 171 | ClientCxIOHook: clientCxIOHook, 172 | UpstreamCxIOHook: upstreamCxIOHook, 173 | ProxyHook: proxyHook, 174 | Logger: logger, 175 | Opts: protocol.DNSProxyOpts{ 176 | MaxUpstreamRetries: config.Upstream.MaxConnectionRetries, 177 | }, 178 | } 179 | 180 | if config.Listener.UDP != nil { 181 | logger.Info( 182 | "main: configuring UDP server listener: addr=%s max_concurrent_conns=%d", 183 | config.Listener.UDP.Address, 184 | config.Listener.UDP.MaxConcurrentConnections, 185 | ) 186 | 187 | opts := network.UDPServerOpts{ 188 | MaxConcurrentConnections: config.Listener.UDP.MaxConcurrentConnections, 189 | ReadTimeout: config.Listener.UDP.ReadTimeout, 190 | WriteTimeout: config.Listener.UDP.WriteTimeout, 191 | } 192 | 193 | udpServer := network.NewUDPServer(config.Listener.UDP.Address, opts) 194 | 195 | go func() { 196 | if err := udpServer.ListenAndServe(h); err != nil { 197 | panic(err) 198 | } 199 | }() 200 | } 201 | 202 | if config.Listener.TCP != nil { 203 | logger.Info( 204 | "main: configuring TCP server listener: addr=%s", 205 | config.Listener.TCP.Address, 206 | ) 207 | 208 | opts := network.TCPServerOpts{ 209 | ReadTimeout: config.Listener.TCP.ReadTimeout, 210 | WriteTimeout: config.Listener.TCP.WriteTimeout, 211 | } 212 | 213 | tcpServer := network.NewTCPServer( 214 | config.Listener.TCP.Address, 215 | clientCxLifecycleHook, 216 | opts, 217 | ) 218 | 219 | go func() { 220 | if err := tcpServer.ListenAndServe(h); err != nil { 221 | panic(err) 222 | } 223 | }() 224 | } 225 | 226 | // Serve indefinitely 227 | logger.Info("main: serving indefinitely") 228 | <-make(chan bool) 229 | } 230 | -------------------------------------------------------------------------------- /config.example.yaml: -------------------------------------------------------------------------------- 1 | metrics: 2 | statsd: 3 | addr: udp://127.0.0.1:8125 4 | sample_rate: 1.0 5 | listener: 6 | tcp: 7 | addr: 127.0.0.1:53 8 | 
read_timeout: 5s 9 | write_timeout: 5s 10 | udp: 11 | addr: 127.0.0.1:53 12 | max_concurrent_connections: 64 13 | write_timeout: 5s 14 | upstream: 15 | load_balancing_policy: RoundRobin 16 | max_connection_retries: 10 17 | servers: 18 | - addr: 1.1.1.1:853 19 | server_name: cloudflare-dns.com 20 | connection_pool_size: 8 21 | connect_timeout: 100ms 22 | handshake_timeout: 250ms 23 | read_timeout: 5s 24 | write_timeout: 5s 25 | stale_timeout: 10s 26 | - addr: 1.0.0.1:853 27 | server_name: cloudflare-dns.com 28 | connection_pool_size: 8 29 | connect_timeout: 100ms 30 | handshake_timeout: 250ms 31 | read_timeout: 5s 32 | write_timeout: 5s 33 | stale_timeout: 10s 34 | - addr: 8.8.8.8:853 35 | server_name: dns.google 36 | connection_pool_size: 8 37 | connect_timeout: 100ms 38 | handshake_timeout: 250ms 39 | read_timeout: 5s 40 | write_timeout: 5s 41 | stale_timeout: 10s 42 | - addr: 8.8.4.4:853 43 | server_name: dns.google 44 | connection_pool_size: 8 45 | connect_timeout: 100ms 46 | handshake_timeout: 250ms 47 | read_timeout: 5s 48 | write_timeout: 5s 49 | stale_timeout: 10s 50 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module dotproxy 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect 7 | github.com/getsentry/raven-go v0.2.0 8 | github.com/pkg/errors v0.9.1 // indirect 9 | golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 10 | golang.org/x/tools v0.1.0 11 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b 12 | lib.kevinlin.info/aperture v0.0.0-20210116070205-5bba968871c5 13 | ) 14 | -------------------------------------------------------------------------------- /init/dotproxy.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=High performance DNS-over-TLS proxy 3 | After=network.target 4 | 5 | [Service] 6 | Type=simple 7 | Restart=always 8 | RestartSec=30 9 | User=root 10 | SyslogIdentifier=dotproxy 11 | Environment=DOTPROXY_CONFIG=/etc/dotproxy/config.yaml 12 | EnvironmentFile=-/etc/default/dotproxy 13 | ExecStart=/usr/bin/dotproxy --verbosity info 14 | 15 | [Install] 16 | WantedBy=multi-user.target 17 | -------------------------------------------------------------------------------- /internal/data/doc.go: -------------------------------------------------------------------------------- 1 | // Package data contains general-purpose data structures for in-memory storage of ephemeral objects. 2 | package data 3 | -------------------------------------------------------------------------------- /internal/data/mru.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "container/heap" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // MRUQueue is an abstraction on top of a priority queue that assigns priorities based on 10 | // timestamps, for most-recently-used retrieval semantics. 11 | type MRUQueue struct { 12 | store *PriorityQueue 13 | capacity int 14 | mutex sync.Mutex 15 | } 16 | 17 | // NewMRUQueue creates a new MRU queue with the specified capacity. 18 | // The capacity may be any non-positive integer to disable the capacity limit. 
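//
// A minimal usage sketch (illustrative only; note that priorities have
// one-second resolution, so items pushed within the same second may pop in
// either order):
//
//	q := NewMRUQueue(2)
//	q.Push("a")
//	q.Push("b")
//	q.Push("c")              // returns false: the queue is at capacity
//	v, usedAt, ok := q.Pop() // most recently pushed item and its push timestamp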
19 | func NewMRUQueue(capacity int) *MRUQueue { 20 | var store PriorityQueue 21 | 22 | if capacity > 0 { 23 | store = make(PriorityQueue, 0, capacity) 24 | } else { 25 | store = make(PriorityQueue, 0) 26 | } 27 | 28 | heap.Init(&store) 29 | 30 | return &MRUQueue{store: &store, capacity: capacity} 31 | } 32 | 33 | // Push inserts a new value into the queue. It is tagged with a priority equal to the timestamp at 34 | // which the item is inserted. It is considered an error to add an item beyond the queue's 35 | // provisioned capacity. 36 | func (m *MRUQueue) Push(value interface{}) bool { 37 | m.mutex.Lock() 38 | defer m.mutex.Unlock() 39 | 40 | // Refuse to add beyond capacity 41 | if m.capacity > 0 && m.store.Len() == m.capacity { 42 | return false 43 | } 44 | 45 | heap.Push(m.store, &Item{ 46 | value: value, 47 | priority: int(time.Now().Unix()), 48 | }) 49 | 50 | return true 51 | } 52 | 53 | // Pop removes the most recently used item from the queue. It returns the item itself, the timestamp 54 | // at which it was last used, and a boolean indicating whether the pop was successful. 55 | func (m *MRUQueue) Pop() (interface{}, time.Time, bool) { 56 | m.mutex.Lock() 57 | defer m.mutex.Unlock() 58 | 59 | if m.store.Len() == 0 { 60 | return nil, time.Unix(0, 0), false 61 | } 62 | 63 | item := heap.Pop(m.store).(*Item) 64 | return item.value, time.Unix(int64(item.priority), 0), true 65 | } 66 | 67 | // Size reads the current size of the queue. 68 | func (m *MRUQueue) Size() int { 69 | m.mutex.Lock() 70 | defer m.mutex.Unlock() 71 | 72 | return m.store.Len() 73 | } 74 | 75 | // Empty returns whether the queue holds no items. 76 | func (m *MRUQueue) Empty() bool { 77 | m.mutex.Lock() 78 | defer m.mutex.Unlock() 79 | 80 | return m.store.Len() == 0 81 | } 82 | -------------------------------------------------------------------------------- /internal/data/priority.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | "container/heap" 5 | ) 6 | 7 | // Item describes an entry in the priority queue. 8 | type Item struct { 9 | value interface{} 10 | priority int 11 | index int 12 | } 13 | 14 | // PriorityQueue implements heap.Interface and holds Items. 15 | // This implementation is adapted from the container/heap documentation: 16 | // https://golang.org/pkg/container/heap/ 17 | type PriorityQueue []*Item 18 | 19 | // Len returns the current size of the queue. 20 | func (pq PriorityQueue) Len() int { 21 | return len(pq) 22 | } 23 | 24 | // Less instructs heap.Interface how to sort items within the heap. 25 | // This queue pops the highest-priority item first; since container/heap implements a min-heap, 26 | // this particular application considers a higher priority as "less." 27 | func (pq PriorityQueue) Less(i, j int) bool { 28 | return pq[i].priority > pq[j].priority 29 | } 30 | 31 | // Swap swaps the ith and jth items in the backing data structure. 32 | func (pq PriorityQueue) Swap(i, j int) { 33 | pq[i], pq[j] = pq[j], pq[i] 34 | pq[i].index = i 35 | pq[j].index = j 36 | } 37 | 38 | // Push adds a new item to the backing data structure. 39 | func (pq *PriorityQueue) Push(x interface{}) { 40 | n := len(*pq) 41 | item := x.(*Item) 42 | item.index = n 43 | *pq = append(*pq, item) 44 | } 45 | 46 | // Pop removes the last item from the backing data structure.
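//
// Note that this method is heap.Interface plumbing: callers are expected to go
// through the container/heap package rather than invoke it directly. For
// example (illustrative):
//
//	item := heap.Pop(pq).(*Item) // yields the highest-priority Item, per Less above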
47 | func (pq *PriorityQueue) Pop() interface{} { 48 | old := *pq 49 | n := len(old) 50 | item := old[n-1] 51 | item.index = -1 52 | *pq = old[0 : n-1] 53 | 54 | return item 55 | } 56 | 57 | // update modifies the priority and value of an Item in the queue. 58 | func (pq *PriorityQueue) update(item *Item, value string, priority int) { 59 | item.value = value 60 | item.priority = priority 61 | heap.Fix(pq, item.index) 62 | } 63 | -------------------------------------------------------------------------------- /internal/log/console.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // ConsoleLogger is a simple, leveled, standard output logging engine. 9 | type ConsoleLogger struct { 10 | level Level 11 | } 12 | 13 | // NewConsoleLogger creates a logger limited to the specified level. Only log messages that are less 14 | // verbose than the specified level are logged. 15 | func NewConsoleLogger(level Level) Logger { 16 | return &ConsoleLogger{level} 17 | } 18 | 19 | // Debug logs a debug message, if permitted by the current level. 20 | func (l *ConsoleLogger) Debug(format string, v ...interface{}) { 21 | l.log(Debug, format, v...) 22 | } 23 | 24 | // Info logs an informational message, if permitted by the current level. 25 | func (l *ConsoleLogger) Info(format string, v ...interface{}) { 26 | l.log(Info, format, v...) 27 | } 28 | 29 | // Warn logs a warning message, if permitted by the current level. 30 | func (l *ConsoleLogger) Warn(format string, v ...interface{}) { 31 | l.log(Warn, format, v...) 32 | } 33 | 34 | // Error logs an error message, if permitted by the current level. 35 | func (l *ConsoleLogger) Error(format string, v ...interface{}) { 36 | l.log(Error, format, v...) 37 | } 38 | 39 | // Level reads the current logging level. 40 | func (l *ConsoleLogger) Level() Level { 41 | return l.level 42 | } 43 | 44 | // log logs a message to standard output with a timestamp and level indicator, if permitted by the 45 | // current level. 46 | func (l *ConsoleLogger) log(level Level, format string, v ...interface{}) { 47 | if l.level.Enables(level) { 48 | fmt.Printf( 49 | "%s %s\t%s\n", 50 | time.Now().Format("2006-01-02 15:04:05"), 51 | level, 52 | fmt.Sprintf(format, v...), 53 | ) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /internal/log/doc.go: -------------------------------------------------------------------------------- 1 | // Package log contains abstractions for application logging at various verbosity levels. 2 | package log 3 | -------------------------------------------------------------------------------- /internal/log/level.go: -------------------------------------------------------------------------------- 1 | //go:generate go run golang.org/x/tools/cmd/stringer -type=Level -linecomment=true 2 | 3 | package log 4 | 5 | import ( 6 | "strings" 7 | ) 8 | 9 | // Level parametrizes supported log verbosity levels. 10 | type Level int 11 | 12 | const ( 13 | // Debug messages trace application-level behaviors. 14 | Debug Level = iota // DEBUG 15 | // Info messages convey general events. 16 | Info // INFO 17 | // Warn messages describe non-erroring divergences from the ideal code path. 18 | Warn // WARN 19 | // Error messages indicate behavior that is not intended and should be corrected. 20 | Error // ERROR 21 | ) 22 | 23 | // ParseLevel looks up a Level constant by its stringified (case-insensitive) representation. 
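//
// For example (illustrative):
//
//	level, ok := ParseLevel("info")  // returns (Info, true)
//	level, ok = ParseLevel("trace")  // unknown name: returns (Error, false)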
24 | func ParseLevel(level string) (Level, bool) { 25 | knownLevels := []Level{Debug, Info, Warn, Error} 26 | 27 | for _, knownLevel := range knownLevels { 28 | if strings.EqualFold(level, knownLevel.String()) { 29 | return knownLevel, true 30 | } 31 | } 32 | 33 | return Error, false 34 | } 35 | 36 | // Enables indicates whether the current log level enables logging at another level. 37 | // 38 | // For example, 39 | // Debug enables Debug, Info, Warn, and Error 40 | // Info enables Info, Warn, and Error, but not Debug 41 | // Error enables Error, but not Debug, Info, or Warn 42 | func (l Level) Enables(other Level) bool { 43 | return l <= other 44 | } 45 | -------------------------------------------------------------------------------- /internal/log/logger.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | // Logger defines a common interface shared by logging engines. 4 | type Logger interface { 5 | // Debug logs a debug message. 6 | Debug(format string, v ...interface{}) 7 | 8 | // Info logs an informational message. 9 | Info(format string, v ...interface{}) 10 | 11 | // Warn logs a warning message. 12 | Warn(format string, v ...interface{}) 13 | 14 | // Error logs an error message. 15 | Error(format string, v ...interface{}) 16 | 17 | // Level returns the currently configured logging level. 18 | Level() Level 19 | } 20 | -------------------------------------------------------------------------------- /internal/meta/config.go: -------------------------------------------------------------------------------- 1 | package meta 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "time" 7 | 8 | "gopkg.in/yaml.v3" 9 | 10 | "dotproxy/internal/network" 11 | ) 12 | 13 | // ApplicationConfig is a top-level block for application-level meta configuration. 14 | type ApplicationConfig struct { 15 | SentryDSN string `yaml:"sentry_dsn"` 16 | } 17 | 18 | // MetricsConfig is a top-level block for metrics configuration. 19 | type MetricsConfig struct { 20 | Statsd *struct { 21 | Address string `yaml:"addr"` 22 | SampleRate float64 `yaml:"sample_rate"` 23 | } `yaml:"statsd"` 24 | } 25 | 26 | // ListenerConfig is a top-level block for server listener configuration. 27 | type ListenerConfig struct { 28 | TCP *struct { 29 | Address string `yaml:"addr"` 30 | ReadTimeout time.Duration `yaml:"read_timeout"` 31 | WriteTimeout time.Duration `yaml:"write_timeout"` 32 | } `yaml:"tcp"` 33 | UDP *struct { 34 | Address string `yaml:"addr"` 35 | MaxConcurrentConnections int `yaml:"max_concurrent_connections"` 36 | ReadTimeout time.Duration `yaml:"read_timeout"` 37 | WriteTimeout time.Duration `yaml:"write_timeout"` 38 | } `yaml:"udp"` 39 | } 40 | 41 | // UpstreamServer describes parameters for a single upstream server. 42 | type UpstreamServer struct { 43 | Address string `yaml:"addr"` 44 | ServerName string `yaml:"server_name"` 45 | ConnectionPoolSize int `yaml:"connection_pool_size"` 46 | ConnectTimeout time.Duration `yaml:"connect_timeout"` 47 | HandshakeTimeout time.Duration `yaml:"handshake_timeout"` 48 | ReadTimeout time.Duration `yaml:"read_timeout"` 49 | WriteTimeout time.Duration `yaml:"write_timeout"` 50 | StaleTimeout time.Duration `yaml:"stale_timeout"` 51 | } 52 | 53 | // UpstreamConfig is a top-level block for upstream configuration.
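//
// The corresponding YAML fragment looks like the following (adapted from
// config.example.yaml):
//
//	upstream:
//	  load_balancing_policy: RoundRobin
//	  max_connection_retries: 10
//	  servers:
//	    - addr: 1.1.1.1:853
//	      server_name: cloudflare-dns.com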
54 | type UpstreamConfig struct { 55 | LoadBalancingPolicy string `yaml:"load_balancing_policy"` 56 | MaxConnectionRetries int `yaml:"max_connection_retries"` 57 | Servers []UpstreamServer `yaml:"servers"` 58 | } 59 | 60 | // Config describes all application configuration options. 61 | type Config struct { 62 | Application *ApplicationConfig `yaml:"application"` 63 | Metrics *MetricsConfig `yaml:"metrics"` 64 | Listener *ListenerConfig `yaml:"listener"` 65 | Upstream *UpstreamConfig `yaml:"upstream"` 66 | } 67 | 68 | // ParseConfig parses a Config struct instance from a file specified as a path on disk. 69 | func ParseConfig(path string) (*Config, error) { 70 | data, err := ioutil.ReadFile(path) 71 | if err != nil { 72 | return nil, fmt.Errorf("config: error reading config: err=%v", err) 73 | } 74 | 75 | var cfg *Config 76 | if err := yaml.Unmarshal(data, &cfg); err != nil { 77 | return nil, fmt.Errorf("config: error parsing config: err=%v", err) 78 | } 79 | 80 | if err := cfg.validate(); err != nil { 81 | return nil, err 82 | } 83 | 84 | return cfg, nil 85 | } 86 | 87 | // validate the contents of the configuration. Returns an error if validation failed; nil otherwise. 88 | func (c *Config) validate() error { 89 | /* Metrics */ 90 | 91 | // Users can omit the metrics block entirely to disable metrics reporting. 92 | if c.Metrics != nil && c.Metrics.Statsd != nil { 93 | if c.Metrics.Statsd.Address == "" { 94 | return fmt.Errorf("config: missing metrics statsd address") 95 | } 96 | 97 | if c.Metrics.Statsd.SampleRate < 0 || c.Metrics.Statsd.SampleRate > 1 { 98 | return fmt.Errorf("config: statsd sample rate must be in range [0.0, 1.0]") 99 | } 100 | } 101 | 102 | /* Listener */ 103 | 104 | if c.Listener == nil { 105 | return fmt.Errorf("config: missing top-level listener config key") 106 | } 107 | 108 | if c.Listener.TCP == nil && c.Listener.UDP == nil { 109 | return fmt.Errorf("config: at least one TCP or UDP listener must be specified") 110 | } 111 | 112 | if c.Listener.TCP != nil && c.Listener.TCP.Address == "" { 113 | return fmt.Errorf("config: missing TCP server listening address") 114 | } 115 | 116 | if c.Listener.UDP != nil && c.Listener.UDP.Address == "" { 117 | return fmt.Errorf("config: missing UDP server listening address") 118 | } 119 | 120 | /* Upstream */ 121 | 122 | if c.Upstream == nil { 123 | return fmt.Errorf("config: missing top-level upstream config key") 124 | } 125 | 126 | // Validate the load balancing policy, only if provided (empty signifies default). 127 | if c.Upstream.LoadBalancingPolicy != "" { 128 | if _, ok := network.ParseLoadBalancingPolicy(c.Upstream.LoadBalancingPolicy); !ok { 129 | return fmt.Errorf( 130 | "config: unknown load balancing policy: policy=%s", 131 | c.Upstream.LoadBalancingPolicy, 132 | ) 133 | } 134 | } 135 | 136 | if len(c.Upstream.Servers) == 0 { 137 | return fmt.Errorf("config: no upstream servers specified") 138 | } 139 | 140 | for idx, server := range c.Upstream.Servers { 141 | if server.Address == "" { 142 | return fmt.Errorf("config: missing server address: idx=%d", idx) 143 | } 144 | 145 | if server.ServerName == "" { 146 | return fmt.Errorf("config: missing server TLS hostname: idx=%d", idx) 147 | } 148 | } 149 | 150 | return nil 151 | } 152 | -------------------------------------------------------------------------------- /internal/meta/doc.go: -------------------------------------------------------------------------------- 1 | // Package meta contains logic related to the application itself. 
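//
// As a sketch of how the build wires this package up: meta.VersionSHA is
// injected at link time by the Makefile. The equivalent manual invocation
// (using git rev-parse HEAD as an illustrative stand-in for the CI-provided
// VERSION_SHA) would be:
//
//	go build -ldflags "-X dotproxy/internal/meta.VersionSHA=$(git rev-parse HEAD)" cmd/dotproxy/main.go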
2 | package meta 3 | -------------------------------------------------------------------------------- /internal/meta/version.go: -------------------------------------------------------------------------------- 1 | package meta 2 | 3 | // VersionSHA is a build-time injected variable describing the Git commit SHA at which dotproxy was 4 | // built. It is used as a general purpose, global version identifier. 5 | var VersionSHA string 6 | -------------------------------------------------------------------------------- /internal/metrics/doc.go: -------------------------------------------------------------------------------- 1 | // Package metrics contains abstractions for emission of metrics generated throughout the lifetime 2 | // of the application. Currently, the only supported metrics output engine is statsd. 3 | // 4 | // The nature of the application is such that metrics are generated at various points in time 5 | // throughout a single request lifecycle. Thus, the metrics emissions in this package are structured 6 | // around the notion of hooks: a hook interface defines methods that are invoked by the server's 7 | // main logic routines while serving a client request. Thus, they "hook" into lifecycle points in 8 | // logic. Implementations of hook interfaces actually output the metrics to a backend engine; this 9 | // responsibility is decoupled from the semantics of "hooking" into business logic. 10 | package metrics 11 | -------------------------------------------------------------------------------- /internal/metrics/hook.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "sync/atomic" 7 | "time" 8 | 9 | "lib.kevinlin.info/aperture" 10 | ) 11 | 12 | // ConnectionLifecycleHook is a metrics hook interface for reporting events that occur during a TCP 13 | // connection lifecycle. Note that it is not pertinent to UDP transports, since it is an inherently 14 | // connectionless protocol. 15 | type ConnectionLifecycleHook interface { 16 | // EmitConnectionOpen reports the event that a connection was successfully opened. 17 | EmitConnectionOpen(latency time.Duration, addr net.Addr) 18 | 19 | // EmitConnectionClose reports the event that a connection was closed. 20 | EmitConnectionClose(addr net.Addr) 21 | 22 | // EmitConnectionError reports occurrence of an error establishing a connection. 23 | EmitConnectionError() 24 | } 25 | 26 | // ConnectionIOHook is a metrics hook interface for reporting events related to I/O with an 27 | // established TCP or UDP connection. 28 | type ConnectionIOHook interface { 29 | // EmitRead reports a successful connection read. 30 | EmitRead(latency time.Duration, addr net.Addr) 31 | 32 | // EmitReadError reports the event that a connection read failed. 33 | EmitReadError(addr net.Addr) 34 | 35 | // EmitWrite reports a successful connection write. 36 | EmitWrite(latency time.Duration, addr net.Addr) 37 | 38 | // EmitWriteError reports the event that a connection write failed. 39 | EmitWriteError(addr net.Addr) 40 | 41 | // EmitRetry reports the event that an I/O operation was retried due to failure. 42 | EmitRetry(addr net.Addr) 43 | } 44 | 45 | // ProxyHook is a metrics hook interface for reporting events and latencies related to end-to-end 46 | // proxying of a client request with an upstream server. 47 | type ProxyHook interface { 48 | // EmitRequestSize reports the size of the proxied request on the wire. 
49 | EmitRequestSize(bytes int64, client net.Addr) 50 | 51 | // EmitResponseSize reports the size of the proxied response on the wire. 52 | EmitResponseSize(bytes int64, upstream net.Addr) 53 | 54 | // EmitRTT reports the total, end-to-end latency associated with serving a single request 55 | // from a client. This includes the time to establish/teardown all connections, transact 56 | // with the upstream, and proxy the response to/from the client. 57 | EmitRTT(latency time.Duration, client net.Addr, upstream net.Addr) 58 | 59 | // EmitUpstreamLatency reports the latency associated with transacting with the upstream 60 | // to serve a single request. 61 | EmitUpstreamLatency(latency time.Duration, client net.Addr, upstream net.Addr) 62 | 63 | // EmitProcess reports the occurrence of a processed proxy request. 64 | EmitProcess(client net.Addr, upstream net.Addr) 65 | 66 | // EmitError reports the occurrence of a critical error in the proxy lifecycle that causes 67 | // the request to not be correctly served. 68 | EmitError() 69 | } 70 | 71 | // AsyncStatsdConnectionLifecycleHook is an implementation of ConnectionLifecycleHook that outputs 72 | // metrics asynchronously to statsd. 73 | type AsyncStatsdConnectionLifecycleHook struct { 74 | client aperture.Statsd 75 | source string 76 | } 77 | 78 | // AsyncStatsdConnectionIOHook is an implementation of ConnectionIOHook that outputs metrics 79 | // asynchronously to statsd. 80 | type AsyncStatsdConnectionIOHook struct { 81 | client aperture.Statsd 82 | source string 83 | } 84 | 85 | // AsyncStatsdProxyHook is an implementation of ProxyHook that outputs metrics asynchronously to 86 | // statsd. 87 | type AsyncStatsdProxyHook struct { 88 | client aperture.Statsd 89 | sequenceID int64 90 | } 91 | 92 | // NoopConnectionLifecycleHook implements the ConnectionLifecycleHook interface but noops on all 93 | // emissions. 94 | type NoopConnectionLifecycleHook struct{} 95 | 96 | // NoopConnectionIOHook implements the ConnectionIOHook interface but noops on all emissions. 97 | type NoopConnectionIOHook struct{} 98 | 99 | // NoopProxyHook implements the ProxyHook interface but noops on all emissions. 100 | type NoopProxyHook struct{} 101 | 102 | // NewAsyncStatsdConnectionLifecycleHook creates a new client with the specified source, statsd 103 | // address, and statsd sample rate. The source denotes the entity with whom the server is opening 104 | // and closing TCP connections. 
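//
// A minimal construction sketch (illustrative; the udp:// address follows
// config.example.yaml, and the error is treated as fatal as in main.go):
//
//	hook, err := NewAsyncStatsdConnectionLifecycleHook("client", "udp://127.0.0.1:8125", 1.0, meta.VersionSHA)
//	if err != nil {
//		panic(err)
//	}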
105 | func NewAsyncStatsdConnectionLifecycleHook(source string, addr string, sampleRate float64, version string) (ConnectionLifecycleHook, error) { 106 | client, err := statsdClientFactory(addr, sampleRate, version) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 111 | return &AsyncStatsdConnectionLifecycleHook{ 112 | client: client, 113 | source: source, 114 | }, nil 115 | } 116 | 117 | // EmitConnectionOpen statsd implementation 118 | func (h *AsyncStatsdConnectionLifecycleHook) EmitConnectionOpen(latency time.Duration, addr net.Addr) { 119 | go func() { 120 | tags := map[string]interface{}{ 121 | "addr": ipFromAddr(addr), 122 | "transport": transportFromAddr(addr), 123 | } 124 | 125 | h.client.Count(fmt.Sprintf("event.%s.cx_open", h.source), 1, tags) 126 | 127 | if latency > 0 { 128 | h.client.Timing(fmt.Sprintf("latency.%s.cx_open", h.source), latency, tags) 129 | } 130 | }() 131 | } 132 | 133 | // EmitConnectionClose statsd implementation 134 | func (h *AsyncStatsdConnectionLifecycleHook) EmitConnectionClose(addr net.Addr) { 135 | go h.client.Count(fmt.Sprintf("event.%s.cx_close", h.source), 1, map[string]interface{}{ 136 | "addr": ipFromAddr(addr), 137 | "transport": transportFromAddr(addr), 138 | }) 139 | } 140 | 141 | // EmitConnectionError statsd implementation 142 | func (h *AsyncStatsdConnectionLifecycleHook) EmitConnectionError() { 143 | go h.client.Count(fmt.Sprintf("event.%s.cx_error", h.source), 1, nil) 144 | } 145 | 146 | // NewNoopConnectionLifecycleHook creates a noop implementation of ConnectionLifecycleHook. 147 | func NewNoopConnectionLifecycleHook() ConnectionLifecycleHook { 148 | return &NoopConnectionLifecycleHook{} 149 | } 150 | 151 | // EmitConnectionOpen noops. 152 | func (h *NoopConnectionLifecycleHook) EmitConnectionOpen(latency time.Duration, addr net.Addr) {} 153 | 154 | // EmitConnectionClose noops. 155 | func (h *NoopConnectionLifecycleHook) EmitConnectionClose(addr net.Addr) {} 156 | 157 | // EmitConnectionError noops. 158 | func (h *NoopConnectionLifecycleHook) EmitConnectionError() {} 159 | 160 | // NewAsyncStatsdConnectionIOHook creates a new client with the specified source, statsd address, 161 | // and statsd sample rate. The source denotes the entity with whom the server is performing I/O. 162 | func NewAsyncStatsdConnectionIOHook(source string, addr string, sampleRate float64, version string) (ConnectionIOHook, error) { 163 | client, err := statsdClientFactory(addr, sampleRate, version) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 168 | return &AsyncStatsdConnectionIOHook{ 169 | client: client, 170 | source: source, 171 | }, nil 172 | } 173 | 174 | // EmitRead statsd implementation. 175 | func (h *AsyncStatsdConnectionIOHook) EmitRead(latency time.Duration, addr net.Addr) { 176 | go func() { 177 | tags := map[string]interface{}{ 178 | "addr": ipFromAddr(addr), 179 | "transport": transportFromAddr(addr), 180 | } 181 | 182 | h.client.Count(fmt.Sprintf("event.%s.cx_read", h.source), 1, tags) 183 | h.client.Timing(fmt.Sprintf("latency.%s.cx_read", h.source), latency, tags) 184 | }() 185 | } 186 | 187 | // EmitReadError statsd implementation. 188 | func (h *AsyncStatsdConnectionIOHook) EmitReadError(addr net.Addr) { 189 | go h.client.Count(fmt.Sprintf("event.%s.cx_read_error", h.source), 1, map[string]interface{}{ 190 | "addr": ipFromAddr(addr), 191 | "transport": transportFromAddr(addr), 192 | }) 193 | } 194 | 195 | // EmitWrite statsd implementation. 
196 | func (h *AsyncStatsdConnectionIOHook) EmitWrite(latency time.Duration, addr net.Addr) { 197 | go func() { 198 | tags := map[string]interface{}{ 199 | "addr": ipFromAddr(addr), 200 | "transport": transportFromAddr(addr), 201 | } 202 | 203 | h.client.Count(fmt.Sprintf("event.%s.cx_write", h.source), 1, tags) 204 | h.client.Timing(fmt.Sprintf("latency.%s.cx_write", h.source), latency, tags) 205 | }() 206 | } 207 | 208 | // EmitWriteError statsd implementation. 209 | func (h *AsyncStatsdConnectionIOHook) EmitWriteError(addr net.Addr) { 210 | go h.client.Count(fmt.Sprintf("event.%s.cx_write_error", h.source), 1, map[string]interface{}{ 211 | "addr": ipFromAddr(addr), 212 | "transport": transportFromAddr(addr), 213 | }) 214 | } 215 | 216 | // EmitRetry statsd implementation. 217 | func (h *AsyncStatsdConnectionIOHook) EmitRetry(addr net.Addr) { 218 | go h.client.Count(fmt.Sprintf("event.%s.cx_io_retry", h.source), 1, map[string]interface{}{ 219 | "addr": ipFromAddr(addr), 220 | "transport": transportFromAddr(addr), 221 | }) 222 | } 223 | 224 | // NewNoopConnectionIOHook creates a noop implementation of ConnectionIOHook. 225 | func NewNoopConnectionIOHook() ConnectionIOHook { 226 | return &NoopConnectionIOHook{} 227 | } 228 | 229 | // EmitRead noops. 230 | func (h *NoopConnectionIOHook) EmitRead(latency time.Duration, addr net.Addr) {} 231 | 232 | // EmitReadError noops. 233 | func (h *NoopConnectionIOHook) EmitReadError(addr net.Addr) {} 234 | 235 | // EmitWrite noops. 236 | func (h *NoopConnectionIOHook) EmitWrite(latency time.Duration, addr net.Addr) {} 237 | 238 | // EmitWriteError noops. 239 | func (h *NoopConnectionIOHook) EmitWriteError(addr net.Addr) {} 240 | 241 | // EmitRetry noops. 242 | func (h *NoopConnectionIOHook) EmitRetry(addr net.Addr) {} 243 | 244 | // NewAsyncStatsdProxyHook creates a new client with the specified statsd address and sample rate. 
245 | func NewAsyncStatsdProxyHook(addr string, sampleRate float64, version string) (ProxyHook, error) { 246 | client, err := statsdClientFactory(addr, sampleRate, version) 247 | if err != nil { 248 | return nil, err 249 | } 250 | 251 | return &AsyncStatsdProxyHook{client: client}, nil 252 | } 253 | 254 | // EmitRequestSize statsd implementation 255 | func (h *AsyncStatsdProxyHook) EmitRequestSize(bytes int64, client net.Addr) { 256 | go h.client.Size("size.proxy.request", bytes, map[string]interface{}{ 257 | "addr": ipFromAddr(client), 258 | }) 259 | } 260 | 261 | // EmitResponseSize statsd implementation 262 | func (h *AsyncStatsdProxyHook) EmitResponseSize(bytes int64, upstream net.Addr) { 263 | go h.client.Size("size.proxy.response", bytes, map[string]interface{}{ 264 | "addr": ipFromAddr(upstream), 265 | }) 266 | } 267 | 268 | // EmitRTT statsd implementation 269 | func (h *AsyncStatsdProxyHook) EmitRTT(latency time.Duration, client net.Addr, upstream net.Addr) { 270 | go h.client.Timing("latency.proxy.tx_rtt", latency, map[string]interface{}{ 271 | "client": ipFromAddr(client), 272 | "upstream": ipFromAddr(upstream), 273 | "transport": transportFromAddr(client), 274 | }) 275 | } 276 | 277 | // EmitUpstreamLatency statsd implementation 278 | func (h *AsyncStatsdProxyHook) EmitUpstreamLatency(latency time.Duration, client net.Addr, upstream net.Addr) { 279 | go h.client.Timing("latency.proxy.tx_upstream", latency, map[string]interface{}{ 280 | "client": ipFromAddr(client), 281 | "upstream": ipFromAddr(upstream), 282 | }) 283 | } 284 | 285 | // EmitProcess statsd implementation 286 | func (h *AsyncStatsdProxyHook) EmitProcess(client net.Addr, upstream net.Addr) { 287 | go func() { 288 | tags := map[string]interface{}{ 289 | "client": ipFromAddr(client), 290 | "upstream": ipFromAddr(upstream), 291 | } 292 | 293 | h.client.Count("event.proxy.process", 1, tags) 294 | h.client.Gauge( 295 | "gauge.proxy.sequence_id", 296 | float64(atomic.LoadInt64(&h.sequenceID)), 297 | tags, 298 | ) 299 | 300 | atomic.AddInt64(&h.sequenceID, 1) 301 | }() 302 | } 303 | 304 | // EmitError statsd implementation 305 | func (h *AsyncStatsdProxyHook) EmitError() { 306 | go h.client.Count("event.proxy.error", 1, nil) 307 | } 308 | 309 | // NewNoopProxyHook creates a noop implementation of ProxyHook. 310 | func NewNoopProxyHook() ProxyHook { 311 | return &NoopProxyHook{} 312 | } 313 | 314 | // EmitRequestSize noops. 315 | func (h *NoopProxyHook) EmitRequestSize(bytes int64, client net.Addr) {} 316 | 317 | // EmitResponseSize noops. 318 | func (h *NoopProxyHook) EmitResponseSize(bytes int64, upstream net.Addr) {} 319 | 320 | // EmitRTT noops. 321 | func (h *NoopProxyHook) EmitRTT(latency time.Duration, client net.Addr, upstream net.Addr) {} 322 | 323 | // EmitUpstreamLatency noops. 324 | func (h *NoopProxyHook) EmitUpstreamLatency(latency time.Duration, client net.Addr, upstream net.Addr) { 325 | } 326 | 327 | // EmitProcess noops. 328 | func (h *NoopProxyHook) EmitProcess(client net.Addr, upstream net.Addr) {} 329 | 330 | // EmitError noops. 331 | func (h *NoopProxyHook) EmitError() {} 332 | 333 | // statsdClientFactory creates a configured statsd client with reasonable defaults for the given 334 | // statsd server address and sample rate. 
335 | func statsdClientFactory(addr string, sampleRate float64, version string) (*aperture.Client, error) { 336 | return aperture.NewClient(&aperture.Config{ 337 | Address: addr, 338 | Prefix: "dotproxy", 339 | SampleRate: sampleRate, 340 | TransportProbeInterval: 10 * time.Second, 341 | DefaultTags: map[string]interface{}{ 342 | "version": version, 343 | }, 344 | }) 345 | } 346 | 347 | // ipFromAddr returns the IP address from a full net.Addr, or "null" if unavailable. 348 | func ipFromAddr(addr net.Addr) string { 349 | switch networkAddr := addr.(type) { 350 | case *net.UDPAddr: 351 | return networkAddr.IP.String() 352 | case *net.TCPAddr: 353 | return networkAddr.IP.String() 354 | default: 355 | return "null" 356 | } 357 | } 358 | 359 | // transportFromAddr returns the transport protocol (as a string) behind a net.Addr, or "null" if 360 | // unavailable. 361 | func transportFromAddr(addr net.Addr) string { 362 | switch addr.(type) { 363 | case *net.UDPAddr: 364 | return "udp" 365 | case *net.TCPAddr: 366 | return "tcp" 367 | default: 368 | return "null" 369 | } 370 | } 371 | -------------------------------------------------------------------------------- /internal/network/client.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "syscall" 9 | "time" 10 | 11 | "dotproxy/internal/metrics" 12 | ) 13 | 14 | // Client defines the interface for a TCP network client. 15 | type Client interface { 16 | // Conn retrieves a single persistent connection. 17 | Conn() (*PersistentConn, error) 18 | 19 | // Stats returns historical client stats. 20 | Stats() Stats 21 | } 22 | 23 | // Stats formalizes stats tracked per-client. 24 | type Stats struct { 25 | // SuccessfulConnections is the number of connections that the client has successfully 26 | // provided. 27 | SuccessfulConnections int 28 | // FailedConnections is the number of times that the client has failed to provide a 29 | // connection. 30 | FailedConnections int 31 | } 32 | 33 | // TLSClient describes a TLS-secured TCP client that recycles connections in a pool. 34 | type TLSClient struct { 35 | addr string 36 | cxHook metrics.ConnectionLifecycleHook 37 | pool *PersistentConnPool 38 | stats Stats 39 | statsMutex sync.RWMutex 40 | } 41 | 42 | // TLSClientOpts formalizes TLS client configuration options. 43 | type TLSClientOpts struct { 44 | // PoolOpts are connection pool-specific options. 45 | PoolOpts PersistentConnPoolOpts 46 | // ConnectTimeout is the timeout associated with establishing a connection with the remote 47 | // server. 48 | ConnectTimeout time.Duration 49 | // HandshakeTimeout is the timeout associated with performing a TLS handshake with the 50 | // remote server, after a connection has been successfully established. 51 | HandshakeTimeout time.Duration 52 | // ReadTimeout is the timeout associated with each read from a remote connection. 53 | ReadTimeout time.Duration 54 | // WriteTimeout is the timeout associated with each write to a remote connection. 55 | WriteTimeout time.Duration 56 | } 57 | 58 | const ( 59 | // tcpFastOpenConnect is the TCP socket option constant (defined in the kernel) 60 | // controlling whether outgoing connections should use TCP Fast Open to reduce the number of 61 | // round trips, and thus overall latency, when re-establishing a TCP connection to a server. 62 | // It is not yet present in the syscall standard library for platform-agnostic usage.
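	// (If the golang.org/x/sys/unix package is available, the equivalent
	// constant is believed to be unix.TCP_FASTOPEN_CONNECT.)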
63 | // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/tcp.h?h=v4.20#n120 64 | tcpFastOpenConnect = 30 65 | ) 66 | 67 | // NewTLSClient creates a TLSClient pool, connected to a specified remote address. 68 | // This procedure will establish the initial connections, perform TLS handshakes, and validate the 69 | // server identity. 70 | func NewTLSClient(addr string, serverName string, cxHook metrics.ConnectionLifecycleHook, opts TLSClientOpts) (*TLSClient, error) { 71 | // Use a custom dialer that sets the TCP Fast Open socket option and a connection timeout. 72 | dialer := &net.Dialer{ 73 | Timeout: opts.ConnectTimeout, 74 | Control: func(network string, addr string, rc syscall.RawConn) error { 75 | return rc.Control(func(fd uintptr) { 76 | syscall.SetsockoptInt( 77 | int(fd), 78 | syscall.IPPROTO_TCP, 79 | tcpFastOpenConnect, 80 | 1, 81 | ) 82 | }) 83 | }, 84 | } 85 | 86 | conf := &tls.Config{ 87 | ServerName: serverName, 88 | ClientSessionCache: tls.NewLRUClientSessionCache(opts.PoolOpts.Capacity), 89 | } 90 | 91 | // The TLS dialer wraps the custom TCP dialer with a TLS encryption layer and R/W timeouts. 92 | tlsDialer := func() (net.Conn, error) { 93 | conn, err := dialer.Dial("tcp", addr) 94 | if err != nil { 95 | return nil, fmt.Errorf("client: error establishing connection: err=%v", err) 96 | } 97 | 98 | // Implicitly set a TLS handshake timeout by enforcing a R/W deadline on the 99 | // underlying connection. 100 | if opts.HandshakeTimeout > 0 { 101 | conn.SetDeadline(time.Now().Add(opts.HandshakeTimeout)) 102 | } 103 | 104 | tlsConn := tls.Client(conn, conf) 105 | if err := tlsConn.Handshake(); err != nil { 106 | go conn.Close() 107 | return nil, fmt.Errorf("client: TLS handshake failed: err=%v", err) 108 | } 109 | 110 | return NewTCPConn(tlsConn, opts.ReadTimeout, opts.WriteTimeout), nil 111 | } 112 | 113 | pool := NewPersistentConnPool(tlsDialer, cxHook, opts.PoolOpts) 114 | 115 | return &TLSClient{ 116 | addr: addr, 117 | pool: pool, 118 | stats: Stats{}, 119 | }, nil 120 | } 121 | 122 | // Conn retrieves a single persistent connection from the pool. 123 | func (c *TLSClient) Conn() (*PersistentConn, error) { 124 | conn, err := c.pool.Conn() 125 | 126 | defer func() { 127 | go func() { 128 | c.statsMutex.Lock() 129 | defer c.statsMutex.Unlock() 130 | 131 | if err != nil { 132 | c.stats.FailedConnections++ 133 | } else { 134 | c.stats.SuccessfulConnections++ 135 | } 136 | }() 137 | }() 138 | 139 | return conn, err 140 | } 141 | 142 | // Stats returns current client stats. 143 | func (c *TLSClient) Stats() Stats { 144 | c.statsMutex.RLock() 145 | defer c.statsMutex.RUnlock() 146 | 147 | return c.stats 148 | } 149 | 150 | // String returns a string representation of the client. 151 | func (c *TLSClient) String() string { 152 | return fmt.Sprintf("TLSClient{addr: %s, connections: %d}", c.addr, c.pool.Size()) 153 | } 154 | -------------------------------------------------------------------------------- /internal/network/conn.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "time" 7 | ) 8 | 9 | // UDPConn is an abstraction over a UDP net.PacketConn to give it net.Conn-like semantics. It 10 | // statefully tracks connection state changes across reads and writes, assuming that a write follows 11 | // an initial read. 
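//
// A sketch of a single request/response transaction (timeout values are
// illustrative only):
//
//	conn := NewUDPConn(packetConn, 5*time.Second, 5*time.Second)
//	n, err := conn.Read(buf)  // blocks until a datagram arrives; records the sender
//	// ... process buf[:n] ...
//	_, err = conn.Write(resp) // replies to the recorded sender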
12 | type UDPConn struct {
13 |     conn         net.PacketConn
14 |     readTimeout  time.Duration
15 |     writeTimeout time.Duration
16 |     remote       net.Addr
17 | }
18 | 
19 | // TCPConn is an abstraction over a net.Conn that provides dynamic read and write timeouts.
20 | type TCPConn struct {
21 |     readTimeout  time.Duration
22 |     writeTimeout time.Duration
23 | 
24 |     net.Conn
25 | }
26 | 
27 | // NewUDPConn creates a UDPConn from a backing net.PacketConn.
28 | func NewUDPConn(conn net.PacketConn, readTimeout time.Duration, writeTimeout time.Duration) *UDPConn {
29 |     return &UDPConn{
30 |         conn:         conn,
31 |         readTimeout:  readTimeout,
32 |         writeTimeout: writeTimeout,
33 |     }
34 | }
35 | 
36 | // Read performs a read from the remote client. The remote address is statefully tracked as a struct
37 | // member.
38 | func (c *UDPConn) Read(buf []byte) (n int, err error) {
39 |     if c.remote != nil {
40 |         return 0, fmt.Errorf("conn: already associated with a transaction")
41 |     }
42 | 
43 |     if c.readTimeout > 0 {
44 |         if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
45 |             return 0, err
46 |         }
47 |     }
48 | 
49 |     n, c.remote, err = c.conn.ReadFrom(buf)
50 | 
51 |     return
52 | }
53 | 
54 | // Write writes to the same client from which data was read. It is an error to write to a connection
55 | // without a prior read from a remote client.
56 | func (c *UDPConn) Write(buf []byte) (n int, err error) {
57 |     if c.remote == nil {
58 |         return 0, fmt.Errorf("conn: no remote associated with this connection")
59 |     }
60 | 
61 |     if c.writeTimeout > 0 {
62 |         if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
63 |             return 0, err
64 |         }
65 |     }
66 | 
67 |     return c.conn.WriteTo(buf, c.remote)
68 | }
69 | 
70 | // Close closes the underlying connection.
71 | func (c *UDPConn) Close() error {
72 |     return c.conn.Close()
73 | }
74 | 
75 | // LocalAddr obtains the connection's local address.
76 | func (c *UDPConn) LocalAddr() net.Addr {
77 |     return c.conn.LocalAddr()
78 | }
79 | 
80 | // RemoteAddr obtains the connection's remote address.
81 | func (c *UDPConn) RemoteAddr() net.Addr {
82 |     return c.remote
83 | }
84 | 
85 | // SetDeadline sets both the read and write deadline.
86 | func (c *UDPConn) SetDeadline(t time.Time) error {
87 |     return c.conn.SetDeadline(t)
88 | }
89 | 
90 | // SetReadDeadline sets the read deadline.
91 | func (c *UDPConn) SetReadDeadline(t time.Time) error {
92 |     return c.conn.SetReadDeadline(t)
93 | }
94 | 
95 | // SetWriteDeadline sets the write deadline.
96 | func (c *UDPConn) SetWriteDeadline(t time.Time) error {
97 |     return c.conn.SetWriteDeadline(t)
98 | }
99 | 
100 | // NewTCPConn creates a TCPConn from a backing net.Conn.
101 | func NewTCPConn(conn net.Conn, readTimeout time.Duration, writeTimeout time.Duration) *TCPConn {
102 |     return &TCPConn{
103 |         Conn:         conn,
104 |         readTimeout:  readTimeout,
105 |         writeTimeout: writeTimeout,
106 |     }
107 | }
108 | 
109 | // Read sets a read deadline followed by reading from the backing connection.
110 | func (c *TCPConn) Read(buf []byte) (n int, err error) {
111 |     if c.readTimeout > 0 {
112 |         if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
113 |             return 0, err
114 |         }
115 |     }
116 | 
117 |     return c.Conn.Read(buf)
118 | }
119 | 
120 | // Write sets a write deadline followed by writing to the backing connection.
121 | func (c *TCPConn) Write(buf []byte) (n int, err error) { 122 | if c.writeTimeout > 0 { 123 | if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { 124 | return 0, err 125 | } 126 | } 127 | 128 | return c.Conn.Write(buf) 129 | } 130 | -------------------------------------------------------------------------------- /internal/network/doc.go: -------------------------------------------------------------------------------- 1 | // Package network contains abstractions for communicating with other machines over the network. It 2 | // attempts to abstract away the API differences between TCP and UDP connections, and minimizes the 3 | // exposed interaction surfaces for TLS connections in an effort to simplify client usage. 4 | package network 5 | -------------------------------------------------------------------------------- /internal/network/init.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "math/rand" 5 | "time" 6 | ) 7 | 8 | func init() { 9 | // Seed the random number generator once 10 | rand.Seed(time.Now().UnixNano()) 11 | } 12 | -------------------------------------------------------------------------------- /internal/network/persistent.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "time" 7 | 8 | "lib.kevinlin.info/aperture/lib" 9 | 10 | "dotproxy/internal/data" 11 | "dotproxy/internal/metrics" 12 | ) 13 | 14 | // PersistentConnPool is a pool of persistent, long-lived connections. Connections are returned to 15 | // the pool instead of closed for later reuse. 16 | type PersistentConnPool struct { 17 | dialer func() (net.Conn, error) 18 | cxHook metrics.ConnectionLifecycleHook 19 | staleTimeout time.Duration 20 | conns *data.MRUQueue 21 | } 22 | 23 | // PersistentConnPoolOpts formalizes configuration options for a persistent connection pool. 24 | type PersistentConnPoolOpts struct { 25 | // Capacity is the maximum number of cached connections that may be held open in the pool. 26 | // Depending on client and server behaviors, the actual number of connections open at any 27 | // time may be less than or greater than this capacity. For example, there may be more 28 | // connections to serve a high number of concurrent clients, and there may be fewer 29 | // connections if many of them have been destroyed due to timeout or error. 30 | Capacity int 31 | // StaleTimeout is the duration after which a cached connection should be considered stale, 32 | // and thus reconnected before use. This represents the time between connection I/O events. 33 | StaleTimeout time.Duration 34 | } 35 | 36 | // PersistentConn is a net.Conn that lazily closes connections; it invokes a closer callback 37 | // function instead of actually closing the underlying connection. It also augments the net.Conn API 38 | // by providing a Destroy() method that forcefully closes the underlying connection. 39 | type PersistentConn struct { 40 | closer func(destroyed bool) error 41 | destroyed bool 42 | 43 | net.Conn 44 | } 45 | 46 | // NewPersistentConnPool creates a connection pool with the specified dialer factory and 47 | // configuration options. The dialer is a net.Conn factory that describes how a new connection is 48 | // created. 
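//
// A minimal sketch of constructing a pool over plain TCP; the lifecycle hook
// and remote address are assumed to be supplied by the caller:
//
//	dialer := func() (net.Conn, error) {
//		return net.Dial("tcp", "203.0.113.1:853")
//	}
//	pool := NewPersistentConnPool(dialer, cxHook, PersistentConnPoolOpts{
//		Capacity:     4,
//		StaleTimeout: 30 * time.Second,
//	})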
49 | func NewPersistentConnPool(dialer func() (net.Conn, error), cxHook metrics.ConnectionLifecycleHook, opts PersistentConnPoolOpts) *PersistentConnPool { 50 | conns := data.NewMRUQueue(opts.Capacity) 51 | 52 | // The entire pool is initially populated asynchronously with live connections, if possible. 53 | go func() { 54 | for i := 0; i < opts.Capacity; i++ { 55 | dialTimer := lib.NewStopwatch() 56 | conn, err := dialer() 57 | 58 | // It is nonideal, but not necessarily an error, if the pool cannot be 59 | // initially populated to the desired capacity. The size of the pool is 60 | // inherently variable, and pool clients generally degrade gracefully when 61 | // the pool fails to provide a connection. 62 | if err != nil { 63 | cxHook.EmitConnectionError() 64 | } else { 65 | cxHook.EmitConnectionOpen(dialTimer.Elapsed(), conn.RemoteAddr()) 66 | conns.Push(conn) 67 | } 68 | } 69 | }() 70 | 71 | return &PersistentConnPool{ 72 | dialer: dialer, 73 | cxHook: cxHook, 74 | staleTimeout: opts.StaleTimeout, 75 | conns: conns, 76 | } 77 | } 78 | 79 | // Conn returns a single connection. It may be a cached connection that already exists in the pool, 80 | // or it may be a newly created connection in the event that the pool is empty. 81 | func (p *PersistentConnPool) Conn() (*PersistentConn, error) { 82 | value, timestamp, ok := p.conns.Pop() 83 | 84 | // Factory for creating a closer callback that closes the connection if it is destroyed, but 85 | // otherwise returns it to the cached connections pool. 86 | closerFactory := func(conn net.Conn) func(destroyed bool) error { 87 | return func(destroyed bool) error { 88 | if destroyed { 89 | p.cxHook.EmitConnectionClose(conn.RemoteAddr()) 90 | return conn.Close() 91 | } 92 | 93 | return p.put(conn) 94 | } 95 | } 96 | 97 | // A cached connection is available; attempt to use it 98 | if ok { 99 | conn := value.(net.Conn) 100 | 101 | // The connection is not stale; use it 102 | if p.staleTimeout <= 0 || time.Since(timestamp) < p.staleTimeout { 103 | return NewPersistentConn(conn, closerFactory(conn)), nil 104 | } 105 | 106 | // The connection is stale; close it and open a new connection. 107 | // We are not particularly interested in propagating errors that may occur from 108 | // closing the connection, since it is already stale anyways. 109 | p.cxHook.EmitConnectionClose(conn.RemoteAddr()) 110 | go conn.Close() 111 | } 112 | 113 | // A cached connection is not available or stale; create a new one 114 | dialTimer := lib.NewStopwatch() 115 | conn, err := p.dialer() 116 | if err != nil { 117 | p.cxHook.EmitConnectionError() 118 | return nil, err 119 | } 120 | 121 | p.cxHook.EmitConnectionOpen(dialTimer.Elapsed(), conn.RemoteAddr()) 122 | 123 | return NewPersistentConn(conn, closerFactory(conn)), nil 124 | } 125 | 126 | // Size reports the current size of the connection pool. 127 | func (p *PersistentConnPool) Size() int { 128 | return p.conns.Size() 129 | } 130 | 131 | // put attempts to return a connection back to the pool, e.g. when it would otherwise be closed. 132 | // The connection will be reinserted into the pool if there is sufficient capacity; otherwise, the 133 | // connection is simply closed. 134 | func (p *PersistentConnPool) put(conn net.Conn) error { 135 | if ok := p.conns.Push(conn); !ok { 136 | return conn.Close() 137 | } 138 | 139 | return nil 140 | } 141 | 142 | // NewPersistentConn wraps an existing net.Conn with the specified close callback. 
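//
// For example, the pool's Conn method supplies a closer of roughly this shape
// (sketch only; conn and p are captured from the enclosing scope):
//
//	closer := func(destroyed bool) error {
//		if destroyed {
//			return conn.Close()
//		}
//		return p.put(conn)
//	}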
143 | func NewPersistentConn(conn net.Conn, closer func(destroyed bool) error) *PersistentConn {
144 |     return &PersistentConn{closer: closer, Conn: conn}
145 | }
146 | 
147 | // Close invokes the close callback, which may lazily return the connection to a pool rather than
148 | // closing it. The callback is invoked with a single parameter describing whether the connection
149 | // has been marked as destroyed; the interpretation of a destroyed connection is abstracted out to
150 | // the PersistentConn callback supplier.
151 | func (c *PersistentConn) Close() error {
152 |     return c.closer(c.destroyed)
153 | }
154 | 
155 | // Destroy marks the connection as destroyed and invokes the close callback.
156 | func (c *PersistentConn) Destroy() error {
157 |     c.destroyed = true
158 | 
159 |     return c.Close()
160 | }
161 | 
162 | // String implements the Stringer interface for human-consumable representation.
163 | func (c *PersistentConn) String() string {
164 |     return fmt.Sprintf("PersistentConn{%s->%s}", c.LocalAddr(), c.RemoteAddr())
165 | }
166 | 
--------------------------------------------------------------------------------
/internal/network/server.go:
--------------------------------------------------------------------------------
1 | //go:generate go run golang.org/x/tools/cmd/stringer -type=Transport
2 | 
3 | package network
4 | 
5 | import (
6 |     "context"
7 |     "fmt"
8 |     "net"
9 |     "time"
10 | 
11 |     "dotproxy/internal/metrics"
12 | )
13 | 
14 | // contextKey is the type used for context keys passed to server handlers.
15 | type contextKey int
16 | 
17 | // Transport describes a network transport type.
18 | type Transport int
19 | 
20 | // ServerHandler is a common interface that wraps logic for handling incoming connections on any
21 | // network transport.
22 | type ServerHandler interface {
23 |     // Handle describes the routine to run when the server establishes a successful connection
24 |     // with a client. The passed conn is a net.Conn-implementing TCPConn or UDPConn.
25 |     Handle(ctx context.Context, conn net.Conn) error
26 | 
27 |     // ConsumeError is a callback invoked when the server fails to establish a connection with a
28 |     // client, or when the handler returns an error.
29 |     ConsumeError(ctx context.Context, err error)
30 | }
31 | 
32 | // UDPServer describes a server that listens on a UDP address.
33 | type UDPServer struct {
34 |     addr string
35 |     opts UDPServerOpts
36 | }
37 | 
38 | // UDPServerOpts formalizes UDP server configuration options.
39 | type UDPServerOpts struct {
40 |     // MaxConcurrentConnections configures the maximum number of concurrent clients that the
41 |     // server is capable of serving. It is generally recommended to set this value to the
42 |     // highest number of concurrent connections the server can expect to receive, but it is safe
43 |     // to set it lower.
44 |     MaxConcurrentConnections int
45 |     // ReadTimeout is the maximum amount of time the server will wait to read from a client.
46 |     // Note that, since UDP is a connectionless protocol, this timeout value represents the
47 |     // duration of time between when the socket begins listening for a connection to when the
48 |     // client starts writing data.
49 |     ReadTimeout time.Duration
50 |     // WriteTimeout is the maximum amount of time the server is allowed to take to write data
51 |     // back to a client, after which the server will consider the write to have failed.
52 |     WriteTimeout time.Duration
53 | }
54 | 
55 | // TCPServer describes a server that listens on a TCP address.
56 | type TCPServer struct {
57 |     addr   string
58 |     cxHook metrics.ConnectionLifecycleHook
59 |     opts   TCPServerOpts
60 | }
61 | 
62 | // TCPServerOpts formalizes TCP server configuration options.
63 | type TCPServerOpts struct {
64 |     // ReadTimeout is the maximum amount of time the server will wait to read from a client
65 |     // after it has established a connection with the server, after which the server will
66 |     // consider the read to have failed.
67 |     ReadTimeout time.Duration
68 |     // WriteTimeout is the maximum amount of time the server is allowed to take to write to a
69 |     // client, after which the server will consider the write to have failed.
70 |     WriteTimeout time.Duration
71 | }
72 | 
73 | const (
74 |     // TransportContextKey is the name of the context key used to indicate the network transport
75 |     // protocol the handler is serving. This is necessary because the handler APIs are
76 |     // abstracted to the point that they are inherently agnostic to the client connection's
77 |     // underlying transport.
78 |     TransportContextKey contextKey = iota
79 | )
80 | 
81 | const (
82 |     // TCP describes a TCP transport.
83 |     TCP Transport = iota
84 |     // UDP describes a UDP transport.
85 |     UDP
86 | )
87 | 
88 | // NewUDPServer creates a UDP server listening on the specified address.
89 | func NewUDPServer(addr string, opts UDPServerOpts) *UDPServer {
90 |     // Sane option defaults
91 |     if opts.MaxConcurrentConnections <= 0 {
92 |         opts.MaxConcurrentConnections = 16
93 |     }
94 | 
95 |     return &UDPServer{addr, opts}
96 | }
97 | 
98 | // ListenAndServe starts listening on the UDP address with which the server was configured and
99 | // indefinitely serves connections using the specified handler. It returns an error if it fails to
100 | // bind to the initialized address.
101 | func (s *UDPServer) ListenAndServe(handler ServerHandler) error {
102 |     conn, err := net.ListenPacket("udp", s.addr)
103 |     if err != nil {
104 |         return fmt.Errorf("server: failed to listen on UDP socket: err=%v", err)
105 |     }
106 | 
107 |     ctx := context.WithValue(context.Background(), TransportContextKey, UDP)
108 | 
109 |     for i := 0; i < s.opts.MaxConcurrentConnections; i++ {
110 |         go func() {
111 |             for {
112 |                 udpConn := NewUDPConn(conn, s.opts.ReadTimeout, s.opts.WriteTimeout)
113 | 
114 |                 if err := handler.Handle(ctx, udpConn); err != nil {
115 |                     handler.ConsumeError(ctx, err)
116 |                 }
117 |             }
118 |         }()
119 |     }
120 | 
121 |     return nil
122 | }
123 | 
124 | // NewTCPServer creates a TCP server listening on the specified address.
125 | func NewTCPServer(addr string, cxHook metrics.ConnectionLifecycleHook, opts TCPServerOpts) *TCPServer {
126 |     return &TCPServer{addr, cxHook, opts}
127 | }
128 | 
129 | // ListenAndServe starts listening on the TCP address with which the server was configured and
130 | // indefinitely serves connections using the specified handler. It returns an error if it fails to
131 | // bind to the initialized address.
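//
// A minimal handler, sketched purely for illustration (it is not part of this
// package; the cxHook value is assumed to be constructed elsewhere), echoes a
// single read back to the client:
//
//	type echoHandler struct{}
//
//	func (h echoHandler) Handle(ctx context.Context, conn net.Conn) error {
//		buf := make([]byte, 512)
//		n, err := conn.Read(buf)
//		if err != nil {
//			return err
//		}
//		_, err = conn.Write(buf[:n])
//		return err
//	}
//
//	func (h echoHandler) ConsumeError(ctx context.Context, err error) {
//		fmt.Printf("server: handler error: %v\n", err)
//	}
//
//	srv := NewTCPServer(":8053", cxHook, TCPServerOpts{})
//	err := srv.ListenAndServe(echoHandler{})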
132 | func (s *TCPServer) ListenAndServe(handler ServerHandler) error { 133 | ln, err := net.Listen("tcp", s.addr) 134 | if err != nil { 135 | return fmt.Errorf("server: failed to listen on TCP socket: err=%v", err) 136 | } 137 | 138 | ctx := context.WithValue(context.Background(), TransportContextKey, TCP) 139 | 140 | for { 141 | conn, err := ln.Accept() 142 | if err != nil { 143 | s.cxHook.EmitConnectionError() 144 | handler.ConsumeError(ctx, err) 145 | continue 146 | } 147 | 148 | tcpConn := NewTCPConn(conn, s.opts.ReadTimeout, s.opts.WriteTimeout) 149 | s.cxHook.EmitConnectionOpen(0, tcpConn.RemoteAddr()) 150 | 151 | go func() { 152 | defer func() { 153 | s.cxHook.EmitConnectionClose(tcpConn.RemoteAddr()) 154 | tcpConn.Close() 155 | }() 156 | 157 | if err := handler.Handle(ctx, tcpConn); err != nil { 158 | handler.ConsumeError(ctx, err) 159 | } 160 | }() 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /internal/network/sharding.go: -------------------------------------------------------------------------------- 1 | //go:generate go run golang.org/x/tools/cmd/stringer -type=LoadBalancingPolicy 2 | 3 | package network 4 | 5 | import ( 6 | "fmt" 7 | "math/rand" 8 | "strings" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // LoadBalancingPolicy formalizes the load balancing decision policy to apply when proxying requests 14 | // through a sharded network client. 15 | type LoadBalancingPolicy int 16 | 17 | // ShardedClientFactory is a type alias for a unary constructor function that returns a single 18 | // Client that abstracts operations among several child Clients. 19 | type ShardedClientFactory func([]Client) Client 20 | 21 | // RoundRobinShardedClient shards requests among clients fairly in round-robin order. 22 | type RoundRobinShardedClient struct { 23 | clients []Client 24 | 25 | // Current round robin index (not necessarily async-safe) 26 | rrIdx int 27 | } 28 | 29 | // RandomShardedClient shards requests among clients randomly. 30 | type RandomShardedClient struct { 31 | clients []Client 32 | } 33 | 34 | // HistoricalConnectionsShardedClient directs requests to the client that has, up until the time of 35 | // invocation, served the fewest number of successful connections. It is best used when there is a 36 | // need to ensure that load is distributed to all clients fairly even if one of them has failed. 37 | type HistoricalConnectionsShardedClient struct { 38 | clients []Client 39 | } 40 | 41 | // AvailabilityShardedClient provides connections by dynamically adjusting its active client pool to 42 | // prioritize those clients that are successful in providing new connections. It automatically fails 43 | // over failed client connection requests to healthy clients in the pool, temporarily disabling the 44 | // failed client for future requests with an exponential backoff policy. 45 | type AvailabilityShardedClient struct { 46 | clients []Client 47 | 48 | // Tracks the timestamp at which each client last errored 49 | lastError map[Client]time.Time 50 | // Tracks the current duration of time to wait before a failed connection is once again 51 | // available for use. 52 | errorExpiry map[Client]time.Duration 53 | // Mutex used to protect R/W operations on the state maps. 54 | mutex sync.RWMutex 55 | } 56 | 57 | // FailoverShardedClient provides connections in priority order, serially failing over to the next 58 | // client(s) in the list when the primary is not successful in providing a connection. 
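// For example, with clients [primary, secondary], every request is served by
// primary until it fails to provide a connection, at which point secondary is
// tried for that same request.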
59 | type FailoverShardedClient struct {
60 |     clients []Client
61 | }
62 | 
63 | const (
64 |     // RoundRobin statefully iterates through each client on every connection request.
65 |     RoundRobin LoadBalancingPolicy = iota
66 |     // Random selects a client at random to provide the connection.
67 |     Random
68 |     // HistoricalConnections selects the client that has, up until the time of request,
69 |     // provided the fewest successful connections.
70 |     HistoricalConnections
71 |     // Availability randomly selects a client to provide the connection, failing over to another
72 |     // client in the event that it fails to do so. The failed client is temporarily pulled out
73 |     // of the availability pool to prevent subsequent requests from being directed to the failed
74 |     // client.
75 |     Availability
76 |     // Failover provides connections from multiple clients in serial order, only failing over to
77 |     // secondary clients when the primary fails.
78 |     Failover
79 | )
80 | 
81 | // NewShardedClient creates a single Client that provides connections from several other Clients
82 | // governed by a load balancing policy. It returns an error if the specified load balancing policy
83 | // has no associated sharded client factory.
84 | func NewShardedClient(clients []Client, lbPolicy LoadBalancingPolicy) (Client, error) {
85 |     factories := map[LoadBalancingPolicy]ShardedClientFactory{
86 |         RoundRobin:            NewRoundRobinShardedClient,
87 |         Random:                NewRandomShardedClient,
88 |         HistoricalConnections: NewHistoricalConnectionsShardedClient,
89 |         Availability:          NewAvailabilityShardedClient,
90 |         Failover:              NewFailoverShardedClient,
91 |     }
92 | 
93 |     factory, ok := factories[lbPolicy]
94 |     if !ok {
95 |         return nil, fmt.Errorf(
96 |             "sharding: no factory configured for load balancing policy: policy=%s",
97 |             lbPolicy,
98 |         )
99 |     }
100 | 
101 |     return factory(clients), nil
102 | }
103 | 
104 | // NewRoundRobinShardedClient is a client factory for the round robin load balancing policy.
105 | func NewRoundRobinShardedClient(clients []Client) Client {
106 |     return &RoundRobinShardedClient{clients: clients}
107 | }
108 | 
109 | // Conn retrieves a connection from the next client in the round robin index.
110 | func (c *RoundRobinShardedClient) Conn() (*PersistentConn, error) {
111 |     defer func() {
112 |         c.rrIdx = (c.rrIdx + 1) % len(c.clients)
113 |     }()
114 | 
115 |     return c.clients[c.rrIdx].Conn()
116 | }
117 | 
118 | // Stats aggregates stats from all child clients.
119 | func (c *RoundRobinShardedClient) Stats() Stats {
120 |     return aggregateClientsStats(c.clients)
121 | }
122 | 
123 | // NewRandomShardedClient is a client factory for the random load balancing policy.
124 | func NewRandomShardedClient(clients []Client) Client {
125 |     return &RandomShardedClient{clients}
126 | }
127 | 
128 | // Conn selects a client at random to provide the connection.
129 | func (c *RandomShardedClient) Conn() (*PersistentConn, error) {
130 |     return c.clients[rand.Intn(len(c.clients))].Conn()
131 | }
132 | 
133 | // Stats aggregates stats from all child clients.
134 | func (c *RandomShardedClient) Stats() Stats {
135 |     return aggregateClientsStats(c.clients)
136 | }
137 | 
138 | // NewHistoricalConnectionsShardedClient is a client factory for the historical connections load
139 | // balancing policy.
140 | func NewHistoricalConnectionsShardedClient(clients []Client) Client {
141 |     return &HistoricalConnectionsShardedClient{clients}
142 | }
143 | 
144 | // Conn selects the client that has, up until the time of invocation, provided the fewest successful
145 | // connections.
146 | func (c *HistoricalConnectionsShardedClient) Conn() (*PersistentConn, error) {
147 |     var client Client
148 | 
149 |     for _, candidate := range c.clients {
150 |         if client == nil || candidate.Stats().SuccessfulConnections < client.Stats().SuccessfulConnections {
151 |             client = candidate
152 |         }
153 |     }
154 | 
155 |     return client.Conn()
156 | }
157 | 
158 | // Stats aggregates stats from all child clients.
159 | func (c *HistoricalConnectionsShardedClient) Stats() Stats {
160 |     return aggregateClientsStats(c.clients)
161 | }
162 | 
163 | // NewAvailabilityShardedClient is a client factory for the availability load balancing policy.
164 | func NewAvailabilityShardedClient(clients []Client) Client {
165 |     lastError := make(map[Client]time.Time)
166 |     errorExpiry := make(map[Client]time.Duration)
167 | 
168 |     for _, client := range clients {
169 |         lastError[client] = time.Time{}
170 |         errorExpiry[client] = 0
171 |     }
172 | 
173 |     return &AvailabilityShardedClient{
174 |         clients:     clients,
175 |         lastError:   lastError,
176 |         errorExpiry: errorExpiry,
177 |     }
178 | }
179 | 
180 | // Conn attempts to robustly provide a connection from all available clients using a failover retry
181 | // mechanism. It is possible for this method to error if the load balancing policy determines that
182 | // there are no live clients eligible for providing a connection.
183 | func (c *AvailabilityShardedClient) Conn() (*PersistentConn, error) {
184 |     // Describes the amount of time that must elapse before resetting a client's error expiry
185 |     // timer. In other words, this is the minimum amount of time after which a client errors
186 |     // that it is permitted to be retried for a live connection. Otherwise, the client is
187 |     // pulled out of the sharding pool for exponentially increasing durations of time.
188 |     failedClientExpiry := 30 * time.Second
189 | 
190 |     client, err := c.selectAvailable()
191 |     if err != nil {
192 |         return nil, err
193 |     }
194 | 
195 |     conn, err := client.Conn()
196 |     if err != nil {
197 |         c.mutex.Lock()
198 | 
199 |         if c.lastError[client].IsZero() || time.Since(c.lastError[client]) > failedClientExpiry {
200 |             // The client has either never errored before, or the last error is too far
201 |             // in the past. Start its exponential backoff timer at 100 ms, indicating
202 |             // that this client will be marked unavailable for the next 100 ms.
203 |             c.errorExpiry[client] = 100 * time.Millisecond
204 |         } else {
205 |             // The most recent client failure was too recent; double the current expiry
206 |             // time.
207 |             c.errorExpiry[client] *= 2
208 |         }
209 | 
210 |         c.lastError[client] = time.Now()
211 | 
212 |         c.mutex.Unlock()
213 | 
214 |         return c.Conn()
215 |     }
216 | 
217 |     return conn, nil
218 | }
219 | 
220 | // Stats aggregates stats from all child clients.
221 | func (c *AvailabilityShardedClient) Stats() Stats {
222 |     return aggregateClientsStats(c.clients)
223 | }
224 | 
225 | // Select an eligible client at random. This method may error if no clients are available to
226 | // provide connections.
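//
// A concrete sketch of the backoff behavior driven by Conn above (initial
// expiry 100ms, doubling on repeated failures within the 30s window):
//
//	t=0ms    error -> client unavailable for the next 100ms
//	t=200ms  error -> expiry doubles to 200ms
//	t=500ms  error -> expiry doubles to 400ms
//	t=60s    error -> window elapsed; expiry resets to 100ms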
227 | func (c *AvailabilityShardedClient) selectAvailable() (Client, error) { 228 | var eligibleClients []Client 229 | 230 | for _, candidate := range c.clients { 231 | c.mutex.RLock() 232 | lastError := c.lastError[candidate] 233 | expiry := c.errorExpiry[candidate] 234 | c.mutex.RUnlock() 235 | 236 | // The client is considered eligible if it has never errored or if its current 237 | // failure lifetime has expired. 238 | if lastError.IsZero() || time.Since(lastError) > expiry { 239 | eligibleClients = append(eligibleClients, candidate) 240 | } 241 | } 242 | 243 | if len(eligibleClients) == 0 { 244 | return nil, fmt.Errorf("sharding: no live clients are available") 245 | } 246 | 247 | return eligibleClients[rand.Intn(len(eligibleClients))], nil 248 | } 249 | 250 | // NewFailoverShardedClient is a client factory for the failover load balancing policy. 251 | func NewFailoverShardedClient(clients []Client) Client { 252 | return &FailoverShardedClient{clients} 253 | } 254 | 255 | // Conn attempts to provide connections from clients in serial order, failing over to the next 256 | // client on error. 257 | func (c *FailoverShardedClient) Conn() (*PersistentConn, error) { 258 | for _, client := range c.clients { 259 | if conn, err := client.Conn(); err == nil { 260 | return conn, nil 261 | } 262 | } 263 | 264 | return nil, fmt.Errorf("sharding: all clients failed to provide a connection") 265 | } 266 | 267 | // Stats aggregates stats from all child clients. 268 | func (c *FailoverShardedClient) Stats() Stats { 269 | return aggregateClientsStats(c.clients) 270 | } 271 | 272 | // ParseLoadBalancingPolicy parses a LoadBalancingPolicy constant from its stringified 273 | // representation in a case-insensitive manner. 274 | func ParseLoadBalancingPolicy(lbPolicy string) (LoadBalancingPolicy, bool) { 275 | knownLbPolicies := []LoadBalancingPolicy{ 276 | RoundRobin, 277 | Random, 278 | HistoricalConnections, 279 | Availability, 280 | Failover, 281 | } 282 | 283 | for _, knownLbPolicy := range knownLbPolicies { 284 | if strings.ToLower(lbPolicy) == strings.ToLower(knownLbPolicy.String()) { 285 | return knownLbPolicy, true 286 | } 287 | } 288 | 289 | return RoundRobin, false 290 | } 291 | 292 | // aggregateClientsStats creates a single Stats struct from those in multiple Clients. 293 | func aggregateClientsStats(clients []Client) Stats { 294 | var multipleStats []Stats 295 | var aggregatedStats Stats 296 | 297 | for _, client := range clients { 298 | multipleStats = append(multipleStats, client.Stats()) 299 | } 300 | 301 | for _, stats := range multipleStats { 302 | aggregatedStats.SuccessfulConnections += stats.SuccessfulConnections 303 | aggregatedStats.FailedConnections += stats.FailedConnections 304 | } 305 | 306 | return aggregatedStats 307 | } 308 | -------------------------------------------------------------------------------- /internal/protocol/dns_proxy.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "net" 8 | 9 | "github.com/getsentry/raven-go" 10 | "lib.kevinlin.info/aperture/lib" 11 | 12 | "dotproxy/internal/log" 13 | "dotproxy/internal/metrics" 14 | "dotproxy/internal/network" 15 | ) 16 | 17 | // DNSProxyHandler is a semi-DNS-protocol-aware server handler that proxies requests between a 18 | // client and upstream server. 
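//
// A sketch of wiring the handler into a server; the upstream client, hooks,
// and logger are assumed to be constructed elsewhere:
//
//	handler := &DNSProxyHandler{
//		Upstream:         upstreamClient,
//		ClientCxIOHook:   clientIOHook,
//		UpstreamCxIOHook: upstreamIOHook,
//		ProxyHook:        proxyHook,
//		Logger:           logger,
//		Opts:             DNSProxyOpts{MaxUpstreamRetries: 8},
//	}
//	server := network.NewUDPServer(":53", network.UDPServerOpts{})
//	err := server.ListenAndServe(handler)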
19 | type DNSProxyHandler struct {
20 |     Upstream         network.Client
21 |     ClientCxIOHook   metrics.ConnectionIOHook
22 |     UpstreamCxIOHook metrics.ConnectionIOHook
23 |     ProxyHook        metrics.ProxyHook
24 |     Logger           log.Logger
25 |     Opts             DNSProxyOpts
26 | }
27 | 
28 | // DNSProxyOpts formalizes configuration options for the proxy handler.
29 | type DNSProxyOpts struct {
30 |     // MaxUpstreamRetries describes the maximum allowable times the proxy server is permitted to
31 |     // retry a request with the upstream server(s). It is recommended to set this to a liberal
32 |     // value above 0; since connections are pooled and persisted over a long period of time, it
33 |     // is highly likely that any single proxy request will fail (due to a server-side closed
34 |     // connection) and will need to be retried with another connection in the pool.
35 |     MaxUpstreamRetries int
36 | }
37 | 
38 | // ConsumeError logs the proxy error, emits an error metric, and reports the error to Sentry.
39 | func (h *DNSProxyHandler) ConsumeError(ctx context.Context, err error) {
40 |     h.Logger.Error("%v", err)
41 |     h.ProxyHook.EmitError()
42 | 
43 |     raven.CaptureError(err, map[string]string{
44 |         "transport": ctx.Value(network.TransportContextKey).(network.Transport).String(),
45 |     })
46 | }
47 | 
48 | // Handle reads a request from the client connection, writes the request to the upstream connection,
49 | // reads the response from the upstream connection, and finally writes the response back to the
50 | // client. It performs some minimal protocol-aware data shaping and emits metrics along the way.
51 | func (h *DNSProxyHandler) Handle(ctx context.Context, clientConn net.Conn) error {
52 |     rttTxTimer := lib.NewStopwatch()
53 | 
54 |     /* Read the DNS request from the client */
55 | 
56 |     clientReq, err := h.clientRead(clientConn)
57 |     if err != nil {
58 |         return err
59 |     }
60 | 
61 |     h.Logger.Debug(
62 |         "dns_proxy: read request from client: request_bytes=%d transport=%s",
63 |         len(clientReq),
64 |         ctx.Value(network.TransportContextKey),
65 |     )
66 | 
67 |     if ctx.Value(network.TransportContextKey) == network.UDP {
68 |         // Since UDP is connectionless, the initial network read blocks until data is
69 |         // available. Reset the RTT timer here to get an approximately correct estimate of
70 |         // end-to-end latency.
71 |         rttTxTimer = lib.NewStopwatch()
72 | 
73 |         // By RFC specification, DNS over TCP transports should include a two-octet header
74 |         // in the request that denotes the size of the DNS packet. Since this request came
75 |         // in on a UDP transport, augment the request payload to conform to standard.
76 |         clientHeader := make([]byte, 2)
77 |         binary.BigEndian.PutUint16(clientHeader, uint16(len(clientReq)))
78 |         clientReq = append(clientHeader, clientReq...)
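        // For example, a 47-byte UDP query is prefixed with the two octets
        // 0x00 0x2F, producing a 49-byte TCP-style framed request.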
79 |     }
80 | 
81 |     /* Open a (possibly cached) connection to the upstream and perform a W/R transaction */
82 | 
83 |     maxRetries := h.Opts.MaxUpstreamRetries
84 |     if maxRetries <= 0 {
85 |         maxRetries = 16
86 |     }
87 | 
88 |     upstreamResp, upstreamConn, err := h.proxyUpstream(clientConn, clientReq, maxRetries)
89 |     if err != nil {
90 |         return err
91 |     }
92 | 
93 |     // Omit the response's size header if the client initially requested a UDP transport
94 |     if ctx.Value(network.TransportContextKey) == network.UDP {
95 |         upstreamResp = upstreamResp[2:]
96 |     }
97 | 
98 |     /* Write the proxied result back to the client */
99 | 
100 |     if err := h.clientWrite(clientConn, upstreamResp); err != nil {
101 |         return err
102 |     }
103 | 
104 |     h.Logger.Debug(
105 |         "dns_proxy: completed write back to client: rtt=%v transport=%s",
106 |         rttTxTimer.Elapsed(),
107 |         ctx.Value(network.TransportContextKey),
108 |     )
109 | 
110 |     /* Clean up and report end-to-end metrics */
111 | 
112 |     h.ProxyHook.EmitProcess(clientConn.RemoteAddr(), upstreamConn.RemoteAddr())
113 |     h.ProxyHook.EmitRequestSize(int64(len(clientReq)), clientConn.RemoteAddr())
114 |     h.ProxyHook.EmitResponseSize(int64(len(upstreamResp)), upstreamConn.RemoteAddr())
115 |     h.ProxyHook.EmitRTT(
116 |         rttTxTimer.Elapsed(),
117 |         clientConn.RemoteAddr(),
118 |         upstreamConn.RemoteAddr(),
119 |     )
120 | 
121 |     return nil
122 | }
123 | 
124 | // clientRead reads a request from the client.
125 | func (h *DNSProxyHandler) clientRead(conn net.Conn) ([]byte, error) {
126 |     clientReadTimer := lib.NewStopwatch()
127 |     clientReq := make([]byte, 1024) // Plain DNS caps UDP messages at 512 bytes; 1024 leaves headroom for EDNS(0) requests.
128 | 
129 |     clientReadBytes, err := conn.Read(clientReq)
130 |     if err != nil {
131 |         h.ClientCxIOHook.EmitReadError(conn.RemoteAddr())
132 |         return nil, fmt.Errorf("dns_proxy: error reading request from client: err=%v", err)
133 |     }
134 | 
135 |     h.ClientCxIOHook.EmitRead(clientReadTimer.Elapsed(), conn.RemoteAddr())
136 | 
137 |     // Trim the request buffer to only what the server was able to read
138 |     return clientReq[:clientReadBytes], nil
139 | }
140 | 
141 | // upstreamTransact performs a write-read transaction with the upstream connection and returns the
142 | // upstream response.
143 | func (h *DNSProxyHandler) upstreamTransact(client net.Conn, upstream *network.PersistentConn, clientReq []byte) ([]byte, error) {
144 |     upstreamTxTimer := lib.NewStopwatch()
145 | 
146 |     /* Proxy the client request to the upstream */
147 | 
148 |     upstreamWriteTimer := lib.NewStopwatch()
149 | 
150 |     upstreamWriteBytes, err := upstream.Write(clientReq)
151 |     if err != nil || upstreamWriteBytes != len(clientReq) {
152 |         h.UpstreamCxIOHook.EmitWriteError(upstream.RemoteAddr())
153 |         return nil, fmt.Errorf("dns_proxy: error writing to upstream: err=%v", err)
154 |     }
155 | 
156 |     h.UpstreamCxIOHook.EmitWrite(upstreamWriteTimer.Elapsed(), upstream.RemoteAddr())
157 | 
158 |     h.Logger.Debug("dns_proxy: wrote request to upstream: request_bytes=%d", upstreamWriteBytes)
159 | 
160 |     /* Read the response from the upstream */
161 | 
162 |     upstreamReadTimer := lib.NewStopwatch()
163 | 
164 |     // By RFC specification, the server response follows the same format as the TCP request: the
165 |     // first two bytes specify the length of the message.
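    // For example, a header of 0x01 0x2C announces a 300-byte response body. Note
    // that the exact-size checks below assume each payload arrives in a single
    // Read call; a short read is surfaced as an error and handled by the retry
    // logic in proxyUpstream.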
166 | upstreamHeader := make([]byte, 2) 167 | upstreamHeaderBytes, err := upstream.Read(upstreamHeader) 168 | if err != nil || upstreamHeaderBytes != 2 { 169 | h.UpstreamCxIOHook.EmitReadError(upstream.RemoteAddr()) 170 | return nil, fmt.Errorf( 171 | "dns_proxy: error reading header from upstream: err=%v bytes=%d", 172 | err, 173 | upstreamHeaderBytes, 174 | ) 175 | } 176 | 177 | // Parse the alleged size of the remaining response and perform another exactly-sized read. 178 | respSize := binary.BigEndian.Uint16(upstreamHeader) 179 | upstreamResp := make([]byte, respSize) 180 | 181 | h.Logger.Debug("dns_proxy: read upstream header: response_size=%d", respSize) 182 | 183 | upstreamReadBytes, err := upstream.Read(upstreamResp) 184 | if err != nil || upstreamReadBytes != int(respSize) { 185 | h.UpstreamCxIOHook.EmitReadError(upstream.RemoteAddr()) 186 | return nil, fmt.Errorf( 187 | "dns_proxy: error reading full response from upstream: err=%v bytes=%d", 188 | err, 189 | upstreamReadBytes, 190 | ) 191 | } 192 | 193 | h.Logger.Debug("dns_proxy: read upstream response: response_bytes=%d", upstreamReadBytes) 194 | 195 | h.UpstreamCxIOHook.EmitRead(upstreamReadTimer.Elapsed(), upstream.RemoteAddr()) 196 | h.ProxyHook.EmitUpstreamLatency( 197 | upstreamTxTimer.Elapsed(), 198 | client.RemoteAddr(), 199 | upstream.RemoteAddr(), 200 | ) 201 | 202 | return append(upstreamHeader, upstreamResp...), nil 203 | } 204 | 205 | // proxyUpstream opens an upstream connection and performs a write-read transaction with a client 206 | // request, wrapping retry logic. It returns the upstream response, the upstream connection, and 207 | // optionally an error. 208 | func (h *DNSProxyHandler) proxyUpstream(client net.Conn, clientReq []byte, retries int) ([]byte, net.Conn, error) { 209 | upstream, err := h.Upstream.Conn() 210 | if err != nil { 211 | return nil, nil, fmt.Errorf( 212 | "dns_proxy: error opening upstream connection: err=%v", 213 | err, 214 | ) 215 | } 216 | 217 | h.Logger.Debug("dns_proxy: created upstream connection: conn=%v", upstream) 218 | 219 | resp, err := h.upstreamTransact(client, upstream, clientReq) 220 | if err != nil { 221 | // No matter the retry budget, destroy the connection if it fails during I/O 222 | go upstream.Destroy() 223 | 224 | if retries > 0 { 225 | h.UpstreamCxIOHook.EmitRetry(upstream.RemoteAddr()) 226 | h.Logger.Debug( 227 | "dns_proxy: upstream I/O failed; retrying: retry=%d", 228 | retries, 229 | ) 230 | 231 | return h.proxyUpstream(client, clientReq, retries-1) 232 | } 233 | 234 | h.Logger.Debug("dns_proxy: upstream I/O failed; available retries exhausted") 235 | 236 | return nil, nil, err 237 | } 238 | 239 | // Upstream transaction succeeded; schedule the connection for reinsertion into the 240 | // long-lived connection pool 241 | go upstream.Close() 242 | 243 | h.Logger.Debug("dns_proxy: completed upstream proxy: response_bytes=%d", len(resp)) 244 | 245 | return resp, upstream, err 246 | } 247 | 248 | // clientWrite writes data back to the client. 
249 | func (h *DNSProxyHandler) clientWrite(conn net.Conn, upstreamResp []byte) error { 250 | clientWriteTimer := lib.NewStopwatch() 251 | clientWriteBytes, err := conn.Write(upstreamResp) 252 | 253 | if err != nil { 254 | h.ClientCxIOHook.EmitWriteError(conn.RemoteAddr()) 255 | return err 256 | } 257 | 258 | if clientWriteBytes != len(upstreamResp) { 259 | h.ClientCxIOHook.EmitWriteError(conn.RemoteAddr()) 260 | return fmt.Errorf( 261 | "dns_proxy: failed writing response bytes to client: expected=%d actual=%d", 262 | len(upstreamResp), 263 | clientWriteBytes, 264 | ) 265 | } 266 | 267 | h.ClientCxIOHook.EmitWrite(clientWriteTimer.Elapsed(), conn.RemoteAddr()) 268 | 269 | return nil 270 | } 271 | -------------------------------------------------------------------------------- /internal/protocol/doc.go: -------------------------------------------------------------------------------- 1 | // Package protocol concerns itself primarily with DNS protocol-specific business logic. It contains 2 | // logic used for understanding the DNS protocol and mediating requests and responses between the 3 | // client and upstream server(s). 4 | package protocol 5 | -------------------------------------------------------------------------------- /tools.go: -------------------------------------------------------------------------------- 1 | // +build tools 2 | 3 | package tools 4 | 5 | import ( 6 | _ "golang.org/x/lint/golint" 7 | _ "golang.org/x/tools/cmd/stringer" 8 | ) 9 | --------------------------------------------------------------------------------