├── .gitignore
├── linux
    ├── aes256cbc.h
    ├── Makefile
    ├── README.md
    ├── gateway_test_server.go
    ├── gateway_test_client.go
    ├── aes256cbc.c
    ├── aes256cbc_test.c
    ├── base64.h
    ├── base64.c
    └── gateway.c
├── copy_slow.go
├── listen_win.go
├── copy_fast.go
├── listen_unix.go
├── .travis.yml
├── README.md
├── main.go
└── main_test.go

/.gitignore:
--------------------------------------------------------------------------------
1 | *.cov
2 | *.pid
--------------------------------------------------------------------------------
/linux/aes256cbc.h:
--------------------------------------------------------------------------------
1 | #ifndef AES256CBC_H_
2 | #define AES256CBC_H_
3 | 
4 | /*
5 |  * aes256cbc_decrypt base64-decodes the text in buf and AES-256-CBC decrypts
6 |  * it in place with a key/IV derived from passphrase. On success buf holds the
7 |  * NUL-terminated plaintext and its length is returned; on any failure 0 is
8 |  * returned. buf is overwritten in both cases.
9 |  */
10 | int aes256cbc_decrypt(unsigned char *passphrase, unsigned char *buf);
11 | 
12 | #endif
--------------------------------------------------------------------------------
/copy_slow.go:
--------------------------------------------------------------------------------
1 | // +build !go1.5
2 | 
3 | package main
4 | 
5 | import "io"
6 | 
7 | // copy pipes src into dst until EOF or error. io.CopyBuffer only exists
8 | // from Go 1.5 on, so older toolchains fall back to plain io.Copy.
9 | func copy(dst io.WriteCloser, src io.ReadCloser) {
10 | 	io.Copy(dst, src)
11 | }
12 | 
--------------------------------------------------------------------------------
/listen_win.go:
--------------------------------------------------------------------------------
1 | // +build windows
2 | 
3 | package main
4 | 
5 | import "net"
6 | 
7 | // listen opens the gateway's TCP listener on cfgGatewayAddr. The -reuse
8 | // flag (cfgReusePort) is not consulted on Windows.
9 | func listen() (net.Listener, error) {
10 | 	return net.Listen("tcp", cfgGatewayAddr)
11 | }
12 | 
--------------------------------------------------------------------------------
/linux/Makefile:
--------------------------------------------------------------------------------
1 | # Builds the C epoll + splice gateway; links against OpenSSL (-lssl -lcrypto).
2 | .PHONY: all clean
3 | 
4 | all: gateway
5 | 
6 | gateway: gateway.c aes256cbc.h aes256cbc.c base64.h base64.c
7 | 	gcc -D_GNU_SOURCE -pthread -O3 -o gateway gateway.c aes256cbc.c base64.c -lssl -lcrypto
8 | 
9 | # -f so that "make clean" also succeeds when nothing has been built yet.
10 | clean:
11 | 	rm -f gateway
--------------------------------------------------------------------------------
/linux/README.md:
--------------------------------------------------------------------------------
1 | 针对Linux系统制作的epoll + splice的gateway,和Go语言版功能基本一致,用来验证zero copy的效率。
2 | 
3 | 建议线上系统还是使用Go语言版,会比较稳定可靠,也方便二次开发。
4 | 
5 | 注意:每个客户端连接将会产生六个文件句柄,所以请确保文件句柄数量限制够大。
6 | 
7 | TODO:
8 | 
9 | * 连接后端超时
10 | * 连接存活检查
--------------------------------------------------------------------------------
/copy_fast.go:
--------------------------------------------------------------------------------
1 | // +build go1.5
2 | 
3 | package main
4 | 
5 | import "io"
6 | 
7 | // copy pipes src into dst until EOF or error, using a buffer borrowed from
8 | // copyBufPool instead of allocating a new one per call.
9 | func copy(dst io.WriteCloser, src io.ReadCloser) {
10 | 	b := copyBufPool.Get().(*[]byte)
11 | 	// Deferred so the buffer goes back to the pool even if dst or src
12 | 	// panics (the callers in main.go recover such panics).
13 | 	defer copyBufPool.Put(b)
14 | 	io.CopyBuffer(dst, src, *b)
15 | }
16 | 
--------------------------------------------------------------------------------
/listen_unix.go:
--------------------------------------------------------------------------------
1 | // +build linux darwin dragonfly freebsd netbsd openbsd
2 | 
3 | package main
4 | 
5 | import (
6 | 	"net"
7 | 
8 | 	"github.com/funny/reuseport"
9 | )
10 | 
11 | // listen opens the gateway's TCP listener on cfgGatewayAddr. When the
12 | // -reuse flag (cfgReusePort) is set it uses a reusable-port listener from
13 | // github.com/funny/reuseport instead of a plain net.Listen.
14 | func listen() (net.Listener, error) {
15 | 	if cfgReusePort {
16 | 		return reuseport.NewReusablePortListener("tcp", cfgGatewayAddr)
17 | 	}
18 | 	return net.Listen("tcp", cfgGatewayAddr)
19 | }
20 | 
--------------------------------------------------------------------------------
/linux/gateway_test_server.go:
--------------------------------------------------------------------------------
1 | // Command gateway_test_server is a TCP echo server on 0.0.0.0:10010,
2 | // intended as a target to route through the gateway when testing by hand.
3 | package main
4 | 
5 | import (
6 | 	"io"
7 | 	"log"
8 | 	"net"
9 | )
10 | 
11 | func main() {
12 | 	lsn, err := net.Listen("tcp", "0.0.0.0:10010")
13 | 	if err != nil {
14 | 		log.Fatalf("Listen failed: %s", err)
15 | 	}
16 | 	for {
17 | 		conn, err := lsn.Accept()
18 | 		if err != nil {
19 | 			log.Fatalf("Accept failed: %s", err)
20 | 		}
21 | 		log.Printf("New client: %s", conn.RemoteAddr())
22 | 
23 | 		// Echo everything back until the client's side reaches EOF or errors.
24 | 		go func() {
25 | 			defer conn.Close()
26 | 			io.Copy(conn, conn)
27 | 			log.Print("Closed")
28 | 		}()
29 | 	}
30 | }
31 | 
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | 
3 | go:
4 |   - 1.4.3
5 |   - 1.5.2
6 |   - tip
7 | 
8 | before_install:
9 |   - uname -a
10 |   - go get golang.org/x/tools/cmd/vet
11 |   - go get golang.org/x/tools/cmd/cover
12 |   - go get github.com/golang/lint/golint
13 |   - go get github.com/mattn/goveralls
14 |   - go get github.com/funny/utest
15 |   - go get github.com/funny/crypto/aes256cbc
16 |   - go get github.com/funny/reuseport
17 | 
18 | install:
19 |   - go get -d -v . && go build -v .
20 | 
21 | script:
22 |   - go vet -x .
23 |   - $HOME/gopath/bin/golint .
24 |   - go test -v -covermode=count -coverprofile=profile.cov
25 | 
26 | after_script:
27 |   - $HOME/gopath/bin/goveralls -coverprofile=profile.cov -service=travis-ci
28 |   - go test -benchmem -bench .
29 | 
--------------------------------------------------------------------------------
/linux/gateway_test_client.go:
--------------------------------------------------------------------------------
1 | // Command gateway_test_client connects to a gateway at 127.0.0.1:10080,
2 | // sends an encrypted target address, then checks that 100000 random
3 | // payloads are echoed back unchanged and logs the elapsed time.
4 | package main
5 | 
6 | import (
7 | 	"bytes"
8 | 	"io"
9 | 	"log"
10 | 	"math/rand"
11 | 	"net"
12 | 	"time"
13 | )
14 | 
15 | func main() {
16 | 	conn, err := net.Dial("tcp", "127.0.0.1:10080")
17 | 	if err != nil {
18 | 		log.Fatal(err)
19 | 	}
20 | 
21 | 	// Handshake: encrypted target address + '\n', answered by a 3-byte code.
22 | 	_, err = conn.Write([]byte("U2FsdGVkX1+JXKDI/2wFpglXX2zzASqnKhqAiM6GvoI=\n"))
23 | 	if err != nil {
24 | 		log.Fatal(err)
25 | 	}
26 | 	code := make([]byte, 3)
27 | 	_, err = io.ReadFull(conn, code)
28 | 	if err != nil {
29 | 		log.Fatal(err)
30 | 	}
31 | 	if !bytes.Equal(code, []byte("200")) {
32 | 		log.Fatalf("handshake failed with code %q", code)
33 | 	}
34 | 
35 | 	t1 := time.Now()
36 | 	for i := 0; i < 100000; i++ {
37 | 		b1 := RandBytes(256)
38 | 		b2 := make([]byte, len(b1))
39 | 
40 | 		_, err := conn.Write(b1)
41 | 		if err != nil {
42 | 			log.Fatal(err)
43 | 		}
44 | 
45 | 		_, err = io.ReadFull(conn, b2)
46 | 		if err != nil {
47 | 			log.Fatal(err)
48 | 		}
49 | 
50 | 		if !bytes.Equal(b1, b2) {
51 | 			log.Fatalf("echo mismatch in round %d", i)
52 | 		}
53 | 	}
54 | 	log.Println("Finish:", time.Since(t1).String())
55 | }
56 | 
57 | // RandBytes returns between 1 and n bytes (inclusive) of pseudo-random data.
58 | func RandBytes(n int) []byte {
59 | 	n = rand.Intn(n) + 1
60 | 	b := make([]byte, n)
61 | 	for i := 0; i < n; i++ {
62 | 		// Intn(256), not Intn(255), so that 0xFF is generated too.
63 | 		b[i] = byte(rand.Intn(256))
64 | 	}
65 | 	return b
66 | }
67 | 
--------------------------------------------------------------------------------
/linux/aes256cbc.c:
--------------------------------------------------------------------------------
1 | #include <string.h>
2 | #include <openssl/evp.h>
3 | #include "base64.h"
4 | 
5 | /*
6 |  * aes256cbc_decrypt base64-decodes the text in buf, checks for the
7 |  * "Salted__" + 8-byte salt header, derives key/IV from passphrase with
8 |  * EVP_BytesToKey(MD5, 1 iteration) and AES-256-CBC decrypts the rest.
9 |  *
10 |  * Everything happens in place: on success buf holds the NUL-terminated
11 |  * plaintext and its length is returned; on any failure 0 is returned.
12 |  */
13 | int aes256cbc_decrypt(unsigned char *passphrase, unsigned char *buf) {
14 |     unsigned char key[EVP_MAX_KEY_LENGTH];
15 |     unsigned char iv[EVP_MAX_IV_LENGTH];
16 | 
17 |     const EVP_CIPHER *cipher = EVP_aes_256_cbc();
18 | 
19 |     int len = base64_decode((char *)buf, (const char *)buf);
20 |     if (len < 16 /* "Salted__" + 8-byte salt */) {
21 |         return 0;
22 |     }
23 | 
24 |     /* Reject input that does not start with the "Salted__" magic. */
25 |     if (memcmp(buf, "Salted__", 8) != 0) {
26 |         return 0;
27 |     }
28 | 
29 |     if (!EVP_BytesToKey(cipher, EVP_md5(),
30 |                         buf + 8 /* 8-byte salt after "Salted__" */,
31 |                         passphrase, (int)strlen((const char *)passphrase), 1,
32 |                         key, iv)) {
33 |         return 0;
34 |     }
35 | 
36 |     /* EVP_DecryptUpdate only accepts out == in or fully disjoint buffers, so
37 |      * shift the ciphertext to the front of buf and decrypt truly in place. */
38 |     len -= 16;
39 |     memmove(buf, buf + 16, len);
40 | 
41 |     EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
42 |     if (!ctx) {
43 |         return 0;
44 |     }
45 | 
46 |     if (1 != EVP_DecryptInit_ex(ctx, cipher, NULL, key, iv)) {
47 |         EVP_CIPHER_CTX_free(ctx);
48 |         return 0;
49 |     }
50 | 
51 |     int outlen = 0;
52 |     if (1 != EVP_DecryptUpdate(ctx, buf, &outlen, buf, len)) {
53 |         EVP_CIPHER_CTX_free(ctx);
54 |         return 0;
55 |     }
56 |     len = outlen;
57 |     if (1 != EVP_DecryptFinal_ex(ctx, buf + len, &outlen)) {
58 |         EVP_CIPHER_CTX_free(ctx);
59 |         return 0;
60 |     }
61 | 
62 |     EVP_CIPHER_CTX_free(ctx);
63 |     len += outlen;
64 |     buf[len] = '\0';
65 |     return len;
66 | }
--------------------------------------------------------------------------------
/linux/aes256cbc_test.c:
--------------------------------------------------------------------------------
1 | #include 
2 | #include 
3 | #include 
4 | #include 
5 | #include "base64.h"
6 | #include "aes256cbc.h"
7 | 
8 | char *password = "p0S8rX680*48";
9 | char *encrypted = "U2FsdGVkX1+JXKDI/2wFpglXX2zzASqnKhqAiM6GvoI=";
10 | 
11 | int decrypt(char *buf, int len, char *key, char *iv) {
12 |     EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
13 |     if (!ctx) {
14 |         return 0;
15 |     }
16 | 
17 |     char *plain = malloc(128);
18 | 
19 |     if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv)) {
20 |         EVP_CIPHER_CTX_free(ctx);
21 |         return 0;
22 |     }
23 | 
24 |     int newlen = 0;
25 |     if (1 != EVP_DecryptUpdate(ctx, plain, &newlen, buf, len)) {
26 |         EVP_CIPHER_CTX_free(ctx);
27 |         return 0;
28 |     }
29 |     len = newlen;
30 |     if (1 != EVP_DecryptFinal_ex(ctx, plain + newlen, &newlen)) {
31 |         ERR_print_errors_fp(stderr);
32 |         EVP_CIPHER_CTX_free(ctx);
33 |         return 0;
34 |     }
35 | 
36 |     EVP_CIPHER_CTX_free(ctx);
37 | 
38 |     len += newlen;
39 |     *(plain + len) = '\0';
40 |     printf("%s\n", plain);
41 |     return len;
42 | }
43 | 
44 | int main(int argc, char *argv[])
45 | {
46 |     unsigned char *plain = malloc(base64_decode_len(encrypted));
47 |     int len = base64_decode(plain, encrypted);
48 | 
49 |     unsigned char salt[9];
50 |     memcpy(salt, plain + 8, 8);
51 |     salt[8] = 0;
52 | 
53 |     const EVP_CIPHER *cipher = EVP_aes_256_cbc();
54 |     unsigned char key[EVP_MAX_KEY_LENGTH], iv[EVP_MAX_IV_LENGTH];
55 | 
56 |     if (!EVP_BytesToKey(cipher, EVP_md5(), plain + 8, password, strlen(password), 1, key, iv)) {
57 |         fprintf(stderr, "EVP_BytesToKey failed\n");
58 |         return 1;
59 |     }
60 | 
61 |     int i;
62 |     printf("Len: %d\n", len);
63 |     printf("Header: "); for(i=0; i<8; ++i) { printf("%c", plain[i]); } printf("\n");
64 |     printf("Text: "); for(i=0; ikey_len; ++i) { printf("%02X", key[i]); } printf("\n");
67 |     printf("IV: "); for(i=0; iiv_len; ++i) { printf("%02X", iv[i]); } printf("\n");
68 | 
69 |     len = decrypt(plain + 16, len - 16, key, iv);
70 |     printf("Text: "); for(i=0; i.
79 |  *
80 |  */
81 | 
82 | 
83 | 
84 | #ifndef _BASE64_H_
85 | #define _BASE64_H_
86 | 
87 | #ifdef __cplusplus
88 | extern "C" {
89 | #endif
90 | 
91 | /* Buffer size needed to encode len bytes, including the trailing NUL. */
92 | int base64_encode_len(int len);
93 | /* Encodes len_plain_src bytes into coded_dst (NUL-terminated) and returns
94 |  * the number of bytes written, including the NUL. */
95 | int base64_encode(char * coded_dst, const char *plain_src,int len_plain_src);
96 | 
97 | /* Upper bound on the decoded size of coded_src, including a trailing NUL. */
98 | int base64_decode_len(const char * coded_src);
99 | /* Decodes coded_src up to its first non-base64 character into plain_dst
100 |  * (NUL-terminated) and returns the decoded length, excluding the NUL. */
101 | int base64_decode(char * plain_dst, const char *coded_src);
102 | 
103 | #ifdef __cplusplus
104 | }
105 | #endif
106 | 
107 | #endif //_BASE64_H_
108 | 
109 | 
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bytes"
5 | 	"flag"
6 | 	"fmt"
7 | 	"io/ioutil"
8 | 	"log"
9 | 	"net"
10 | 	"net/http"
11 | 	_ "net/http/pprof"
12 | 	"os"
13 | 	"os/signal"
14 | 	"runtime/debug"
15 | 	"strconv"
16 | 	"sync"
17 | 	"syscall"
18 | 	"time"
19 | 
20 | 	"github.com/funny/crypto/aes256cbc"
21 | )
22 | 
23 | // miniBufferSize is the smallest copy buffer the gateway uses; init raises
24 | // smaller -buffer values to it (a zero-length buffer makes io.CopyBuffer panic).
25 | const miniBufferSize = 1024
26 | 
27 | var (
28 | 	configed       = false
29 | 	cfgSecret      []byte
30 | 	cfgGatewayAddr = "0.0.0.0:0"
31 | 	cfgPprofAddr   = ""
32 | 	cfgReusePort   = false
33 | 	cfgDialRetry   = uint(1)
34 | 	cfgDialTimeout = uint(3) // seconds; init converts it to nanoseconds
35 | 	cfgBufferSize  = uint(16 * 1024)
36 | 
37 | 	// 3-byte status codes sent to the client to end the handshake.
38 | 	codeOK          = []byte("200")
39 | 	codeBadReq      = []byte("400")
40 | 	codeBadAddr     = []byte("401")
41 | 	codeDialErr     = []byte("502")
42 | 	codeDialTimeout = []byte("504")
43 | 
44 | 	isTest           bool // tests set this so fatal/fatalf panic instead of exiting
45 | 	handshakeBufPool sync.Pool
46 | 	copyBufPool      sync.Pool
47 | )
48 | 
49 | // init registers and parses the command-line flags, normalizes the parsed
50 | // values and sets up the buffer pools.
51 | func init() {
52 | 	var secret string
53 | 	flag.StringVar(&secret, "secret", "", "The passphrase used to decrypt target server address")
54 | 	flag.StringVar(&cfgGatewayAddr, "addr", cfgGatewayAddr, "Network address for gateway")
55 | 	flag.StringVar(&cfgPprofAddr, "pprof", cfgPprofAddr, "Network address for net/http/pprof")
56 | 	flag.BoolVar(&cfgReusePort, "reuse", cfgReusePort, "Enable reuse port feature")
57 | 	flag.UintVar(&cfgDialRetry, "retry", cfgDialRetry, "Retry times when dial to target server timeout")
58 | 	flag.UintVar(&cfgDialTimeout, "timeout", cfgDialTimeout, "Timeout seconds when dial to targer server")
59 | 	flag.UintVar(&cfgBufferSize, "buffer", cfgBufferSize, "Buffer size for io.CopyBuffer()")
60 | 	flag.Parse()
61 | 
62 | 	cfgSecret = []byte(secret)
63 | 
64 | 	// From here on cfgDialTimeout holds a time.Duration value (nanoseconds).
65 | 	cfgDialTimeout = uint(time.Second) * cfgDialTimeout
66 | 
67 | 	if cfgBufferSize < miniBufferSize {
68 | 		cfgBufferSize = miniBufferSize
69 | 	}
70 | 
71 | 	handshakeBufPool.New = func() interface{} {
72 | 		buf := make([]byte, 64 /* longest crypted address */ +1 /* \n */)
73 | 		return &buf
74 | 	}
75 | 
76 | 	// Pools hold *[]byte rather than []byte so that Put does not allocate.
77 | 	copyBufPool.New = func() interface{} {
78 | 		buf := make([]byte, cfgBufferSize)
79 | 		return &buf
80 | 	}
81 | }
82 | 
83 | // main starts the gateway (and the optional pprof server), writes
84 | // gateway.pid, logs the effective settings and blocks until SIGTERM/SIGINT.
85 | func main() {
86 | 	if len(cfgSecret) == 0 {
87 | 		fatal("Missing passphrase")
88 | 		return
89 | 	}
90 | 
91 | 	if cfgPprofAddr != "" {
92 | 		listener, err := net.Listen("tcp", cfgPprofAddr)
93 | 		if err != nil {
94 | 			fatalf("Setup pprof failed: %s", err)
95 | 		}
96 | 		cfgPprofAddr = listener.Addr().String()
97 | 		go http.Serve(listener, nil)
98 | 	} else {
99 | 		cfgPprofAddr = "disable"
100 | 	}
101 | 
102 | 	pid := syscall.Getpid()
103 | 	if err := ioutil.WriteFile("gateway.pid", []byte(strconv.Itoa(pid)), 0644); err != nil {
104 | 		fatalf("Can't write pid file: %s", err)
105 | 	}
106 | 	defer os.Remove("gateway.pid")
107 | 
108 | 	start()
109 | 
110 | 	printf(`Gateway running
111 | Address: %s
112 | Reuse port: %v
113 | Dial retry: %d
114 | Dial timeout: %s
115 | Buffer size: %d
116 | Passphrase: %s
117 | Profiling: %s
118 | Process ID: %d`,
119 | 		cfgGatewayAddr,
120 | 		cfgReusePort,
121 | 		cfgDialRetry,
122 | 		time.Duration(cfgDialTimeout),
123 | 		cfgBufferSize,
124 | 		cfgSecret,
125 | 		cfgPprofAddr,
126 | 		pid)
127 | 
128 | 	exitChan := make(chan os.Signal, 1)
129 | 	signal.Notify(exitChan, syscall.SIGTERM)
130 | 	signal.Notify(exitChan, syscall.SIGINT)
131 | 	<-exitChan
132 | 	printf("Gateway killed")
133 | }
134 | 
135 | // fatal logs t and exits the process; under tests (isTest) it panics with t
136 | // instead so the failure can be recovered and checked.
137 | func fatal(t string) {
138 | 	if !isTest {
139 | 		log.Fatal(t)
140 | 	}
141 | 	panic(t)
142 | }
143 | 
144 | // fatalf is the formatted variant of fatal.
145 | func fatalf(t string, args ...interface{}) {
146 | 	if !isTest {
147 | 		log.Fatalf(t, args...)
148 | 	}
149 | 	panic(fmt.Sprintf(t, args...))
150 | }
151 | 
152 | // printf logs a formatted message; it is silent under tests.
153 | func printf(t string, args ...interface{}) {
154 | 	if !isTest {
155 | 		log.Printf(t, args...)
156 | 	}
157 | }
158 | 
159 | // start opens the gateway listener, stores its actual address in
160 | // cfgGatewayAddr (resolving port 0) and serves it in a new goroutine.
161 | func start() {
162 | 	listener, err := listen()
163 | 	if err != nil {
164 | 		fatalf("Setup listener failed: %s", err)
165 | 	}
166 | 	cfgGatewayAddr = listener.Addr().String()
167 | 	go loop(listener)
168 | }
169 | 
170 | // loop accepts connections and handles each in its own goroutine. A
171 | // non-temporary accept error is fatal (exits the process, or panics under
172 | // tests).
173 | func loop(listener net.Listener) {
174 | 	defer listener.Close()
175 | 	for {
176 | 		conn, err := accept(listener)
177 | 		if err != nil {
178 | 			fatalf("Gateway accept failed: %s", err)
179 | 			return
180 | 		}
181 | 		go handle(conn)
182 | 	}
183 | }
184 | 
185 | // accept returns the next connection from listener. Temporary errors are
186 | // retried after a sleep that starts at 5ms, doubles each time and is capped
187 | // at 1s; any other error is returned.
188 | func accept(listener net.Listener) (net.Conn, error) {
189 | 	var tempDelay time.Duration
190 | 	for {
191 | 		conn, err := listener.Accept()
192 | 		if err != nil {
193 | 			if ne, ok := err.(net.Error); ok && ne.Temporary() {
194 | 				if tempDelay == 0 {
195 | 					tempDelay = 5 * time.Millisecond
196 | 				} else {
197 | 					tempDelay *= 2
198 | 				}
199 | 				if max := 1 * time.Second; tempDelay > max {
200 | 					tempDelay = max
201 | 				}
202 | 				time.Sleep(tempDelay)
203 | 				continue
204 | 			}
205 | 			return nil, err
206 | 		}
207 | 		return conn, nil
208 | 	}
209 | }
210 | 
211 | // handle serves one client: it runs the handshake, then pipes data in both
212 | // directions between the client and the target. Whichever direction ends
213 | // first closes both connections, which unblocks the other. Panics are
214 | // recovered and logged.
215 | func handle(conn net.Conn) {
216 | 	defer func() {
217 | 		conn.Close()
218 | 		if err := recover(); err != nil {
219 | 			printf("panic: %v\n\n%s", err, debug.Stack())
220 | 		}
221 | 	}()
222 | 
223 | 	agent := handshake(conn)
224 | 	if agent == nil {
225 | 		return
226 | 	}
227 | 	defer agent.Close()
228 | 
229 | 	go func() {
230 | 		defer func() {
231 | 			agent.Close()
232 | 			conn.Close()
233 | 			if err := recover(); err != nil {
234 | 				printf("panic: %v\n\n%s", err, debug.Stack())
235 | 			}
236 | 		}()
237 | 		copy(conn, agent)
238 | 	}()
239 | 	copy(agent, conn)
240 | }
241 | 
242 | // handshake reads the client's encrypted target address (base64, at most 64
243 | // bytes, terminated by '\n'), decrypts it with cfgSecret, dials the target
244 | // and replies with a 3-byte status code. Bytes received after the '\n' are
245 | // forwarded to the target. It returns the target connection, or nil on
246 | // failure, in which case an error code has been sent to the client when possible.
247 | func handshake(conn net.Conn) (agent net.Conn) {
248 | 	var b = handshakeBufPool.Get().(*[]byte)
249 | 	buf := *b
250 | 	defer handshakeBufPool.Put(b)
251 | 
252 | 	// read and decrypt target server address
253 | 	var err error
254 | 	var addr, remain []byte
255 | 	for n, nn := 0, 0; n < len(buf); n += nn {
256 | 		nn, err = conn.Read(buf[n:])
257 | 		if err != nil {
258 | 			conn.Write(codeBadReq)
259 | 			return nil
260 | 		}
261 | 		if i := bytes.IndexByte(buf[n:n+nn], '\n'); i >= 0 {
262 | 			if addr, err = aes256cbc.DecryptBase64(cfgSecret, buf[:n+i]); err != nil {
263 | 				conn.Write(codeBadAddr)
264 | 				return nil
265 | 			}
266 | 			remain = buf[n+i+1 : n+nn]
267 | 			break
268 | 		}
269 | 	}
270 | 	if addr == nil {
271 | 		// no complete address line fitted in the buffer
272 | 		conn.Write(codeBadReq)
273 | 		return nil
274 | 	}
275 | 
276 | 	// dial to target server, retrying only on timeouts; always try at least
277 | 	// once, otherwise "-retry 0" would report success without a connection
278 | 	retry := cfgDialRetry
279 | 	if retry == 0 {
280 | 		retry = 1
281 | 	}
282 | 	for i := uint(0); i < retry; i++ {
283 | 		agent, err = net.DialTimeout("tcp", string(addr), time.Duration(cfgDialTimeout))
284 | 		if err == nil {
285 | 			break
286 | 		}
287 | 		if ne, ok := err.(net.Error); ok && ne.Timeout() {
288 | 			continue
289 | 		}
290 | 		conn.Write(codeDialErr)
291 | 		return nil
292 | 	}
293 | 	if err != nil {
294 | 		conn.Write(codeDialTimeout)
295 | 		return nil
296 | 	}
297 | 
298 | 	// send succeed code
299 | 	if _, err = conn.Write(codeOK); err != nil {
300 | 		agent.Close()
301 | 		return nil
302 | 	}
303 | 
304 | 	// send remainder data in buffer
305 | 	if len(remain) > 0 {
306 | 		if _, err = agent.Write(remain); err != nil {
307 | 			agent.Close()
308 | 			return nil
309 | 		}
310 | 	}
311 | 	return agent
312 | }
313 | 
--------------------------------------------------------------------------------
/linux/base64.c:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2003 Apple Computer, Inc. All rights reserved.
3 |  *
4 |  * @APPLE_LICENSE_HEADER_START@
5 |  *
6 |  * Copyright (c) 1999-2003 Apple Computer, Inc. All Rights Reserved.
7 |  *
8 |  * This file contains Original Code and/or Modifications of Original Code
9 |  * as defined in and that are subject to the Apple Public Source License
10 |  * Version 2.0 (the 'License'). You may not use this file except in
11 |  * compliance with the License.
Please obtain a copy of the License at 12 | * http://www.opensource.apple.com/apsl/ and read it before using this 13 | * file. 14 | * 15 | * The Original Code and all software distributed under the License are 16 | * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER 17 | * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES, 18 | * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY, 19 | * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT. 20 | * Please see the License for the specific language governing rights and 21 | * limitations under the License. 22 | * 23 | * @APPLE_LICENSE_HEADER_END@ 24 | */ 25 | /* ==================================================================== 26 | * Copyright (c) 1995-1999 The Apache Group. All rights reserved. 27 | * 28 | * Redistribution and use in source and binary forms, with or without 29 | * modification, are permitted provided that the following conditions 30 | * are met: 31 | * 32 | * 1. Redistributions of source code must retain the above copyright 33 | * notice, this list of conditions and the following disclaimer. 34 | * 35 | * 2. Redistributions in binary form must reproduce the above copyright 36 | * notice, this list of conditions and the following disclaimer in 37 | * the documentation and/or other materials provided with the 38 | * distribution. 39 | * 40 | * 3. All advertising materials mentioning features or use of this 41 | * software must display the following acknowledgment: 42 | * "This product includes software developed by the Apache Group 43 | * for use in the Apache HTTP server project (http://www.apache.org/)." 44 | * 45 | * 4. The names "Apache Server" and "Apache Group" must not be used to 46 | * endorse or promote products derived from this software without 47 | * prior written permission. For written permission, please contact 48 | * apache@apache.org. 49 | * 50 | * 5. 
Products derived from this software may not be called "Apache" 51 | * nor may "Apache" appear in their names without prior written 52 | * permission of the Apache Group. 53 | * 54 | * 6. Redistributions of any form whatsoever must retain the following 55 | * acknowledgment: 56 | * "This product includes software developed by the Apache Group 57 | * for use in the Apache HTTP server project (http://www.apache.org/)." 58 | * 59 | * THIS SOFTWARE IS PROVIDED BY THE APACHE GROUP ``AS IS'' AND ANY 60 | * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 61 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 62 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE APACHE GROUP OR 63 | * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 64 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 65 | * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 66 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 67 | * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 68 | * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 69 | * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 70 | * OF THE POSSIBILITY OF SUCH DAMAGE. 71 | * ==================================================================== 72 | * 73 | * This software consists of voluntary contributions made by many 74 | * individuals on behalf of the Apache Group and was originally based 75 | * on public domain software written at the National Center for 76 | * Supercomputing Applications, University of Illinois, Urbana-Champaign. 77 | * For more information on the Apache Group and the Apache HTTP server 78 | * project, please see . 79 | * 80 | */ 81 | 82 | /* Base64 encoder/decoder. Originally Apache file ap_base64.c 83 | */ 84 | 85 | #include 86 | 87 | #include "base64.h" 88 | 89 | /* aaaack but it's fast and const should make it shared text page. 
*/
90 | /* Maps each byte to its 6-bit base64 value; 64 marks bytes outside the
91 |  * base64 alphabet, including the '=' padding character. */
92 | static const unsigned char pr2six[256] =
93 | {
94 |     /* ASCII table */
95 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
96 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
97 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63,
98 |     52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64,
99 |     64,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
100 |     15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64,
101 |     64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
102 |     41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64,
103 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
104 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
105 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
106 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
107 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
108 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
109 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
110 |     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64
111 | };
112 | 
113 | /* Returns an upper bound on the number of bytes base64_decode() writes for
114 |  * bufcoded, including the trailing NUL. */
115 | int base64_decode_len(const char *bufcoded)
116 | {
117 |     int nbytesdecoded;
118 |     register const unsigned char *bufin;
119 |     register int nprbytes;
120 | 
121 |     bufin = (const unsigned char *) bufcoded;
122 |     while (pr2six[*(bufin++)] <= 63);
123 | 
124 |     nprbytes = (bufin - (const unsigned char *) bufcoded) - 1;
125 |     nbytesdecoded = ((nprbytes + 3) / 4) * 3;
126 | 
127 |     return nbytesdecoded + 1;
128 | }
129 | 
130 | /* Decodes bufcoded up to its first non-base64 character into bufplain,
131 |  * NUL-terminates it and returns the decoded length (excluding the NUL). */
132 | int base64_decode(char *bufplain, const char *bufcoded)
133 | {
134 |     int nbytesdecoded;
135 |     register const unsigned char *bufin;
136 |     register unsigned char *bufout;
137 |     register int nprbytes;
138 | 
139 |     bufin = (const unsigned char *) bufcoded;
140 |     while (pr2six[*(bufin++)] <= 63);
141 |     nprbytes = (bufin - (const unsigned char *) bufcoded) - 1;
142 |     nbytesdecoded = ((nprbytes + 3) / 4) * 3;
143 | 
144 |     bufout = (unsigned char *) bufplain;
145 |     bufin = (const unsigned char *) bufcoded;
146 | 
147 |     while (nprbytes > 4) {
148 |         *(bufout++) =
149 |             (unsigned char) (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4);
150 |         *(bufout++) =
151 |             (unsigned char) (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2);
152 |         *(bufout++) =
153 |             (unsigned char) (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]);
154 |         bufin += 4;
155 |         nprbytes -= 4;
156 |     }
157 | 
158 |     /* Note: (nprbytes == 1) would be an error, so just ignore that case */
159 |     if (nprbytes > 1) {
160 |         *(bufout++) =
161 |             (unsigned char) (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4);
162 |     }
163 |     if (nprbytes > 2) {
164 |         *(bufout++) =
165 |             (unsigned char) (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2);
166 |     }
167 |     if (nprbytes > 3) {
168 |         *(bufout++) =
169 |             (unsigned char) (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]);
170 |     }
171 | 
172 |     *(bufout++) = '\0';
173 |     nbytesdecoded -= (4 - nprbytes) & 3;
174 |     return nbytesdecoded;
175 | }
176 | 
177 | static const char basis_64[] =
178 |     "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
179 | 
180 | /* Returns the buffer size base64_encode() needs for len input bytes,
181 |  * including the trailing NUL. */
182 | int base64_encode_len(int len)
183 | {
184 |     return ((len + 2) / 3 * 4) + 1;
185 | }
186 | 
187 | /* Encodes len bytes of string into encoded (with '=' padding), NUL-terminates
188 |  * it and returns the number of bytes written, including the NUL. */
189 | int base64_encode(char *encoded, const char *string, int len)
190 | {
191 |     int i;
192 |     char *p;
193 | 
194 |     p = encoded;
195 |     for (i = 0; i < len - 2; i += 3) {
196 |         *p++ = basis_64[(string[i] >> 2) & 0x3F];
197 |         *p++ = basis_64[((string[i] & 0x3) << 4) |
198 |                         ((int) (string[i + 1] & 0xF0) >> 4)];
199 |         *p++ = basis_64[((string[i + 1] & 0xF) << 2) |
200 |                         ((int) (string[i + 2] & 0xC0) >> 6)];
201 |         *p++ = basis_64[string[i + 2] & 0x3F];
202 |     }
203 |     if (i < len) {
204 |         *p++ = basis_64[(string[i] >> 2) & 0x3F];
205 |         if (i == (len - 1)) {
206 |             *p++ = basis_64[((string[i] & 0x3) << 4)];
207 |             *p++ = '=';
208 |         }
209 |         else {
210 |             *p++ = basis_64[((string[i] & 0x3) << 4) |
211 |                             ((int) (string[i + 1] & 0xF0) >> 4)];
212 |             *p++ = basis_64[((string[i + 1] & 0xF) << 2)];
213 |         }
214 |         *p++ = '=';
215 |     }
216 | 
217 |     *p++ = '\0';
218 |     return p - encoded;
219 | }
--------------------------------------------------------------------------------
/main_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"io"
5 | 	"math/rand"
6 | 	"net"
7 | 	"strings"
8 | 	"sync"
9 | 	"testing"
10 | 	"time"
11 | 
12 | 	"github.com/funny/crypto/aes256cbc"
13 | 	"github.com/funny/utest"
14 | )
15 | 
16 | // init puts the gateway into test mode (fatal errors panic instead of
17 | // exiting), sets the passphrase and starts the gateway in the background,
18 | // giving it two seconds to come up.
19 | func init() {
20 | 	isTest = true
21 | 	cfgSecret = []byte("test")
22 | 	go main()
23 | 	time.Sleep(time.Second * 2)
24 | }
25 | 
26 | // RandBytes returns between 1 and n bytes (inclusive) of pseudo-random data.
27 | func RandBytes(n int) []byte {
28 | 	n = rand.Intn(n) + 1
29 | 	b := make([]byte, n)
30 | 	for i := 0; i < n; i++ {
31 | 		// Intn(256), not Intn(255), so that 0xFF is generated too.
32 | 		b[i] = byte(rand.Intn(256))
33 | 	}
34 | 	return b
35 | }
36 | 
37 | // Test_Fatals checks that each startup failure is reported through
38 | // fatal/fatalf with the expected message, restoring the globals it changes.
39 | func Test_Fatals(t *testing.T) {
40 | 	// missing passphrase
41 | 	oldSecret := cfgSecret
42 | 	defer func() {
43 | 		cfgSecret = oldSecret
44 | 	}()
45 | 	cfgSecret = nil
46 | 	func() {
47 | 		defer func() {
48 | 			err := recover()
49 | 			utest.NotNilNow(t, err)
50 | 			utest.Assert(t, strings.Contains(err.(string), "Missing passphrase"))
51 | 		}()
52 | 		main()
53 | 	}()
54 | 	cfgSecret = oldSecret
55 | 
56 | 	// bad pprof address
57 | 	cfgPprofAddr = "xxoo"
58 | 	func() {
59 | 		defer func() {
60 | 			err := recover()
61 | 			utest.NotNilNow(t, err)
62 | 			utest.Assert(t, strings.Contains(err.(string), "Setup pprof failed"))
63 | 		}()
64 | 		main()
65 | 	}()
66 | 	cfgPprofAddr = "0.0.0.0:0"
67 | 
68 | 	// bad gateway address
69 | 	oldAddr := cfgGatewayAddr
70 | 	defer func() {
71 | 		cfgGatewayAddr = oldAddr
72 | 	}()
73 | 	cfgGatewayAddr = "abc"
74 | 	func() {
75 | 		defer func() {
76 | 			err := recover()
77 | 			utest.NotNilNow(t, err)
78 | 			utest.Assert(t, strings.Contains(err.(string), "Setup listener failed"))
79 | 		}()
80 | 		main()
81 | 	}()
82 | 
83 | 	// bad gateway address with reuse port
84 | 	oldReuse := cfgReusePort
85 | 	defer func() {
86 | 		cfgReusePort = oldReuse
87 | 	}()
88 | 	cfgReusePort = true
89 | 	func() {
90 | 		defer func() {
91 | 			err := recover()
92 | 			utest.NotNilNow(t, err)
93 | 			utest.Assert(t, strings.Contains(err.(string), "Setup listener failed"))
94 | 		}()
95 | 		start()
96 | 	}()
97 | }
98 | 
99 | // Test_BadReq1: closing the write side before sending anything yields 400.
100 | func Test_BadReq1(t *testing.T) {
101 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
102 | 	utest.IsNilNow(t, err)
103 | 	defer conn.Close()
104 | 
105 | 	err = conn.(*net.TCPConn).CloseWrite()
106 | 	utest.IsNilNow(t, err)
107 | 
108 | 	code := make([]byte, 3)
109 | 	_, err = io.ReadFull(conn, code)
110 | 	utest.IsNilNow(t, err)
111 | 	utest.EqualNow(t, string(code), string(codeBadReq))
112 | }
113 | 
114 | // Test_BadReq2: a partial address with no '\n' followed by EOF yields 400.
115 | func Test_BadReq2(t *testing.T) {
116 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
117 | 	utest.IsNilNow(t, err)
118 | 	defer conn.Close()
119 | 
120 | 	_, err = conn.Write([]byte("abc"))
121 | 	utest.IsNilNow(t, err)
122 | 
123 | 	err = conn.(*net.TCPConn).CloseWrite()
124 | 	utest.IsNilNow(t, err)
125 | 
126 | 	code := make([]byte, 3)
127 | 	_, err = io.ReadFull(conn, code)
128 | 	utest.IsNilNow(t, err)
129 | 	utest.EqualNow(t, string(code), string(codeBadReq))
130 | }
131 | 
132 | // Test_BadReq3: more bytes than the handshake buffer holds, with no '\n',
133 | // yields 400.
134 | func Test_BadReq3(t *testing.T) {
135 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
136 | 	utest.IsNilNow(t, err)
137 | 	defer conn.Close()
138 | 
139 | 	_, err = conn.Write(make([]byte, 128))
140 | 	utest.IsNilNow(t, err)
141 | 	code := make([]byte, 3)
142 | 	_, err = io.ReadFull(conn, code)
143 | 	utest.IsNilNow(t, err)
144 | 	utest.EqualNow(t, string(code), string(codeBadReq))
145 | }
146 | 
147 | // Test_BadAddr: a line that does not decrypt yields 401.
148 | func Test_BadAddr(t *testing.T) {
149 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
150 | 	utest.IsNilNow(t, err)
151 | 	defer conn.Close()
152 | 
153 | 	_, err = conn.Write([]byte("abc\n"))
154 | 	utest.IsNilNow(t, err)
155 | 	code := make([]byte, 3)
156 | 	_, err = io.ReadFull(conn, code)
157 | 	utest.IsNilNow(t, err)
158 | 	utest.EqualNow(t, string(code), string(codeBadAddr))
159 | }
160 | 
161 | // Test_CodeDialErr: a target that cannot be dialed (0.0.0.0:0) yields 502.
162 | func Test_CodeDialErr(t *testing.T) {
163 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
164 | 	utest.IsNilNow(t, err)
165 | 	defer conn.Close()
166 | 
167 | 	encryptedAddr, err := aes256cbc.EncryptString("test", "0.0.0.0:0")
168 | 	utest.IsNilNow(t, err)
169 | 
170 | 	_, err = conn.Write([]byte(encryptedAddr))
171 | 	utest.IsNilNow(t, err)
172 | 	_, err = conn.Write([]byte("\n"))
173 | 	utest.IsNilNow(t, err)
174 | 
175 | 	code := make([]byte, 3)
176 | 	_, err = io.ReadFull(conn, code)
177 | 	utest.IsNilNow(t, err)
178 | 	utest.EqualNow(t, string(code), string(codeDialErr))
179 | }
180 | 
181 | // Test_DialTimeout: with a 10ns dial timeout the dial times out, yielding 504.
182 | func Test_DialTimeout(t *testing.T) {
183 | 	oldTimeout := cfgDialTimeout
184 | 	cfgDialTimeout = 10
185 | 	defer func() {
186 | 		cfgDialTimeout = oldTimeout
187 | 	}()
188 | 
189 | 	listener, err := net.Listen("tcp", "0.0.0.0:0")
190 | 	utest.IsNilNow(t, err)
191 | 	defer listener.Close()
192 | 
193 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
194 | 	utest.IsNilNow(t, err)
195 | 	defer conn.Close()
196 | 
197 | 	encryptedAddr, err := aes256cbc.EncryptString("test", listener.Addr().String())
198 | 	utest.IsNilNow(t, err)
199 | 
200 | 	_, err = conn.Write([]byte(encryptedAddr))
201 | 	utest.IsNilNow(t, err)
202 | 	_, err = conn.Write([]byte("\n"))
203 | 	utest.IsNilNow(t, err)
204 | 
205 | 	code := make([]byte, 3)
206 | 	_, err = io.ReadFull(conn, code)
207 | 	utest.IsNilNow(t, err)
208 | 	utest.EqualNow(t, string(code), string(codeDialTimeout))
209 | }
210 | 
211 | // Test_OK: a valid, encrypted address of a listening server yields 200.
212 | func Test_OK(t *testing.T) {
213 | 	listener, err := net.Listen("tcp", "0.0.0.0:0")
214 | 	utest.IsNilNow(t, err)
215 | 	defer listener.Close()
216 | 
217 | 	conn, err := net.Dial("tcp", cfgGatewayAddr)
218 | 	utest.IsNilNow(t, err)
219 | 	defer conn.Close()
220 | 
221 | 	encryptedAddr, err := aes256cbc.EncryptString(string(cfgSecret), listener.Addr().String())
222 | 	utest.IsNilNow(t, err)
223 | 
224 | 	_, err = conn.Write([]byte(encryptedAddr))
225 | 	utest.IsNilNow(t, err)
226 | 	_, err = conn.Write([]byte("\n"))
227 | 	utest.IsNilNow(t, err)
228 | 
229 | 	code := make([]byte, 3)
230 | 	_, err = io.ReadFull(conn, code)
231 | 	utest.IsNilNow(t, err)
232 | 	utest.EqualNow(t, string(code), string(codeOK))
233 | }
234 | 
235 | // TestError is a net.Error whose Timeout/Temporary results are configurable.
236 | type TestError struct {
237 | 	timeout   bool
238 | 	temporary bool
239 | }
240 | 
241 | func (e TestError) Error() string {
242 | 	return "This is test error"
243 | }
244 | 
245 | func (e TestError) Timeout() bool {
246 | 	return e.timeout
247 | }
248 | 
249 | func (e TestError) Temporary() bool {
250 | 	return e.temporary
251 | }
252 | 
253 | // TestListener is a fake net.Listener: with n == -1 Accept always fails with
254 | // err; otherwise it fails n times and then returns an empty *net.TCPConn.
255 | type TestListener struct {
256 | 	n   int
257 | 	err TestError
258 | }
259 | 
260 | func (l *TestListener) Accept() (net.Conn, error) {
261 | 	if l.n == -1 {
262 | 		return nil, l.err
263 | 	}
264 | 	if l.n == 0 {
265 | 		return &net.TCPConn{}, nil
266 | 	}
267 | 	l.n--
268 | 	return nil, l.err
269 | }
270 | 
271 | func (l *TestListener) Close() error {
272 | 	return nil
273 | }
274 | 
275 | func (l *TestListener) Addr() net.Addr {
276 | 	return nil
277 | }
278 | 
279 | // Test_Accept covers accept's retry on temporary errors, its pass-through of
280 | // other errors, and loop's fatal on accept failure.
281 | func Test_Accept(t *testing.T) {
282 | 	_, err := accept(&TestListener{
283 | 		9, TestError{false, true},
284 | 	})
285 | 	utest.IsNilNow(t, err)
286 | 
287 | 	_, err = accept(&TestListener{
288 | 		-1, TestError{true, false},
289 | 	})
290 | 	utest.NotNilNow(t, err)
291 | 
292 | 	func() {
293 | 		defer func() {
294 | 			err := recover()
295 | 			utest.NotNilNow(t, err)
296 | 			utest.Assert(t, strings.Contains(err.(string), "Gateway accept failed"))
297 | 		}()
298 | 		loop(&TestListener{
299 | 			-1, TestError{true, false},
300 | 		})
301 | 	}()
302 | }
303 | 
304 | // TestReadWriteCloser panics on every Read and Write and records Close.
305 | type TestReadWriteCloser struct {
306 | 	closed bool
307 | }
308 | 
309 | func (t *TestReadWriteCloser) Write(_ []byte) (int, error) {
310 | 	panic("just panic")
311 | }
312 | 
313 | func (t *TestReadWriteCloser) Read(_ []byte) (int, error) {
314 | 	panic("just panic")
315 | }
316 | 
317 | func (t *TestReadWriteCloser) Close() error {
318 | 	t.closed = true
319 | 	return nil
320 | }
321 | 
322 | // Test_Transfer runs 20 sessions through the gateway to a local echo server
323 | // and checks that the bytes sent right after the address line, and 10000
324 | // random payloads per session, come back unchanged.
325 | func Test_Transfer(t *testing.T) {
326 | 	listener, err := net.Listen("tcp", "0.0.0.0:0")
327 | 	utest.IsNilNow(t, err)
328 | 	defer listener.Close()
329 | 	go func() {
330 | 		for {
331 | 			conn, err := listener.Accept()
332 | 			if err != nil {
333 | 				// The listener is closed when the test returns: stop here
334 | 				// instead of spinning on the same error forever.
335 | 				return
336 | 			}
337 | 			go func() {
338 | 				defer conn.Close()
339 | 				io.Copy(conn, conn)
340 | 			}()
341 | 		}
342 | 	}()
343 | 
344 | 	for i := 0; i < 20; i++ {
345 | 		conn, err := net.Dial("tcp", cfgGatewayAddr)
346 | 		utest.IsNilNow(t, err)
347 | 		defer conn.Close()
348 | 
349 | 		encryptedAddr, err := aes256cbc.EncryptString(string(cfgSecret), listener.Addr().String())
350 | 		utest.IsNilNow(t, err)
351 | 
352 | 		_, err = conn.Write([]byte(encryptedAddr))
353 | 		utest.IsNilNow(t, err)
354 | 		_, err = conn.Write([]byte("\nabc"))
355 | 		utest.IsNilNow(t, err)
356 | 
357 | 		code := make([]byte, 6)
358 | 		_, err = io.ReadFull(conn, code)
359 | 		utest.IsNilNow(t, err)
360 | 		utest.EqualNow(t, string(code[:3]), string(codeOK))
361 | 		utest.EqualNow(t, string(code[3:]), "abc")
362 | 
363 | 		for j := 0; j < 10000; j++ {
364 | 			b1 := RandBytes(256)
365 | 			_, err = conn.Write(b1)
366 | 			utest.IsNilNow(t, err)
367 | 
368 | 			b2 := make([]byte, len(b1))
369 | 			_, err = io.ReadFull(conn, b2)
370 | 			utest.IsNilNow(t, err)
371 | 
372 | 			utest.EqualNow(t, b1, b2)
373 | 		}
374 | 	}
375 | }
376 | 
377 | // testBufPool1 pools []byte values directly, so every Put boxes a slice header.
378 | var testBufPool1 = sync.Pool{
379 | 	New: func() interface{} {
380 | 		return make([]byte, 64)
381 | 	},
382 | }
383 | 
384 | // testBufPool2 pools *[]byte, the pattern main.go uses for its buffer pools.
385 | var testBufPool2 = sync.Pool{
386 | 	New: func() interface{} {
387 | 		buf := make([]byte, 64)
388 | 		return &buf
389 | 	},
390 | }
391 | 
392 | func Benchmark_BufPool1(b *testing.B) {
393 | 	var buf []byte
394 | 	for i := 0; i < b.N; i++ {
395 | 		buf = testBufPool1.Get().([]byte)
396 | 		testBufPool1.Put(buf)
397 | 	}
398 | 	_ = buf
399 | }
400 | 
401 | func Benchmark_BufPool2(b *testing.B) {
402 | 	var buf []byte
403 | 	for i := 0; i < b.N; i++ {
404 | 		p := testBufPool2.Get().(*[]byte)
405 | 		buf = *p
406 | 		testBufPool2.Put(p)
407 | 	}
408 | 	_ = buf
409 | }
410 | 
--------------------------------------------------------------------------------
/linux/gateway.c:
--------------------------------------------------------------------------------
1 | #include 
2 | #include 
3 | #include 
4 | #include 
5 | #include 
6 | #include 
7 | #include 
8 | #include 
9 | #include 
10 | 
#include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "aes256cbc.h" 18 | 19 | #define PIPE_SIZE 32768 // splice pipe buffer size 20 | #define HS_BUFF_SIZE 64 // handshake buffer size 21 | #define MAX_EVENTS 20 // max epoll event for each epoll frame 22 | 23 | const char *HS_DONE = "200"; // handshake finished 24 | const char *HS_WAIT_DONE = "200"; // handshake request processed but waiting connect to backend 25 | const char *HS_BAD_REQ = "400"; // incomplete handshake request 26 | const char *HS_BAD_ADDR = "401"; // decrypt backend address failed or parse failed 27 | const char *HS_DIAL_ERR = "500"; // can't connect to backend 28 | const char *HS_DIAL_TIMEOUT = "504"; // connect to backend timeout 29 | 30 | // catch SIGTERM 31 | int gw_stop_flag = 0; 32 | void gw_stop(int signal_id) { 33 | gw_stop_flag = 1; 34 | } 35 | 36 | // handshake state 37 | struct gw_hs_state { 38 | char buf[HS_BUFF_SIZE + 1]; 39 | const char *code; 40 | int readed; 41 | int writed; 42 | }; 43 | 44 | struct gw_conn { 45 | int fd; 46 | int pipe[2]; 47 | int events; 48 | int buffered; 49 | int deleted; 50 | struct gw_conn *other; 51 | struct gw_conn **del_poll; 52 | struct gw_hs_state *hs_state; 53 | struct gw_conn *prev; 54 | struct gw_conn *next; 55 | }; 56 | 57 | // all connections 58 | struct gw_conn *gw_conn_list; 59 | 60 | static struct gw_conn * 61 | gw_add_conn(int pd, int fd, struct gw_conn *del_poll[]) { 62 | int flags = fcntl(fd, F_GETFL, 0); 63 | if (flags < 0) { 64 | fprintf(stderr, "Can't get socket flag - %s\n", strerror(errno)); 65 | return NULL; 66 | } 67 | if (fcntl(fd, F_SETFL, flags|O_NONBLOCK) != 0) { 68 | fprintf(stderr, "Can't set O_NONBLOCK flag - %s\n", strerror(errno)); 69 | return NULL; 70 | } 71 | int opt = PIPE_SIZE / 2; 72 | if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &opt, sizeof(opt)) != 0) { 73 | fprintf(stderr, "Can't set socket receive buffer - %s\n", strerror(errno)); 74 | return NULL; 75 | } 76 | opt = PIPE_SIZE; 77 
| if (setsockopt (fd, SOL_SOCKET, SO_SNDBUF, &opt, sizeof(opt)) != 0) { 78 | fprintf(stderr, "Can't set socket send buffer - %s\n", strerror(errno)); 79 | return NULL; 80 | } 81 | opt = 1; 82 | if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) != 0) { 83 | fprintf(stderr, "Can't set TCP_NODELAY flag - %s\n", strerror(errno)); 84 | return NULL; 85 | } 86 | 87 | struct gw_conn *conn = (struct gw_conn *)calloc(1, sizeof(struct gw_conn)); 88 | if (conn == NULL) { 89 | fprintf(stderr, "Can't alloc memory for gw_conn - %s\n", strerror(errno)); 90 | return NULL; 91 | } 92 | if (pipe2(conn->pipe, O_NONBLOCK) != 0) { 93 | fprintf(stderr, "Can't create pipe - %s\n", strerror(errno)); 94 | free(conn); 95 | return NULL; 96 | } 97 | if (fcntl(conn->pipe[1], F_SETPIPE_SZ, PIPE_SIZE) != PIPE_SIZE) { 98 | fprintf(stderr, "Can't set pipe buffer size - %s\n", strerror(errno)); 99 | goto FAIL; 100 | } 101 | conn->fd = fd; 102 | conn->del_poll = del_poll; 103 | 104 | struct epoll_event event; 105 | event.data.ptr = conn; 106 | event.events = EPOLLIN | EPOLLOUT | EPOLLRDHUP | EPOLLET; 107 | if (epoll_ctl(pd, EPOLL_CTL_ADD, fd, &event) != 0) { 108 | fprintf(stderr, "Can't add socket into epoll - %s\n", strerror(errno)); 109 | goto FAIL; 110 | } 111 | return conn; 112 | 113 | FAIL: 114 | close(conn->pipe[0]); 115 | close(conn->pipe[1]); 116 | free(conn); 117 | return NULL; 118 | } 119 | 120 | static void 121 | gw_del_conn(struct gw_conn *conn) { 122 | if (conn->deleted == 1) 123 | return; 124 | 125 | struct gw_conn **del_poll = conn->del_poll; 126 | 127 | conn->deleted = 1; 128 | for (int i = 0; i < MAX_EVENTS; i ++) { 129 | if (del_poll[i] == NULL) { 130 | del_poll[i] = conn; 131 | break; 132 | } 133 | } 134 | 135 | // never happens? 
136 | if (conn->other->deleted == 1) 137 | return; 138 | 139 | conn->other->deleted = 1; 140 | for (int i = 0; i < MAX_EVENTS; i ++) { 141 | if (del_poll[i] == NULL) { 142 | del_poll[i] = conn->other; 143 | break; 144 | } 145 | } 146 | } 147 | 148 | static void 149 | gw_free_conn(int pd, struct gw_conn *conn) { 150 | if (conn->hs_state != NULL) { 151 | free(conn->hs_state); 152 | } 153 | 154 | struct epoll_event event; 155 | epoll_ctl(pd, EPOLL_CTL_DEL, conn->fd, &event); 156 | 157 | close(conn->fd); 158 | close(conn->pipe[0]); 159 | close(conn->pipe[1]); 160 | free(conn); 161 | } 162 | 163 | static void 164 | gw_clean_del_poll(int pd, struct gw_conn *del_poll[]) { 165 | for (int i = 0; i < MAX_EVENTS; i ++) { 166 | if (del_poll[i] == NULL) 167 | break; 168 | gw_free_conn(pd, del_poll[i]); 169 | del_poll[i] = NULL; 170 | } 171 | } 172 | 173 | static int 174 | gw_listen(char *addr) { 175 | int lsn = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK , 0); 176 | if (lsn < 0) { 177 | fprintf(stderr, "Can't create listener - %s\n", strerror(errno)); 178 | goto FAIL; 179 | } 180 | 181 | int opt = 1; 182 | if (setsockopt(lsn, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) != 0) { 183 | fprintf(stderr, "Can't set SO_REUSEADDR on listener - %s\n", strerror(errno)); 184 | goto FAIL; 185 | } 186 | 187 | // parse address 188 | char *clone = strchr(addr, ':'); 189 | if (clone == NULL) { 190 | fprintf(stderr, "Can't parse address - %s\n", addr); 191 | goto FAIL; 192 | } 193 | addr[clone - addr] = '\0'; 194 | struct sockaddr_in lsn_addr; 195 | lsn_addr.sin_family = AF_INET; 196 | lsn_addr.sin_port = htons(atoi(clone + 1)); 197 | lsn_addr.sin_addr.s_addr = inet_addr(addr); 198 | socklen_t addr_len = sizeof(struct sockaddr_in); 199 | 200 | if (bind(lsn, (struct sockaddr *)&lsn_addr, addr_len) != 0) { 201 | fprintf(stderr, "Can't bind address %s:%s - %s\n", addr, (clone + 1), strerror(errno)); 202 | goto FAIL; 203 | } 204 | 205 | if (listen(lsn, 128) != 0) { 206 | fprintf(stderr, "Can't 
listen - %s\n", strerror(errno)); 207 | goto FAIL; 208 | } 209 | 210 | if (getsockname(lsn, (struct sockaddr *)&lsn_addr, &addr_len) != 0) { 211 | fprintf(stderr, "Can't get listener address - %s\n", strerror(errno)); 212 | goto FAIL; 213 | } 214 | 215 | fprintf(stderr, "Setup proxy at %s:%d\n", inet_ntoa(lsn_addr.sin_addr), ntohs(lsn_addr.sin_port)); 216 | return lsn; 217 | 218 | FAIL: 219 | close(lsn); 220 | return -1; 221 | } 222 | 223 | static int 224 | gw_accept(int pd, int lsn, struct gw_conn *del_poll[]) { 225 | struct sockaddr_in addr; 226 | socklen_t addr_len = sizeof(struct sockaddr_in); 227 | for (;;) { 228 | int fd = accept(lsn, &addr, &addr_len); 229 | if (fd < 0) { 230 | if (errno == EAGAIN || errno == EWOULDBLOCK) 231 | break; 232 | return -1; 233 | } 234 | struct gw_conn *conn = gw_add_conn(pd, fd, del_poll); 235 | if (conn == NULL) { 236 | close(fd); 237 | continue; 238 | } 239 | // setup handshake state 240 | conn->hs_state = (struct gw_hs_state *)calloc( 241 | 1, sizeof(struct gw_hs_state) 242 | ); 243 | if (conn->hs_state == NULL) { 244 | fprintf(stderr, "Can't malloc memory for gw_hs_state - %s", strerror(errno)); 245 | gw_del_conn(conn); 246 | continue; 247 | } 248 | } 249 | return 0; 250 | } 251 | 252 | static void 253 | gw_handshake_in(int pd, struct gw_conn *conn, char *secret) { 254 | if (!(conn->events & EPOLLIN)) 255 | return; 256 | 257 | if (conn->hs_state->code != NULL) 258 | return; 259 | 260 | struct gw_hs_state *state = conn->hs_state; 261 | 262 | // read AES256-CBC encrypted address from client side. 
263 | int begin = state->readed; 264 | while (state->readed < HS_BUFF_SIZE) { 265 | int n = read(conn->fd, state->buf, HS_BUFF_SIZE - state->readed); 266 | if (n == 0) { 267 | state->code = HS_BAD_REQ; 268 | return; 269 | } 270 | if (n < 0) { 271 | if (errno == EAGAIN || errno == EWOULDBLOCK) { 272 | conn->events &= ~EPOLLIN; 273 | break; 274 | } 275 | state->code = HS_BAD_REQ; 276 | return; 277 | } 278 | state->readed += n; 279 | } 280 | 281 | // not change 282 | if (begin == state->readed) 283 | return; 284 | 285 | // make buffer like a c string and search '\n' 286 | state->buf[state->readed] = '\0'; 287 | if (strchr(state->buf + begin, '\n') == NULL) { 288 | if (state->readed == HS_BUFF_SIZE - 1) { 289 | state->code = HS_BAD_REQ; 290 | } 291 | return; 292 | } 293 | 294 | // decrypt the backend address 295 | if (aes256cbc_decrypt(secret, state->buf) == 0) { 296 | state->code = HS_BAD_ADDR; 297 | return; 298 | } 299 | 300 | // create a nonblock TCP socket 301 | int fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); 302 | if (fd < 0) { 303 | state->code = HS_DIAL_ERR; 304 | return; 305 | } 306 | 307 | // parse address 308 | char *clone = strchr(state->buf, ':'); 309 | if (clone == NULL) { 310 | state->code = HS_BAD_ADDR; 311 | return; 312 | } 313 | state->buf[clone - state->buf] = '\0'; 314 | struct sockaddr_in addr; 315 | addr.sin_family = AF_INET; 316 | addr.sin_port = htons(atoi(clone + 1)); 317 | addr.sin_addr.s_addr = inet_addr(state->buf); 318 | socklen_t addr_len = sizeof(struct sockaddr_in); 319 | 320 | // connect to backend 321 | int err = connect(fd, &addr, addr_len); 322 | if (err < 0 && errno != EINPROGRESS) { 323 | fprintf(stderr, "Can't connect to backend %s:%s - %s\n", state->buf, clone + 1, strerror(errno)); 324 | close(fd); 325 | state->code = HS_DIAL_ERR; 326 | return; 327 | } 328 | 329 | // setup backend connection 330 | conn->other = gw_add_conn(pd, fd, conn->del_poll); 331 | if (conn->other == NULL) { 332 | close(fd); 333 | state->code = 
HS_DIAL_ERR; 334 | return; 335 | } 336 | conn->other->other = conn; 337 | 338 | // succeed code 339 | state->code = err == 0 ? HS_DONE : HS_WAIT_DONE; 340 | } 341 | 342 | static void 343 | gw_handshake_out(struct gw_conn *conn) { 344 | struct gw_hs_state *state = conn->hs_state; 345 | 346 | if (!(conn->events & EPOLLOUT)) 347 | return; 348 | 349 | if (state->code == NULL) 350 | return; 351 | 352 | while (state->writed < 3) { 353 | int n = write(conn->fd, state->code + state->writed, 3 - state->writed); 354 | if (n == 0) 355 | return; 356 | if (n < 0) { 357 | if (errno == EAGAIN || errno == EWOULDBLOCK) { 358 | conn->events &= ~EPOLLOUT; 359 | return; 360 | } 361 | gw_del_conn(conn); 362 | return; 363 | } 364 | state->writed += n; 365 | } 366 | 367 | if (state->code == HS_DONE) { 368 | free(state); 369 | conn->hs_state = NULL; 370 | } else { 371 | gw_del_conn(conn); 372 | } 373 | } 374 | 375 | static int 376 | gw_splice_in(struct gw_conn *conn) { 377 | while (conn->buffered < PIPE_SIZE) { 378 | int n = splice( 379 | conn->fd, NULL, 380 | conn->pipe[1], NULL, 381 | PIPE_SIZE - conn->buffered, 382 | SPLICE_F_MOVE | SPLICE_F_NONBLOCK 383 | ); 384 | if (n == 0) { 385 | return -1; 386 | } 387 | if (n < 0) { 388 | if (errno == EAGAIN || errno == EWOULDBLOCK) { 389 | conn->events &= ~EPOLLIN; 390 | break; 391 | } 392 | return -1; 393 | } 394 | conn->buffered += n; 395 | } 396 | return 0; 397 | } 398 | 399 | static int 400 | gw_splice_out(struct gw_conn *conn) { 401 | while (conn->other->buffered > 0) { 402 | int n = splice( 403 | conn->other->pipe[0], NULL, 404 | conn->fd, NULL, 405 | conn->other->buffered, 406 | SPLICE_F_MOVE | SPLICE_F_NONBLOCK 407 | ); 408 | if (n == 0) { 409 | break; 410 | } 411 | if (n < 0) { 412 | if (errno == EAGAIN || errno == EWOULDBLOCK) { 413 | conn->events &= ~EPOLLOUT; 414 | break; 415 | } 416 | return -1; 417 | } 418 | conn->other->buffered -= n; 419 | } 420 | return 0; 421 | } 422 | 423 | static void 424 | gw_loop(int pd, int lsn, char 
*secret) { 425 | // Add the connections that want to close into del_poll and close them at the end of event frame. 426 | // Because the connections may be referenced in current event frame. 427 | struct gw_conn *del_poll[MAX_EVENTS]; 428 | bzero(del_poll, sizeof(struct gw_conn *) * MAX_EVENTS); 429 | 430 | struct epoll_event readys[MAX_EVENTS]; 431 | for (;;) { 432 | int rc = epoll_wait(pd, readys, MAX_EVENTS, -1); 433 | if (rc < 0) { 434 | if (errno == EINTR && !gw_stop_flag) { 435 | continue; 436 | } 437 | break; 438 | } 439 | for (register int i = 0; i < rc; i ++) { 440 | // is listener? 441 | if (readys[i].data.ptr == NULL) { 442 | if (readys[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) { 443 | // TODO: close connections. 444 | fprintf(stderr, "listener closed\n"); 445 | return; 446 | } 447 | 448 | if (gw_accept(pd, lsn, del_poll) != 0) { 449 | // TODO: close connections. 450 | fprintf(stderr, "listener closed\n"); 451 | return; 452 | } 453 | continue; 454 | } 455 | 456 | // save events for EPOLLET 457 | struct gw_conn *conn = (struct gw_conn *)readys[i].data.ptr; 458 | conn->events |= readys[i].events; 459 | 460 | // deleted by previous event? 461 | if (conn->deleted == 1) 462 | continue; 463 | 464 | // error happens? 465 | if (conn->events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) { 466 | // client waiting handshake code? 467 | if (conn->other != NULL 468 | && conn->other->hs_state != NULL 469 | && conn->other->hs_state->code == HS_WAIT_DONE) { 470 | conn->other->hs_state->code = HS_DIAL_ERR; 471 | gw_handshake_out(conn->other); 472 | } else { 473 | gw_del_conn(conn); 474 | } 475 | continue; 476 | } 477 | 478 | // doing handshake? 479 | // TODO: read/write timeout 480 | if (conn->hs_state != NULL) { 481 | gw_handshake_in(pd, conn, secret); 482 | gw_handshake_out(conn); 483 | if (conn->hs_state != NULL) 484 | continue; 485 | } 486 | 487 | struct gw_conn *other = conn->other; 488 | 489 | // client waiting handshake code? 
490 | if (other->hs_state != NULL && other->hs_state->code == HS_WAIT_DONE) { 491 | other->hs_state->code = HS_DONE; 492 | gw_handshake_out(other); 493 | } 494 | 495 | // splice in 496 | if (conn->events & EPOLLIN) { 497 | if (gw_splice_in(conn) != 0) { 498 | gw_del_conn(conn); 499 | continue; 500 | } 501 | // because EPOLLET 502 | if (other->events & EPOLLOUT) { 503 | if (gw_splice_out(other) != 0) { 504 | gw_del_conn(other); 505 | continue; 506 | } 507 | } 508 | } 509 | 510 | // splice out 511 | if (conn->events & EPOLLOUT) { 512 | if (gw_splice_out(conn) != 0) { 513 | gw_del_conn(conn); 514 | continue; 515 | } 516 | // because EPOLLET 517 | if (other->events & EPOLLIN) { 518 | if (gw_splice_in(other) != 0) { 519 | gw_del_conn(other); 520 | continue; 521 | } 522 | } 523 | } 524 | } 525 | // close bad connections 526 | gw_clean_del_poll(pd, del_poll); 527 | } 528 | } 529 | 530 | int 531 | main(int argc, char *argv[]) { 532 | // the passphrase for AES256-CBC decrypt 533 | char *secret = getenv("GW_SECRET"); 534 | if (secret == NULL) { 535 | fprintf(stderr, "Missing GW_SECRET environment variable\n"); 536 | return 1; 537 | } 538 | 539 | int ret = 1; 540 | 541 | // create a nonblock listener 542 | char *addr = getenv("GW_ADDR"); 543 | if (addr == NULL) { 544 | addr = strdup("0.0.0.0:0"); 545 | } 546 | int lsn = gw_listen(addr); 547 | if (lsn < 0) { 548 | goto END; 549 | } 550 | 551 | int pd = epoll_create(10); 552 | if (pd < 0) { 553 | fprintf(stderr, "Can't create epoll - %s\n", strerror(errno)); 554 | goto END; 555 | } 556 | 557 | // listener event 558 | struct epoll_event event; 559 | event.data.ptr = NULL; 560 | event.events = EPOLLIN | EPOLLET; 561 | if (epoll_ctl(pd, EPOLL_CTL_ADD, lsn, &event) != 0) { 562 | fprintf(stderr, "Can't add listener into epoll - %s\n", strerror(errno)); 563 | goto END; 564 | } 565 | 566 | FILE *pid_file = fopen("gateway.pid", "w"); 567 | if (pid_file == NULL || fprintf(pid_file, "%d", getpid()) < 0) { 568 | fprintf(stderr, "Can't 
record process ID - %s\n", strerror(errno)); 569 | goto END; 570 | } 571 | fclose(pid_file); 572 | 573 | // catch SIGTERM 574 | struct sigaction sa; 575 | memset(&sa, 0, sizeof(struct sigaction *)); 576 | sa.sa_handler = gw_stop; 577 | sa.sa_flags = 0; 578 | sigemptyset (&(sa.sa_mask)); 579 | if (sigaction(SIGTERM, &sa, NULL) != 0) { 580 | fprintf(stderr, "Can't catch SIGTERM signal - %s\n", strerror(errno)); 581 | goto END; 582 | } 583 | 584 | // event loop 585 | ret = 0; 586 | fprintf(stderr, "Getway running, pid = %d\n", getpid()); 587 | gw_loop(pd, lsn, secret); 588 | fprintf(stderr, "Getway killed\n"); 589 | 590 | END: 591 | // dispose things 592 | if (addr) free(addr); 593 | if (lsn > 0) close(lsn); 594 | if (pd > 0) close(pd); 595 | if (pid_file) remove("gateway.pid"); 596 | return ret; 597 | } --------------------------------------------------------------------------------