├── .gitignore ├── go.mod ├── run.go ├── LICENSE ├── controller ├── normal.go ├── server.go ├── roundrobin.go ├── regex.go ├── direct.go ├── prewarm.go └── boost.go ├── config ├── setting.json └── setting.go ├── utils └── log.go ├── go.sum ├── README.md └── test └── bench.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | # goland 18 | .idea 19 | 20 | # mac os file 21 | .DS_store 22 | 23 | # go build 24 | go_build_run_go 25 | moto 26 | # moto log 27 | moto.log -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module moto 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/natefinch/lumberjack v2.0.0+incompatible 7 | go.uber.org/zap v1.17.0 8 | ) 9 | 10 | require ( 11 | github.com/BurntSushi/toml v0.3.1 // indirect 12 | github.com/pkg/errors v0.9.1 // indirect 13 | github.com/stretchr/testify v1.7.1 // indirect 14 | go.uber.org/atomic v1.7.0 // indirect 15 | go.uber.org/multierr v1.6.0 // indirect 16 | gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect 17 | gopkg.in/yaml.v2 v2.4.0 // indirect 18 | gopkg.in/yaml.v3 v3.0.1 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /run.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "moto/config" 7 | "moto/controller" 8 | "moto/utils" 9 | "os" 10 | "sync" 11 | ) 12 | 13 | func main() { 14 | conf := flag.String("config", "", "Path to config file") 15 | flag.Parse() 16 | 17 | // 
Load config if a path is provided; overrides default and env 18 | if *conf != "" { 19 | if err := config.Reload(*conf); err != nil { 20 | fmt.Printf("failed to load config: %v\n", err) 21 | os.Exit(1) 22 | } 23 | } 24 | 25 | defer utils.Logger.Sync() 26 | 27 | utils.Logger.Info("MOTO 启动...") 28 | // single-sided build: no accelerator init required 29 | wg := &sync.WaitGroup{} 30 | for _, v := range config.GlobalCfg.Rules { 31 | wg.Add(1) 32 | go controller.Listen(v, wg) 33 | } 34 | wg.Wait() 35 | utils.Logger.Info("MOTO 关闭...") 36 | } 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 cppla 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/controller/normal.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"io"
5 | 	"moto/config"
6 | 	"moto/utils"
7 | 	"net"
8 | 
9 | 	"go.uber.org/zap"
10 | )
11 | 
12 | // HandleNormal dials the rule's targets in order and relays conn to the first
13 | // one that answers. Socket options (no-delay, keep-alive) are applied by
14 | // DialFast, which creates every connection outboundDial returns.
15 | func HandleNormal(conn net.Conn, rule *config.Rule) {
16 | 	defer conn.Close()
17 | 
18 | 	var target net.Conn
19 | 	// Try the targets one by one until a dial succeeds.
20 | 	for _, v := range rule.Targets {
21 | 		c, err := outboundDial(v.Address)
22 | 		if err != nil {
23 | 			utils.Logger.Error("无法建立连接,尝试下一个目标",
24 | 				zap.String("ruleName", rule.Name),
25 | 				zap.String("remoteAddr", conn.RemoteAddr().String()),
26 | 				zap.String("targetAddr", v.Address),
27 | 				zap.Error(err))
28 | 			continue
29 | 		}
30 | 		target = c
31 | 		break
32 | 	}
33 | 	if target == nil {
34 | 		utils.Logger.Error("所有目标均连接失败,无法处理连接",
35 | 			zap.String("ruleName", rule.Name),
36 | 			zap.String("remoteAddr", conn.RemoteAddr().String()))
37 | 		return
38 | 	}
39 | 	defer target.Close()
40 | 
41 | 	utils.Logger.Debug("建立连接",
42 | 		zap.String("ruleName", rule.Name),
43 | 		zap.String("remoteAddr", conn.RemoteAddr().String()),
44 | 		zap.String("targetAddr", target.RemoteAddr().String()))
45 | 
46 | 	// Relay in both directions; when one direction ends, closing both
47 | 	// connections unblocks the other io.Copy.
48 | 	go func() {
49 | 		io.Copy(conn, target)
50 | 		conn.Close()
51 | 		target.Close()
52 | 	}()
53 | 	io.Copy(target, conn)
54 | }
55 | 
--------------------------------------------------------------------------------
/controller/server.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"errors"
5 | 	"moto/config"
6 | 	"moto/utils"
7 | 	"net"
8 | 	"sync"
9 | 	"time"
10 | )
11 | 
12 | // Listen serves one rule: it opens a TCP listener on rule.Listen, drops
13 | // clients whose IP is in rule.Blacklist and hands every other connection to
14 | // the handler for rule.Mode in its own goroutine. wg.Done is called when
15 | // Listen returns, which happens if the mode is unknown, the listener cannot
16 | // be opened, or the listener has been closed.
17 | func Listen(rule *config.Rule, wg *sync.WaitGroup) {
18 | 	defer wg.Done()
19 | 
20 | 	// Resolve the handler up front so an unknown mode fails here instead of
21 | 	// accepting connections that no handler would ever serve or close.
22 | 	var handle func(net.Conn, *config.Rule)
23 | 	switch rule.Mode {
24 | 	case "normal":
25 | 		handle = HandleNormal
26 | 	case "regex":
27 | 		handle = HandleRegexp
28 | 	case "boost":
29 | 		handle = HandleBoost
30 | 	case "roundrobin":
31 | 		handle = HandleRoundrobin
32 | 	default:
33 | 		utils.Logger.Error(rule.Name + " unsupported mode: " + rule.Mode)
34 | 		return
35 | 	}
36 | 
37 | 	if rule.Prewarm {
38 | 		initPrewarm(rule)
39 | 	}
40 | 	listener, err := net.Listen("tcp", rule.Listen)
41 | 	if err != nil {
42 | 		utils.Logger.Error(rule.Name + " failed to listen at " + rule.Listen + ": " + err.Error())
43 | 		return
44 | 	}
45 | 	defer listener.Close()
46 | 	utils.Logger.Info(rule.Name + " listening at " + rule.Listen)
47 | 	for {
48 | 		conn, err := listener.Accept()
49 | 		if err != nil {
50 | 			if errors.Is(err, net.ErrClosed) {
51 | 				// Retrying on a closed listener would spin forever.
52 | 				utils.Logger.Error(rule.Name + " listener closed at " + rule.Listen)
53 | 				return
54 | 			}
55 | 			utils.Logger.Error(rule.Name + " failed to accept at " + rule.Listen + ": " + err.Error())
56 | 			time.Sleep(time.Second * 1)
57 | 			continue
58 | 		}
59 | 		// Blacklist keys are bare IPs. SplitHostPort strips the port and, for
60 | 		// IPv6, the brackets that cutting at the last ':' would leave behind.
61 | 		if len(rule.Blacklist) != 0 {
62 | 			clientIP := conn.RemoteAddr().String()
63 | 			if host, _, splitErr := net.SplitHostPort(clientIP); splitErr == nil {
64 | 				clientIP = host
65 | 			}
66 | 			if rule.Blacklist[clientIP] {
67 | 				utils.Logger.Info(rule.Name + " disconnected ip in blacklist: " + clientIP)
68 | 				conn.Close()
69 | 				continue
70 | 			}
71 | 		}
72 | 		go handle(conn, rule)
73 | 	}
74 | }
75 | 
--------------------------------------------------------------------------------
/controller/roundrobin.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"io"
5 | 	"moto/config"
6 | 	"moto/utils"
7 | 	"net"
8 | 	"sync"
9 | 	"sync/atomic"
10 | 	"time"
11 | 
12 | 	"go.uber.org/zap"
13 | )
14 | 
15 | // roundrobinCounters holds one *atomic.Uint64 per rule, keyed by *config.Rule,
16 | // so every roundrobin rule rotates through its own targets independently. A
17 | // single shared counter could pin each of two interleaved rules to a single
18 | // target.
19 | var roundrobinCounters sync.Map
20 | 
21 | // nextRoundrobinIndex returns the next target index in [0, n) for rule.
22 | func nextRoundrobinIndex(rule *config.Rule, n int) int {
23 | 	c, ok := roundrobinCounters.Load(rule)
24 | 	if !ok {
25 | 		c, _ = roundrobinCounters.LoadOrStore(rule, new(atomic.Uint64))
26 | 	}
27 | 	return int(c.(*atomic.Uint64).Add(1) % uint64(n))
28 | }
29 | 
30 | // HandleRoundrobin relays conn to the rule's next target in rotation. If that
31 | // target cannot be dialed, the connection is handed to HandleBoost, which
32 | // races all targets instead.
33 | func HandleRoundrobin(conn net.Conn, rule *config.Rule) {
34 | 	defer conn.Close()
35 | 
36 | 	if len(rule.Targets) == 0 {
37 | 		utils.Logger.Error("规则没有配置目标,无法处理连接",
38 | 			zap.String("ruleName", rule.Name),
39 | 			zap.String("remoteAddr", conn.RemoteAddr().String()))
40 | 		return
41 | 	}
42 | 	v := rule.Targets[nextRoundrobinIndex(rule, len(rule.Targets))]
43 | 
44 | 	roundrobinBegin := time.Now()
45 | 	target, err := outboundDial(v.Address)
46 | 	if err != nil {
47 | 		utils.Logger.Error("无法建立连接,切换到 boost 模式",
48 | 			zap.String("ruleName", rule.Name),
49 | 			zap.String("remoteAddr", conn.RemoteAddr().String()),
50 | 			zap.String("targetAddr", v.Address),
51 | 			zap.Int64("failedTime(ms)", time.Since(roundrobinBegin).Milliseconds()),
52 | 			zap.Error(err))
53 | 		HandleBoost(conn, rule)
54 | 		return
55 | 	}
56 | 	defer target.Close()
57 | 
58 | 	utils.Logger.Debug("建立连接",
59 | 		zap.String("ruleName", rule.Name),
60 | 		zap.String("remoteAddr", conn.RemoteAddr().String()),
61 | 		zap.String("targetAddr", target.RemoteAddr().String()),
62 | 		zap.Int64("roundrobinTime(ms)", time.Since(roundrobinBegin).Milliseconds()))
63 | 
64 | 	// Relay in both directions; when one direction ends, closing both
65 | 	// connections unblocks the other io.Copy.
66 | 	go func() {
67 | 		io.Copy(conn, target)
68 | 		conn.Close()
69 | 		target.Close()
70 | 	}()
71 | 	io.Copy(target, conn)
72 | }
73 | 
-------------------------------------------------------------------------------- /config/setting.json: -------------------------------------------------------------------------------- 1 | { 2 | "log": { 3 | "level": "debug", 4 | "path": "./moto.log", 5 | "version": "1.0.1", 6 | "date": "2024-07-23" 7 | }, 8 | "rules": [ 9 | { 10 | "name": "正常模式", 11 | "listen": ":81", 12 | "mode": "normal", 13 | "prewarm": false, 14 | "timeout": 60000, 15 | "blacklist": null, 16 | "targets": [ 17 | { 18 | "address": "www.baidu.com:80" 19 | }, 20 | { 21 | "address": "www.baidu.com:80" 22 | } 23 | ] 24 | }, 25 | { 26 | "name": "正则模式", 27 | "listen": ":82", 28 | "mode": "regex", 29 | "prewarm": false, 30 | "timeout": 60000, 31 | "blacklist": null, 32 | "targets": [ 33 | { 34 | "regexp": "^(GET|POST|HEAD|DELETE|PUT|CONNECT|OPTIONS|TRACE)", 35 | "address": "www.baidu.com:80" 36 | }, 37 | { 38 | "regexp": "^SSH", 39 | "address": "www.baidu.com:22" 40 | } 41 | ] 42 | }, 43 | { 44 | "name": "智能加速", 45 | "listen": ":83", 46 | "mode": "boost", 47 | "prewarm": true, 48 | "timeout": 60000, 49 | "blacklist": null, 50 | "targets": [ 51 | { 52 | "address": "www.baidu.com:80" 53 | }, 54 | { 55 | "address": "www.baidu.com:80" 56 | } 57 | ] 58 | }, 59 | { 60 | "name": "轮询模式", 61 | "listen": ":84", 62 | "mode": "roundrobin", 63 | "prewarm": true, 64 | "timeout": 60000, 65 | "blacklist": null, 66 | "targets": [ 67 | { 68 | "address": "www.baidu.com:80" 69 | }, 70 | { 71 | "address": "www.baidu.com:80" 72 | } 73 | ] 74 | } 75 | ] 76 | } 77 | 
--------------------------------------------------------------------------------
/controller/regex.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"io"
5 | 	"moto/config"
6 | 	"moto/utils"
7 | 	"net"
8 | 	"time"
9 | 
10 | 	"go.uber.org/zap"
11 | )
12 | 
13 | // HandleRegexp reads the client's first packet, dials the first target whose
14 | // regular expression matches it, replays that packet to the target and then
15 | // relays the rest of the stream in both directions.
16 | func HandleRegexp(conn net.Conn, rule *config.Rule) {
17 | 	defer conn.Close()
18 | 
19 | 	// The protocol is recognized from the client's first bytes, so bound the
20 | 	// wait for them with the rule's timeout (milliseconds).
21 | 	conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(rule.Timeout)))
22 | 	// Take whatever the client sends first, up to 4096 bytes. A single Read
23 | 	// returns as soon as data arrives; io.CopyN would wait for exactly 4096
24 | 	// bytes and so time out on any shorter first packet (e.g. an HTTP
25 | 	// request line or an SSH banner).
26 | 	buf := make([]byte, 4096)
27 | 	n, err := conn.Read(buf)
28 | 	if n == 0 {
29 | 		utils.Logger.Error("无法处理连接,读取首包失败",
30 | 			zap.String("ruleName", rule.Name),
31 | 			zap.String("remoteAddr", conn.RemoteAddr().String()),
32 | 			zap.Error(err))
33 | 		return
34 | 	}
35 | 	firstPacket := buf[:n]
36 | 
37 | 	var target net.Conn
38 | 	// Try the targets in order; only those whose pattern matches are dialed.
39 | 	for _, v := range rule.Targets {
40 | 		// Re stays nil when the rule failed validation; skip instead of panicking.
41 | 		if v.Re == nil || !v.Re.Match(firstPacket) {
42 | 			continue
43 | 		}
44 | 		c, err := outboundDial(v.Address)
45 | 		if err != nil {
46 | 			utils.Logger.Error("无法建立连接",
47 | 				zap.String("ruleName", rule.Name),
48 | 				zap.String("remoteAddr", conn.RemoteAddr().String()),
49 | 				zap.String("targetAddr", v.Address),
50 | 				zap.Error(err))
51 | 			continue
52 | 		}
53 | 		target = c
54 | 		break
55 | 	}
56 | 	if target == nil {
57 | 		utils.Logger.Error("未匹配到任何目标,无法处理连接",
58 | 			zap.String("ruleName", rule.Name),
59 | 			zap.String("remoteAddr", conn.RemoteAddr().String()))
60 | 		return
61 | 	}
62 | 	defer target.Close()
63 | 
64 | 	utils.Logger.Debug("建立连接",
65 | 		zap.String("ruleName", rule.Name),
66 | 		zap.String("remoteAddr", conn.RemoteAddr().String()),
67 | 		zap.String("targetAddr", target.RemoteAddr().String()))
68 | 	// Matched: clear the first-packet deadline before relaying.
69 | 	conn.SetReadDeadline(time.Time{})
70 | 	// Replay the first packet to the target before relaying the rest.
71 | 	if _, err := target.Write(firstPacket); err != nil {
72 | 		utils.Logger.Error("无法转发首包",
73 | 			zap.String("ruleName", rule.Name),
74 | 			zap.String("remoteAddr", conn.RemoteAddr().String()),
75 | 			zap.Error(err))
76 | 		return
77 | 	}
78 | 
79 | 	go func() {
80 | 		io.Copy(conn, target)
81 | 		conn.Close()
82 | 		target.Close()
83 | 	}()
84 | 	io.Copy(target, conn)
85 | }
86 | 
--------------------------------------------------------------------------------
/utils/log.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"moto/config"
5 | 	"time"
6 | 
7 | 	"github.com/natefinch/lumberjack"
8 | 	"go.uber.org/zap"
9 | 	"go.uber.org/zap/zapcore"
10 | )
11 | 
12 | var (
13 | 	// Logger is the process-wide logger. It writes JSON lines to the
14 | 	// rotating file named by config.GlobalCfg.Log.Path.
15 | 	Logger *zap.Logger
16 | )
17 | 
18 | // init builds Logger from the configuration that package config loaded
19 | // during its own initialization.
20 | func init() {
21 | 	// The level is looked up on every call, so a later config.Reload that
22 | 	// replaces config.GlobalCfg also changes it. Unknown level names map to
23 | 	// the zero Level, zapcore.InfoLevel.
24 | 	highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
25 | 		return lvl >= levelMap[config.GlobalCfg.Log.Level]
26 | 	})
27 | 
28 | 	//lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
29 | 	//	return lvl >= zapcore.DebugLevel
30 | 	//})
31 | 
32 | 	// The file name is captured once, here; config.Reload does not change it.
33 | 	hook := lumberjack.Logger{
34 | 		Filename:   config.GlobalCfg.Log.Path,
35 | 		MaxSize:    1024,
36 | 		MaxBackups: 5,
37 | 		MaxAge:     30,
38 | 		Compress:   true,
39 | 	}
40 | 
41 | 	//consoles := zapcore.AddSync(os.Stdout)
42 | 	files := zapcore.AddSync(&hook)
43 | 
44 | 	encoderConfig := zapcore.EncoderConfig{
45 | 		TimeKey:  "ts",
46 | 		LevelKey: "level",
47 | 		NameKey:  "logger",
48 | 		//CallerKey: "caller",
49 | 		MessageKey:    "msg",
50 | 		StacktraceKey: "stacktrace",
51 | 		LineEnding:    zapcore.DefaultLineEnding,
52 | 		EncodeLevel:   zapcore.LowercaseLevelEncoder,
53 | 		//EncodeLevel: zapcore.CapitalColorLevelEncoder,
54 | 		EncodeTime:     TimeEncoder,
55 | 		EncodeDuration: zapcore.SecondsDurationEncoder,
56 | 		EncodeCaller:   zapcore.ShortCallerEncoder,
57 | 	}
58 | 
59 | 	//consoleEncoder := zapcore.NewJSONEncoder(encoderConfig)
60 | 	fileEncoder := zapcore.NewJSONEncoder(encoderConfig)
61 | 
62 | 	core := zapcore.NewTee(
63 | 		//zapcore.NewCore(consoleEncoder, consoles, lowPriority),
64 | 		zapcore.NewCore(fileEncoder, files, highPriority),
65 | 	)
66 | 
67 | 	Logger = zap.New(
68 | 		core,
69 | 		zap.AddCaller(),
70 | 		zap.Development())
71 | 
72 | }
73 | 
74 | // levelMap maps the "level" values accepted in setting.json to zap levels.
75 | var levelMap = map[string]zapcore.Level{
76 | 	"debug":  zapcore.DebugLevel,
77 | 	"info":   zapcore.InfoLevel,
78 | 	"warn":   zapcore.WarnLevel,
79 | 	"error":  zapcore.ErrorLevel,
80 | 	"dpanic": zapcore.DPanicLevel,
81 | 	"panic":  zapcore.PanicLevel,
82 | 	"fatal":  zapcore.FatalLevel,
83 | }
84 | 
85 | // TimeEncoder renders timestamps as "2006-01-02 15:04:05.000".
86 | func TimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
87 | 	enc.AppendString(t.Format("2006-01-02 15:04:05.000"))
88 | }
89 | 
90 | // ZapString is a thin alias for zap.String.
91 | func ZapString(k, v string) zap.Field { return zap.String(k, v) }
92 | 
93 | // ZapErr is a thin alias for zap.Error.
94 | func ZapErr(err error) zap.Field { return zap.Error(err) }
95 | 
--------------------------------------------------------------------------------
/controller/direct.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"context"
5 | 	"io"
6 | 	"net"
7 | 	"net/netip"
8 | 	"time"
9 | )
10 | 
11 | // dialConn wraps a connection produced by DialFast together with the time
12 | // the dial took.
13 | type dialConn struct {
14 | 	net.Conn
15 | 	latency time.Duration
16 | }
17 | 
18 | // DialLatency reports how long DialFast spent establishing the connection.
19 | func (d *dialConn) DialLatency() time.Duration { return d.latency }
20 | 
21 | // ReadFrom forwards to the wrapped connection's ReadFrom when it has one.
22 | // Embedding promotes only the net.Conn methods, so without this an io.Copy
23 | // into a dialConn could not reach *net.TCPConn.ReadFrom (which the net
24 | // package may serve with splice(2) on Linux) and would copy in userspace.
25 | func (d *dialConn) ReadFrom(r io.Reader) (int64, error) {
26 | 	if rf, ok := d.Conn.(io.ReaderFrom); ok {
27 | 		return rf.ReadFrom(r)
28 | 	}
29 | 	return io.Copy(d.Conn, r)
30 | }
31 | 
32 | // WriteTo mirrors ReadFrom for an io.Copy out of a dialConn: it hands the
33 | // unwrapped connection to the destination's ReadFrom when there is one.
34 | func (d *dialConn) WriteTo(w io.Writer) (int64, error) {
35 | 	if rf, ok := w.(io.ReaderFrom); ok {
36 | 		return rf.ReadFrom(d.Conn)
37 | 	}
38 | 	return io.Copy(w, d.Conn)
39 | }
40 | 
41 | // wrapDialed wraps a freshly dialed connection with the time elapsed since
42 | // start and, for TCP, applies the socket options the connection handlers
43 | // ask for (no Nagle delay, 30s keep-alive). Tuning has to happen here:
44 | // callers only ever see the wrapper, on which a *net.TCPConn type assertion
45 | // never matches.
46 | func wrapDialed(c net.Conn, start time.Time) net.Conn {
47 | 	latency := time.Since(start)
48 | 	if tc, ok := c.(*net.TCPConn); ok {
49 | 		_ = tc.SetNoDelay(true)
50 | 		_ = tc.SetKeepAlive(true)
51 | 		_ = tc.SetKeepAlivePeriod(30 * time.Second)
52 | 	}
53 | 	return &dialConn{Conn: c, latency: latency}
54 | }
55 | 
56 | // DialFast opens a TCP connection to addr ("host:port") and records how long
57 | // it took.
58 | //
59 | //   - If addr cannot be split into host and port, or host is an IP literal,
60 | //     it is dialed directly with a 3s timeout.
61 | //   - Otherwise host is resolved and every resolved IP is dialed
62 | //     concurrently, attempt i starting i*50ms after the first (a simplified
63 | //     Happy Eyeballs); resolution and dials share a 3s budget. The first
64 | //     success wins, the remaining attempts are cancelled and any connection
65 | //     they still produce is closed.
66 | //   - If resolution fails or every attempt fails, addr is dialed directly
67 | //     as a last resort.
68 | func DialFast(addr string) (net.Conn, error) {
69 | 	start := time.Now()
70 | 	direct := func(target string) (net.Conn, error) {
71 | 		c, err := (&net.Dialer{Timeout: 3 * time.Second}).Dial("tcp", target)
72 | 		if err != nil {
73 | 			return nil, err
74 | 		}
75 | 		return wrapDialed(c, start), nil
76 | 	}
77 | 
78 | 	host, port, err := net.SplitHostPort(addr)
79 | 	if err != nil {
80 | 		return direct(addr)
81 | 	}
82 | 	if ip, perr := netip.ParseAddr(host); perr == nil {
83 | 		return direct(net.JoinHostPort(ip.String(), port))
84 | 	}
85 | 
86 | 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
87 | 	defer cancel()
88 | 	addrs, rerr := net.DefaultResolver.LookupIP(ctx, "ip", host)
89 | 	if rerr != nil || len(addrs) == 0 {
90 | 		return direct(addr)
91 | 	}
92 | 
93 | 	type result struct {
94 | 		c   net.Conn
95 | 		err error
96 | 	}
97 | 	// Every attempt sends exactly one result and the buffer has room for all
98 | 	// of them, so no attempt ever blocks and every connection is accounted
99 | 	// for: either returned to the caller or closed below.
100 | 	results := make(chan result, len(addrs))
101 | 	for i, ip := range addrs {
102 | 		go func(delay int, ip net.IP) {
103 | 			if delay > 0 {
104 | 				select {
105 | 				case <-time.After(time.Duration(delay) * 50 * time.Millisecond):
106 | 				case <-ctx.Done():
107 | 					results <- result{err: ctx.Err()}
108 | 					return
109 | 				}
110 | 			}
111 | 			d := &net.Dialer{Timeout: 2 * time.Second}
112 | 			c, e := d.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), port))
113 | 			results <- result{c: c, err: e}
114 | 		}(i, ip)
115 | 	}
116 | 
117 | 	for pending := len(addrs); pending > 0; pending-- {
118 | 		r := <-results
119 | 		if r.err != nil {
120 | 			continue
121 | 		}
122 | 		// First success: cancel the other attempts and close whatever they
123 | 		// still manage to connect.
124 | 		cancel()
125 | 		go func(left int) {
126 | 			for ; left > 0; left-- {
127 | 				if late := <-results; late.c != nil {
128 | 					_ = late.c.Close()
129 | 				}
130 | 			}
131 | 		}(pending - 1)
132 | 		return wrapDialed(r.c, start), nil
133 | 	}
134 | 	// Every concurrent attempt failed; fall back to a plain dial of addr.
135 | 	return direct(addr)
136 | }
137 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= 2 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 3 | github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/natefinch/lumberjack v2.0.0+incompatible h1:4QJd3OLAMgj7ph+yZTuX13Ld4UpgHp07nNdFX7mqFfM= 7 | github.com/natefinch/lumberjack v2.0.0+incompatible/go.mod h1:Wi9p2TTF5DG5oU+6YfsmYQpsTIOm0B1VNzQg9Mw6nPk= 8 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 9 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 10 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 14 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 15 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 16 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 17 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 18 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 19 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 20 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 21 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 22 | go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= 23 | go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= 24 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 | gopkg.in/natefinch/lumberjack.v2 
v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= 26 | gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= 27 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 28 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 29 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 30 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 32 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 33 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Moto 2 | 3 | 端口转发、正则匹配[端口复用]转发、智能加速、轮询加速。TCP转发,零拷贝转发, 单边加速。 4 | high-speed motorcycle,可以上高速的摩托车🏍️~ 5 | 6 | ## 模式 7 | - 普通模式[normal]:逐一连接目标地址,成功为止 8 | - 正则模式[regex]:利用正则匹配第一个数据报文来实现端口复用 9 | - 智能加速[boost]:多线路多TCP主动竞争最优TCP通道,大幅降低网络丢包、中断、切换、出口高低峰的影响! 
10 | - 轮询模式[roundrobin]:分散连接到所有目标地址 11 | 12 | 目标地址为域名时,会对解析出的全部 IP 并发拨号,并使用最先连通的连接。 13 | 14 | ## 演示,自动择路 15 | ``` 16 | `work from home(china telecom)`: 17 | {"level":"debug","ts":"2022-06-08 12:17:59.444","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :49751","targetAddr":"47.241.9.9 [新加坡 阿里云] :85","decisionTime(ms)":79} 18 | {"level":"debug","ts":"2022-06-08 12:18:05.050","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :49774","targetAddr":"47.241.9.9 [新加坡 阿里云] :85","decisionTime(ms)":81} 19 | {"level":"debug","ts":"2022-06-08 12:18:05.493","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :49783","targetAddr":"34.124.1.1 [美国 得克萨斯州] :85","decisionTime(ms)":75} 20 | {"level":"debug","ts":"2022-06-08 12:18:05.838","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :49792","targetAddr":"47.241.9.9 [新加坡 阿里云] :85","decisionTime(ms)":84} 21 | {"level":"debug","ts":"2022-06-08 12:18:09.176","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :49810","targetAddr":"34.124.1.1 [美国 得克萨斯州] :85","decisionTime(ms)":81} 22 | 23 | `in office(china unicom)`: 24 | {"level":"debug","ts":"2022-06-09 19:24:43.216","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :63847","targetAddr":"119.28.5.2 [香港 腾讯云] :85","decisionTime(ms)":66} 25 | {"level":"debug","ts":"2022-06-09 19:24:49.412","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :63878","targetAddr":"119.28.5.2 [香港 腾讯云] :85","decisionTime(ms)":49} 26 | {"level":"debug","ts":"2022-06-09 19:27:07.666","msg":"establish connection","ruleName":"智能加速","remoteAddr":"127.0.0.1 [本机地址] :64256","targetAddr":"119.28.5.2 [香港 腾讯云] :85","decisionTime(ms)":55} 27 | ``` 28 | 29 | ## 运行 30 | ```bash 31 | go run ./run.go # 使用默认 config/setting.json 32 | go run ./run.go --config config/setting.json 33 | ``` 34 | 也可通过环境变量 `MOTO_CONFIG=/path/to/your.json` 指定默认配置文件;同时传入 `--config` 时以 `--config` 中的规则为准,但日志文件路径仍取自启动时加载的默认配置(或 `MOTO_CONFIG`)。 35 | 36 | 37 | ## 构建
38 | ```bash 39 | # build 40 | go build ./... 41 | # build for linux 42 | CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo 43 | 44 | # build for macos 45 | CGO_ENABLED=0 GOOS=darwin go build -a -installsuffix cgo 46 | 47 | # build for windows 48 | CGO_ENABLED=0 GOOS=windows go build -a -installsuffix cgo 49 | ``` 50 | 51 | ## 常用正则 52 | |协议|正则表达式| 53 | | --- | ---| 54 | |HTTP|^(GET\|POST\|HEAD\|DELETE\|PUT\|CONNECT\|OPTIONS\|TRACE)| 55 | |SSH|^SSH| 56 | |HTTPS(SSL)|^\x16\x03| 57 | |RDP|^\x03\x00\x00| 58 | |SOCKS5|^\x05| 59 | |HTTP代理|(^CONNECT)\|(Proxy-Connection:)| 60 | 61 | 1、复制到JSON中记得注意特殊符号,例如^\\x16\\x03得改成^\\\\x16\\\\x03** 62 | 2、正则模式的原理是根据客户端建立连接后第一个数据包的特征进行判断是什么协议,该方式不支持连接建立之后服务器主动握手的协议,例如VNC,FTP,MYSQL,被动SSH等。** 63 | 64 | ## 参考 65 | - better way for tcp relay: https://hostloc.com/thread-969397-1-1.html 66 | - switcher: https://github.com/crabkun/switcher 67 | 
--------------------------------------------------------------------------------
/config/setting.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"fmt"
6 | 	"os"
7 | 	"regexp"
8 | )
9 | 
10 | // projectConfig is the top-level configuration read from setting.json.
11 | type projectConfig struct {
12 | 	Log   log     `json:"log"`
13 | 	Rules []*Rule `json:"rules"`
14 | }
15 | 
16 | // log holds the logger section of setting.json.
17 | type log struct {
18 | 	Level   string `json:"level"`
19 | 	Path    string `json:"path"`
20 | 	Version string `json:"version"`
21 | 	Date    string `json:"date"`
22 | }
23 | 
24 | // Rule describes one listening port and how its incoming traffic is routed.
25 | // Timeout is in milliseconds; Blacklist is keyed by client IP.
26 | type Rule struct {
27 | 	Name    string `json:"name"`
28 | 	Listen  string `json:"listen"`
29 | 	Mode    string `json:"mode"`
30 | 	Prewarm bool   `json:"prewarm"`
31 | 	Targets []*struct {
32 | 		Regexp  string         `json:"regexp"`
33 | 		Re      *regexp.Regexp `json:"-"`
34 | 		Address string         `json:"address"`
35 | 	} `json:"targets"`
36 | 	Timeout   uint64          `json:"timeout"`
37 | 	Blacklist map[string]bool `json:"blacklist"`
38 | }
39 | 
40 | // Defaults (milliseconds) that verify applies when a rule's timeout is 0.
41 | const (
42 | 	// defaultRegexTimeoutMs bounds the wait for the client's first packet.
43 | 	defaultRegexTimeoutMs = 500
44 | 	// defaultBoostTimeoutMs bounds the boost race, which roundrobin also
45 | 	// falls back to; with 0 the race would give up before any dial finished.
46 | 	defaultBoostTimeoutMs = 3000
47 | )
48 | 
49 | // GlobalCfg points at the configuration currently in effect. Package
50 | // initialization always leaves it non-nil.
51 | var GlobalCfg *projectConfig
52 | 
53 | func init() {
54 | 	// MOTO_CONFIG overrides the default configuration path.
55 | 	path := os.Getenv("MOTO_CONFIG")
56 | 	if path == "" {
57 | 		path = "config/setting.json"
58 | 	}
59 | 	cfg, err := load(path)
60 | 	if err != nil {
61 | 		fmt.Printf("failed to load setting.json: %s\n", err.Error())
62 | 		// Keep an empty configuration rather than nil: utils dereferences
63 | 		// GlobalCfg in its own init, and main may still load --config.
64 | 		cfg = &projectConfig{}
65 | 	}
66 | 	GlobalCfg = cfg
67 | }
68 | 
69 | // Reload loads the configuration at path (validating rules and filling in
70 | // defaults) and makes it the global configuration. On error GlobalCfg is
71 | // left unchanged.
72 | func Reload(path string) error {
73 | 	cfg, err := load(path)
74 | 	if err != nil {
75 | 		return err
76 | 	}
77 | 	GlobalCfg = cfg
78 | 	return nil
79 | }
80 | 
81 | // load reads and parses the configuration at path and verifies every rule.
82 | // Rules that fail verification are reported and dropped, so no listener is
83 | // started for a half-initialized rule (for example a regex rule whose
84 | // patterns were never compiled).
85 | func load(path string) (*projectConfig, error) {
86 | 	buf, err := os.ReadFile(path)
87 | 	if err != nil {
88 | 		return nil, err
89 | 	}
90 | 	// Decode into a value so that a JSON `null` document yields an empty
91 | 	// configuration instead of a nil pointer.
92 | 	var cfg projectConfig
93 | 	if err := json.Unmarshal(buf, &cfg); err != nil {
94 | 		return nil, err
95 | 	}
96 | 	valid := cfg.Rules[:0]
97 | 	for i, v := range cfg.Rules {
98 | 		if v == nil {
99 | 			fmt.Printf("verify rule failed at pos %d : null rule\n", i)
100 | 			continue
101 | 		}
102 | 		if err := v.verify(); err != nil {
103 | 			fmt.Printf("verify rule failed at pos %d : %s\n", i, err.Error())
104 | 			continue
105 | 		}
106 | 		valid = append(valid, v)
107 | 	}
108 | 	cfg.Rules = valid
109 | 	if len(cfg.Rules) == 0 {
110 | 		fmt.Printf("empty rule\n")
111 | 	}
112 | 	return &cfg, nil
113 | }
114 | 
115 | // verify checks a rule, fills in its default timeout and, for regex rules,
116 | // compiles each target's pattern into Re.
117 | func (c *Rule) verify() error {
118 | 	if c.Name == "" {
119 | 		return fmt.Errorf("empty name")
120 | 	}
121 | 	if c.Listen == "" {
122 | 		return fmt.Errorf("invalid listen address")
123 | 	}
124 | 	if len(c.Targets) == 0 {
125 | 		return fmt.Errorf("invalid targets")
126 | 	}
127 | 	switch c.Mode {
128 | 	case "normal":
129 | 		// No mode-specific defaults.
130 | 	case "regex":
131 | 		if c.Timeout == 0 {
132 | 			c.Timeout = defaultRegexTimeoutMs
133 | 		}
134 | 	case "boost", "roundrobin":
135 | 		if c.Timeout == 0 {
136 | 			c.Timeout = defaultBoostTimeoutMs
137 | 		}
138 | 	default:
139 | 		return fmt.Errorf("unsupported mode %q", c.Mode)
140 | 	}
141 | 	for i, v := range c.Targets {
142 | 		if v == nil || v.Address == "" {
143 | 			return fmt.Errorf("invalid address at pos %d", i)
144 | 		}
145 | 		if c.Mode == "regex" {
146 | 			r, err := regexp.Compile(v.Regexp)
147 | 			if err != nil {
148 | 				return fmt.Errorf("invalid regexp at pos %d : %s", i, err.Error())
149 | 			}
150 | 			v.Re = r
151 | 		}
152 | 	}
153 | 	return nil
154 | }
155 | 
--------------------------------------------------------------------------------
/controller/prewarm.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"net"
5 | 	"sync"
6 | 	"time"
7 | 
8 | 	"moto/config"
9 | 	"moto/utils"
10 | 
11 | 	"go.uber.org/zap"
12 | )
13 | 
14 | // Prewarm pool sizing and freshness.
15 | const (
16 | 	// prewarmInitialSize is the number of idle connections initially kept
17 | 	// ready per target address, for every mode.
18 | 	prewarmInitialSize = 16
19 | 	// prewarmPerTargetMax caps dynamic growth so a pool cannot grow without
20 | 	// bound.
21 | 	prewarmPerTargetMax = 256
22 | 	// prewarmMaxIdle is how long a prewarmed connection may wait unused
23 | 	// before it is closed instead of handed out. Servers may drop
24 | 	// connections on which nothing is sent, and a dead pooled connection
25 | 	// would fail the client it is given to. The value is a heuristic.
26 | 	prewarmMaxIdle = 30 * time.Second
27 | )
28 | 
29 | var prewarmPools sync.Map // target address -> *prewarmPool
30 | 
31 | // idleConn is a prewarmed connection and the moment it entered the pool.
32 | type idleConn struct {
33 | 	conn  net.Conn
34 | 	since time.Time
35 | }
36 | 
37 | // prewarmPool keeps a small set of pre-dialed TCP connections for one target
38 | // address. desired, idle and warming are guarded by mu. idle is ordered
39 | // oldest first: entries are appended, and timestamped, under mu.
40 | type prewarmPool struct {
41 | 	addr    string
42 | 	desired int
43 | 
44 | 	mu      sync.Mutex
45 | 	idle    []idleConn
46 | 	warming int
47 | }
48 | 
49 | // initPrewarm starts background prewarming for every target of rule.
50 | func initPrewarm(rule *config.Rule) {
51 | 	if !rule.Prewarm {
52 | 		return
53 | 	}
54 | 	desired := prewarmInitialSize
55 | 	for _, target := range rule.Targets {
56 | 		ensurePrewarmPool(target.Address, desired)
57 | 	}
58 | }
59 | 
60 | // ensurePrewarmPool returns the pool for addr, creating it on first use,
61 | // raises its target size to at least desired and starts any missing dials.
62 | func ensurePrewarmPool(addr string, desired int) *prewarmPool {
63 | 	poolAny, _ := prewarmPools.LoadOrStore(addr, &prewarmPool{addr: addr, desired: desired})
64 | 	pool := poolAny.(*prewarmPool)
65 | 	pool.mu.Lock()
66 | 	if desired > pool.desired {
67 | 		pool.desired = desired
68 | 	}
69 | 	pool.ensureLocked()
70 | 	pool.mu.Unlock()
71 | 	return pool
72 | }
73 | 
74 | // ensureLocked starts enough background dials for idle plus in-flight
75 | // connections to reach desired. The caller must hold p.mu.
76 | func (p *prewarmPool) ensureLocked() {
77 | 	need := p.desired - len(p.idle) - p.warming
78 | 	if need <= 0 {
79 | 		return
80 | 	}
81 | 	for i := 0; i < need; i++ {
82 | 		p.warming++
83 | 		go p.dialOne()
84 | 	}
85 | }
86 | 
87 | // dialOne dials one connection and appends it to the idle list. After a
88 | // failed dial it waits 500ms, then lets ensureLocked start a replacement.
89 | func (p *prewarmPool) dialOne() {
90 | 	conn, err := DialFast(p.addr)
91 | 	if err != nil {
92 | 		utils.Logger.Warn("预热连接失败", zap.String("target", p.addr), zap.Error(err))
93 | 		time.Sleep(500 * time.Millisecond)
94 | 		p.mu.Lock()
95 | 		p.warming--
96 | 		if p.warming < 0 {
97 | 			p.warming = 0
98 | 		}
99 | 		p.ensureLocked()
100 | 		p.mu.Unlock()
101 | 		return
102 | 	}
103 | 	if tc, ok := conn.(*net.TCPConn); ok {
104 | 		_ = tc.SetKeepAlive(true)
105 | 		_ = tc.SetKeepAlivePeriod(30 * time.Second)
106 | 		_ = tc.SetNoDelay(true)
107 | 	}
108 | 	p.mu.Lock()
109 | 	p.warming--
110 | 	// Timestamp under the lock so that idle stays ordered by since.
111 | 	p.idle = append(p.idle, idleConn{conn: conn, since: time.Now()})
112 | 	p.ensureLocked()
113 | 	p.mu.Unlock()
114 | }
115 | 
116 | // pruneStaleLocked closes and drops idle connections older than
117 | // prewarmMaxIdle. Because idle is ordered oldest first, the stale entries
118 | // form a prefix. The caller must hold p.mu.
119 | func (p *prewarmPool) pruneStaleLocked(now time.Time) {
120 | 	stale := 0
121 | 	for stale < len(p.idle) && now.Sub(p.idle[stale].since) > prewarmMaxIdle {
122 | 		_ = p.idle[stale].conn.Close()
123 | 		stale++
124 | 	}
125 | 	if stale == 0 {
126 | 		return
127 | 	}
128 | 	kept := copy(p.idle, p.idle[stale:])
129 | 	clear(p.idle[kept:]) // release references to the closed connections
130 | 	p.idle = p.idle[:kept]
131 | }
132 | 
133 | // acquirePrewarmed hands out the newest idle connection for addr, if a pool
134 | // exists and still has a fresh one, and tops the pool back up. When fewer
135 | // than a quarter of desired connections remain idle after the take, desired
136 | // grows by twice the estimated number of in-use connections (at least 1),
137 | // capped at prewarmPerTargetMax; it never shrinks.
138 | func acquirePrewarmed(addr string) (net.Conn, bool) {
139 | 	poolAny, ok := prewarmPools.Load(addr)
140 | 	if !ok {
141 | 		return nil, false
142 | 	}
143 | 	pool := poolAny.(*prewarmPool)
144 | 	pool.mu.Lock()
145 | 	defer pool.mu.Unlock()
146 | 	pool.pruneStaleLocked(time.Now())
147 | 	n := len(pool.idle)
148 | 	if n == 0 {
149 | 		pool.ensureLocked()
150 | 		return nil, false
151 | 	}
152 | 	conn := pool.idle[n-1].conn
153 | 	pool.idle[n-1] = idleConn{}
154 | 	pool.idle = pool.idle[:n-1]
155 | 	// Dynamic growth: in-use connections are approximated as
156 | 	// desired - remaining idle - warming.
157 | 	remaining := len(pool.idle)
158 | 	if pool.desired > 0 && remaining*4 < pool.desired {
159 | 		oldDesired := pool.desired
160 | 		active := pool.desired - remaining - pool.warming
161 | 		if active < 0 {
162 | 			active = 0
163 | 		}
164 | 		growth := active * 2
165 | 		if growth < 1 {
166 | 			growth = 1
167 | 		}
168 | 		pool.desired += growth
169 | 		if pool.desired > prewarmPerTargetMax {
170 | 			pool.desired = prewarmPerTargetMax
171 | 		}
172 | 		utils.Logger.Debug("预热动态扩容",
173 | 			zap.String("target", pool.addr),
174 | 			zap.Int("remainingIdle", remaining),
175 | 			zap.Int("activeApprox", active),
176 | 			zap.Int("warming", pool.warming),
177 | 			zap.Int("growth", growth),
178 | 			zap.Int("oldDesired", oldDesired),
179 | 			zap.Int("newDesired", pool.desired))
180 | 	}
181 | 	pool.ensureLocked()
182 | 	return conn, true
183 | }
184 | 
185 | // outboundDial returns a prewarmed connection for addr when one is available
186 | // and otherwise dials a new one with DialFast.
187 | func outboundDial(addr string) (net.Conn, error) {
188 | 	if conn, ok := acquirePrewarmed(addr); ok {
189 | 		return conn, nil
190 | 	}
191 | 	c, err := DialFast(addr)
192 | 	if err != nil {
193 | 		return nil, err
194 | 	}
195 | 	return c, nil
196 | }
197 | 
--------------------------------------------------------------------------------
/controller/boost.go:
--------------------------------------------------------------------------------
1 | package controller
2 | 
3 | import (
4 | 	"context"
5 | 	"io"
6 | 	"moto/config"
7 | 	"moto/utils"
8 | 	"net"
9 | 	"sync"
10 | 	"time"
11 | 
12 | 	"go.uber.org/zap"
13 | )
14 | 
15 | const (
16 | 	boostWinnerTTL       = 30 * time.Second // how long a winning target stays cached
17 | 	boostRevalidateAfter = boostWinnerTTL / 2
18 | )
19 | 
20 | type boostWinnerEntry struct {
21 | 	addr    string
22 | 	expires time.Time
23 | }
24 | 
25 | var boostWinnerCache sync.Map
26 | 
27 | const boostWinnerCacheMax = 256 // bounds the cache when there are very many rules
28 | 
29 | type dialResult struct {
30 | 	conn net.Conn
31 | 	addr string
32 | }
33 | 
34 | func loadBoostWinner(ruleName string) (string, bool, time.Time) {
35 | 	if v, ok := boostWinnerCache.Load(ruleName); ok {
36 | 		entry := v.(boostWinnerEntry)
37 | 		if time.Now().Before(entry.expires) {
38 | 			return entry.addr, true, entry.expires
39 | 		}
40 | 		boostWinnerCache.Delete(ruleName)
41 | 	}
42 | 	return "", false, time.Time{}
43 | }
44 | 
45 | func storeBoostWinner(ruleName, addr string) {
46 | 	// Crude size bound: if the cache already holds more than boostWinnerCacheMax entries, evict one arbitrary entry (Range order is unspecified).
47 | 	count := 0
48 | 	boostWinnerCache.Range(func(k, v any) bool {
49 | 		count++
50 | 		if count > boostWinnerCacheMax {
51 | 			// Evict the entry being visited and stop iterating.
52 | 			boostWinnerCache.Delete(k)
53 | 			return false
54 | 		}
55 | 		return true
56 | 	})
57 | 	boostWinnerCache.Store(ruleName, boostWinnerEntry{addr: addr, expires: time.Now().Add(boostWinnerTTL)})
58 | }
59 | 
60 | // There is no explicit drop API: entries expire after boostWinnerTTL or are deleted when dialing the cached target fails.
61 | 
62 | // lazyRevalidate re-runs the target race in the background, without touching the current request, and caches the first target to connect.
63 | func lazyRevalidate(rule *config.Rule) {
64 | 	// Racing only makes sense with more than one target.
65 | 	if len(rule.Targets) < 2 {
66 | 		return
67 | 	}
68 | 	// A short decision timeout keeps background refreshes from piling up.
69 | 	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
70 | 	defer cancel()
71 | 	switchBetter := make(chan dialResult) // unbuffered: a send succeeds only while the select below receives, so every losing connection reaches c.Close()
72 | 	// Dial with DialFast, not outboundDial: a pooled connection is ready instantly, so it would "win" regardless of route latency and also waste a prewarmed connection.
73 | 	for _, v := range rule.Targets {
74 | 		addr := v.Address
75 | 		go func(a string) {
76 | 			if c, err := DialFast(a); err == nil {
77 | 				select {
78 | 				case switchBetter <- dialResult{conn: c, addr: a}:
79 | 				case <-ctx.Done():
80 | 					c.Close()
81 | 				}
82 | 			}
83 | 		}(addr)
84 | 	}
85 | 	var best dialResult
86 | 	select {
87 | 	case best = <-switchBetter:
88 | 		cancel()
89 | 	case <-ctx.Done():
90 | 		return
91 | 	}
92 | 	if tc, ok := best.conn.(*net.TCPConn); ok {
93 | 		_ = tc.SetNoDelay(true)
94 | 		_ = tc.SetKeepAlive(true)
95 | 		_ = tc.SetKeepAlivePeriod(30 * time.Second)
96 | 	}
97 | 	storeBoostWinner(rule.Name, best.addr)
98 | 	utils.Logger.Debug("懒惰刷新winner",
99 | 		zap.String("ruleName", rule.Name),
100 | 		zap.String("targetAddr", best.conn.RemoteAddr().String()))
101 | 	best.conn.Close()
102 | }
103 | 
104 | // HandleBoost 同时发起多路拨号,挑选最先成功的连接并套上单边加速。
105 | func HandleBoost(conn net.Conn, rule *config.Rule) {
106 | 	defer conn.Close()
107 | 
108 | 	decisionBegin := time.Now()
109 | 
110 | 	if addr, ok, exp := loadBoostWinner(rule.Name); ok {
111 | 		// 命中缓存后,判断是否需要后台懒惰校验。
112 | 		var triggerLazy bool
113 | 		if !exp.IsZero() {
114 | 			lifeLeft := time.Until(exp)
115 | 			if lifeLeft < boostRevalidateAfter {
116 | 				triggerLazy = true
117 | 			}
118 | 		}
119 | 		if cachedConn, err := outboundDial(addr); err == nil {
120 | 			if tc, ok := cachedConn.(*net.TCPConn); ok {
121 | 				_ = tc.SetNoDelay(true)
122 | 				_ = tc.SetKeepAlive(true)
123 | 				_ = tc.SetKeepAlivePeriod(30 * time.Second)
124 | 			}
125 | 			storeBoostWinner(rule.Name, addr)
126 | 			fields := []zap.Field{
127 | 				zap.String("ruleName", rule.Name),
128 | 				zap.String("remoteAddr", conn.RemoteAddr().String()),
129 | 				zap.String("targetAddr", 
cachedConn.RemoteAddr().String()), 130 | zap.Int64("decisionTime(ms)", time.Since(decisionBegin).Milliseconds()), 131 | zap.Bool("boostCacheHit", true), 132 | } 133 | if triggerLazy { 134 | fields = append(fields, zap.Bool("boostLazyRefresh", true)) 135 | } 136 | utils.Logger.Debug("建立连接", fields...) 137 | 138 | if triggerLazy { 139 | go lazyRevalidate(rule) 140 | } 141 | 142 | defer cachedConn.Close() 143 | 144 | go func() { 145 | io.Copy(conn, cachedConn) 146 | conn.Close() 147 | cachedConn.Close() 148 | }() 149 | io.Copy(cachedConn, conn) 150 | return 151 | } 152 | // 缓存线路拨号失败:直接从缓存移除,下次重新竞速 153 | boostWinnerCache.Delete(rule.Name) 154 | } 155 | 156 | // 并发拨号选择最快线路 157 | ctx, cancel := context.WithCancel(context.Background()) 158 | defer cancel() 159 | switchBetter := make(chan dialResult, 1) 160 | for _, v := range rule.Targets { 161 | go func(address string) { 162 | if tryGetQuickConn, err := outboundDial(address); err == nil { 163 | select { 164 | case switchBetter <- dialResult{conn: tryGetQuickConn, addr: address}: 165 | case <-ctx.Done(): 166 | tryGetQuickConn.Close() 167 | } 168 | } 169 | }(v.Address) 170 | } 171 | // 全部连接失败: 所有线路延迟或中断 172 | dtx, dance := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(rule.Timeout)) 173 | defer dance() 174 | 175 | var winner dialResult 176 | select { 177 | case winner = <-switchBetter: 178 | cancel() 179 | case <-dtx.Done(): 180 | utils.Logger.Error("加速决策失败:所有线路均不可用", 181 | zap.String("ruleName", rule.Name)) 182 | return 183 | } 184 | 185 | if tc, ok := winner.conn.(*net.TCPConn); ok { 186 | _ = tc.SetNoDelay(true) 187 | _ = tc.SetKeepAlive(true) 188 | _ = tc.SetKeepAlivePeriod(30 * time.Second) 189 | } 190 | storeBoostWinner(rule.Name, winner.addr) 191 | 192 | utils.Logger.Debug("建立连接", 193 | zap.String("ruleName", rule.Name), 194 | zap.String("remoteAddr", conn.RemoteAddr().String()), 195 | zap.String("targetAddr", winner.conn.RemoteAddr().String()), 196 | zap.Int64("decisionTime(ms)", 
time.Since(decisionBegin).Milliseconds()), 197 | zap.Bool("boostCacheHit", false)) 198 | 199 | defer winner.conn.Close() 200 | 201 | go func() { 202 | io.Copy(conn, winner.conn) 203 | conn.Close() 204 | winner.conn.Close() 205 | }() 206 | io.Copy(winner.conn, conn) 207 | } 208 | -------------------------------------------------------------------------------- /test/bench.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import asyncio 3 | import argparse 4 | import time 5 | import statistics 6 | import json 7 | import random 8 | import sys 9 | from collections import Counter, defaultdict 10 | from typing import Optional, List, Dict, Any 11 | 12 | PROXY_HOST = "127.0.0.1" 13 | PROXY_PORT = 84 14 | TARGET_HOST = "www.baidu.com" 15 | TARGET_PORT = 80 16 | HTTP_REQ_TEMPLATE = ( 17 | "GET / HTTP/1.1\r\n" 18 | "Host: {host}\r\n" 19 | "User-Agent: prewarm-test/0.1\r\n" 20 | "Accept: */*\r\n" 21 | "Connection: close\r\n" 22 | "\r\n" 23 | ).encode() 24 | 25 | class Result: 26 | __slots__ = ( 27 | "ok","error","connect_ms","first_byte_ms","total_ms","status","phase" 28 | ) 29 | def __init__(self, ok: bool, error: Optional[str], connect_ms: float, 30 | first_byte_ms: float, total_ms: float, status: Optional[int], phase: str): 31 | self.ok = ok 32 | self.error = error 33 | self.connect_ms = connect_ms 34 | self.first_byte_ms = first_byte_ms 35 | self.total_ms = total_ms 36 | self.status = status 37 | self.phase = phase 38 | 39 | async def socks5_http_get(timeout: float, phase: str) -> Result: 40 | start = time.monotonic() 41 | reader = writer = None 42 | try: 43 | # 连接到本地 SOCKS5 44 | conn_begin = time.monotonic() 45 | reader, writer = await asyncio.wait_for( 46 | asyncio.open_connection(PROXY_HOST, PROXY_PORT), 47 | timeout=timeout 48 | ) 49 | # SOCKS5 greeting 50 | writer.write(b"\x05\x01\x00") # VER=5, NMETHODS=1, METHOD=0(no auth) 51 | await writer.drain() 52 | resp = await 
asyncio.wait_for(reader.readexactly(2), timeout=timeout) 53 | if resp != b"\x05\x00": 54 | raise RuntimeError(f"socks5 greet resp invalid: {resp!r}") 55 | 56 | # CONNECT 请求 57 | host_bytes = TARGET_HOST.encode() 58 | pkt = bytearray() 59 | pkt += b"\x05" # VER 60 | pkt += b"\x01" # CMD=CONNECT 61 | pkt += b"\x00" # RSV 62 | pkt += b"\x03" # ATYP=DOMAIN 63 | pkt += bytes([len(host_bytes)]) 64 | pkt += host_bytes 65 | pkt += TARGET_PORT.to_bytes(2, "big") 66 | writer.write(pkt) 67 | await writer.drain() 68 | # 应答:VER REP RSV ATYP ... 最少 10 字节 (域名长度可能不同) 69 | ver_rep = await asyncio.wait_for(reader.readexactly(4), timeout=timeout) 70 | if len(ver_rep) != 4 or ver_rep[1] != 0x00: 71 | raise RuntimeError(f"socks5 connect failed: {ver_rep!r}") 72 | atyp = ver_rep[3] 73 | if atyp == 1: # IPv4 74 | await asyncio.wait_for(reader.readexactly(4+2), timeout=timeout) 75 | elif atyp == 3: 76 | ln = await asyncio.wait_for(reader.readexactly(1), timeout=timeout) 77 | await asyncio.wait_for(reader.readexactly(ln[0] + 2), timeout=timeout) 78 | elif atyp == 4: # IPv6 79 | await asyncio.wait_for(reader.readexactly(16+2), timeout=timeout) 80 | else: 81 | raise RuntimeError(f"socks5 atyp unsupported: {atyp}") 82 | 83 | connect_done = time.monotonic() 84 | connect_ms = (connect_done - conn_begin) * 1000.0 85 | 86 | # 发起 HTTP 请求 87 | writer.write(HTTP_REQ_TEMPLATE.replace(b"{host}", TARGET_HOST.encode())) 88 | await writer.drain() 89 | 90 | # 首字节 91 | first_chunk = await asyncio.wait_for(reader.read(1), timeout=timeout) 92 | if not first_chunk: 93 | raise RuntimeError("empty first byte") 94 | first_byte_ms = (time.monotonic() - start) * 1000.0 95 | 96 | # 读剩余响应(简单读取到 EOF) 97 | buf = bytearray(first_chunk) 98 | while True: 99 | try: 100 | chunk = await asyncio.wait_for(reader.read(4096), timeout=timeout) 101 | except asyncio.TimeoutError: 102 | raise RuntimeError("read timeout") 103 | if not chunk: 104 | break 105 | buf += chunk 106 | if len(buf) > 64 * 1024: 107 | # 不需要整站大 body,适度截断 108 | 
break 109 | 110 | total_ms = (time.monotonic() - start) * 1000.0 111 | # 简单解析状态码 112 | status = None 113 | try: 114 | head = bytes(buf.split(b"\r\n", 1)[0]) 115 | if head.startswith(b"HTTP/"): 116 | parts = head.split() 117 | if len(parts) >= 2 and parts[1].isdigit(): 118 | status = int(parts[1]) 119 | except Exception: 120 | pass 121 | 122 | return Result(True, None, connect_ms, first_byte_ms, total_ms, status, phase) 123 | except Exception as e: 124 | total_ms = (time.monotonic() - start) * 1000.0 125 | return Result(False, str(e), 0.0, 0.0, total_ms, None, phase) 126 | finally: 127 | if writer: 128 | try: 129 | writer.close() 130 | await writer.wait_closed() 131 | except Exception: 132 | pass 133 | 134 | def percentiles(values: List[float], ps=(50,90,95,99)) -> Dict[int,float]: 135 | if not values: 136 | return {p: 0.0 for p in ps} 137 | s = sorted(values) 138 | out = {} 139 | n = len(s) 140 | for p in ps: 141 | k = int(round((p/100.0)*(n-1))) 142 | out[p] = s[k] 143 | return out 144 | 145 | async def run_phase(phase_name: str, concurrency: int, total: int, 146 | timeout: float, jitter: float, results: List[Result]): 147 | sem = asyncio.Semaphore(concurrency) 148 | started = 0 149 | 150 | async def worker(idx: int): 151 | nonlocal started 152 | async with sem: 153 | if jitter > 0: 154 | await asyncio.sleep(random.random()*jitter) 155 | res = await socks5_http_get(timeout, phase_name) 156 | results.append(res) 157 | 158 | tasks = [] 159 | for i in range(total): 160 | started += 1 161 | tasks.append(asyncio.create_task(worker(i))) 162 | await asyncio.gather(*tasks) 163 | 164 | def summarize(results: List[Result]): 165 | ok = [r for r in results if r.ok] 166 | fail = [r for r in results if not r.ok] 167 | 168 | connect_ms = [r.connect_ms for r in ok] 169 | fb_ms = [r.first_byte_ms for r in ok] 170 | total_ms = [r.total_ms for r in ok] 171 | 172 | codes = Counter(r.status for r in ok if r.status is not None) 173 | errors = Counter(r.error for r in fail) 174 | 175 | 
def fmt_dist(c: Counter, top=5): 176 | return ", ".join(f"{k}:{v}" for k,v in c.most_common(top)) or "-" 177 | 178 | def fmt_perc(vals, name): 179 | p = percentiles(vals) 180 | return f"{name} p50={p[50]:.1f} p90={p[90]:.1f} p95={p[95]:.1f} p99={p[99]:.1f}" 181 | 182 | lines = [] 183 | lines.append(f"Total={len(results)} OK={len(ok)} Fail={len(fail)} " 184 | f"SuccessRate={ (len(ok)/len(results)*100 if results else 0):.2f}%") 185 | if ok: 186 | lines.append(fmt_perc(connect_ms, "Connect(ms)")) 187 | lines.append(fmt_perc(fb_ms, "FirstByte(ms)")) 188 | lines.append(fmt_perc(total_ms, "Total(ms)")) 189 | lines.append(f"HTTP Codes: {fmt_dist(codes)}") 190 | if fail: 191 | lines.append(f"Errors: {fmt_dist(errors)}") 192 | # 按 phase 汇总 193 | by_phase = defaultdict(list) 194 | for r in results: 195 | by_phase[r.phase].append(r) 196 | if len(by_phase) > 1: 197 | lines.append("Per-Phase Success:") 198 | for ph, lst in by_phase.items(): 199 | o = sum(1 for x in lst if x.ok) 200 | lines.append(f" {ph}: {o}/{len(lst)} = {o/len(lst)*100:.1f}%") 201 | return "\n".join(lines) 202 | 203 | def parse_args(): 204 | ap = argparse.ArgumentParser(description="SOCKS5 concurrency test for dynamic prewarm observation") 205 | g = ap.add_mutually_exclusive_group(required=True) 206 | g.add_argument("-c","--concurrency", type=int, help="固定并发数") 207 | g.add_argument("-r","--ramp", help="分阶段并发列表, 例如: 50,100,200") 208 | ap.add_argument("-t","--total", type=int, help="总请求数(固定并发模式必须)") 209 | ap.add_argument("--per-stage", type=int, help="每个阶段请求数(ramp 模式必须)") 210 | ap.add_argument("--timeout", type=float, default=5.0, help="单请求超时秒") 211 | ap.add_argument("--jitter", type=float, default=0.0, help="启动抖动最大秒 (0~1 小幅随机延迟)") 212 | ap.add_argument("--save", help="保存所有结果为 JSON 文件") 213 | ap.add_argument("--seed", type=int, help="随机种子") 214 | return ap.parse_args() 215 | 216 | def print_header(): 217 | print("="*70) 218 | print(" SOCKS5 High Concurrency Test (observe server prewarm scaling) ") 219 | 
print("="*70) 220 | 221 | async def main(): 222 | args = parse_args() 223 | if args.seed is not None: 224 | random.seed(args.seed) 225 | 226 | print_header() 227 | results: List[Result] = [] 228 | t0 = time.monotonic() 229 | 230 | if args.concurrency: 231 | if not args.total: 232 | print("--total 必须指定(固定并发模式)", file=sys.stderr) 233 | sys.exit(1) 234 | print(f"[Phase single] concurrency={args.concurrency} total={args.total}") 235 | await run_phase("phase1", args.concurrency, args.total, 236 | args.timeout, args.jitter, results) 237 | else: 238 | # ramp 模式 239 | stages = [int(x.strip()) for x in args.ramp.split(",") if x.strip()] 240 | if not stages: 241 | print("无效 ramp 列表", file=sys.stderr) 242 | sys.exit(1) 243 | if not args.per_stage: 244 | print("--per-stage 必须指定(ramp 模式)", file=sys.stderr) 245 | sys.exit(1) 246 | for i, c in enumerate(stages, 1): 247 | print(f"[Phase {i}] concurrency={c} total={args.per_stage}") 248 | await run_phase(f"phase{i}", c, args.per_stage, 249 | args.timeout, args.jitter, results) 250 | 251 | elapsed = time.monotonic() - t0 252 | print("\n=== Summary ===") 253 | print(summarize(results)) 254 | print(f"Elapsed: {elapsed:.2f}s Approx QPS: {len(results)/elapsed:.1f}") 255 | 256 | if args.save: 257 | out = [] 258 | for r in results: 259 | out.append({ 260 | "ok": r.ok, 261 | "error": r.error, 262 | "connect_ms": r.connect_ms, 263 | "first_byte_ms": r.first_byte_ms, 264 | "total_ms": r.total_ms, 265 | "status": r.status, 266 | "phase": r.phase 267 | }) 268 | with open(args.save, "w") as f: 269 | json.dump(out, f, ensure_ascii=False, indent=2) 270 | print(f"Saved JSON results -> {args.save}") 271 | 272 | if __name__ == "__main__": 273 | try: 274 | asyncio.run(main()) 275 | except KeyboardInterrupt: 276 | print("\nInterrupted.", file=sys.stderr) 277 | 278 | # 较小并发测试: 279 | # python3 bench.py -c 50 -t 500 280 | # ramp 模式: 281 | # python3 bench.py -r 50,100,200,400 --per-stage 400 
--------------------------------------------------------------------------------