├── .github └── workflows │ └── go.yml ├── LICENSE ├── Makefile ├── README.md ├── conn.go ├── filter.go ├── go.mod ├── keepalive.go ├── listener.go ├── override.go ├── protolistener.go ├── proxy.go ├── proxy_test.go ├── tls.go └── tls_test.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: [push] 3 | jobs: 4 | 5 | build: 6 | name: Build 7 | runs-on: ubuntu-latest 8 | steps: 9 | 10 | - name: Set up Go 1.22 11 | uses: actions/setup-go@v1 12 | with: 13 | go-version: 1.22 14 | id: go 15 | 16 | - name: Check out code into the Go module directory 17 | uses: actions/checkout@v1 18 | 19 | - name: Get dependencies 20 | run: | 21 | go get -v -t -d ./... 22 | 23 | - name: Build 24 | run: go build -v . 25 | 26 | - name: Test 27 | run: go test -v . 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Mark Karpelès 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | #!/bin/make 2 | GOROOT:=$(shell PATH="/pkg/main/dev-lang.go.dev/bin:$$PATH" go env GOROOT) 3 | GOPATH:=$(shell $(GOROOT)/bin/go env GOPATH) 4 | 5 | .PHONY: test deps 6 | 7 | all: 8 | $(GOPATH)/bin/goimports -w -l . 9 | $(GOROOT)/bin/go build -v 10 | 11 | deps: 12 | $(GOROOT)/bin/go get -v -t . 13 | 14 | test: 15 | $(GOROOT)/bin/go test -v 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MagicTLS 2 | 3 | [![Build Status](https://github.com/KarpelesLab/magictls/workflows/Go/badge.svg)](https://github.com/KarpelesLab/magictls/actions) 4 | [![GoDoc](https://godoc.org/github.com/KarpelesLab/magictls?status.svg)](https://godoc.org/github.com/KarpelesLab/magictls) 5 | 6 | A simple Go library that detects the protocol automatically: 7 | 8 | * Support for PROXY and PROXYv2 allows detecting the real user's IP when, for example, [using AWS elastic load balancers](https://docs.aws.amazon.com/elasticloadbalancing/latest/classic/enable-proxy-protocol.html). Because the protocol is detected automatically, the daemon works even before ELB is properly configured and avoids rejecting requests by mistake. 9 | * Automatic TLS support allows using a single port for SSL and non-SSL traffic, and simplifies configuration. 10 | * Allows creating listeners for specific TLS-negotiated protocols, allowing a single port to be used for many things easily.
11 | 12 | This library was used in some of my projects; I've cleaned it up and licensed it under the MIT License since it's small and useful. Pull requests are welcome. 13 | 14 | It is written to work with protocols where the client sends the first data, and it expects the client to send at least 16 bytes. This works nicely with HTTP (`GET / HTTP/1.0\r\n` is exactly 16 bytes), SSL, etc., but may not work with protocols such as POP3, IMAP or SMTP where the server is expected to send the first bytes unless TLS is required. In that case, using only the `ForceTLS` filter still lets you benefit from TLS NextProto routing. 15 | 16 | ## Usage 17 | 18 | Use `magictls.Listen()` to create sockets the same way you would use `tls.Listen()`. 19 | 20 | ```go 21 | socket, err := magictls.Listen("tcp", ":8080", tlsConfig) 22 | if err != nil { 23 | ... 24 | } 25 | log.Fatal(http.Serve(socket, handler)) 26 | ``` 27 | 28 | The created listener can receive various configurations. For example, if you need to force all connections to be TLS and only want to use PROXY protocol detection: 29 | 30 | ```go 31 | socket, err := magictls.Listen("tcp", ":8443", tlsConfig) 32 | if err != nil { 33 | ... 34 | } 35 | socket.Filters = []magictls.Filter{magictls.DetectProxy, magictls.ForceTLS} 36 | log.Fatal(http.Serve(socket, handler)) 37 | ``` 38 | 39 | It is also possible to implement your own filters. 40 | 41 | ### PROXY protocol allowed IPs 42 | 43 | Depending on your provider, you may need to allow more than the local IPs for PROXY protocol.
44 | 45 | For example Google Cloud's global load balancer [uses a wider range of IPs](https://cloud.google.com/load-balancing/docs/firewall-rules) that may come with PROXY requests and need this to be called: 46 | 47 | ```go 48 | magictls.AddAllowedProxies("35.191.0.0/16", "130.211.0.0/22", "2600:2d00:1:b029::/64", "2600:2d00:1:1::/64") 49 | magictls.AddAllowedProxiesSpf("_cloud-eoips.googleusercontent.com") 50 | ``` 51 | 52 | ### autocert 53 | 54 | This can be used with [autocert](https://godoc.org/golang.org/x/crypto/acme/autocert) too for automatic TLS certificates. Note that in this case you are required to have a listener on port 443. 55 | 56 | ```go 57 | // initialize autocert structure 58 | m := &autocert.Manager{ 59 | Prompt: autocert.AcceptTOS, 60 | HostPolicy: autocert.HostWhitelist("domain", "domain2"), 61 | Cache: autocert.DirCache("/tmp"), // use os.UserCacheDir() to find where to put that 62 | } 63 | // grab autocert TLS config 64 | cfg := m.TLSConfig() 65 | // you may want to add to cfg.NextProtos any protocol you want to handle with ProtoListener. Be careful to not overwrite it. 66 | cfg.NextProtos = append(cfg.NextProtos, "my-proto") 67 | // standard listen 68 | socket, err := magictls.Listen("tcp", ":443", cfg) 69 | if err != nil { 70 | ... 71 | } 72 | ... 73 | ``` 74 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import ( 4 | "bufio" 5 | "crypto/tls" 6 | "net" 7 | "time" 8 | ) 9 | 10 | // Conn is used to prepend data to the data stream when we need to 11 | // unread what we've read. It can be used as a net.Conn. 
type Conn struct {
	conn net.Conn // wrapped connection
	rbuf []byte   // bytes already read from conn (peeked) and not yet returned by Read
	l, r net.Addr // addresses reported by LocalAddr/RemoteAddr; may be overridden
	used bool     // set once l or r has been overridden
}

// isUsed reports whether the wrapper carries state that the wrapped connection
// does not have: pending peeked bytes or overridden addresses. When it returns
// false, the wrapped connection can be used directly instead.
func (c *Conn) isUsed() bool {
	return len(c.rbuf) != 0 || c.used
}

// Read implements net.Conn. Peeked bytes are returned first; once they are
// exhausted, reads are forwarded to the wrapped connection.
func (c *Conn) Read(b []byte) (int, error) {
	if len(c.rbuf) == 0 {
		return c.conn.Read(b)
	}
	n := copy(b, c.rbuf)
	c.rbuf = c.rbuf[n:]
	if len(c.rbuf) == 0 {
		c.rbuf = nil // drop the reference to the fully consumed buffer
	}
	return n, nil
}

// PeekMore performs a single read of up to count bytes from the wrapped
// connection, appends what was read to the peek buffer and returns the whole
// peek buffer. Peeked bytes are not consumed: Read returns them again unless
// they are discarded with SkipPeek.
//
// If the read fails, the error is returned together with the peek buffer,
// which includes any bytes delivered by that same read (an io.Reader may
// return data along with an error such as io.EOF).
func (c *Conn) PeekMore(count int) ([]byte, error) {
	buf := make([]byte, count)
	n, err := c.conn.Read(buf)
	if n > 0 {
		// keep the bytes even when err != nil, otherwise they would be lost
		c.rbuf = append(c.rbuf, buf[:n]...)
	}
	return c.rbuf, err
}

// PeekUntil reads from the wrapped connection until the peek buffer holds at
// least count bytes, or an error happens. It returns the whole peek buffer,
// which may be longer than count if earlier peeks already read more.
func (c *Conn) PeekUntil(count int) ([]byte, error) {
	for len(c.rbuf) < count {
		if _, err := c.PeekMore(count - len(c.rbuf)); err != nil {
			return c.rbuf, err
		}
	}
	return c.rbuf, nil
}

// SkipPeek discards the first count bytes of the peek buffer, or the whole
// buffer if count is larger than or equal to its length.
func (c *Conn) SkipPeek(count int) {
	if count >= len(c.rbuf) {
		c.rbuf = nil
		return
	}
	c.rbuf = c.rbuf[count:]
}

// Write implements net.Conn by writing directly to the wrapped connection.
func (c *Conn) Write(b []byte) (int, error) {
	return c.conn.Write(b)
}

// Close closes the wrapped connection.
func (c *Conn) Close() error {
	return c.conn.Close()
}

// LocalAddr returns the local address, which may have been overridden with
// SetLocalAddr.
func (c *Conn) LocalAddr() net.Addr {
	return c.l
}

// RemoteAddr returns the remote address, which may have been overridden with
// SetRemoteAddr.
func (c *Conn) RemoteAddr() net.Addr {
	return c.r
}

// SetLocalAddr overrides the address returned by LocalAddr.
func (c *Conn) SetLocalAddr(l net.Addr) {
	c.l = l
	c.used = true
}

// SetRemoteAddr overrides the address returned by RemoteAddr.
func (c *Conn) SetRemoteAddr(r net.Addr) {
	c.r = r
	c.used = true
}

// SetDeadline forwards to the wrapped connection.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}

// SetReadDeadline forwards to the wrapped connection.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}

// SetWriteDeadline forwards to the wrapped connection.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}

// Unwrap returns the wrapped connection, or nil while peeked bytes are still
// pending (unwrapping at that point would lose them). Overridden addresses are
// not carried over to the returned connection.
func (c *Conn) Unwrap() net.Conn {
	if len(c.rbuf) != 0 {
		return nil
	}
	return c.conn
}

// GetTlsConn will attempt to unwrap the given connection in order to locate
// a TLS connection, or return nil if none found.
128 | func GetTlsConn(c net.Conn) *tls.Conn { 129 | for { 130 | switch cv := c.(type) { 131 | case *tls.Conn: 132 | return cv 133 | case interface{ Unwrap() net.Conn }: 134 | c = cv.Unwrap() 135 | default: 136 | return nil 137 | } 138 | } 139 | } 140 | 141 | // HijackedConn allows returning a simple net.Conn from a Conn+ReadWriter as returned by http.Hijacker.Hijack() 142 | func HijackedConn(c net.Conn, io *bufio.ReadWriter, err error) (net.Conn, error) { 143 | if err != nil { 144 | return nil, err 145 | } 146 | ln := io.Reader.Buffered() 147 | if ln == 0 { 148 | // nothing in reader, let's just return c 149 | return c, nil 150 | } 151 | data, err := io.Reader.Peek(ln) // should not fail 152 | if err != nil { 153 | return nil, err 154 | } 155 | return &Conn{conn: c, rbuf: data, l: c.LocalAddr(), r: c.RemoteAddr()}, nil 156 | } 157 | -------------------------------------------------------------------------------- /filter.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | // Filter is a generic magictls filter, to be used when accepting a connection. 4 | // Default filters provided for convenience include DetectProxy, DetectTLS and 5 | // ForceTLS. 
6 | type Filter func(conn *Conn, srv *Listener) error 7 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/KarpelesLab/magictls 2 | 3 | go 1.21 4 | -------------------------------------------------------------------------------- /keepalive.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import "time" 4 | 5 | // tcpKeepaliveConn defines methods typically available on TCP connections to enable keepalive (satisfied by *net.TCPConn); listenLoop type-asserts accepted connections to it, so connections without these methods are left as-is. 6 | type tcpKeepaliveConn interface { 7 | SetKeepAlive(keepalive bool) error 8 | SetKeepAlivePeriod(d time.Duration) error 9 | } 10 | -------------------------------------------------------------------------------- /listener.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "io" 7 | "log" 8 | "net" 9 | "os" 10 | "sync" 11 | "time" 12 | ) 13 | // ErrDuplicateProtocol is returned by Listener.ProtoListener when one of the requested protocols already has a listener. 14 | var ErrDuplicateProtocol = errors.New("protocol already has a listener") 15 | // queuePoint is one item handed out by Accept: a ready connection, or (with c nil) an error. 16 | type queuePoint struct { 17 | c net.Conn // connection returned by Accept 18 | e error // error returned by Accept 19 | } 20 | 21 | // Listener is a stream network listener supporting TLS and 22 | // PROXY protocol automatically. It assumes no matter what the used protocol 23 | // is, at least 16 bytes will always be initially sent (true for HTTP). 24 | type Listener struct { 25 | ports []net.Listener // sockets opened by Listen/ListenFilter, guarded by portsLk 26 | portsLk sync.Mutex 27 | addr net.Addr // address of the first port opened; nil for ListenNull listeners 28 | queue chan queuePoint // items waiting to be returned by Accept 29 | proto map[string]*protoListener // negotiated protocol -> listener created by ProtoListener, guarded by protoLk 30 | protoLk sync.RWMutex 31 | 32 | TLSConfig *tls.Config 33 | Filters []Filter // filters used when no per-port or per-call override is given 34 | *log.Logger // optional; when non-nil, filter errors are logged through it 35 | Timeout time.Duration // read deadline (from now) set before each filter runs 36 | 37 | // threads 38 | thCnt uint32 // goroutines currently running processFilters (incremented by HandleConn) 39 | thMax uint32 // HandleConn closes new connections once thCnt reaches this (see SetThreads) 40 | thCntLock sync.RWMutex // guards thCnt and thMax 41 | thCntCond sync.Cond // NOTE(review): not referenced by any code visible in this dump 42 | } 43 | 44 | // Listen creates a hybrid TCP/TLS listener accepting connections on the given 45 | // network address using net.Listen.
The configuration config must be non-nil 46 | // and must include at least one certificate or else set GetCertificate. If 47 | // not, then only PROXY protocol support will be available. 48 | // 49 | // If the connection uses TLS protocol, then Accept() returned net.Conn will 50 | // actually be a tls.Conn object. 51 | func Listen(network, laddr string, config *tls.Config) (*Listener, error) { 52 | r := ListenNull() 53 | r.TLSConfig = config 54 | 55 | if err := r.Listen(network, laddr); err != nil { 56 | return nil, err 57 | } 58 | 59 | return r, nil 60 | } 61 | 62 | // ListenNull creates a listener that is not actually listening to anything, 63 | // but can be used to push connections via PushConn. This can be useful to use 64 | // a http.Server with custom listeners. 65 | func ListenNull() *Listener { 66 | return &Listener{ 67 | queue: make(chan queuePoint, 8), 68 | proto: make(map[string]*protoListener), 69 | Filters: []Filter{DetectProxy, DetectTLS}, 70 | Timeout: 15 * time.Second, 71 | thMax: 64, 72 | } 73 | } 74 | 75 | // Listen makes the given listener listen on an extra port. Each listener will 76 | // spawn a new goroutine. 77 | func (r *Listener) Listen(network, laddr string) error { 78 | return r.ListenFilter(network, laddr, nil) 79 | } 80 | 81 | // ListenFilter listens on a given port with the selected filters used instead 82 | // of the default ones. 83 | func (r *Listener) ListenFilter(network, laddr string, filters []Filter) error { 84 | port, err := net.Listen(network, laddr) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | if r.addr == nil { 90 | r.addr = port.Addr() 91 | } 92 | 93 | r.portsLk.Lock() 94 | defer r.portsLk.Unlock() 95 | 96 | r.ports = append(r.ports, port) 97 | 98 | go r.listenLoop(port, filters) 99 | return nil 100 | } 101 | 102 | // ProtoListener returns a net.Listener that will receive connections for which 103 | // TLS is enabled and the specified protocol(s) have been negociated between 104 | // client and server. 
105 | func (r *Listener) ProtoListener(proto ...string) (net.Listener, error) { 106 | r.protoLk.Lock() 107 | defer r.protoLk.Unlock() 108 | 109 | // check if none of proto are taken 110 | for _, pr := range proto { 111 | if _, found := r.proto[pr]; found { 112 | return nil, ErrDuplicateProtocol 113 | } 114 | } 115 | 116 | // create listener, register 117 | l := &protoListener{ 118 | proto: proto, 119 | queue: make(chan *queuePoint, 8), 120 | parent: r, 121 | } 122 | 123 | for _, pr := range proto { 124 | r.proto[pr] = l 125 | } 126 | 127 | return l, nil 128 | } 129 | 130 | // SetThreads sets the number of threads (goroutines) magictls will spawn in 131 | // parallel when handling incoming connections. Note that once a connection 132 | // leaves Accept() it is not tracked anymore. 133 | // Filters will however run in parallel for those connections, meaning that 134 | // one connection's handshake taking time will not block other connections. 135 | func (r *Listener) SetThreads(count uint32) { 136 | r.thCntLock.Lock() 137 | defer r.thCntLock.Unlock() 138 | 139 | r.thMax = count 140 | } 141 | 142 | // GetRunningThreads returns the current number of running threads. 143 | func (r *Listener) GetRunningThreads() uint32 { 144 | r.thCntLock.RLock() 145 | defer r.thCntLock.RUnlock() 146 | 147 | return r.thCnt 148 | } 149 | 150 | // Accept blocks until a connection is available, then return said connection 151 | // or an error if the listener was closed. 152 | func (r *Listener) Accept() (net.Conn, error) { 153 | // TODO implement timeouts? 154 | p, ok := <-r.queue 155 | if !ok { 156 | return nil, io.EOF 157 | } 158 | 159 | return p.c, p.e 160 | } 161 | 162 | // processFilters is run in a thread and will execute filters as needed. 
163 | func (r *Listener) processFilters(c net.Conn, filters []Filter) { 164 | defer func() { 165 | r.thCntLock.Lock() 166 | r.thCnt -= 1 167 | r.thCntLock.Unlock() 168 | }() 169 | 170 | cw := &Conn{ 171 | conn: c, 172 | l: c.LocalAddr(), 173 | r: c.RemoteAddr(), 174 | } 175 | 176 | var ( 177 | tlsconn *tls.Conn 178 | negociatedProtocol string 179 | ) 180 | 181 | if filters == nil { 182 | filters = r.Filters 183 | } 184 | 185 | // for each filter 186 | for _, f := range filters { 187 | cw.SetReadDeadline(time.Now().Add(r.Timeout)) 188 | err := f(cw, r) 189 | if err != nil { 190 | if err == io.EOF { 191 | // ignore EOF errors, those are typically not important 192 | continue 193 | } 194 | if errors.Is(err, os.ErrDeadlineExceeded) { 195 | // timeout reached for running this filter 196 | continue 197 | } 198 | if ov, ok := err.(*Override); ok { 199 | if ov.Conn != nil { 200 | if t, ok := ov.Conn.(*tls.Conn); ok { 201 | // keep this tls connection nearby 202 | tlsconn = t 203 | } 204 | // perform override 205 | cw = &Conn{ 206 | conn: ov.Conn, 207 | l: ov.Conn.LocalAddr(), 208 | r: ov.Conn.RemoteAddr(), 209 | } 210 | } 211 | if ov.Protocol != "" { 212 | negociatedProtocol = ov.Protocol 213 | } 214 | continue 215 | } 216 | 217 | // For now we ignore all filter errors 218 | if r.Logger != nil { 219 | r.Logger.Printf("filter error on new connection: %s", err) 220 | } 221 | cw.Close() 222 | return 223 | } 224 | } 225 | cw.SetReadDeadline(time.Time{}) // disable any timeout 226 | 227 | var final net.Conn 228 | final = cw 229 | if !cw.isUsed() { 230 | // skip cw 231 | final = cw.conn 232 | } 233 | 234 | if tlsconn != nil && negociatedProtocol == "" { 235 | // special case: this is a tls socket. 
Check NegotiatedProtocol 236 | negociatedProtocol = tlsconn.ConnectionState().NegotiatedProtocol 237 | } 238 | 239 | if negociatedProtocol != "" { 240 | // grab lock 241 | r.protoLk.RLock() 242 | v, ok := r.proto[negociatedProtocol] 243 | r.protoLk.RUnlock() 244 | 245 | if ok { 246 | // send value 247 | v.queue <- &queuePoint{c: final, e: nil} 248 | return 249 | } 250 | } 251 | r.queue <- queuePoint{c: final} 252 | } 253 | 254 | // Close() closes the socket. 255 | func (r *Listener) Close() error { 256 | r.portsLk.Lock() 257 | defer r.portsLk.Unlock() 258 | 259 | for n, port := range r.ports { 260 | if err := port.Close(); err != nil { 261 | r.ports = r.ports[n:] // drop any port that was successfully closed 262 | return err 263 | } 264 | } 265 | r.ports = nil 266 | return nil 267 | } 268 | 269 | // Addr returns the address the socket is currently listening on, or nil for 270 | // null listeners. 271 | func (r *Listener) Addr() net.Addr { 272 | return r.addr 273 | } 274 | 275 | func (r *Listener) listenLoop(port net.Listener, filterOverride []Filter) { 276 | var tempDelay time.Duration // how long to sleep on accept failure 277 | for { 278 | c, err := port.Accept() 279 | if err != nil { 280 | // check for temporary error 281 | if ne, ok := err.(net.Error); ok && ne.Temporary() { 282 | if tempDelay == 0 { 283 | tempDelay = 5 * time.Millisecond 284 | } else { 285 | tempDelay *= 2 286 | } 287 | if max := 1 * time.Second; tempDelay > max { 288 | tempDelay = max 289 | } 290 | time.Sleep(tempDelay) 291 | continue 292 | } 293 | 294 | // send error & close 295 | r.queue <- queuePoint{e: err} 296 | close(r.queue) 297 | return 298 | } else { 299 | // enable tcp keepalive 300 | if kc, ok := c.(tcpKeepaliveConn); ok { 301 | kc.SetKeepAlive(true) 302 | kc.SetKeepAlivePeriod(3 * time.Minute) 303 | } 304 | 305 | r.HandleConn(c, filterOverride) 306 | } 307 | } 308 | } 309 | 310 | // PushConn allows pushing an existing connection to the queue as if it had 311 | // just been 
accepted by the server. No auto-detection will be performed. 312 | func (r *Listener) PushConn(c net.Conn) { 313 | r.queue <- queuePoint{c: c} 314 | } 315 | 316 | // HandleConn will run detection on a given incoming connection and attempt to 317 | // find if it should parse any kind of PROXY headers, or TLS handshake/etc. 318 | func (r *Listener) HandleConn(c net.Conn, filterOverride []Filter) { 319 | r.thCntLock.Lock() 320 | if r.thCnt >= r.thMax { 321 | // out of luck 322 | r.thCntLock.Unlock() 323 | c.Close() 324 | return 325 | } 326 | r.thCnt += 1 327 | r.thCntLock.Unlock() 328 | 329 | go r.processFilters(c, filterOverride) 330 | } 331 | 332 | func (p *Listener) String() string { 333 | return p.addr.String() 334 | } 335 | -------------------------------------------------------------------------------- /override.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import "net" 4 | 5 | // special type returned as error by filters to return a different Conn 6 | type Override struct { 7 | Conn net.Conn 8 | Protocol string 9 | } 10 | 11 | func (o *Override) Error() string { 12 | return "Connection override required" 13 | } 14 | -------------------------------------------------------------------------------- /protolistener.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import ( 4 | "io" 5 | "net" 6 | ) 7 | 8 | type protoListener struct { 9 | proto []string 10 | queue chan *queuePoint 11 | parent *Listener 12 | } 13 | 14 | func (p *protoListener) Accept() (net.Conn, error) { 15 | pc, ok := <-p.queue 16 | if !ok { 17 | return nil, io.EOF 18 | } 19 | 20 | return pc.c, pc.e 21 | } 22 | 23 | func (p *protoListener) Addr() net.Addr { 24 | return p.parent.Addr() 25 | } 26 | 27 | func (p *protoListener) Close() error { 28 | if p.queue == nil { 29 | return nil 30 | } 31 | 32 | // remove self 33 | p.parent.protoLk.Lock() 34 | defer 
p.parent.protoLk.Unlock() 35 | 36 | for _, proto := range p.proto { 37 | delete(p.parent.proto, proto) 38 | } 39 | 40 | // we can close p.queue here because we have the lock 41 | close(p.queue) 42 | p.queue = nil 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "log" 8 | "net" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | var allowedProxyIps []*net.IPNet 14 | 15 | func init() { 16 | SetAllowedProxies("127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fd00::/8") 17 | } 18 | 19 | type proxyError struct { 20 | version int 21 | msg string 22 | } 23 | 24 | func (p *proxyError) Error() string { 25 | return fmt.Sprintf("Error in PROXYv%d protocol: %s", p.version, p.msg) 26 | } 27 | 28 | // SetAllowedProxies allows modifying the list of IP addresses allowed to use 29 | // proxy protocol. Any host matching a CIDR listed in here will be trusted to 30 | // provide the client's real IP. 31 | // 32 | // By default all local IPs are allowed as these cannot appear on Internet. 33 | // 34 | // SetAllowedProxies("127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fd00::/8") 35 | func SetAllowedProxies(cidrs ...string) error { 36 | allowedProxyIps = nil 37 | 38 | return AddAllowedProxies(cidrs...) 39 | } 40 | 41 | // AddAllowedProxies adds to the list of allowed proxies 42 | func AddAllowedProxies(cidrs ...string) error { 43 | allowed := allowedProxyIps 44 | 45 | for _, s := range cidrs { 46 | _, ipn, err := net.ParseCIDR(s) 47 | if err != nil { 48 | return err 49 | } 50 | allowed = append(allowed, ipn) 51 | } 52 | 53 | allowedProxyIps = allowed 54 | return nil 55 | } 56 | 57 | // AddAllowedProxiesSpf will perform TXT lookup on the given hosts and add 58 | // those IPs as allowed proxies. 
This is only performed once and may need to 59 | // be refreshed from times to times. 60 | func AddAllowedProxiesSpf(spfhosts ...string) error { 61 | // example: magictls.AddAllowedProxiesSpf("_cloud-eoips.googleusercontent.com") 62 | var cidrs []string 63 | 64 | for _, host := range spfhosts { 65 | txtrecords, err := net.LookupTXT(host) 66 | if err != nil { 67 | return err 68 | } 69 | for _, rec := range txtrecords { 70 | // typical response: v=spf1 ip4: ~all 71 | recData := strings.Fields(rec) 72 | for _, rec := range recData { 73 | if val, ok := strings.CutPrefix(rec, "ip4:"); ok { 74 | if strings.IndexByte(val, '/') == -1 { 75 | // no / found, this is a host and not a cidr 76 | val += "/32" 77 | } 78 | cidrs = append(cidrs, val) 79 | } else if val, ok = strings.CutPrefix(rec, "ip6:"); ok { 80 | if strings.IndexByte(val, '/') == -1 { 81 | // no / found, this is a host and not a cidr 82 | val += "/128" 83 | } 84 | cidrs = append(cidrs, val) 85 | } 86 | } 87 | } 88 | } 89 | 90 | return AddAllowedProxies(cidrs...) 91 | } 92 | 93 | // DetectProxy is a magictls filter that will detect proxy protocol headers 94 | // (both versions) and update local/remote addr based on these if the 95 | // source is an allowed proxy (see SetAllowedProxies). 
96 | func DetectProxy(cw *Conn, srv *Listener) error { 97 | proxyAllow := false 98 | 99 | switch ipaddr := cw.r.(type) { 100 | case *net.TCPAddr: 101 | for _, n := range allowedProxyIps { 102 | if n.Contains(ipaddr.IP) { 103 | proxyAllow = true 104 | break 105 | } 106 | } 107 | case *net.IPAddr: 108 | for _, n := range allowedProxyIps { 109 | if n.Contains(ipaddr.IP) { 110 | proxyAllow = true 111 | break 112 | } 113 | } 114 | } 115 | 116 | if !proxyAllow { 117 | return nil 118 | } 119 | 120 | buf, err := cw.PeekUntil(16) 121 | if err != nil { 122 | return err 123 | } 124 | 125 | // detect proxy 126 | if bytes.Compare(buf[:12], []byte{0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a}) == 0 { 127 | // proxy protocol v2 128 | var verCmd, fam uint8 129 | verCmd = buf[12] 130 | fam = buf[13] 131 | ln := binary.BigEndian.Uint16(buf[14:16]) 132 | var d []byte 133 | if ln > 0 { 134 | tmp, err := cw.PeekUntil(16 + int(ln)) 135 | if err != nil { 136 | log.Printf("magictls: failed to read proxy v2 data") 137 | return err 138 | } 139 | d = tmp[16:] 140 | } 141 | if err := parseProxyV2Data(cw, verCmd, fam, d); err != nil { 142 | return err 143 | } 144 | cw.SkipPeek(16 + int(ln)) 145 | return nil 146 | } else if bytes.Compare(buf[:6], []byte("PROXY ")) == 0 { 147 | // proxy protocol v1 148 | var pos int 149 | 150 | for { 151 | buf, err = cw.PeekMore(128) // max proxy line length is 107 bytes in theory 152 | if err != nil { 153 | log.Printf("magictls: failed to read full line of proxy protocol") 154 | return err 155 | } 156 | 157 | pos = bytes.IndexByte(buf, '\n') 158 | if pos > 0 { 159 | break 160 | } 161 | if len(buf) > 128 { 162 | log.Printf("magictls: got proxy protocol intro but line is too long, ignoring") 163 | return nil 164 | } 165 | } 166 | 167 | err := parseProxyLine(cw, buf[:pos]) 168 | if err != nil { 169 | return err 170 | } 171 | 172 | cw.SkipPeek(pos + 1) 173 | } 174 | return nil 175 | } 176 | 177 | func parseProxyLine(c *Conn, buf []byte) 
error { 178 | if buf[len(buf)-1] == '\r' { 179 | buf = buf[:len(buf)-1] 180 | } 181 | 182 | s := bytes.Split(buf, []byte{' '}) 183 | if bytes.Compare(s[0], []byte("PROXY")) != 0 { 184 | return &proxyError{version: 1, msg: "invalid proxy line provided"} 185 | } 186 | 187 | // see: magictls://www.haproxy.org/download/1.5/doc/proxy-protocol.txt 188 | switch string(s[1]) { 189 | case "UNKNOWN": 190 | return nil // do nothing 191 | case "TCP4", "TCP6": 192 | if len(s) < 6 { 193 | return &proxyError{version: 1, msg: "not enough parameters for TCP PROXY"} 194 | } 195 | rPort, _ := strconv.Atoi(string(s[4])) 196 | lPort, _ := strconv.Atoi(string(s[5])) 197 | c.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP(string(s[2])), Port: rPort}) 198 | c.SetLocalAddr(&net.TCPAddr{IP: net.ParseIP(string(s[3])), Port: lPort}) 199 | return nil 200 | default: 201 | return &proxyError{version: 1, msg: "invalid proxy transport provided"} 202 | } 203 | } 204 | 205 | func parseProxyV2Data(c *Conn, verCmd, fam uint8, d []byte) error { 206 | if verCmd>>4&0xf != 0x2 { 207 | return &proxyError{version: 2, msg: "uynsupported header version"} 208 | } 209 | switch verCmd & 0xf { 210 | case 0x0: // LOCAL (health check, etc) 211 | return nil 212 | case 0x1: // PROXY 213 | break 214 | default: 215 | return &proxyError{version: 2, msg: "unsupported proxy type"} 216 | } 217 | 218 | switch fam >> 4 & 0xf { 219 | case 0x0: // UNSPEC 220 | return nil 221 | case 0x1, 0x2: // AF_INET, AF_INET6 222 | break 223 | case 0x3: // AF_UNIX 224 | return nil 225 | default: 226 | return &proxyError{version: 2, msg: "unsupported address family"} 227 | } 228 | 229 | switch fam & 0xf { 230 | case 0x0: // UNSPEC 231 | return nil 232 | case 0x1, 0x2: // STREAM, DGRAM 233 | break 234 | default: 235 | return &proxyError{version: 2, msg: "unsupported protocol"} 236 | } 237 | 238 | // sanitarization done, let's parse data 239 | b := bytes.NewBuffer(d) 240 | var rPort, lPort uint16 241 | 242 | switch fam >> 4 & 0xf { 243 | case 0x1: 
// AF_INET 244 | if len(d) < 12 { 245 | return &proxyError{version: 2, msg: "not enough data for ipv4"} 246 | } 247 | rip := make([]byte, 4) 248 | lip := make([]byte, 4) 249 | binary.Read(b, binary.BigEndian, rip) 250 | binary.Read(b, binary.BigEndian, lip) 251 | binary.Read(b, binary.BigEndian, &rPort) 252 | binary.Read(b, binary.BigEndian, &lPort) 253 | 254 | c.SetRemoteAddr(&net.TCPAddr{IP: rip, Port: int(rPort)}) 255 | c.SetLocalAddr(&net.TCPAddr{IP: lip, Port: int(lPort)}) 256 | case 0x2: // AF_INET6 257 | if len(d) < 36 { 258 | return &proxyError{version: 2, msg: "not enough data for ipv6"} 259 | } 260 | rip := make([]byte, 16) 261 | lip := make([]byte, 16) 262 | binary.Read(b, binary.BigEndian, rip) 263 | binary.Read(b, binary.BigEndian, lip) 264 | binary.Read(b, binary.BigEndian, &rPort) 265 | binary.Read(b, binary.BigEndian, &lPort) 266 | 267 | c.SetRemoteAddr(&net.TCPAddr{IP: rip, Port: int(rPort)}) 268 | c.SetLocalAddr(&net.TCPAddr{IP: lip, Port: int(lPort)}) 269 | } 270 | return nil 271 | } 272 | -------------------------------------------------------------------------------- /proxy_test.go: -------------------------------------------------------------------------------- 1 | package magictls 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "testing" 10 | ) 11 | 12 | const testStr = "This is a test string. It could really be anything as long as it's unique and longer than 128 bytes. The proxy protocol detector expects at least 128 bytes of input data from the remote peer in order to know we have a proxy thing, but we can add a timeout to that eventually..." 
13 | 
14 | // send writes each string in data to w as a big-endian int32 length prefix
15 | // followed by the string's bytes, returning the first write error.
16 | func send(w io.Writer, data ...string) error {
17 | 	for _, s := range data {
18 | 		if err := binary.Write(w, binary.BigEndian, int32(len(s))); err != nil {
19 | 			return err
20 | 		}
21 | 		if _, err := io.WriteString(w, s); err != nil {
22 | 			return err
23 | 		}
24 | 	}
25 | 	return nil
26 | }
27 | 
28 | // readOne reads one length-prefixed string as written by send. It returns ""
29 | // if the prefix cannot be read, is negative, or the payload is truncated.
30 | func readOne(r io.Reader) string {
31 | 	var l int32
32 | 	if err := binary.Read(r, binary.BigEndian, &l); err != nil || l < 0 {
33 | 		return ""
34 | 	}
35 | 	b := make([]byte, l)
36 | 	if _, err := io.ReadFull(r, b); err != nil {
37 | 		return ""
38 | 	}
39 | 	return string(b)
40 | }
41 | 
42 | // proxyV1test is one TestProxy case: the PROXY v1 header to send before the
43 | // payload ("" for none) and the remote address the server should report
44 | // ("" to skip that check).
45 | type proxyV1test struct {
46 | 	proxy  string
47 | 	expect string
48 | }
49 | 
50 | // TestProxy checks that PROXY v1 headers are stripped from the stream by the
51 | // listener and that a TCP4 header overrides the reported remote address.
52 | func TestProxy(t *testing.T) {
53 | 	rdy := make(chan int)
54 | 	go testProxySrv(rdy)
55 | 	port := <-rdy
56 | 
57 | 	t.Logf("running tests on port %d", port)
58 | 
59 | 	tests := []proxyV1test{
60 | 		{"", ""},
61 | 		{"PROXY UNKNOWN\n", ""},
62 | 		{"PROXY UNKNOWN\r\n", ""},
63 | 		{"PROXY TCP4 1.1.1.1 2.2.2.2 123 456\n", "1.1.1.1:123"},
64 | 	}
65 | 
66 | 	for _, test := range tests {
67 | 		c, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
68 | 		if err != nil {
69 | 			t.Errorf("failed test %+v: %s", test, err)
70 | 			continue
71 | 		}
72 | 
73 | 		if test.proxy != "" {
74 | 			if _, err := io.WriteString(c, test.proxy); err != nil {
75 | 				c.Close()
76 | 				t.Errorf("failed to send proxy header %q: %s", test.proxy, err)
77 | 				continue
78 | 			}
79 | 		}
80 | 
81 | 		if err := send(c, testStr); err != nil {
82 | 			c.Close()
83 | 			t.Errorf("failed to send payload for %+v: %s", test, err)
84 | 			continue
85 | 		}
86 | 
87 | 		l1 := readOne(c)
88 | 		l2 := readOne(c)
89 | 		c.Close()
90 | 
91 | 		if test.expect != "" && l1 != test.expect {
92 | 			t.Errorf("expected %s but got %s", test.expect, l1)
93 | 		}
94 | 
95 | 		if l2 != testStr {
96 | 			t.Errorf("failed, expected string but got %s / %s", l1, l2)
97 | 		}
98 | 	}
99 | }
100 | 
101 | // testProxySrv runs a magictls listener (nil TLS config) on an ephemeral
102 | // loopback port, sends the chosen port on rdy, then for each accepted
103 | // connection reads one length-prefixed string and replies with the
104 | // connection's RemoteAddr followed by that string.
105 | //
106 | // NOTE(review): the listener is never closed, so this goroutine runs until the
107 | // test binary exits.
108 | func testProxySrv(rdy chan int) {
109 | 	l, err := Listen("tcp", "127.0.0.1:0", nil)
110 | 	if err != nil {
111 | 		// shouldn't happen
112 | 		panic(err)
113 | 	}
114 | 
115 | 	// return port
116 | 	rdy <- l.Addr().(*net.TCPAddr).Port
117 | 
118 | 	// read loop (default handler)
119 | 	for {
120 | 		c, err := l.Accept()
121 | 		if err != nil {
122 | 			log.Printf("accept error: %s", err)
123 | 			return
124 | 		}
125 | 
126 | 		data := readOne(c)
127 | 		if err := send(c, c.RemoteAddr().String(), data); err != nil {
128 | 			log.Printf("send error: %s", err)
129 | 		}
130 | 		c.Close()
131 | 	}
132 | }
133 | 
--------------------------------------------------------------------------------
/tls.go:
--------------------------------------------------------------------------------
1 | package magictls
2 | 
3 | import "crypto/tls"
4 | 
5 | // DetectTLS is a magictls filter that will attempt to detect if the connection
6 | // is a TLS client. This works best with protocols where the first byte is
7 | // expected to be an ASCII character, such as HTTP. This will not work well if
8 | // the client does not send the first message.
9 | //
10 | // If the first byte looks like the start of a TLS/SSLv3 handshake record (0x16)
11 | // or an SSLv2 record (high bit set), DetectTLS behaves exactly like ForceTLS;
12 | // otherwise it returns nil and leaves the connection as is.
13 | func DetectTLS(conn *Conn, srv *Listener) error {
14 | 	buf, err := conn.PeekUntil(1)
15 | 	if err != nil {
16 | 		return err
17 | 	}
18 | 
19 | 	// perform auto-detection: a set high bit is SSLv2, probably (at least, not
20 | 	// HTTP); 0x16 is an SSLv3/TLS handshake record
21 | 	if buf[0]&0x80 == 0x80 || buf[0] == 0x16 {
22 | 		return ForceTLS(conn, srv)
23 | 	}
24 | 
25 | 	// probably not tls
26 | 	return nil
27 | }
28 | 
29 | // ForceTLS is a magictls filter that will engage TLS mode.
30 | //
31 | // It runs a server-side TLS handshake on conn with srv.TLSConfig. On success
32 | // it returns an *Override wrapping the resulting *tls.Conn. On failure conn is
33 | // closed, since the handshake has already consumed client data and the
34 | // connection cannot be reused, and the handshake error is returned.
35 | func ForceTLS(conn *Conn, srv *Listener) error {
36 | 	cs := tls.Server(conn, srv.TLSConfig)
37 | 	if err := cs.Handshake(); err != nil {
38 | 		// note: at this point we lost data, connection should be closed
39 | 		conn.Close()
40 | 		return err
41 | 	}
42 | 	return &Override{Conn: cs}
43 | }
44 | 
--------------------------------------------------------------------------------
/tls_test.go:
--------------------------------------------------------------------------------
1 | package magictls
2 | 
3 | import (
4 | 	"crypto/ecdsa"
5 | 	"crypto/elliptic"
6 | 	"crypto/rand"
7 | 	"crypto/tls"
8 | 	"crypto/x509"
9 | 	"crypto/x509/pkix"
10 | 	"fmt"
11 | 	"io"
12 | 	"log"
13 | 	"math/big"
14 | 	"net"
15 | 	"testing"
16 | 	"time"
17 | )
18 | 
19 | // Test fixtures shared by the TLS tests: a pool trusting the self-signed
20 | // certificate, its P-256 private key, and the certificate itself. pk and ca
21 | // are populated by init.
22 | var (
23 | 	testP = x509.NewCertPool()
24 | 	pk    *ecdsa.PrivateKey
25 | 	ca    *x509.Certificate
26 | )
27 | 
28 | // init generates a self-signed certificate for "localhost", valid for 24
29 | // hours, and adds it to testP so clients can verify the test server.
30 | func init() {
31 | 	var err error
32 | 	pk, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
33 | 	if err != nil {
34 | 		panic(err)
35 | 	}
36 | 
37 | 	// initialize some basic stuff for testing
38 | 	catpl := &x509.Certificate{
39 | 		SerialNumber:          big.NewInt(1),
40 | 		Issuer:                pkix.Name{CommonName: "localhost"},
41 | 		Subject:               pkix.Name{CommonName: "localhost"},
42 | 		BasicConstraintsValid: true,
43 | 		IsCA:                  true,
44 | 		NotBefore:             time.Now(),
45 | 		NotAfter:              time.Now().Add(24 * time.Hour),
46 | 		DNSNames:              []string{"localhost"},
47 | 	}
48 | 
49 | 	caBin, err := x509.CreateCertificate(rand.Reader, catpl, catpl, pk.Public(), pk)
50 | 	if err != nil {
51 | 		panic(err)
52 | 	}
53 | 	ca, err = x509.ParseCertificate(caBin)
54 | 	if err != nil {
55 | 		panic(err)
56 | 	}
57 | 
58 | 	testP.AddCert(ca)
59 | }
60 | 
61 | // TestTLS checks, on a single port, TLS auto-detection with the default
62 | // handler, ALPN dispatch to the "a" and "b" protocol listeners, and PROXY v1
63 | // and v2 headers sent ahead of the TLS handshake.
64 | func TestTLS(t *testing.T) {
65 | 	rdy := make(chan int)
66 | 
67 | 	go testSrv(rdy)
68 | 	port := <-rdy
69 | 
70 | 	t.Logf("running tests on port %d", port)
71 | 	addr := fmt.Sprintf("localhost:%d", port)
72 | 
73 | 	// expect reads c until EOF, closes it and compares the response with want.
74 | 	expect := func(step string, c *tls.Conn, want string) {
75 | 		t.Helper()
76 | 		buf, err := io.ReadAll(c)
77 | 		c.Close()
78 | 		if err != nil {
79 | 			t.Fatalf("%s: failed read: %s", step, err)
80 | 		}
81 | 		if string(buf) != want {
82 | 			t.Fatalf("%s: invalid response: %s", step, buf)
83 | 		}
84 | 	}
85 | 
86 | 	// default handler (no ALPN)
87 | 	c, err := tls.Dial("tcp", addr, &tls.Config{RootCAs: testP})
88 | 	if err != nil {
89 | 		t.Fatalf("failed test: %s", err)
90 | 	}
91 | 	expect("default", c, "hello world")
92 | 
93 | 	// test with "A"
94 | 	c, err = tls.Dial("tcp", addr, &tls.Config{RootCAs: testP, NextProtos: []string{"a"}})
95 | 	if err != nil {
96 | 		t.Fatalf("failed test: %s", err)
97 | 	}
98 | 	expect("alpn a", c, "hello from A")
99 | 
100 | 	// test with "B" and proxy protocol
101 | 	ctcp, err := net.Dial("tcp", addr)
102 | 	if err != nil {
103 | 		t.Fatalf("failed connect: %s", err)
104 | 	}
105 | 
106 | 	// send proxy line
107 | 	fmt.Fprintf(ctcp, "PROXY TCP4 10.0.0.1 10.0.0.2 123 456\r\n")
108 | 
109 | 	// initialize tls
110 | 	c = tls.Client(ctcp, &tls.Config{RootCAs: testP, ServerName: "localhost", NextProtos: []string{"b"}})
111 | 	expect("proxy v1", c, "IP = 10.0.0.1:123")
112 | 
113 | 	// test with "B" and proxyv2 protocol
114 | 	ctcp, err = net.Dial("tcp", addr)
115 | 	if err != nil {
116 | 		t.Fatalf("failed connect: %s", err)
117 | 	}
118 | 
119 | 	// send proxyv2 signature
120 | 	ctcp.Write([]byte{0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a})
121 | 	// send proxyv2 header: version 2 / PROXY, AF_INET / STREAM, 12 address
122 | 	// bytes: 10.0.0.1 -> 10.0.0.2, ports 123 -> 456
123 | 	ctcp.Write([]byte{0x21, 0x11, 0, 12, 10, 0, 0, 1, 10, 0, 0, 2, 0, 123, 1, 200})
124 | 
125 | 	// initialize tls
126 | 	c = tls.Client(ctcp, &tls.Config{RootCAs: testP, ServerName: "localhost", NextProtos: []string{"b"}})
127 | 	expect("proxy v2", c, "IP = 10.0.0.1:123")
128 | }
129 | 
130 | // testSrv runs a TLS-enabled magictls listener on an ephemeral loopback port
131 | // using the self-signed test certificate and sends the port on rdy.
132 | // Connections negotiating ALPN "a" or "b" go to handleA/handleB; all others
133 | // are answered with "hello world" by the default accept loop.
134 | //
135 | // NOTE(review): the listeners are never closed, so these goroutines run until
136 | // the test binary exits.
137 | func testSrv(rdy chan int) {
138 | 	cfg := &tls.Config{
139 | 		RootCAs: testP,
140 | 		Certificates: []tls.Certificate{
141 | 			{
142 | 				Certificate: [][]byte{ca.Raw},
143 | 				PrivateKey:  pk,
144 | 				Leaf:        ca,
145 | 			},
146 | 		},
147 | 		NextProtos: []string{"a", "b", "c", "d", "e", "f"},
148 | 	}
149 | 
150 | 	l, err := Listen("tcp", "127.0.0.1:0", cfg)
151 | 	if err != nil {
152 | 		// shouldn't happen
153 | 		panic(err)
154 | 	}
155 | 
156 | 	// register the protocol listeners before publishing the port so that no
157 | 	// client can connect before ALPN dispatch is in place
158 | 	subA, err := l.ProtoListener("a")
159 | 	if err != nil {
160 | 		panic(err)
161 | 	}
162 | 	subB, err := l.ProtoListener("b")
163 | 	if err != nil {
164 | 		panic(err)
165 | 	}
166 | 	go handleA(subA)
167 | 	go handleB(subB)
168 | 
169 | 	// return port
170 | 	rdy <- l.Addr().(*net.TCPAddr).Port
171 | 
172 | 	// read loop (default handler)
173 | 	for {
174 | 		c, err := l.Accept()
175 | 		if err != nil {
176 | 			log.Printf("accept error: %s", err)
177 | 			return
178 | 		}
179 | 
180 | 		c.Write([]byte("hello world"))
181 | 		c.Close()
182 | 	}
183 | }
184 | 
185 | // handleA answers every connection accepted on l with "hello from A".
186 | func handleA(l net.Listener) {
187 | 	for {
188 | 		c, err := l.Accept()
189 | 		if err != nil {
190 | 			log.Printf("accept error: %s", err)
191 | 			return
192 | 		}
193 | 
194 | 		fmt.Fprintf(c, "hello from A")
195 | 		c.Close()
196 | 	}
197 | }
198 | 
199 | // handleB answers every connection accepted on l with its remote address,
200 | // which reflects any PROXY header the client sent.
201 | func handleB(l net.Listener) {
202 | 	for {
203 | 		c, err := l.Accept()
204 | 		if err != nil {
205 | 			log.Printf("accept error: %s", err)
206 | 			return
207 | 		}
208 | 
209 | 		fmt.Fprintf(c, "IP = %s", c.RemoteAddr())
210 | 		c.Close()
211 | 	}
212 | }
213 | 
--------------------------------------------------------------------------------