├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── go.mod ├── main.go ├── main_test.go └── profiling.go /.gitignore: -------------------------------------------------------------------------------- 1 | json-tcp-lb 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:latest as builder 2 | 3 | WORKDIR /src/json-tcp-lb 4 | COPY go.mod ./ 5 | #noexternal deps yet 6 | #RUN go mod download 7 | 8 | COPY *.go ./ 9 | 10 | RUN go test ./... 11 | RUN CGO_ENABLED=0 go build -ldflags="-w -s" -o /usr/bin/json-tcp-lb 12 | 13 | #### 14 | 15 | FROM alpine 16 | COPY --from=builder /usr/bin/json-tcp-lb /usr/bin/json-tcp-lb 17 | ENV USER=jsonlb 18 | ENV UID=10001 19 | RUN adduser \ 20 | --disabled-password \ 21 | --gecos "" \ 22 | --home "/nonexistent" \ 23 | --shell "/sbin/nologin" \ 24 | --no-create-home \ 25 | --uid "${UID}" \ 26 | "${USER}" 27 | 28 | USER $USER:$USER 29 | 30 | ENTRYPOINT ["/usr/bin/json-tcp-lb"] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Corelight, Inc. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | (1) Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | (2) Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in 12 | the documentation and/or other materials provided with the 13 | distribution. 
14 | 15 | (3) Neither the name of Corelight nor the names of any contributors 16 | may be used to endorse or promote products derived from this 17 | software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # json tcp lb 2 | 3 | [![Docker Automated Build](https://img.shields.io/docker/cloud/automated/corelight/json-tcp-lb.svg)](https://cloud.docker.com/repository/docker/corelight/json-tcp-lb/builds) 4 | [![Docker Build Status](https://img.shields.io/docker/cloud/build/corelight/json-tcp-lb.svg)](https://cloud.docker.com/repository/docker/corelight/json-tcp-lb/builds) 5 | 6 | 7 | This is a simple line based tcp load balancing proxy. It is designed to work with newline 8 | delimited json, but will work with any line based protocol. 9 | 10 | This is different from a basic TCP proxy in that it will load balance data in a 11 | single connection across multiple destinations. 12 | 13 | ## Features 14 | 15 | * Load balancing to multiple connections across multiple targets. 
16 | * Failed transmissions will be retried to avoid ever losing data. 17 | * Target failover and failback. 18 | 19 | ## Implementation 20 | 21 | * The proxy will start up N worker `connections` to each `target`. 22 | * The proxy will read data from the incoming connection into a 16KB buffer. 23 | * The buffer will be split cleanly on a newline boundary, or combined with additional data until at least one newline is seen. 24 | * The buffer containing one or more lines is placed onto a channel and will be pulled by a worker and transmitted to a target. 25 | * If a worker cannot connect to its target, it will retry with backoff against a random target instead. 26 | * Every 5 minutes, a worker that has failed over will attempt to reconnect to its original target. 27 | 28 | ## Usage 29 | 30 | Usage of ./json-tcp-lb: 31 | -addr string 32 | Address to listen on (default "0.0.0.0") 33 | -connections int 34 | Number of outbound connections to make to each target (default 4) 35 | -port int 36 | Port to listen on (default 9000) 37 | -target string 38 | Address to proxy to. separate multiple with comma (default "127.0.0.1:9999") 39 | -tls-cert string 40 | TLS Certificate PEM file. Configuring this enables TLS 41 | -tls-key string 42 | TLS Certificate Key PEM file 43 | -tls-target 44 | Connect to the targets using TLS 45 | -tls-target-skip-verify 46 | Accepts any certificate presented by the target 47 | 48 | ## TLS 49 | 50 | TLS can be used on the incoming connections, the outbound target connections, or both. 51 | 52 | To listen using TLS, provide the `tls-cert` and `tls-key` options. 53 | 54 | To connect to targets using TLS, enable the `tls-target` option. 55 | 56 | ## Container Usage 57 | 58 | The `Dockerfile` can be used to build the tcp lb in a container, which can then be run as in the following example: 59 | ```bash 60 | docker build -t json-tcp-lb .
61 | docker run -p ${LISTEN_PORT}:${LISTEN_PORT} json-tcp-lb -port ${LISTEN_PORT} -target ${TARGET_STRING} 62 | ``` 63 | 64 | ## Alternatives 65 | 66 | I'm not aware of any simple alternatives. This is similar to something like 67 | gRPC load balancing across a single http/2 session in something like Envoy. 68 | It's likely possible to add 'newline delimited' data as a codec in Envoy or 69 | another load balancer. 70 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/corelight/json-tcp-lb 2 | 3 | go 1.16 4 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "errors" 8 | "flag" 9 | "fmt" 10 | "log" 11 | "math/rand" 12 | "net" 13 | "os" 14 | "os/signal" 15 | "strings" 16 | "sync" 17 | "syscall" 18 | "time" 19 | ) 20 | 21 | const bufferSize int = 16384 //Size of each read from an inbound connection 22 | // Config holds the listener, target and TLS settings; main populates it from the command-line flags. 23 | type Config struct { 24 | // Basic Settings 25 | Addr string 26 | Port int 27 | Connections int //Outbound connections made to each target 28 | Targets []string 29 | 30 | // Listen using TLS 31 | CertFile string 32 | KeyFile string 33 | 34 | // Connect using TLS 35 | TargetTLS bool 36 | TargetTLSSkipVerify bool 37 | } 38 | // bufPool recycles the buffers that receive fills and the transmit workers drain. 39 | var bufPool = sync.Pool{ 40 | New: func() interface{} { 41 | buf := new(bytes.Buffer) 42 | buf.Grow(bufferSize * 2) 43 | return buf 44 | }, 45 | } 46 | 47 | // sleep returns nil after the specified duration or error if interrupted.
48 | // Modified slightly from 49 | // from https://stackoverflow.com/questions/55135239/how-can-i-sleep-with-responsive-context-cancelation 50 | func sleep(ctx context.Context, d time.Duration) error { 51 | t := time.NewTimer(d) 52 | select { 53 | case <-ctx.Done(): 54 | t.Stop() 55 | return fmt.Errorf("Interrupted: %w", ctx.Err()) 56 | case <-t.C: 57 | } 58 | return nil 59 | } 60 | // receive splits the data read from conn on newline boundaries and sends each chunk of complete lines to out; an unterminated tail left when reading stops is sent as well. 61 | func receive(conn net.Conn, out chan *bytes.Buffer) { 62 | log.Printf("New connection from %s", conn.RemoteAddr()) 63 | defer conn.Close() 64 | buf := make([]byte, bufferSize) 65 | //Holds the trailing partial line (if any) until its newline arrives 66 | stringbuf := bufPool.Get().(*bytes.Buffer) 67 | for { 68 | n, err := conn.Read(buf) 69 | //Handle the n > 0 bytes before err, per the io.Reader contract: a TLS 70 | //connection can return its final data together with io.EOF. 71 | lastNewlineIndex := bytes.LastIndexByte(buf[:n], byte('\n')) 72 | if lastNewlineIndex != -1 { 73 | //Newline, truncate and send 74 | stringbuf.Write(buf[:lastNewlineIndex+1]) 75 | out <- stringbuf 76 | stringbuf = bufPool.Get().(*bytes.Buffer) 77 | stringbuf.Write(buf[lastNewlineIndex+1 : n]) 78 | } else { 79 | //No Newline, append to buffer 80 | stringbuf.Write(buf[:n]) 81 | } 82 | if err != nil { 83 | if errors.Is(err, net.ErrClosed) { 84 | log.Printf("Closed inbound connection from %s", conn.RemoteAddr()) 85 | } else { 86 | log.Printf("Error reading from %s: %v", conn.RemoteAddr(), err) 87 | } 88 | break 89 | } 90 | } 91 | if stringbuf.Len() > 0 { 92 | out <- stringbuf 93 | } 94 | } 95 | // Worker owns one outbound connection; it prefers defaultTarget and falls back to other entries of targets when connecting fails. 96 | type Worker struct { 97 | id int 98 | targets []string //The list of all available targets 99 | defaultTarget string //The default target this worker should be using 100 | curTarget string //The currently used target 101 | conn net.Conn 102 | lastReconnect time.Time //When conn was last established 103 | cfg Config 104 | } 105 | // String names the worker for display, e.g. "worker-03". 106 | func (w Worker) String() string { 107 | return fmt.Sprintf("worker-%02d", w.id) 108 | } 109 | // isConnectedToPrimary reports whether the worker is currently using its default target. 110 | func (w Worker) isConnectedToPrimary() bool { 111 | return w.curTarget == w.defaultTarget 112 | } 113 | func (w *Worker) Connect(ctx context.Context, target string) (net.Conn, error) {
114 | var conn net.Conn 115 | var err error 116 | if !w.cfg.TargetTLS { 117 | conn, err = (&net.Dialer{Timeout: 5 * time.Second}).DialContext(ctx, "tcp", target) 118 | } else { 119 | conf := &tls.Config{} 120 | if w.cfg.TargetTLSSkipVerify { 121 | conf.InsecureSkipVerify = true 122 | } 123 | conn, err = (&tls.Dialer{NetDialer: &net.Dialer{Timeout: 10 * time.Second}, Config: conf}).DialContext(ctx, "tcp", target) 124 | } 125 | return conn, err 126 | } 127 | 128 | // ConnectWithRetries tries to connect to a target with exponential backoff 129 | func (w *Worker) ConnectWithRetries(ctx context.Context) error { 130 | rand.Seed(time.Now().UnixNano()) 131 | delay := 2 * time.Second 132 | w.curTarget = w.defaultTarget 133 | for { 134 | conn, err := w.Connect(ctx, w.curTarget) 135 | //log.Printf("Worker %d: Opening connection to %v", w.id, w.target) 136 | if err == nil { 137 | w.Close() 138 | log.Printf("Worker %d: connected to %s", w.id, w.curTarget) 139 | w.conn = conn 140 | w.lastReconnect = time.Now() 141 | return nil 142 | } 143 | log.Printf("Worker %d: Unable connect to %s: %v", w.id, w.curTarget, err) 144 | //The context is done 145 | if ctx.Err() != nil { 146 | return err 147 | } 148 | sleep(ctx, delay) //An interrupted sleep is caught by the ctx check after the next (immediately cancelled) dial 149 | delay *= 2 150 | if delay > 30*time.Second { 151 | delay = 30 * time.Second 152 | } 153 | //After a failure, move onto a random target 154 | w.curTarget = w.targets[rand.Intn(len(w.targets))] 155 | } 156 | } 157 | // ConnectIfNeeded connects when there is no connection and, while on a fallback target, retries the default target once more than 5 minutes have passed since the last connection. 158 | func (w *Worker) ConnectIfNeeded(ctx context.Context) error { 159 | if w.conn == nil { 160 | return w.ConnectWithRetries(ctx) 161 | } 162 | //If not connected to the desired target, try reconnecting if it's been 5 minutes 163 | if !w.isConnectedToPrimary() && time.Since(w.lastReconnect) > 5*time.Minute { 164 | log.Printf("Worker %d: attempting to reconnect to primary target", w.id) 165 | return w.ConnectWithRetries(ctx) 166 | } 167 | return nil 168 | } 169 | func (w *Worker) Close() { 170 | if w.conn != nil { 171 | w.conn.Close() 172 | w.conn = nil 173 | } 174 | } 175 | // Reconnect drops the current connection and reconnects starting from the default target; any connection error is discarded. 176 | func (w *Worker)
Reconnect(ctx context.Context) { 177 | w.Close() 178 | w.ConnectWithRetries(ctx) 179 | } 180 | func (w *Worker) Write(b []byte) (int, error) { 181 | w.conn.SetDeadline(time.Now().Add(30 * time.Second)) //Bound each write so a stalled target surfaces as an error and triggers a reconnect 182 | return w.conn.Write(b) 183 | } 184 | func (w *Worker) WriteWithRetries(ctx context.Context, b []byte) (int, error) { 185 | total := len(b) //Reported on success even if only an unsent remainder had to be rewritten 186 | for { 187 | w.ConnectIfNeeded(ctx) //Only fails once ctx is done; transmit passes context.TODO() 188 | n, err := w.Write(b) 189 | if err == nil { 190 | return total, nil 191 | } 192 | log.Printf("Worker %d: Error writing to %s: %v. n=%d, len=%d", w.id, w.curTarget, err, n, len(b)) 193 | w.Close() 194 | b = b[bytes.LastIndexByte(b[:n], '\n')+1:] //Resend from the interrupted line on; lines written in full before the failure would otherwise be duplicated 195 | } 196 | } 197 | func transmit(ctx context.Context, cfg Config, worker int, outputChan chan *bytes.Buffer, target int) { 198 | var b *bytes.Buffer 199 | 200 | w := &Worker{ 201 | id: worker, 202 | cfg: cfg, 203 | targets: cfg.Targets, 204 | defaultTarget: cfg.Targets[target], 205 | } 206 | err := w.ConnectWithRetries(ctx) 207 | //Only happens if we are exiting during startup 208 | if err != nil { 209 | return 210 | } 211 | var exit bool 212 | 213 | doneChan := ctx.Done() 214 | 215 | idleCount := 0 216 | timer := time.NewTicker(1 * time.Second) 217 | defer timer.Stop() 218 | 219 | for { 220 | select { 221 | case <-timer.C: 222 | idleCount++ 223 | //Exit if we are done and have not received any logs to write in 5 ticks. 224 | if exit && idleCount >= 5 { 225 | w.Close() 226 | return 227 | } 228 | case <-doneChan: 229 | log.Printf("Worker %d: draining records and exiting...", worker) 230 | exit = true 231 | doneChan = nil 232 | case b = <-outputChan: 233 | idleCount = 0 234 | //This will retry forever and will not fail 235 | w.WriteWithRetries(context.TODO(), b.Bytes()) 236 | //Message successfully sent.. but...
237 | //Only return small buffers to the pool 238 | if b.Cap() <= 1024*1024 { 239 | b.Reset() 240 | bufPool.Put(b) 241 | } 242 | } 243 | } 244 | } 245 | // proxy fans the lines from every connection accepted on l out to the target workers. It returns nil once ctx is cancelled, or the Accept error otherwise, in both cases after the workers have drained. 246 | func proxy(ctx context.Context, l net.Listener, cfg Config) error { 247 | numTargets := len(cfg.Targets) 248 | outputChan := make(chan *bytes.Buffer, cfg.Connections*numTargets*2) 249 | //Derived context: an Accept failure must stop the workers too, or wg.Wait() below would never return 250 | ctx, cancel := context.WithCancel(ctx) 251 | defer cancel() 252 | var wg sync.WaitGroup 253 | for i := 0; i < cfg.Connections*numTargets; i++ { 254 | wg.Add(1) 255 | go func(idx int) { 256 | targetIdx := idx % numTargets 257 | transmit(ctx, cfg, idx+1, outputChan, targetIdx) 258 | log.Printf("Worker %d done", idx+1) 259 | wg.Done() 260 | }(i) 261 | } 262 | go func() { 263 | <-ctx.Done() 264 | l.Close() 265 | }() 266 | var err error 267 | for { 268 | //Declared per iteration so each closer goroutine captures its own connection 269 | var conn net.Conn 270 | if conn, err = l.Accept(); err != nil { 271 | break 272 | } 273 | go func() { 274 | <-ctx.Done() 275 | conn.Close() 276 | }() 277 | go receive(conn, outputChan) 278 | } 279 | //On shutdown the goroutine above closes l, so that Accept error is expected rather than a failure 280 | shutdown := ctx.Err() != nil 281 | cancel() 282 | //Wait for all workers to exit 283 | wg.Wait() 284 | if shutdown { 285 | return nil 286 | } 287 | return err 288 | } 289 | 290 | // listen opens a TCP listener on cfg.Addr:cfg.Port, using TLS when both CertFile and KeyFile are set. 291 | func listen(cfg Config) (net.Listener, error) { 292 | bind := fmt.Sprintf("%s:%d", cfg.Addr, cfg.Port) 293 | 294 | if cfg.CertFile == "" || cfg.KeyFile == "" { 295 | log.Printf("Listening on %s", bind) 296 | return net.Listen("tcp", bind) 297 | } 298 | //Errors are returned rather than fatal so callers (and tests) can handle them 299 | cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) 300 | if err != nil { 301 | return nil, fmt.Errorf("loading TLS key pair: %w", err) 302 | } 303 | config := &tls.Config{Certificates: []tls.Certificate{cert}} 304 | 305 | log.Printf("listening on %s using TLS", bind) 306 | return tls.Listen("tcp", bind, config) 307 | } 308 | 309 | // listenAndProxy runs the proxy on a new listener until SIGINT or SIGTERM cancels it, or until proxy returns an error. 310 | func listenAndProxy(cfg Config) error { 311 | l, err := listen(cfg) 312 | if err != nil { 313 | return err 314 | } 315 | //Created after listen so cancel is released on every return path 316 | ctx, cancel := context.WithCancel(context.Background()) 317 | defer cancel() 318 | sigs := make(chan os.Signal, 1) 319 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 320 | go func() { 321 | sig := <-sigs 322 | log.Printf("Received
signal %s, exiting", sig) 323 | cancel() 324 | }() 325 | 326 | return proxy(ctx, l, cfg) 327 | } 328 | 329 | func main() { 330 | var cfg Config 331 | var targets string 332 | flag.StringVar(&cfg.Addr, "addr", "0.0.0.0", "Address to listen on") 333 | flag.IntVar(&cfg.Port, "port", 9000, "Port to listen on") 334 | flag.StringVar(&targets, "target", "127.0.0.1:9999", "Address to proxy to. separate multiple with comma") 335 | flag.BoolVar(&cfg.TargetTLS, "tls-target", false, "Connect to the targets using TLS") 336 | flag.BoolVar(&cfg.TargetTLSSkipVerify, "tls-target-skip-verify", false, "Accepts any certificate presented by the target") 337 | flag.IntVar(&cfg.Connections, "connections", 4, "Number of outbound connections to make to each target") 338 | flag.StringVar(&cfg.CertFile, "tls-cert", "", "TLS Certificate PEM file. Configuring this enables TLS") 339 | flag.StringVar(&cfg.KeyFile, "tls-key", "", "TLS Certificate Key PEM file") 340 | flag.Parse() 341 | //With no workers, nothing would ever drain what receive sends 342 | if cfg.Connections < 1 { 343 | log.Fatalf("-connections must be at least 1, got %d", cfg.Connections) 344 | } 345 | cfg.Targets = strings.Split(targets, ",") 346 | err := listenAndProxy(cfg) 347 | if err != nil { 348 | log.Fatal(err) 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/rand" 7 | "crypto/rsa" 8 | "crypto/tls" 9 | "crypto/x509" 10 | "crypto/x509/pkix" 11 | "encoding/pem" 12 | "fmt" 13 | "io" 14 | "math/big" 15 | "net" 16 | "os" 17 | "path/filepath" 18 | "sync" 19 | "testing" 20 | "time" 21 | ) 22 | 23 | // generateTestCerts generates a self signed key pair 24 | func generateTestCerts(t *testing.T) (string, string) { 25 | // Mostly from https://go.dev/src/crypto/tls/generate_cert.go 26 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | notBefore := time.Now().Add(-5 * time.Minute) 31 | notAfter := notBefore.Add(10 * time.Minute) 32 | serialNumber :=
big.NewInt(42) 33 | 34 | keyUsage := x509.KeyUsageDigitalSignature 35 | keyUsage |= x509.KeyUsageKeyEncipherment 36 | // The template is passed to CreateCertificate below as both the certificate and its parent, so the result is self-signed. 37 | template := x509.Certificate{ 38 | SerialNumber: serialNumber, 39 | Subject: pkix.Name{ 40 | Organization: []string{"Acme Co"}, 41 | }, 42 | NotBefore: notBefore, 43 | NotAfter: notAfter, 44 | 45 | KeyUsage: keyUsage, 46 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 47 | BasicConstraintsValid: true, 48 | } 49 | template.IPAddresses = append(template.IPAddresses, net.ParseIP("127.0.0.1")) //Only 127.0.0.1 is covered; the proxy tests reach targets via "localhost" and skip verification 50 | 51 | derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | 56 | certDir := t.TempDir() 57 | 58 | certPath := filepath.Join(certDir, "cert.pem") 59 | keyPath := filepath.Join(certDir, "key.pem") 60 | 61 | certOut, err := os.Create(certPath) 62 | if err != nil { 63 | t.Fatalf("Failed to open %v for writing: %v", certPath, err) 64 | } 65 | if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 66 | t.Fatalf("Failed to write data to %v: %v", certPath, err) 67 | } 68 | if err := certOut.Close(); err != nil { 69 | t.Fatalf("Error closing %v: %v", certPath, err) 70 | } 71 | t.Logf("wrote %v\n", certPath) 72 | 73 | keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 74 | if err != nil { 75 | t.Fatalf("Failed to open %v for writing: %v", keyPath, err) 76 | } 77 | privBytes, err := x509.MarshalPKCS8PrivateKey(priv) 78 | if err != nil { 79 | t.Fatalf("Unable to marshal private key: %v", err) 80 | } 81 | if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { 82 | t.Fatalf("Failed to write data to %v: %v", keyPath, err) 83 | } 84 | if err := keyOut.Close(); err != nil { 85 | t.Fatalf("Error closing %v: %v", keyPath, err) 86 | } 87 | t.Logf("wrote %v", keyPath) 88 | 89 | return certPath, keyPath 90 | } 91 | // listenTestConnection accepts exactly `connections` connections on l and returns the total number of newlines read across all of them. 92 | func listenTestConnection(t *testing.T, l
net.Listener, connections int) (int, error) { 93 | resultChan := make(chan int, connections) 94 | 95 | var lines int 96 | for i := 0; i < connections; i++ { 97 | conn, err := l.Accept() 98 | if err != nil { 99 | return 0, err 100 | } 101 | t.Logf("Accepted connection %d of %d", i+1, connections) 102 | go handleTestConnection(t, conn, resultChan) 103 | } 104 | for i := 0; i < connections; i++ { 105 | lines += <-resultChan 106 | } 107 | return lines, nil 108 | } 109 | // handleTestConnection counts the newlines read from conn until a read error (normally io.EOF) and sends the count on resultChan. 110 | func handleTestConnection(t *testing.T, conn net.Conn, resultChan chan int) { 111 | var lines int 112 | t.Logf("New connection from %s", conn.RemoteAddr()) 113 | defer conn.Close() 114 | buf := make([]byte, 4096) 115 | for { 116 | n, err := conn.Read(buf) 117 | lines += bytes.Count(buf[:n], []byte("\n")) //Count before checking err: a TLS Read can return its last data together with io.EOF 118 | if err != nil { 119 | if err != io.EOF { 120 | t.Logf("Error reading: %v", err) 121 | } 122 | break 123 | } 124 | } 125 | t.Logf("Connection got %d lines", lines) 126 | resultChan <- lines 127 | } 128 | 129 | const logLine = `{"_path":"conn","_system_name":"HQ","_write_ts":"2018-11-28T04:50:48.848281Z","ts":"2018-11-28T04:50:38.834880Z","uid":"CX6jut3BmNwFdYkgrk","id.orig_h":"fc00::165","id.orig_p":44206,"id.resp_h":"fc00::1","id.resp_p":53,"proto":"udp","service":"dns","duration":0.004537,"orig_bytes":55,"resp_bytes":55,"conn_state":"SF","local_orig":false,"local_resp":false,"missed_bytes":0,"history":"Dd","orig_pkts":1,"orig_ip_bytes":103,"resp_pkts":1,"resp_ip_bytes":103,"tunnel_parents":[],"corelight_shunted":false,"orig_l2_addr":"ac:1f:6b:00:81:9a","resp_l2_addr":"b4:75:0e:08:08:c1"}` 130 | // connectToProxy dials 127.0.0.1:cfg.Port, using TLS without certificate verification when cfg has a certificate and key configured. 131 | func connectToProxy(cfg Config) (net.Conn, error) { 132 | var conn net.Conn 133 | var err error 134 | target := fmt.Sprintf("127.0.0.1:%d", cfg.Port) 135 | if cfg.CertFile == "" || cfg.KeyFile == "" { 136 | conn, err = net.DialTimeout("tcp", target, 5*time.Second) 137 | } else { 138 | //TODO: fixme.
this needs to be tested properly 139 | conf := &tls.Config{ 140 | InsecureSkipVerify: true, 141 | } 142 | conn, err = tls.DialWithDialer(&net.Dialer{Timeout: 10 * time.Second}, "tcp", target, conf) 143 | } 144 | return conn, err 145 | } 146 | // spew connects via connectToProxy and writes logLine plus a newline `lines` times, all under a single 5 second deadline. 147 | func spew(t *testing.T, cfg Config, lines int) error { 148 | line := []byte(logLine + "\n") 149 | conn, err := connectToProxy(cfg) 150 | if err != nil { 151 | return fmt.Errorf("Spew: %w", err) 152 | } 153 | defer conn.Close() 154 | conn.SetDeadline(time.Now().Add(5 * time.Second)) 155 | t.Logf("Spewing %d lines to port %d", lines, cfg.Port) 156 | for i := 0; i < lines; i++ { 157 | _, err := conn.Write(line) 158 | if err != nil { 159 | return fmt.Errorf("Spew failed: %w", err) 160 | } 161 | } 162 | return nil 163 | } 164 | // TestDirect checks the test harness alone: spew writes straight to the listener, with no proxy in between. 165 | func TestDirect(t *testing.T) { 166 | connections := 8 167 | linesPerConnnection := 10000 168 | expected := connections * linesPerConnnection 169 | l, err := net.Listen("tcp", ":0") 170 | if err != nil { 171 | t.Fatal(err) 172 | } 173 | port := l.Addr().(*net.TCPAddr).Port 174 | t.Logf("Listening for tests on %d", port) 175 | defer l.Close() 176 | cfg := Config{ 177 | Port: port, 178 | } 179 | for i := 0; i < connections; i++ { 180 | go spew(t, cfg, linesPerConnnection) 181 | } 182 | lines, err := listenTestConnection(t, l, connections) 183 | if err != nil { 184 | t.Fatal(err) 185 | } 186 | t.Logf("Got %d lines total", lines) 187 | if lines != expected { 188 | t.Errorf("Expected %d lines, got %d", expected, lines) 189 | } 190 | } 191 | // testProxy runs a proxy configured from cfg in front of a single target listener, spews from 8 clients, and checks that every line arrives. 192 | func testProxy(t *testing.T, cfg Config) { 193 | connections := 8 194 | linesPerConnnection := 10000 195 | expected := connections * linesPerConnnection 196 | 197 | // Common settings for all tests 198 | cfg.Connections = connections 199 | cfg.Addr = "" 200 | cfg.Port = 0 //Select dynamically 201 | 202 | // Setup downstream listener with matching TLS setting 203 | var targetCfg Config 204 | if cfg.TargetTLS { 205 | cert, key := generateTestCerts(t) 206 |
targetCfg.CertFile = cert 207 | targetCfg.KeyFile = key 208 | } 209 | targetListener, err := listen(targetCfg) 210 | if err != nil { 211 | t.Fatal(err) 212 | } 213 | port := targetListener.Addr().(*net.TCPAddr).Port 214 | t.Logf("Target listening on %d", port) 215 | defer targetListener.Close() 216 | target := fmt.Sprintf("localhost:%d", port) 217 | cfg.Targets = []string{target} 218 | 219 | // Setup proxy listener 220 | proxyListener, err := listen(cfg) 221 | if err != nil { 222 | t.Fatal(err) 223 | } 224 | proxyPort := proxyListener.Addr().(*net.TCPAddr).Port 225 | t.Logf("Proxy listening on %d", proxyPort) 226 | cfg.Port = proxyPort 227 | defer proxyListener.Close() 228 | ctx, cancel := context.WithCancel(context.Background()) 229 | go proxy(ctx, proxyListener, cfg) 230 | // Spew everything and then close the proxy listener 231 | go func() { 232 | var wg sync.WaitGroup 233 | for i := 0; i < connections; i++ { 234 | wg.Add(1) 235 | go func() { 236 | err := spew(t, cfg, linesPerConnnection) 237 | if err == nil { 238 | t.Logf("Spew done") 239 | } else { 240 | t.Logf("spew: %v", err) 241 | } 242 | wg.Done() 243 | }() 244 | } 245 | wg.Wait() 246 | t.Logf("All spew done") 247 | time.Sleep(1 * time.Second) 248 | cancel() 249 | }() 250 | 251 | lines, err := listenTestConnection(t, targetListener, connections) 252 | if err != nil { 253 | t.Fatal(err) 254 | } 255 | t.Logf("Got %d lines total", lines) 256 | if lines != expected { 257 | t.Errorf("Expected %d lines, got %d", expected, lines) 258 | } 259 | } 260 | 261 | func TestProxyPlainPlain(t *testing.T) { 262 | t.Parallel() 263 | // All default settings 264 | cfg := Config{} 265 | testProxy(t, cfg) 266 | } 267 | 268 | func TestProxyTLSPlain(t *testing.T) { 269 | t.Parallel() 270 | cert, key := generateTestCerts(t) 271 | cfg := Config{ 272 | CertFile: cert, 273 | KeyFile: key, 274 | } 275 | testProxy(t, cfg) 276 | } 277 | 278 | func TestProxyPlainTLS(t *testing.T) { 279 | t.Parallel() 280 | cfg := Config{ 281 | TargetTLS: true, 282 |
TargetTLSSkipVerify: true, 283 | } 284 | testProxy(t, cfg) 285 | } 286 | 287 | func TestProxyTLSTLS(t *testing.T) { 288 | t.Parallel() 289 | cert, key := generateTestCerts(t) 290 | cfg := Config{ 291 | CertFile: cert, 292 | KeyFile: key, 293 | TargetTLS: true, 294 | TargetTLSSkipVerify: true, 295 | } 296 | testProxy(t, cfg) 297 | } 298 | -------------------------------------------------------------------------------- /profiling.go: -------------------------------------------------------------------------------- 1 | //go:build prof 2 | // +build prof 3 | 4 | package main 5 | 6 | import ( 7 | "log" 8 | "net/http" 9 | _ "net/http/pprof" 10 | ) 11 | 12 | // init serves pprof on localhost:6060 in the background; this file is only compiled with -tags prof. 13 | func init() { 14 | go func() { 15 | log.Println("Listening on localhost:6060 for pprof") 16 | log.Println(http.ListenAndServe("localhost:6060", nil)) 17 | }() 18 | } 19 | --------------------------------------------------------------------------------