├── .gitignore ├── Dockerfile ├── Makefile ├── README.md ├── anyproxy.go ├── conf ├── router.yaml └── tcpcopy.yaml ├── config └── config.go ├── crypto └── aes.go ├── examples ├── chan │ └── send.go ├── docker_pull.png ├── https_capture.png ├── https_flow.jpg ├── nginx_vhost.conf ├── parseurl │ └── test.go ├── stress │ └── tcpcopy.go ├── telnet │ ├── main.go │ └── wireshark.png └── websocket │ ├── client.go │ ├── conn.go │ ├── hub.go │ └── main.go ├── go.mod ├── go.sum ├── grace ├── autoinc │ └── autoinc.go ├── conn.go ├── grace.go └── server.go ├── logging ├── logger.go └── timewriter.go ├── nat ├── bridge.go ├── bridge_hub.go ├── client.go ├── client_hub.go ├── conn.go ├── handler.go └── message.go ├── proto ├── client.go ├── http.go ├── http │ └── header.go ├── keep.go ├── request.go ├── server.go ├── socks5.go ├── stats │ ├── counter.go │ └── stats.go ├── stream.go ├── stream_addr.go ├── stream_windows.go ├── tcp │ └── reader.go ├── tcpcopy.go ├── text │ └── reader.go ├── tunnel.go └── websocket.go ├── scripts ├── build.sh ├── setup.sh └── win-start.bat └── utils ├── cache └── cache.go ├── conf ├── config.go ├── path.go └── router.go ├── daemon ├── daemon.go ├── daemon_fork.go └── daemon_windows.go ├── help └── help.go ├── tools └── string.go └── trace └── trace.go /.gitignore: -------------------------------------------------------------------------------- 1 | .history 2 | .vscode 3 | dist/ 4 | logs/ 5 | anyproxy 6 | anyproxy-alpine 7 | anyproxy-darwin 8 | anyproxy-windows.exe 9 | tunneld 10 | tunneld-alpine 11 | tunnel/logs/ 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 参考 https://studygolang.com/articles/26823 2 | FROM golang:1.13.11-alpine AS builder 3 | 4 | WORKDIR /go/src/github.com/keminar/anyproxy 5 | # Go 版本>=1.13 设置GOPROXY 6 | RUN go env -w GOPROXY=https://goproxy.cn,direct 7 | COPY go.mod . 8 | COPY go.sum . 
9 | # Download dependencies in their own layer so it is cached 10 | RUN go mod download 11 | 12 | COPY . . 13 | 14 | # tunneld is no longer a separate program: it was merged into anyproxy ("-mode tunnel") and tunnel/tunneld.go no longer exists 15 | RUN go build -o /go/bin/anyproxy anyproxy.go 16 | 17 | # debian is smaller than the centos and golang images 18 | # FROM debian:9 AS final 19 | # the alpine image is the smallest and is enough most of the time 20 | FROM alpine:3.11 AS final 21 | 22 | WORKDIR /go 23 | COPY --from=builder /go/bin/anyproxy /go/bin/anyproxy 24 | 25 | # avoid running as the container's root user 26 | RUN adduser -u 1000 -D appuser 27 | RUN mkdir logs/ && chown appuser logs/ 28 | 29 | USER appuser 30 | 31 | # CMD vs ENTRYPOINT usage: https://blog.csdn.net/u010900754/article/details/78526443 32 | # tunnel (tunneld) mode: docker run anyproxy:latest -mode tunnel 33 | ENTRYPOINT [ "/go/bin/anyproxy" ] -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | linux: 3 | bash ./scripts/build.sh linux 4 | mac: 5 | bash ./scripts/build.sh mac 6 | windows: 7 | bash ./scripts/build.sh windows 8 | alpine: 9 | bash ./scripts/build.sh alpine 10 | all: 11 | bash ./scripts/build.sh all 12 | clean: 13 | rm -rf dist/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Any Proxy 2 | 3 | anyproxy 是一个部署在Linux系统上的tcp流转发器,可以将收到的请求按域名划分链路发本地请求或者转到下一级代理。可以代替Proxifier做Linux下的客户端, 也可以配合Proxifier当它的服务端。经过跨平台编译,如果只做网络包的转发可以在windows等平台使用。 4 | 5 | [下载二进制包](http://cloudme.io/) 6 | 7 | tunneld 是一个anyproxy的服务端模式,带密钥验证,部署在服务器上接收anyproxy的请求,并代理发出请求或是转到下一个tunneld。用于跨内网访问资源使用。非anyproxy请求一概拒绝处理 8 | 9 | # 路由支持 10 | 11 | ``` 12 | +----------+ +----------+ +----------+ 13 | | Computer | <==> | anyproxy | <==> | Internet | 14 | +----------+ +----------+ +----------+ 15 | 16 | # or 17 | +----------+ +----------+ +---------+ +----------+ 18 | | Computer | <==> | anyproxy | <==> | tunneld | <==> | Internet | 19 | +----------+ +----------+ +---------+ +----------+ 20
| 21 | # or 22 | +----------+ +----------+ +---------+ +----------+ 23 | | Computer | <==> | anyproxy | <==> | socks5 | <==> | Internet | 24 | +----------+ +----------+ +---------+ +----------+ 25 | 26 | # or 27 | +----------+ +----------+ +---------+ +---------+ +----------+ 28 | | Computer | <==> | anyproxy | <==> | tunneld | <==> | socks5 | <==> | Internet | 29 | +----------+ +----------+ +---------+ +---------+ +----------+ 30 | 31 | # or 32 | +----------+ +---------+ +-----------+ ws +-----------+ +---------+ 33 | | Computer | <==> | Nginx A | <==> | anyproxy S| <==> | anyproxy C| <==> | Nginx B | 34 | +----------+ +---------+ +-----------+ +-----------+ +---------+ 35 | ``` 36 | 37 | # 使用案例 38 | > 案例1:解决Docker pull官方镜像的问题 39 | 40 | `使用iptables将本用户下tcp流转到anyproxy,再进行docker pull操作` 41 | 42 | > 案例2: 解决相同域名访问网站不同测试环境的问题 43 | 44 | `本地通过内网 anyproxy 代理上网,遇到测试服务器域名则跳到外网tunneld转发,网站的nginx根据来源IP进行转发到特定测试环境(有几个环境就需要有几个tunneld服务且IP要不同)` 45 | 46 | > 案例3: 解决HTTPS抓包问题 47 | 48 | `本地将https请求到服务器,服务器解证书后增加特定头部转到anyproxy websocket服务端,本地另起一个anyproxy的websocket客户端接收并将http请求转发到Charles` 49 | 50 | > 案例4: 解决内网tcp端口给外网访问 51 | 52 | `假如本机是192网段,容器内是10网段,在本机启动一个程序监听本机端口同时桥接到容器内的应用的端口,这样就可以通过本机端口访问容器内的tcp服务(配置项是tcpcopy)` 53 | 54 | # 源码编译 55 | 56 | > 安装Go环境并设置GOPROXY 57 | 58 | Go环境安装比较简单,这里不做介绍,GOPROXY对不同版本有些差异,设置方法如下 59 | ``` 60 | # Go 版本>=1.11 61 | export GOPROXY=https://goproxy.cn 62 | # Go 版本>=1.13  63 | go env -w GOPROXY=https://goproxy.cn,direct 64 | ``` 65 | 66 | > 下载编译 67 | ``` 68 | git clone https://github.com/keminar/anyproxy.git 69 | cd anyproxy 70 | make all 71 | ``` 72 | 73 | > 本机启动 74 | 75 | ``` 76 | # 示例1. 以anyproxy用户启动 77 | sudo -u anyproxy ./anyproxy 78 | 79 | # 示例2. 以后台进程方式运行 80 | ./anyproxy -daemon 81 | 82 | # 示例3. 启动tunneld 83 | ./anyproxy -mode tunnel 84 | 85 | # 示例4. 启动anyproxy并将请求转给tunneld 86 | ./anyproxy -p 'tunnel://127.0.0.1:3001' 87 | 88 | # 示例5. 启动anyproxy并将请求转给socks5 89 | ./anyproxy -p 'socks5://127.0.0.1:10000' 90 | 91 | # 示例6. 
端口转发 92 | ./anyproxy -c conf/tcpcopy.yaml 93 | 94 | # 其它帮助 95 | ./anyproxy -h 96 | ``` 97 | 98 | 注:因为本地iptables转发是Linux功能,所以windows系统使用时精简掉了此部分功能 99 | 100 | > 平滑重启 101 | 102 | ``` 103 | # 首先查到进程pid,然后发送HUP信号 104 | kill -HUP pid 105 | ``` 106 | 107 | 108 | > 使用Docker 109 | 110 | ``` 111 | # 构建 112 | docker build -t anyproxy:latest . 113 | # 运行 114 | docker run anyproxy:latest 115 | # 开放端口并带参数运行 116 | docker run -p 3000:3000 anyproxy:latest -p '127.0.0.1:3001' 117 | ``` 118 | 119 | # 代理设置 120 | 121 | * 防火墙全局代理 122 | 123 | ``` 124 | #添加一个不可以登录的用户 125 | sudo useradd -M -s /sbin/nologin anyproxy 126 | # uid为anyproxy的tcp请求不转发,并用anyproxy用户启动anyproxy程序 127 | sudo iptables -t nat -A OUTPUT -p tcp -m owner --uid-owner anyproxy -j RETURN 128 | sudo -u anyproxy ./anyproxy -daemon 129 | # 指定root账号本地请求不走代理 130 | sudo iptables -t nat -A OUTPUT -p tcp -d 192.168.0.0/16 -m owner --uid-owner 0 -j RETURN 131 | sudo iptables -t nat -A OUTPUT -p tcp -d 172.17.0.0/16 -m owner --uid-owner 0 -j RETURN 132 | # 指定root账号的http/https请求走代理 133 | sudo iptables -t nat -A OUTPUT -p tcp -m multiport --dport 80,443 -m owner --uid-owner 0 -j REDIRECT --to-port 3000 134 | ``` 135 | 136 | > 如何删除全局代理 137 | ``` 138 | # 查看当前规则 139 | sudo iptables -t nat -L -n --line-number 140 | 141 | # 输出 142 | ...以上省略 143 | Chain OUTPUT (policy ACCEPT) 144 | num target prot opt source destination 145 | 1 RETURN tcp -- 0.0.0.0/0 0.0.0.0/0 owner UID match 1004 146 | 2 REDIRECT tcp -- 0.0.0.0/0 0.0.0.0/0 redir ports 3000 147 | ...以下省略 148 | 149 | # 按顺序依次为OUTPUT链的第一条规则和第二条规则 150 | # 假如想删除nat表OUTPUT链的第2条规则 151 | sudo iptables -t nat -D OUTPUT 2 152 | ``` 153 | * 浏览器 [Chrome设置](https://zhidao.baidu.com/question/204679423955769445.html) 154 | * 手机端 [苹果](https://jingyan.baidu.com/article/84b4f565add95060f7da3271.html) [安卓](https://jingyan.baidu.com/article/219f4bf7ff97e6de442d38c8.html) 155 | 156 | # Todo 157 | 158 | > ~~划线~~ 部分为已实现功能 159 | * ~~可将请求转发到Tunnel服务~~ 160 | * ~~对域名支持加Host绑定~~ 161 | * ~~对域名配置请求出口~~ 162 | *
~~增加全局默认出口配置~~ 163 | * ~~配置文件支持~~ 164 | * ~~服务间通信增加token验证可配~~ 165 | * ~~日志信息完善~~ 166 | * ~~DNS解析增加cache~~ 167 | * ~~自动路由模式下可设置检测时间和cache~~ 168 | * ~~可以自定义代理server,如果不可用则用全局的~~ 169 | * ~~server多级转发~~ 170 | * ~~加域名黑名单功能,不给请求~~ 171 | * ~~支持转发到socket5服务~~ 172 | * ~~支持HTTP/1.1 keep-alive 一外链接多次请求不同域名~~ 173 | * ~~修复iptables转发后百度贴吧无法访问的问题~~ 174 | * ~~支持windows平台使用~~ 175 | * ~~通过websocket实现内网穿透(必须为http的非CONNECT请求)~~ 176 | * ~~订阅增加邮箱标识,用于辨别在线用户~~ 177 | * ~~与Tunnel功能合并,使用mode区分~~ 178 | * ~~启用ws-listen后的平滑重启问题~~ 179 | * ~~监听配置文件变化重新加载路由~~ 180 | * ~~支持proxy时转换端口号~~ 181 | * ~~支持tcpcopy模式,用此转发连mysql~~ 182 | * ~~支持socks5协议接入~~ 183 | * ~~统计上下行流量~~ 184 | * ~~修复不支持http upgrade socket的问题~~ 185 | * TCP 增加更多协议解析支持,如rtmp,ftp, socks5, https(SNI)等 186 | * tunel token支持按host配置 187 | 188 | # 感谢 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /anyproxy.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net/http" 9 | _ "net/http/pprof" 10 | "os" 11 | "time" 12 | 13 | "github.com/keminar/anyproxy/config" 14 | "github.com/keminar/anyproxy/grace" 15 | "github.com/keminar/anyproxy/logging" 16 | "github.com/keminar/anyproxy/nat" 17 | "github.com/keminar/anyproxy/proto" 18 | "github.com/keminar/anyproxy/utils/conf" 19 | "github.com/keminar/anyproxy/utils/daemon" 20 | "github.com/keminar/anyproxy/utils/help" 21 | "github.com/keminar/anyproxy/utils/tools" 22 | ) 23 | 24 | var ( 25 | gListenAddrPort string 26 | gProxyServerSpec string 27 | gConfigFile string 28 | gWebsocketListen string 29 | gWebsocketConn string 30 | gMode string 31 | gHelp bool 32 | gDebug int 33 | gPprof string 34 | gVersion bool 35 | ) 36 | 37 | func init() { 38 | flag.Usage = help.Usage 39 | flag.StringVar(&gListenAddrPort, "l", "", "listen address of socks5 and http proxy") 40 | 
flag.StringVar(&gProxyServerSpec, "p", "", "Proxy servers to use") 41 | flag.StringVar(&gConfigFile, "c", "", "Config file path, default is router.yaml") 42 | flag.StringVar(&gWebsocketListen, "ws-listen", "", "Websocket address and port to listen on") 43 | flag.StringVar(&gWebsocketConn, "ws-connect", "", "Websocket Address and port to connect") 44 | flag.StringVar(&gMode, "mode", "", "Run mode(proxy, tunnel). proxy mode default") 45 | flag.IntVar(&gDebug, "debug", 0, "debug mode (0, 1, 2, 3)") 46 | flag.StringVar(&gPprof, "pprof", "", "pprof port, disable if empty") 47 | flag.BoolVar(&gVersion, "v", false, "Show build version") 48 | flag.BoolVar(&gHelp, "h", false, "This usage message") 49 | } 50 | 51 | func main() { 52 | flag.Parse() 53 | if gHelp { 54 | flag.Usage() 55 | return 56 | } 57 | if gVersion { 58 | help.ShowVersion() 59 | return 60 | } 61 | 62 | config.SetDebugLevel(gDebug) 63 | conf.LoadAllConfig(gConfigFile) 64 | 65 | // 检查配置是否存在 66 | if conf.RouterConfig == nil { 67 | time.Sleep(60 * time.Second) 68 | os.Exit(2) 69 | } 70 | 71 | cmdName := "anyproxy" 72 | defLogDir := fmt.Sprintf("%s%s%s%s", conf.AppPath, string(os.PathSeparator), "logs", string(os.PathSeparator)) 73 | logDir := config.IfEmptyThen(conf.RouterConfig.Log.Dir, defLogDir, "") 74 | if _, err := os.Stat(logDir); err != nil { 75 | log.Println(err) 76 | time.Sleep(60 * time.Second) 77 | os.Exit(2) 78 | } 79 | 80 | envRunMode := fmt.Sprintf("%s_run_mode", cmdName) 81 | fd := logging.ErrlogFd(logDir, cmdName) 82 | // 是否后台运行 83 | daemon.Daemonize(envRunMode, fd) 84 | 85 | gListenAddrPort = config.IfEmptyThen(gListenAddrPort, conf.RouterConfig.Listen, ":3000") 86 | gListenAddrPort = tools.FillPort(gListenAddrPort) 87 | config.SetListenPort(gListenAddrPort) 88 | 89 | var writer io.Writer 90 | // 前台执行,daemon运行根据环境变量识别 91 | if daemon.IsForeground(envRunMode) { 92 | // 同时输出到日志和标准输出 93 | writer = io.Writer(os.Stdout) 94 | } 95 | 96 | logging.SetDefaultLogger(logDir, fmt.Sprintf("%s.%d", cmdName, 
config.ListenPort), true, 3, writer) 97 | // 设置代理 98 | gProxyServerSpec = config.IfEmptyThen(gProxyServerSpec, conf.RouterConfig.Default.Proxy, "") 99 | config.SetProxyServer(gProxyServerSpec) 100 | 101 | // 调试模式 102 | if len(gPprof) > 0 { 103 | go func() { 104 | gPprof = tools.FillPort(gPprof) 105 | //浏览器访问: http://:5001/debug/pprof/ 106 | log.Println("Starting pprof debug server ...") 107 | // 这里不要使用log.Fatal会在平滑重启时导致进程退出 108 | // 因为http server现在没办法一次平滑重启,会报端口冲突,可以通过多次重试来启动pprof 109 | for i := 0; i < 10; i++ { 110 | log.Println(http.ListenAndServe(gPprof, nil)) 111 | time.Sleep(10 * time.Second) 112 | } 113 | }() 114 | } 115 | 116 | // websocket 服务端 117 | gWebsocketListen = config.IfEmptyThen(gWebsocketListen, conf.RouterConfig.Websocket.Listen, "") 118 | if gWebsocketListen != "" { 119 | gWebsocketListen = tools.FillPort(gWebsocketListen) 120 | go nat.NewServer(&gWebsocketListen) 121 | } 122 | // websocket 客户端 123 | gWebsocketConn = config.IfEmptyThen(gWebsocketConn, conf.RouterConfig.Websocket.Connect, "") 124 | if gWebsocketConn != "" { 125 | gWebsocketConn = tools.FillPort(gWebsocketConn) 126 | go nat.ConnectServer(&gWebsocketConn) 127 | } 128 | 129 | // tcp 同是监听IPv4 和 IPv6 130 | // tcp4 仅监听使用IPv4 131 | // tcp6 仅监听使用IPv6 132 | network := "tcp" 133 | if conf.RouterConfig.Network != "" { 134 | network = conf.RouterConfig.Network 135 | } 136 | // 运行模式 137 | if gMode == "tunnel" { 138 | server := grace.NewServer(gListenAddrPort, proto.ServerHandler, network) 139 | server.ListenAndServe() 140 | } else { 141 | server := grace.NewServer(gListenAddrPort, proto.ClientHandler, network) 142 | server.ListenAndServe() 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /conf/router.yaml: -------------------------------------------------------------------------------- 1 | # 监听端口IP, 优先级低于启动传参 2 | listen: 3 | # 日志目录 4 | log: 5 | dir: ./logs/ 6 | # 监听配置文件变化 7 | watcher: true 8 | # anyproxy 和 tunnel通信密钥, 必须16位长度 9 | token: 
anyproxyproxyany 10 | # 可访问的客户端IP,为空不限制 11 | allowIP: 12 | # - 172.17.0.12 13 | 14 | # http非CONNECT请求首行域名处理 15 | firstLine: 16 | #是否带Host, on带,off不带,默认带 17 | host: on 18 | #按域名配带Host,on带,off不带,其他用默认 19 | #一般vue本地项目要把域名配置为off 20 | #注意:域名和端口中间的冒号改为点,如localhost:5173配置为localhost.5173 21 | custom: 22 | localhost.5173: off 23 | 24 | # 设置此项,进入tcpCopy模式,则配置中的hosts域名代理不再生效 25 | # tcpcopy模式下allowIP设置有效 26 | tcpcopy: 27 | # enable: false 28 | # ip: 127.0.0.1 29 | # port: 3306 30 | 31 | # 默认操作,可热加载 32 | default: 33 | # 使用的DNS服务器 local 当前环境, remote远程, 仅当target使用remote有效 34 | dns: local 35 | # 默认环境,local 当前环境, remote 远程, deny 禁止 36 | # auto根据dial选择,local dial失败则remote 37 | target: auto 38 | # tcp 请求环境,local 当前环境, remote 远程, deny 禁止 39 | tcpTarget: remote 40 | # 默认域名比对方案,contain 包含,equal 完全相等, preg 正则 41 | match: equal 42 | # 全局代理服务器, 优先级低于启动传参 43 | proxy: 44 | 45 | # 域名,可热加载 46 | hosts: 47 | - name: github 48 | # contain 包含,equal 完全相等, preg 正则 49 | match: contain 50 | # 参考全局target 51 | # 如果有用proxy自定义代理可用,target强制当remote使用,proxy代理不可用,target按原逻辑处理 52 | target: remote 53 | # 参考全局localDns 54 | dns: remote 55 | # 支持 http:// , tunnel:// , socks5:// 三种协议,默认 tunnel:// 56 | #proxy: http://127.0.0.1:8888 57 | # 支持多代理,支持忽略全局代理并执行 last 或 deny 2种逻辑 58 | proxy: http://127.0.0.1:8888, http://127.0.0.1:7777 last 59 | - name: golang.org 60 | match: contain 61 | target: auto 62 | dns: remote 63 | - name: www.baidu.com 64 | match: equal 65 | target: auto 66 | - name: google 67 | match: contain 68 | target: deny 69 | - name: dev.example.com 70 | ip: 127.0.0.1 71 | port: 72 | - from: 80 73 | to: 88 74 | allowIP: 75 | # - 172.17.0.12 76 | 77 | #websocket配置 78 | #对于服务端需要配置 listen, user, pass 三个参数 79 | #对于客户端未配置connect / user / email 都不发起连接 80 | websocket: 81 | # 监听端口 82 | listen: 83 | # ip 端口 84 | connect: 85 | # connect 域名 86 | host: 87 | # 用户名 88 | user: 89 | # 密码 90 | pass: 91 | # Email用于定位用户,不鉴权 92 | email: 93 | # 订阅头部信息 94 | subscribe: 95 | - key: 96 | val: 
-------------------------------------------------------------------------------- /conf/tcpcopy.yaml: -------------------------------------------------------------------------------- 1 | watcher: true 2 | listen: 192.168.1.2:3306 3 | allowIP: 4 | - 192.168.1.2 5 | tcpcopy: 6 | enable: true 7 | ip: 10.0.0.2 8 | port: 3306 9 | 10 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "log" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/keminar/anyproxy/utils/tools" 9 | ) 10 | 11 | // ProxyScheme is the upstream proxy protocol: tunnel is the custom protocol, socks5 and http are standard protocols 12 | var ProxyScheme string = "tunnel" 13 | 14 | // ProxyServer is the upstream proxy host 15 | var ProxyServer string 16 | 17 | // ProxyPort is the upstream proxy port 18 | var ProxyPort uint16 19 | 20 | // TimeFormat is the layout used to format times 21 | var TimeFormat string = "2006-01-02 15:04:05" 22 | 23 | // DebugLevel is the debug verbosity level 24 | var DebugLevel int 25 | 26 | // ListenPort is the local listen port 27 | var ListenPort uint16 28 | 29 | const ( 30 | // LevelShort short log format 31 | LevelShort int = iota 32 | // LevelLong long log format 33 | LevelLong 34 | // LevelDebug long logs plus extra debug logs 35 | LevelDebug 36 | // LevelDebugBody also prints bodies 37 | LevelDebugBody 38 | ) 39 | 40 | // SetProxyServer parses a spec of the form [scheme://]host:port into ProxyScheme, ProxyServer and ProxyPort. 41 | // An empty spec is a no-op. A spec that is not host:port, or whose port is not a valid uint16, is logged 42 | // and leaves ProxyServer/ProxyPort unchanged. 43 | func SetProxyServer(gProxyServerSpec string) { 44 | if gProxyServerSpec == "" { 45 | return 46 | } 47 | 48 | // Strip the scheme first 49 | tmp := strings.Split(gProxyServerSpec, "://") 50 | if len(tmp) == 2 { 51 | ProxyScheme = tmp[0] 52 | gProxyServerSpec = tmp[1] 53 | } 54 | // Then split off the port; this must come after the scheme check because "://" also contains ':' 55 | tmp = strings.Split(gProxyServerSpec, ":") 56 | if len(tmp) != 2 { 57 | // used to be ignored silently, so e.g. a missing port went unnoticed 58 | log.Printf("Set proxy server err: %q is not host:port\n", gProxyServerSpec) 59 | return 60 | } 61 | // bitSize 16 rejects ports above 65535 instead of silently wrapping them in uint16() 62 | port, err := strconv.ParseUint(tmp[1], 10, 16) 63 | if err != nil { 64 | log.Printf("Set proxy port err %s\n", err.Error()) 65 | return 66 | } 67 | ProxyServer = tmp[0] 68 | ProxyPort = uint16(port) 69 | log.Printf("Proxy server is %s://%s:%d\n", ProxyScheme, ProxyServer, ProxyPort) 70 | } 71 | 72 | // SetDebugLevel sets the debug level 73 | func SetDebugLevel(gDebug int) { 74 | DebugLevel = gDebug 75 | } 76 | 77 | // SetListenPort stores the port of gListenAddrPort (as returned by tools.GetPort) in ListenPort. 78 | // An unparsable or out-of-range port is logged and stored as 0. 79 | func SetListenPort(gListenAddrPort string) { 80 | intStr := tools.GetPort(gListenAddrPort) 81 | // bitSize 16 rejects out-of-range ports instead of silently truncating them in uint16() 82 | port, err := strconv.ParseUint(intStr, 10, 16) 83 | if err != nil { 84 | log.Printf("SetListenPort err %s\n", err.Error()) 85 | port = 0 86 | } 87 | ListenPort = uint16(port) 88 | } 89 | 90 | // IfEmptyThen returns str if it is non-empty, otherwise str2 if it is non-empty, otherwise str3 91 | func IfEmptyThen(str string, str2 string, str3 string) string { 92 | if str == "" { 93 | if str2 == "" { 94 | return str3 95 | } 96 | return str2 97 | } 98 | return str 99 | } 100 |
-------------------------------------------------------------------------------- /crypto/aes.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "errors" 8 | ) 9 | 10 | // ErrInvalidCiphertext is returned by DecryptAES when the input is not a whole number of blocks 11 | // or its PKCS#7 padding is malformed (for example a wrong key or corrupted data). 12 | var ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext") 13 | 14 | // padding returns a new slice holding src followed by PKCS#7 padding up to a multiple of blockSize. 15 | // src itself is never written to, even when it has spare capacity. 16 | func padding(src []byte, blockSize int) []byte { 17 | padNum := blockSize - len(src)%blockSize 18 | out := make([]byte, len(src), len(src)+padNum) 19 | copy(out, src) 20 | return append(out, bytes.Repeat([]byte{byte(padNum)}, padNum)...) 21 | } 22 | 23 | // unpadding strips PKCS#7 padding from a decrypted buffer. An empty buffer, or a pad length of zero, 24 | // above blockSize, or above the buffer length (which used to panic when slicing), yields ErrInvalidCiphertext. 25 | func unpadding(src []byte, blockSize int) ([]byte, error) { 26 | n := len(src) 27 | if n == 0 { 28 | return nil, ErrInvalidCiphertext 29 | } 30 | unPadNum := int(src[n-1]) 31 | if unPadNum == 0 || unPadNum > blockSize || unPadNum > n { 32 | return nil, ErrInvalidCiphertext 33 | } 34 | return src[:n-unPadNum], nil 35 | } 36 | 37 | // EncryptAES encrypts src with AES-CBC and PKCS#7 padding; key must be 16, 24 or 32 bytes. 38 | // The caller's src is left untouched. 39 | // NOTE(review): the IV is derived from the key (its first block) instead of being random, so equal 40 | // plaintexts encrypt to equal ciphertexts. Kept as-is for wire compatibility with existing peers. 41 | func EncryptAES(src []byte, key []byte) ([]byte, error) { 42 | block, err := aes.NewCipher(key) 43 | if err != nil { 44 | return nil, err 45 | } 46 | dst := padding(src, block.BlockSize()) 47 | // key[:BlockSize] equals key for 16-byte keys; for 24/32-byte keys it avoids the 48 | // NewCBCEncrypter panic on an IV that is not exactly one block long 49 | cipher.NewCBCEncrypter(block, key[:block.BlockSize()]).CryptBlocks(dst, dst) 50 | return dst, nil 51 | } 52 | 53 | // DecryptAES reverses EncryptAES. It returns ErrInvalidCiphertext instead of panicking when src is empty, 54 | // not a multiple of the block size, or carries malformed padding. src is not modified. 55 | func DecryptAES(src []byte, key []byte) ([]byte, error) { 56 | block, err := aes.NewCipher(key) 57 | if err != nil { 58 | return nil, err 59 | } 60 | bs := block.BlockSize() 61 | // CryptBlocks panics on input that is not a whole number of blocks 62 | if len(src) == 0 || len(src)%bs != 0 { 63 | return nil, ErrInvalidCiphertext 64 | } 65 | dst := make([]byte, len(src)) 66 | cipher.NewCBCDecrypter(block, key[:bs]).CryptBlocks(dst, src) 67 | return unpadding(dst, bs) 68 | }
-------------------------------------------------------------------------------- /examples/chan/send.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "log" 4 | 5 | func main() { 6 | log.Println("test1============") 7 | test1() 8 | log.Println("test2============") 9 | test2() 10 | log.Println("test3============") 11 | test3() 12 | log.Println("test4============") 13 | test4() 14 | log.Println("test5============") 15 | test5() 16 | log.Println("test6============") 17 | test6() 18 | 19 | //结论 ok 为判断通道是否关闭, default为判断通道是否放满或者无数据时都会调用 20 | } 21 | 22 | func test1() { 23 | send := make(chan int) 24 | close(send) 25 | 26 | select { 27 | case t := <-send: //不用ok 28 | log.Println(t) //被执行 29 | default: 30 | log.Println("default") 31 | } 32 | } 33 | 34 | func test2() { 35 | send := make(chan int) 36 | close(send) 37 | 38 | select { 39 | case t, ok := <-send: // 使用ok 40 | log.Println(t, ok) //被执行,且ok为false 41 | default: 42 | log.Println("default") 43 | } 44 | } 45 | 46 | func test3() { 47 | send := make(chan int) 48 | close(send) 49 | 50 | select { 51 | case t, ok := <-send: // 使用ok 52 | log.Println(t, ok)
//被执行,且ok为false 53 | } 54 | } 55 | 56 | func test4() { 57 | send := make(chan int) 58 | go func() { 59 | // 无close 60 | for i := 0; i < 10; i++ { 61 | send <- i 62 | } 63 | }() 64 | 65 | for i := 0; i < 20; i++ { 66 | select { 67 | case t, ok := <-send: 68 | log.Println(t, ok) 69 | default: 70 | log.Println("send is full or send is empty") //部分被执行 71 | } 72 | } 73 | } 74 | 75 | func test5() { 76 | send := make(chan int) 77 | go func() { 78 | for i := 0; i < 10; i++ { 79 | send <- i 80 | } 81 | }() 82 | 83 | for i := 0; i < 10; i++ { 84 | select { 85 | case t, ok := <-send: 86 | log.Println(t, ok) //全部执行 87 | } 88 | } 89 | } 90 | 91 | func test6() { 92 | send := make(chan int, 1) 93 | for i := 0; i < 5; i++ { 94 | select { 95 | case send <- i: 96 | log.Println(i) 97 | default: 98 | log.Println("send is full") //部分被执行 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /examples/docker_pull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keminar/anyproxy/8fb25d1512710ad5411e29e774677f3cb04cc9b0/examples/docker_pull.png -------------------------------------------------------------------------------- /examples/https_capture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keminar/anyproxy/8fb25d1512710ad5411e29e774677f3cb04cc9b0/examples/https_capture.png -------------------------------------------------------------------------------- /examples/https_flow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keminar/anyproxy/8fb25d1512710ad5411e29e774677f3cb04cc9b0/examples/https_flow.jpg -------------------------------------------------------------------------------- /examples/nginx_vhost.conf: -------------------------------------------------------------------------------- 1 | map $http_upgrade 
$connection_upgrade { 2 | default upgrade; 3 | '' close; 4 | } 5 | 6 | server { 7 | listen 80; 8 | server_name ws.example.com; 9 | 10 | default_type application/octet-stream; 11 | 12 | sendfile on; 13 | tcp_nopush on; 14 | tcp_nodelay on; 15 | gzip on; 16 | gzip_min_length 1000; 17 | gzip_proxied any; 18 | 19 | proxy_next_upstream error; 20 | 21 | location / { 22 | include proxy.conf; 23 | proxy_pass http://127.0.0.1:3002; 24 | keepalive_timeout 65; 25 | proxy_http_version 1.1; 26 | proxy_set_header X-Scheme $scheme; 27 | proxy_set_header Host $http_host; 28 | proxy_set_header Upgrade $http_upgrade; 29 | proxy_set_header Connection $connection_upgrade; 30 | } 31 | access_log logs/n_$HOST.log; 32 | } 33 | 34 | server { 35 | listen 443 ssl; 36 | server_name 3.1415.tech; 37 | root /var/www/1415; 38 | 39 | # "ssl on;" was deprecated in nginx 1.15.0 and removed in 1.25.1; TLS is enabled by the ssl flag on listen above 40 | ssl_certificate "/etc/nginx/cert/3.1415.tech.pem"; 41 | ssl_certificate_key "/etc/nginx/cert/3.1415.tech.key"; 42 | 43 | ssl_session_timeout 5m; 44 | ssl_ciphers ECDHE-RSA-AES128-GCM-SHA256:ECDHE:ECDH:AES:HIGH:!NULL:!aNULL:!MD5:!ADH:!RC4; 45 | ssl_protocols TLSv1.2 TLSv1.3; # TLSv1 and TLSv1.1 are deprecated (RFC 8996) 46 | ssl_prefer_server_ciphers on; 47 | 48 | location / { 49 | proxy_set_header Host 'data.1415.tech'; 50 | proxy_set_header Anyproxy-Action "websocket"; 51 | proxy_pass http://127.0.0.1:3001; 52 | } 53 | } -------------------------------------------------------------------------------- /examples/parseurl/test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | ) 7 | 8 | func main() { 9 | rawurl := "http://www.example.com:443" 10 | xx, _ := url.ParseRequestURI(rawurl) 11 | fmt.Println(xx.Host, xx.Port()) 12 | fmt.Println(xx.String()) 13 | 14 | rawurl = "http://www.example.com:80" 15 | xx, _ = url.ParseRequestURI(rawurl) 16 | fmt.Println(xx.Host, xx.Port()) 17 | fmt.Println(xx.String()) 18 | 19 | rawurl = "http:///test.html" 20 | xx, _ = url.ParseRequestURI(rawurl) 21 | fmt.Println("test",
xx.Host, xx.Port()) 22 | } 23 | -------------------------------------------------------------------------------- /examples/stress/tcpcopy.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "strings" 10 | "sync" 11 | "time" 12 | 13 | "github.com/keminar/anyproxy/proto/tcp" 14 | ) 15 | 16 | /** 17 | * 用于对tcp服务器的压力测试, 外部请求到本服务的监听端口 18 | * 本服务把流量复制N份与目标服务器交互。同时只将一份的返回数据返回给客户端 19 | * 编译:CGO_ENABLED=0 go build -o /tmp/tcpcopy tcpcopy.go 20 | * 21 | * 注:在测试某Tcp代理程序时 ,部署方式 curl->tcpcopy->某代理->nginx 22 | * 用curl进行测试时可能是因为http_proxy会有Proxy-Connection: keep-alive,造成1000个并发总有几个请求会卡住, 23 | * 排查tcpcopy有发送closeWrite但是nginx没有收到,所以直到Nginx超时退出则结束。 24 | * 后换用socks5协议测试,测试中虽然也有个别请求与结果长度不符,但不会卡住请求。不确认是不是某代理有bug 25 | * 示例如下 26 | * 运行:./tcpcopy -listen 0.0.0.0:10010 -server 127.0.0.1:58813 -num 5000 -debug 1 -ignoreDog 27 | * 其中58813为另一个程序的socks5代理入口,并且可以访问本地80端口 28 | * curl --socks5 '127.0.0.1:10010' http://127.0.0.1/test.html 29 | */ 30 | 31 | // 定义命令行参数对应的变量 32 | var listen = flag.String("listen", ":6000", "本地监听端口") 33 | var server = flag.String("server", ":8000", "目标服务器") 34 | var num = flag.Int("num", 1, "压力测试数") 35 | var connTimeout = flag.Int("connTimeout", 5, "连接目标服务器超时") 36 | var writeTimeout = flag.Int("writeTimeout", 0, "向目标服务器读写超时") 37 | var mustLen = flag.Int("mustLen", 0, "目标服务器返回长度,非此长度的在debug 1以上时输出最后一段内容, 同时会有一个错误计数器") 38 | var panicLen = flag.Int("panicLen", 0, "目标服务器返回长度,非此长度的在debug 1以上时显示异常并退出,优先级大于mustLen") 39 | var ignore = flag.Bool("ignoreDog", false, "忽略127.0.0.1来的请求,防止看门狗的请求被复制") 40 | var debug = flag.Int("debug", 0, "调试日志级别") 41 | 42 | const ( 43 | OUT_NONE = iota 44 | OUT_INFO 45 | OUT_DEBUG 46 | ) 47 | 48 | // tcp 压力测试 49 | func main() { 50 | // 把用户传递的命令行参数解析为对应变量的值 51 | flag.Parse() 52 | if *num <= 0 { 53 | *num = 1 54 | } 55 | 56 | fmt.Println("本地监听", *listen) 57 | fmt.Println("压测目标", *server) 58 | fmt.Println("压测连接数", *num) 59 | 
fmt.Println("防看门狗", *ignore) 60 | if *panicLen > 0 { 61 | fmt.Println("程序panic长度", *panicLen) 62 | } 63 | 64 | err := accept() 65 | fmt.Println(err) 66 | } 67 | 68 | func accept() (err error) { 69 | var lnaddr *net.TCPAddr 70 | lnaddr, err = net.ResolveTCPAddr("tcp", *listen) 71 | if err != nil { 72 | err = fmt.Errorf("net.Listen error: %v", err) 73 | return 74 | } 75 | 76 | ln, err := net.ListenTCP("tcp", lnaddr) 77 | if err != nil { 78 | err = fmt.Errorf("net.Listen error: %v", err) 79 | return 80 | } 81 | 82 | for { 83 | var rw *net.TCPConn 84 | rw, err = ln.AcceptTCP() 85 | if err != nil { 86 | if strings.Contains(err.Error(), "use of closed network connection") { 87 | return 88 | } 89 | if ne, ok := err.(net.Error); ok && ne.Temporary() { 90 | tempDelay := 5 * time.Millisecond 91 | log.Printf("Accept error: %v; retrying in %v\n", err, tempDelay) 92 | time.Sleep(tempDelay) 93 | continue 94 | } 95 | return 96 | } 97 | // 忽略看门狗程序搔扰 98 | if *ignore && strings.Contains(rw.RemoteAddr().String(), "127.0.0.1:") { 99 | rw.Close() 100 | continue 101 | } 102 | go conn(rw) 103 | } 104 | } 105 | // conn dials *num connections to *server for one accepted client and relays traffic via tunnel.transfer 106 | func conn(rwc *net.TCPConn) { 107 | if *debug >= OUT_INFO { 108 | log.Println("accecpt connection") 109 | } 110 | tunnel := newTunnel(rwc) 111 | defer func() { // close the client and every target dialed so far; targets used to leak one fd each 112 | rwc.Close() 113 | for _, t := range tunnel.targets { 114 | t.conn.Close() 115 | } 116 | }() 117 | for i := 0; i < *num; i++ { 118 | conn, err := net.DialTimeout("tcp", *server, time.Duration(*connTimeout)*time.Second) 119 | if err != nil { 120 | log.Println("connect", i, err) 121 | return 122 | } 123 | if *writeTimeout > 0 { 124 | conn.SetDeadline(time.Now().Add(time.Duration(*writeTimeout) * time.Second)) 125 | } 126 | tunnel.addTarget(conn.(*net.TCPConn)) 127 | } 128 | tunnel.transfer() 129 | } 130 | 131 | const ( 132 | stateNew int = iota 133 | stateActive 134 | stateClosed 135 | stateIdle 136 | ) 137 | 138 | type target struct { 139 | conn *net.TCPConn 140 | reader *tcp.Reader 141 | } 142 | 143 | // 转发实体 144 |
type tunnel struct { 145 | clientConn *net.TCPConn 146 | clientReader *tcp.Reader 147 | targets []target 148 | 149 | curState int 150 | 151 | readSize int64 152 | writeSize int64 153 | 154 | buf []byte 155 | } 156 | 157 | // newTunnel 实例 158 | func newTunnel(client *net.TCPConn) *tunnel { 159 | s := &tunnel{ 160 | clientConn: client, 161 | clientReader: tcp.NewReader(client), 162 | } 163 | return s 164 | } 165 | 166 | // 添加连接 167 | func (s *tunnel) addTarget(conn *net.TCPConn) { 168 | s.targets = append(s.targets, target{conn, tcp.NewReader(conn)}) 169 | } 170 | 171 | // transfer 交换数据 172 | func (s *tunnel) transfer() { 173 | s.curState = stateActive 174 | done := make(chan struct{}) 175 | 176 | //发送请求 177 | go func() { 178 | defer func() { 179 | close(done) 180 | }() 181 | //不能和外层共用err 182 | var err error 183 | var closeWrite int64 184 | s.readSize, closeWrite, err = s.copyBuffer(s.clientReader, "request") 185 | s.logCopyErr("read from request", err) 186 | if *debug >= OUT_INFO { 187 | // 用fmt方便tee到另一个文件日志查看 188 | fmt.Println("request body size", s.readSize, "send closeWrite", closeWrite) 189 | } 190 | }() 191 | 192 | var errLenNum = 0 193 | // 加锁防止顺序错乱 194 | var wg sync.WaitGroup 195 | wg.Add(len(s.targets)) 196 | go func() { // 丢弃其它服务器返回的内容 197 | size := 4 * 1024 198 | for i, t := range s.targets { 199 | if i > 0 { 200 | go func(i int, t target) { 201 | defer func() { 202 | wg.Done() 203 | }() 204 | buf := make([]byte, size) 205 | var c int64 206 | var last string 207 | c = 0 208 | for { 209 | var lastlast string 210 | lastlast = last 211 | nr, er := t.reader.Read(buf) 212 | if nr > 0 { 213 | if *mustLen > 0 || *panicLen > 0 { 214 | last = string(buf[0:nr]) 215 | //fmt.Println("testlast", last) 216 | } 217 | c += int64(nr) 218 | } 219 | if er != nil { 220 | if *debug >= OUT_INFO { 221 | if *panicLen > 0 && c != int64(*panicLen) { 222 | panic(lastlast + string(buf[0:nr])) 223 | } else if *mustLen > 0 && c != int64(*mustLen) { //不为指定大小的结果,输出上一次的值 224 | 
fmt.Println("reader", i, "#lastlast#", lastlast, "#this#", string(buf[0:nr])) 225 | errLenNum++ 226 | } 227 | s.logReaderClosed("reader closed", i, c, er) 228 | } else if *mustLen > 0 && c != int64(*mustLen) { 229 | errLenNum++ 230 | } 231 | return 232 | } 233 | } 234 | }(i, t) 235 | } else { 236 | wg.Done() 237 | } 238 | } 239 | }() 240 | 241 | var err error 242 | //取返回结果 243 | s.writeSize, _, err = s.copyBuffer(s.targets[0].reader, "server") 244 | wg.Wait() 245 | <-done 246 | // 不管是不是正常结束,只要server结束了,函数就会返回,然后底层会自动断开与client的连接 247 | s.logReaderClosed("reader closed", 0, s.writeSize, err) 248 | 249 | if *mustLen > 0 && s.writeSize != int64(*mustLen) { 250 | errLenNum++ 251 | } 252 | if errLenNum > 0 { 253 | log.Println("read content len error num", errLenNum) 254 | } 255 | } 256 | 257 | // copyBuffer 传输数据 258 | func (s *tunnel) copyBuffer(src *tcp.Reader, srcname string) (written int64, closeWrite int64, err error) { 259 | //如果设置过大会耗内存高,4k比较合理 260 | size := 4 * 1024 261 | buf := make([]byte, size) 262 | i := 0 263 | for { 264 | i++ 265 | nr, er := src.Read(buf) 266 | if nr > 0 { 267 | var nw int 268 | var ew error 269 | if srcname == "request" { 270 | nw, ew = s.targets[0].conn.Write(buf[0:nr]) 271 | if *debug >= OUT_DEBUG { 272 | s.logReaderClosed("real request", 0, int64(nw), ew) 273 | } 274 | // 加锁防止顺序错乱 275 | var wg sync.WaitGroup 276 | wg.Add(len(s.targets)) 277 | go func() { //同步发到多个连接 278 | for tk, tv := range s.targets { 279 | if tk > 0 { 280 | nx, ex := tv.conn.Write(buf[0:nr]) 281 | if *debug >= OUT_DEBUG { 282 | s.logReaderClosed("copy request", tk, int64(nx), ex) 283 | } 284 | } 285 | wg.Done() 286 | } 287 | }() 288 | wg.Wait() 289 | } else { 290 | nw, ew = s.clientConn.Write(buf[0:nr]) 291 | //打印测试服务端返回值 292 | //fmt.Println(string(buf[0:nr])) 293 | } 294 | if nw > 0 { 295 | written += int64(nw) 296 | } 297 | if ew != nil { 298 | err = fmt.Errorf("id#1 %s", ew.Error()) 299 | break 300 | } 301 | if nr != nw { 302 | err = fmt.Errorf("id#2 %s", 
io.ErrShortWrite.Error()) 303 | break 304 | } 305 | } 306 | if er != nil { 307 | if er != io.EOF { 308 | err = fmt.Errorf("id#3 %s", er.Error()) 309 | } else { 310 | s.logCopyErr(srcname+" read", er) 311 | if srcname == "server" { 312 | if s.curState != stateClosed { 313 | // 如果非客户端导致的服务端关闭,则关闭客户端读 314 | // Notice: 如果只是CloseRead(),当在windows上执行时,且是做为订阅端从服务器收到请求再转到charles 315 | // 等服务时,当请求的地址返回足够长的内容时会触发卡住问题。 316 | // 流程如 curl -> anyproxy(server) -> ws -> anyproxy(windows) -> charles 317 | // 用Close()可以解决卡住,不过客户端会收到use of closed network connection的错误提醒 318 | s.clientConn.Close() 319 | } 320 | } 321 | } 322 | 323 | if srcname == "request" { 324 | // 客户端已经主动发送了EOF断开, 读取不到内容也算正常 325 | if strings.Contains(er.Error(), "use of closed network connection") { 326 | err = nil 327 | } 328 | // 当客户端断开或出错了,服务端也不用再读了,可以关闭,解决读Server卡住不能到EOF的问题 329 | for tk, tv := range s.targets { 330 | cerr := tv.conn.CloseWrite() 331 | if cerr != nil { //调试时改为panic也行 332 | log.Println("closeWrite", cerr.Error()) 333 | } else { 334 | closeWrite++ 335 | if *debug >= OUT_DEBUG { 336 | log.Println("closeWrite", tk) 337 | } 338 | } 339 | } 340 | s.curState = stateClosed 341 | } 342 | break 343 | } 344 | } 345 | return written, closeWrite, err 346 | } 347 | 348 | // 错误日志 349 | func (s *tunnel) logCopyErr(name string, err error) { 350 | if err == nil || err == io.EOF { 351 | return 352 | } 353 | log.Println(name, err.Error()) 354 | } 355 | 356 | // 读取字节日志 357 | func (s *tunnel) logReaderClosed(msg string, i int, c int64, err error) { 358 | if err != nil && err != io.EOF { 359 | log.Println(msg, i, "size", c, "error", err.Error()) 360 | } else { 361 | log.Println(msg, i, "size", c) 362 | } 363 | } 364 | -------------------------------------------------------------------------------- /examples/telnet/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | func main() { 10 | 
http.HandleFunc("/sleep", SleepHandler) 11 | http.HandleFunc("/", HelloHandler) 12 | http.ListenAndServe("0.0.0.0:8880", nil) 13 | } 14 | 15 | func HelloHandler(w http.ResponseWriter, r *http.Request) { 16 | fmt.Println(time.Now().Unix(), r.Host+r.URL.RequestURI()) 17 | fmt.Fprintf(w, "Hello! %s%s\n", r.Host, r.URL.RequestURI()) 18 | fmt.Println(time.Now().Unix(), "hello end") 19 | } 20 | 21 | func SleepHandler(w http.ResponseWriter, r *http.Request) { 22 | fmt.Println(time.Now().Unix(), r.Host+r.URL.RequestURI()) 23 | time.Sleep(time.Duration(20) * time.Second) 24 | fmt.Fprintf(w, "Sleep! %s%s\n", r.Host, r.URL.RequestURI()) 25 | fmt.Println(time.Now().Unix(), "sleep end") 26 | } 27 | 28 | /* 29 | 请求端日志 30 | $ telnet 127.0.0.1 8880 31 | Trying 127.0.0.1... 32 | Connected to 127.0.0.1. 33 | Escape character is '^]'. 34 | GET /sleep HTTP/1.1 35 | HOST: www.example.com 36 | 37 | GET / HTTP/1.1 38 | HOST: www.aaa.com 39 | 40 | HTTP/1.1 200 OK 41 | Date: Sat, 22 Aug 2020 07:59:49 GMT 42 | Content-Length: 29 43 | Content-Type: text/plain; charset=utf-8 44 | 45 | Sleep! www.example.com/sleep 46 | HTTP/1.1 200 OK 47 | Date: Sat, 22 Aug 2020 07:59:49 GMT 48 | Content-Length: 20 49 | Content-Type: text/plain; charset=utf-8 50 | 51 | Hello! www.aaa.com/ 52 | */ 53 | 54 | /* 55 | 服务端日志 56 | $ go run main.go 57 | 1598108632 www.example.com/sleep 58 | 1598108662 sleep end 59 | 1598108662 www.aaa.com/ 60 | 1598108662 hello end 61 | */ 62 | 63 | // HTTP层面 64 | // 结论: 从请求端看,第二个www.aaa.com的请求响应一定在第一个www.example.com响应后面 65 | // 从接收端看,第二个www.aaa.com的请求接收在第一个www.example.com响应返回后 66 | 67 | // 再结合wireshark.png 68 | // 从wireshark看,第二个请求在第一个请求响应前发出了,为了测试准确性,通过将请求端换另一台机器, 69 | // 在接收端抓包观察,接收端依然及时收到了,说明发送端没缓存处理。接收端还要再测试是在哪变更的顺序 70 | 71 | //将代码运行在anyproxy后面,在copyBuffer函数增加日志 72 | /* 73 | 请求端 74 | $ telnet 127.0.0.1 4000 75 | Trying 127.0.0.1... 76 | Connected to 127.0.0.1. 77 | Escape character is '^]'. 
78 | GET /sleep HTTP/1.1 79 | HOST: 127.0.0.1:8880 80 | 81 | GET / HTTP/1.1 82 | HOST: www.aaa.com 83 | 84 | HTTP/1.1 200 OK 85 | Date: Sat, 22 Aug 2020 15:44:12 GMT 86 | Content-Length: 28 87 | Content-Type: text/plain; charset=utf-8 88 | 89 | Sleep! 127.0.0.1:8880/sleep 90 | Connection closed by foreign host. 91 | */ 92 | 93 | /* 94 | 接收端 95 | 以下日志有删减 96 | $ go run anyproxy.go -l :4000 -d 2 97 | grace/server.go:215: Listening for connections on [::]:4000, pid=1188 98 | proto/client.go:22: ID #1, remoteAddr:127.0.0.1:6657 99 | ID #1, GET /sleep HTTP/1.1 100 | ID #1, Host = [127.0.0.1:8880] 101 | ID #1, 102 | proto/tunnel.go:62: ID #1, receive from client, n=1 103 | 104 | 1598111028138113800 GET / HTTP/1.1 105 | 106 | proto/tunnel.go:83: ID #1, receive from server, n=1, data len: 145 107 | ID #1, 1598111052420328000 HTTP/1.1 200 OK 108 | Date: Sat, 22 Aug 2020 15:44:12 GMT 109 | Content-Length: 28 110 | Content-Type: text/plain; charset=utf-8 111 | 112 | Sleep! 127.0.0.1:8880/sleep 113 | 114 | */ 115 | 116 | // 结论:TCP层面是没有处理顺序的。就是说第二个发送是有及时收到的。但是HTTP层面第二个发送是推迟到第一个请求返回后收到 117 | // 所以是HTTP层面有做一些操作。那么在tcp的copyBuffer函数就不能根据第二个请求是否到达来判断第一个请求是否完成并关闭第一个请求的返回。 118 | -------------------------------------------------------------------------------- /examples/telnet/wireshark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keminar/anyproxy/8fb25d1512710ad5411e29e774677f3cb04cc9b0/examples/telnet/wireshark.png -------------------------------------------------------------------------------- /examples/websocket/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | // +build ignore 6 | 7 | package main 8 | 9 | import ( 10 | "flag" 11 | "log" 12 | "net/url" 13 | "os" 14 | "os/signal" 15 | "time" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | var addr = flag.String("addr", "localhost:8080", "http service address") 21 | 22 | func main() { 23 | flag.Parse() 24 | log.SetFlags(0) 25 | 26 | interrupt := make(chan os.Signal, 1) 27 | signal.Notify(interrupt, os.Interrupt) 28 | 29 | u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"} 30 | log.Printf("connecting to %s", u.String()) 31 | 32 | c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 33 | if err != nil { 34 | log.Fatal("dial:", err) 35 | } 36 | defer c.Close() 37 | 38 | done := make(chan struct{}) 39 | 40 | go func() { 41 | defer close(done) 42 | for { 43 | _, message, err := c.ReadMessage() 44 | if err != nil { 45 | log.Println("read:", err) 46 | return 47 | } 48 | log.Printf("recv: %s", message) 49 | } 50 | }() 51 | 52 | for { 53 | select { 54 | case <-done: 55 | return 56 | case <-interrupt: 57 | log.Println("interrupt") 58 | 59 | // Cleanly close the connection by sending a close message and then 60 | // waiting (with timeout) for the server to close the connection. 61 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 62 | if err != nil { 63 | log.Println("write close:", err) 64 | return 65 | } 66 | select { 67 | case <-done: 68 | case <-time.After(time.Second): 69 | } 70 | return 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /examples/websocket/conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "net/http" 10 | "time" 11 | 12 | "github.com/gorilla/websocket" 13 | ) 14 | 15 | const ( 16 | // Time allowed to write a message to the peer. 17 | writeWait = 10 * time.Second 18 | 19 | // Time allowed to read the next pong message from the peer. 20 | pongWait = 60 * time.Second 21 | 22 | // Send pings to peer with this period. Must be less than pongWait. 23 | pingPeriod = (pongWait * 9) / 10 24 | 25 | // Maximum message size allowed from peer. 26 | maxMessageSize = 512 27 | ) 28 | 29 | var ( 30 | newline = []byte{'\n'} 31 | space = []byte{' '} 32 | ) 33 | 34 | var upgrader = websocket.Upgrader{ 35 | ReadBufferSize: 1024, 36 | WriteBufferSize: 1024, 37 | } 38 | 39 | // Client is a middleman between the websocket connection and the hub. 40 | type Client struct { 41 | hub *Hub 42 | 43 | // The websocket connection. 44 | conn *websocket.Conn 45 | 46 | // Buffered channel of outbound messages. 47 | send chan []byte 48 | } 49 | 50 | // writePump pumps messages from the hub to the websocket connection. 51 | // 52 | // A goroutine running writePump is started for each connection. The 53 | // application ensures that there is at most one writer to a connection by 54 | // executing all writes from this goroutine. 55 | func (c *Client) writePump() { 56 | ticker := time.NewTicker(pingPeriod) 57 | defer func() { 58 | ticker.Stop() 59 | c.conn.Close() 60 | }() 61 | for { 62 | select { 63 | case message, ok := <-c.send: 64 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 65 | if !ok { 66 | // The hub closed the channel. 
67 | c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 68 | return 69 | } 70 | 71 | w, err := c.conn.NextWriter(websocket.BinaryMessage) 72 | if err != nil { 73 | return 74 | } 75 | w.Write(message) 76 | if err := w.Close(); err != nil { 77 | return 78 | } 79 | case <-ticker.C: 80 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 81 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 82 | return 83 | } 84 | } 85 | } 86 | } 87 | 88 | // serveWs handles websocket requests from the peer. 89 | func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { 90 | conn, err := upgrader.Upgrade(w, r, nil) 91 | if err != nil { 92 | log.Println(err) 93 | return 94 | } 95 | client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} 96 | client.hub.register <- client 97 | 98 | // Allow collection of memory referenced by the caller by doing all work in 99 | // new goroutines. 100 | go client.writePump() 101 | } 102 | -------------------------------------------------------------------------------- /examples/websocket/hub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | // Hub maintains the set of active clients and broadcasts messages to the 8 | // clients. 9 | type Hub struct { 10 | // Registered clients. 11 | clients map[*Client]bool 12 | 13 | // Inbound messages from the clients. 14 | broadcast chan []byte 15 | 16 | // Register requests from the clients. 17 | register chan *Client 18 | 19 | // Unregister requests from clients. 
20 | unregister chan *Client 21 | } 22 | 23 | func newHub() *Hub { 24 | return &Hub{ 25 | broadcast: make(chan []byte), 26 | register: make(chan *Client), 27 | unregister: make(chan *Client), 28 | clients: make(map[*Client]bool), 29 | } 30 | } 31 | 32 | func (h *Hub) run() { 33 | for { 34 | select { 35 | case client := <-h.register: 36 | h.clients[client] = true 37 | case client := <-h.unregister: 38 | if _, ok := h.clients[client]; ok { 39 | delete(h.clients, client) 40 | close(client.send) 41 | } 42 | case message := <-h.broadcast: 43 | //fmt.Println(string(message)) 44 | for client := range h.clients { 45 | select { 46 | case client.send <- message: 47 | default: 48 | close(client.send) 49 | delete(h.clients, client) 50 | } 51 | } 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /examples/websocket/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "fmt" 10 | "log" 11 | "net/http" 12 | "time" 13 | ) 14 | 15 | var addr = flag.String("addr", ":8080", "http service address") 16 | 17 | func main() { 18 | flag.Parse() 19 | hub := newHub() 20 | go hub.run() 21 | go func() { 22 | // 发布消息 23 | for { 24 | hub.broadcast <- []byte(fmt.Sprintf("test - %d", time.Now().Second())) 25 | time.Sleep(time.Duration(2) * time.Second) 26 | } 27 | }() 28 | http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { 29 | serveWs(hub, w, r) 30 | }) 31 | err := http.ListenAndServe(*addr, nil) 32 | if err != nil { 33 | log.Fatal("ListenAndServe: ", err) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/keminar/anyproxy 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.23.7 6 | 7 | require ( 8 | github.com/fsnotify/fsnotify v1.4.9 9 | github.com/gorilla/websocket v1.4.2 10 | golang.org/x/net v0.36.0 11 | gopkg.in/yaml.v2 v2.3.0 12 | ) 13 | 14 | require golang.org/x/sys v0.30.0 // indirect 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= 2 | github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= 3 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 4 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 5 | golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= 6 | golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= 7 | golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 8 | golang.org/x/sys v0.30.0 
h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 9 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= 13 | gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 14 | -------------------------------------------------------------------------------- /grace/autoinc/autoinc.go: -------------------------------------------------------------------------------- 1 | //autoInc.go 2 | 3 | package autoinc 4 | 5 | // UintMax 最大值 6 | const UintMax = ^uint(0) 7 | 8 | // AutoInc 自增 9 | type AutoInc struct { 10 | start, step uint 11 | queue chan uint 12 | } 13 | 14 | // New 实例化 15 | func New(start, step uint) (ai *AutoInc) { 16 | ai = &AutoInc{ 17 | start: start, 18 | step: step, 19 | queue: make(chan uint, 4), 20 | } 21 | 22 | go ai.process() 23 | return 24 | } 25 | 26 | // 产生id 27 | func (ai *AutoInc) process() { 28 | defer func() { recover() }() 29 | for i := ai.start; ; i = i + ai.step { 30 | if i > UintMax { 31 | // reset 32 | i = ai.start 33 | } 34 | ai.queue <- i 35 | } 36 | } 37 | 38 | // ID 取id 39 | func (ai *AutoInc) ID() uint { 40 | return <-ai.queue 41 | } 42 | -------------------------------------------------------------------------------- /grace/conn.go: -------------------------------------------------------------------------------- 1 | package grace 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net" 8 | "runtime" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/keminar/anyproxy/grace/autoinc" 13 | ) 14 | 15 | // AutoInc 自增 16 | var autoInc *autoinc.AutoInc 17 | 18 | func init() { 19 | autoInc = autoinc.New(1, 1) 20 | } 21 | 22 | // A conn represents the server side of an HTTP connection. 
23 | type conn struct { 24 | // server is the server on which the connection arrived. 25 | // Immutable; never nil. 26 | server *Server 27 | 28 | // cancelCtx cancels the connection-level context. 29 | cancelCtx context.CancelFunc 30 | 31 | // rwc is the underlying network connection. 32 | rwc *net.TCPConn 33 | 34 | traceID uint 35 | // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously 36 | // inside the Listener's Accept goroutine, as some implementations block. 37 | // It is populated immediately inside the (*conn).serve goroutine. 38 | // This is the value of a Handler's (*Request).RemoteAddr. 39 | remoteAddr string 40 | 41 | curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) 42 | 43 | startTime int64 44 | } 45 | 46 | // Serve a new connection. 47 | func (c *conn) serve(ctx context.Context) { 48 | c.traceID = autoInc.ID() 49 | c.startTime = time.Now().UnixNano() 50 | c.remoteAddr = c.rwc.RemoteAddr().String() 51 | //addr, ok := ctx.Value(grace.LocalAddrContextKey).(net.Addr) 52 | ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) 53 | //traceID, ok := ctx.Value(grace.TraceIDContextKey).(uint) 54 | ctx = context.WithValue(ctx, TraceIDContextKey, c.traceID) 55 | defer func() { 56 | if err := recover(); err != nil { 57 | const size = 64 << 10 58 | buf := make([]byte, size) 59 | buf = buf[:runtime.Stack(buf, false)] 60 | log.Printf("%s panic serving %v: %v\n%s", traceID(c.traceID), c.remoteAddr, err, buf) 61 | } 62 | c.close() 63 | c.setState(c.rwc, StateClosed) 64 | log.Println(traceID(c.traceID), "closed") 65 | }() 66 | ctx, cancelCtx := context.WithCancel(ctx) 67 | c.cancelCtx = cancelCtx 68 | defer cancelCtx() 69 | 70 | c.setState(c.rwc, StateActive) 71 | handler := c.server.Handler 72 | err := handler(ctx, c.rwc) 73 | if err != nil { 74 | log.Printf("%s conn handler %v: %v\n", traceID(c.traceID), c.remoteAddr, err) 75 | } 76 | } 77 | 78 | // TraceID 日志ID 79 | func traceID(id uint) string { 80 | 
return fmt.Sprintf("ID #%d,", id) 81 | } 82 | 83 | // A ConnState represents the state of a client connection to a server. 84 | // It's used by the optional Server.ConnState hook. 85 | type ConnState int 86 | 87 | const ( 88 | // StateNew represents a new connection that is expected to 89 | // send a request immediately. Connections begin at this 90 | // state and then transition to either StateActive or 91 | // StateClosed. 92 | StateNew ConnState = iota 93 | 94 | // StateActive represents a connection that has read 1 or more 95 | // bytes of a request. The Server.ConnState hook for 96 | // StateActive fires before the request has entered a handler 97 | // and doesn't fire again until the request has been 98 | // handled. After the request is handled, the state 99 | // transitions to StateClosed, StateHijacked, or StateIdle. 100 | // For HTTP/2, StateActive fires on the transition from zero 101 | // to one active request, and only transitions away once all 102 | // active requests are complete. That means that ConnState 103 | // cannot be used to do per-request work; ConnState only notes 104 | // the overall state of the connection. 105 | StateActive 106 | 107 | // StateIdle represents a connection that has finished 108 | // handling a request and is in the keep-alive state, waiting 109 | // for a new request. Connections transition from StateIdle 110 | // to either StateActive or StateClosed. 111 | StateIdle 112 | 113 | // StateHijacked represents a hijacked connection. 114 | // This is a terminal state. It does not transition to StateClosed. 115 | StateHijacked 116 | 117 | // StateClosed represents a closed connection. 118 | // This is a terminal state. Hijacked connections do not 119 | // transition to StateClosed. 
120 | StateClosed 121 | ) 122 | 123 | var stateName = map[ConnState]string{ 124 | StateNew: "new", 125 | StateActive: "active", 126 | StateIdle: "idle", 127 | StateHijacked: "hijacked", 128 | StateClosed: "closed", 129 | } 130 | 131 | func (c ConnState) String() string { 132 | return stateName[c] 133 | } 134 | 135 | func (c *conn) setState(nc net.Conn, state ConnState) { 136 | srv := c.server 137 | switch state { 138 | case StateNew: 139 | srv.trackConn(c, true) 140 | case StateHijacked, StateClosed: 141 | srv.trackConn(c, false) 142 | } 143 | if state > 0xff || state < 0 { 144 | panic("internal error") 145 | } 146 | packedState := uint64(time.Now().Unix()<<8) | uint64(state) 147 | atomic.StoreUint64(&c.curState.atomic, packedState) 148 | } 149 | 150 | func (c *conn) getState() (state ConnState, unixSec int64) { 151 | packedState := atomic.LoadUint64(&c.curState.atomic) 152 | return ConnState(packedState & 0xff), int64(packedState >> 8) 153 | } 154 | 155 | // Close the connection. 156 | func (c *conn) close() { 157 | c.rwc.Close() 158 | } 159 | -------------------------------------------------------------------------------- /grace/grace.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 beego Author. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | // Package grace use to hot reload 16 | // Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ 17 | 18 | // examples 19 | // package main 20 | // 21 | // import ( 22 | // "flag" 23 | // "net" 24 | // 25 | // "github.com/keminar/anyproxy/grace" 26 | // ) 27 | // 28 | // var appHandler func(conn net.Conn) error 29 | // 30 | // func main() { 31 | // flag.Parse() 32 | // 33 | // server := grace.NewServer(":3000", appHandler) 34 | // server.ListenAndServe() 35 | // } 36 | 37 | package grace 38 | 39 | import ( 40 | "flag" 41 | "os" 42 | "strings" 43 | "sync" 44 | "syscall" 45 | ) 46 | 47 | const ( 48 | // PreSignal is the position to add filter before signal 49 | PreSignal = iota 50 | // PostSignal is the position to add filter after signal 51 | PostSignal 52 | // StateInit represent the application inited 53 | StateInit 54 | // StateRunning represent the application is running 55 | StateRunning 56 | // StateShuttingDown represent the application is shutting down 57 | StateShuttingDown 58 | // StateTerminate represent the application is killed 59 | StateTerminate 60 | ) 61 | 62 | var ( 63 | regLock *sync.Mutex 64 | runningServers map[string]*Server 65 | runningServersOrder []string 66 | socketPtrOffsetMap map[string]uint 67 | runningServersForked bool 68 | 69 | isChild bool 70 | socketOrder string 71 | 72 | hookableSignals []os.Signal 73 | ) 74 | 75 | func init() { 76 | flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") 77 | flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") 78 | 79 | regLock = &sync.Mutex{} 80 | runningServers = make(map[string]*Server) 81 | runningServersOrder = []string{} 82 | socketPtrOffsetMap = make(map[string]uint) 83 | 84 | hookableSignals = []os.Signal{ 85 | syscall.SIGHUP, 86 | syscall.SIGINT, 87 | syscall.SIGTERM, 88 | } 89 | } 90 | 91 | // NewServer returns a new graceServer. 
92 | func NewServer(addr string, handler ConnHandler, network string) (srv *Server) { 93 | regLock.Lock() 94 | defer regLock.Unlock() 95 | 96 | if !flag.Parsed() { 97 | flag.Parse() 98 | } 99 | if len(socketOrder) > 0 { 100 | for i, addr := range strings.Split(socketOrder, ",") { 101 | socketPtrOffsetMap[addr] = uint(i) 102 | } 103 | } else { 104 | socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) 105 | } 106 | 107 | srv = &Server{ 108 | Addr: addr, 109 | Handler: handler, 110 | sigChan: make(chan os.Signal), 111 | isChild: isChild, 112 | SignalHooks: map[int]map[os.Signal][]func(){ 113 | PreSignal: { 114 | syscall.SIGHUP: {}, 115 | syscall.SIGINT: {}, 116 | syscall.SIGTERM: {}, 117 | }, 118 | PostSignal: { 119 | syscall.SIGHUP: {}, 120 | syscall.SIGINT: {}, 121 | syscall.SIGTERM: {}, 122 | }, 123 | }, 124 | state: StateInit, 125 | Network: network, 126 | terminalChan: make(chan error), //no cache channel 127 | } 128 | 129 | runningServersOrder = append(runningServersOrder, addr) 130 | runningServers[addr] = srv 131 | return srv 132 | } 133 | -------------------------------------------------------------------------------- /grace/server.go: -------------------------------------------------------------------------------- 1 | package grace 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net" 9 | "os" 10 | "os/exec" 11 | "os/signal" 12 | "strings" 13 | "sync" 14 | "syscall" 15 | "time" 16 | ) 17 | 18 | //ConnHandler connection handler definition 19 | type ConnHandler func(ctx context.Context, conn *net.TCPConn) error 20 | 21 | //ErrReloadClose reload graceful 22 | var ErrReloadClose = errors.New("reload graceful") 23 | 24 | //TermTimeout 平滑重启主进程保持秒数 25 | var TermTimeout = 10 26 | 27 | // Server embedded http.Server 28 | type Server struct { 29 | Addr string 30 | Handler ConnHandler 31 | ln *net.TCPListener 32 | SignalHooks map[int]map[os.Signal][]func() 33 | sigChan chan os.Signal 34 | isChild bool 35 | state uint8 36 | Network string 37 | 
terminalChan chan error

	mu         sync.Mutex
	activeConn map[*conn]struct{}
}

// Serve accepts incoming connections on the Listener l,
// creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them.
//
// When the accept loop ends because shutdown closed the listener
// (ErrReloadClose), Serve then blocks until shutdown sends on srv.terminalChan
// and returns that value. Any other accept error is logged and returned.
func (srv *Server) Serve() (err error) {
	srv.state = StateRunning
	defer func() { srv.state = StateTerminate }()

	// A deliberate restart/stop surfaces from serve as ErrReloadClose.
	if err = srv.serve(); err != nil && err != ErrReloadClose {
		log.Println(syscall.Getpid(), "Server.Serve() error:", err)
		return err
	}

	log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
	// wait for Shutdown to return
	return <-srv.terminalChan
}

// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
	name string
}

func (k *contextKey) String() string { return "tcp context value " + k.name }

var (
	// ServerContextKey is a context key. It can be used in HTTP
	// handlers with context.WithValue to access the server that
	// started the handler. The associated value will be of
	// type *Server.
	ServerContextKey = &contextKey{"server"}

	// LocalAddrContextKey is a context key. It can be used in
	// HTTP handlers with context.WithValue to access the local
	// address the connection arrived on.
	// The associated value will be of type net.Addr.
	LocalAddrContextKey = &contextKey{"local-addr"}

	// TraceIDContextKey is the context key for a connection's trace ID.
	TraceIDContextKey = &contextKey{"traceID"}
)

// serve runs the accept loop. Temporary accept errors are retried with a
// backoff that starts at 5ms, doubles, and is capped at 1s. Every accepted
// connection is served on its own goroutine with a context carrying srv
// under ServerContextKey.
func (srv *Server) serve() (err error) {
	var tempDelay time.Duration

	baseCtx := context.Background() // base is always background, per Issue 16220

	//srv, ok := ctx.Value(grace.ServerContextKey).(*grace.Server)
	ctx := context.WithValue(baseCtx, ServerContextKey, srv)
	for {
		var rw *net.TCPConn
		rw, err = srv.ln.AcceptTCP()
		if err != nil {
			// The listener was closed on purpose by shutdown (restart/stop).
			if srv.state == StateShuttingDown && strings.Contains(err.Error(), "use of closed network connection") {
				return ErrReloadClose
			}
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max
				}
				log.Printf("Accept error: %v; retrying in %v\n", err, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			return err
		}
		tempDelay = 0
		c := srv.newConn(rw)
		c.setState(c.rwc, StateNew) // before Serve can return
		go c.serve(ctx)
	}
}

// closeIdleConns closes all idle connections and reports whether the
// server is quiescent.
func (srv *Server) closeIdleConns() bool {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	quiescent := true
	for c := range srv.activeConn {
		st, unixSec := c.getState()
		// Issue 22682: treat StateNew connections as if
		// they're idle if we haven't read the first request's
		// header in over 5 seconds.
		if st == StateNew && unixSec < time.Now().Unix()-5 {
			st = StateIdle
		}
		if st != StateIdle || unixSec == 0 {
			// Assume unixSec == 0 means it's a very new
			// connection, without state set yet.
			quiescent = false
			continue
		}
		c.rwc.Close()
		delete(srv.activeConn, c)
	}
	return quiescent
}

// GetConns returns the number of currently tracked connections.
// It takes srv.mu because trackConn mutates activeConn from connection
// goroutines; do not call it from inside a GetConnRange callback, which
// already runs with srv.mu held.
func (srv *Server) GetConns() int {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	return len(srv.activeConn)
}

// GetConnRange calls f once for every tracked connection, with srv.mu held.
func (srv *Server) GetConnRange(f func(ID uint, startTime int64, remoteAddr string)) {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	for c := range srv.activeConn {
		f(c.traceID, c.startTime, c.remoteAddr)
	}
}

// trackConn adds c to (add == true) or removes it from the set of tracked connections.
func (srv *Server) trackConn(c *conn, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	if srv.activeConn == nil {
		srv.activeConn = make(map[*conn]struct{})
	}
	if add {
		srv.activeConn[c] = struct{}{}
	} else {
		delete(srv.activeConn, c)
	}
}

// newConn wraps an accepted TCP connection for this server.
func (srv *Server) newConn(rwc *net.TCPConn) *conn {
	c := &conn{
		server: srv,
		rwc:    rwc,
	}
	return c
}

// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":3000" is
// used.
189 | func (srv *Server) ListenAndServe() (err error) { 190 | addr := srv.Addr 191 | if addr == "" { 192 | addr = ":3000" 193 | } 194 | 195 | go srv.handleSignals() 196 | 197 | srv.ln, err = srv.getListener(addr) 198 | if err != nil { 199 | log.Println(os.Getpid(), err) 200 | return err 201 | } 202 | 203 | ppid := os.Getppid() 204 | if srv.isChild && ppid != 1 { //增加一个安全检查 205 | process, err := os.FindProcess(ppid) 206 | if err != nil { 207 | log.Println(os.Getpid(), err) 208 | return err 209 | } 210 | err = process.Signal(syscall.SIGTERM) 211 | if err != nil { 212 | return err 213 | } 214 | } 215 | 216 | log.Println(fmt.Sprintf("Listening for connections on %v, pid=%d", srv.ln.Addr(), os.Getpid())) 217 | 218 | return srv.Serve() 219 | } 220 | 221 | // getListener either opens a new socket to listen on, or takes the acceptor socket 222 | // it got passed when restarted. 223 | func (srv *Server) getListener(laddr string) (l *net.TCPListener, err error) { 224 | if srv.isChild { 225 | var ptrOffset uint 226 | if len(socketPtrOffsetMap) > 0 { 227 | ptrOffset = socketPtrOffsetMap[laddr] 228 | log.Println(os.Getpid(), "laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) 229 | } 230 | 231 | f := os.NewFile(uintptr(3+ptrOffset), "") 232 | 233 | var ln net.Listener 234 | ln, err = net.FileListener(f) 235 | if err != nil { 236 | err = fmt.Errorf("net.FileListener error: %v", err) 237 | return 238 | } 239 | l = ln.(*net.TCPListener) 240 | } else { 241 | var lnaddr *net.TCPAddr 242 | lnaddr, err = net.ResolveTCPAddr(srv.Network, laddr) 243 | if err != nil { 244 | err = fmt.Errorf("net.Listen error: %v", err) 245 | return 246 | } 247 | 248 | l, err = net.ListenTCP(srv.Network, lnaddr) 249 | if err != nil { 250 | err = fmt.Errorf("net.Listen error: %v", err) 251 | return 252 | } 253 | } 254 | return 255 | } 256 | 257 | // handleSignals listens for os Signals and calls any hooked in function that the 258 | // user had registered with the signal. 
259 | func (srv *Server) handleSignals() { 260 | var sig os.Signal 261 | 262 | signal.Notify( 263 | srv.sigChan, 264 | hookableSignals..., 265 | ) 266 | 267 | pid := syscall.Getpid() 268 | for { 269 | sig = <-srv.sigChan 270 | srv.signalHooks(PreSignal, sig) 271 | switch sig { 272 | case syscall.SIGHUP: 273 | log.Println(pid, "Received SIGHUP. forking.") 274 | err := srv.fork() 275 | if err != nil { 276 | log.Println("Fork err:", err) 277 | } 278 | case syscall.SIGINT: 279 | log.Println(pid, "Received SIGINT.") 280 | // ctrl+c无等待时间 281 | srv.shutdown(0) 282 | case syscall.SIGTERM: 283 | log.Println(pid, "Received SIGTERM.") 284 | srv.shutdown(TermTimeout) 285 | default: 286 | log.Printf("Received %v: nothing i care about...\n", sig) 287 | } 288 | srv.signalHooks(PostSignal, sig) 289 | } 290 | } 291 | 292 | // 处理默认消息之外的钩子 293 | func (srv *Server) signalHooks(ppFlag int, sig os.Signal) { 294 | if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { 295 | return 296 | } 297 | for _, f := range srv.SignalHooks[ppFlag][sig] { 298 | f() 299 | } 300 | } 301 | 302 | // shutdown closes the listener so that no new connections are accepted. it also 303 | // starts a goroutine that will serverTimeout (stop all running requests) the server 304 | // after DefaultTimeout. 
305 | func (srv *Server) shutdown(timeout int) { 306 | if srv.state != StateRunning { 307 | return 308 | } 309 | 310 | srv.state = StateShuttingDown 311 | // listen close就不能accept新的链接,已接收的链接不受影响 312 | // 关闭已连接的是用tcpConn.Close(), 为了简单下面是用超时来等待处理 313 | srv.ln.Close() 314 | 315 | if timeout > 0 { 316 | log.Println(syscall.Getpid(), fmt.Sprintf("Waiting max %d second for connections to finish...", timeout)) 317 | // 等一定时间让已接收的请求处理一下,如果还处理不完就强制关闭了 318 | after := time.After(time.Duration(timeout) * time.Second) 319 | 320 | var shutdownPollInterval = 500 * time.Millisecond 321 | ticker := time.NewTicker(shutdownPollInterval) 322 | defer ticker.Stop() 323 | for { 324 | srv.closeIdleConns() 325 | if len(srv.activeConn) == 0 { 326 | break 327 | } 328 | force := false 329 | select { 330 | case <-after: 331 | // 这里加break没用,只会跳一层select ,所以加一个变量 332 | force = true 333 | case <-ticker.C: 334 | } 335 | if force { 336 | break 337 | } 338 | } 339 | } 340 | 341 | srv.terminalChan <- nil 342 | } 343 | 344 | func (srv *Server) fork() (err error) { 345 | regLock.Lock() 346 | defer regLock.Unlock() 347 | if runningServersForked { 348 | return 349 | } 350 | runningServersForked = true 351 | 352 | var files = make([]*os.File, len(runningServers)) 353 | var orderArgs = make([]string, len(runningServers)) 354 | for _, srvPtr := range runningServers { 355 | f, _ := srvPtr.ln.File() 356 | files[socketPtrOffsetMap[srvPtr.Addr]] = f 357 | orderArgs[socketPtrOffsetMap[srvPtr.Addr]] = srvPtr.Addr 358 | } 359 | 360 | //log.Println(files) 361 | path := os.Args[0] 362 | var args []string 363 | if len(os.Args) > 1 { 364 | for _, arg := range os.Args[1:] { 365 | if arg == "-graceful" { 366 | break 367 | } 368 | args = append(args, arg) 369 | } 370 | } 371 | args = append(args, "-graceful") 372 | if len(runningServers) > 1 { 373 | args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) 374 | log.Println(args) 375 | } 376 | cmd := exec.Command(path, args...) 
377 | cmd.Stdout = os.Stdout 378 | cmd.Stderr = os.Stderr 379 | cmd.ExtraFiles = files 380 | err = cmd.Start() 381 | if err != nil { 382 | log.Fatalf("Restart: Failed to launch, error: %v", err) 383 | } 384 | 385 | return 386 | } 387 | 388 | // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal. 389 | func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) { 390 | if ppFlag != PreSignal && ppFlag != PostSignal { 391 | err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal") 392 | return 393 | } 394 | for _, s := range hookableSignals { 395 | if s == sig { 396 | srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f) 397 | return 398 | } 399 | } 400 | err = fmt.Errorf("Signal '%v' is not supported", sig) 401 | return 402 | } 403 | -------------------------------------------------------------------------------- /logging/logger.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | 10 | "github.com/keminar/anyproxy/config" 11 | ) 12 | 13 | // SetDefaultLogger 设置日志 14 | func SetDefaultLogger(dir, prefix string, compress bool, reserveDay int, w io.Writer) { 15 | timeWriter := &TimeWriter{ 16 | Dir: dir, 17 | Prefix: prefix, 18 | Compress: compress, 19 | ReserveDay: reserveDay, 20 | } 21 | // 同时输出到日志和标准输出 22 | writers := []io.Writer{ 23 | timeWriter, 24 | } 25 | if w != nil { 26 | writers = append(writers, w) 27 | } 28 | log.SetOutput(io.MultiWriter(writers...)) 29 | switch config.DebugLevel { 30 | case config.LevelLong: 31 | log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds) 32 | case config.LevelDebug: 33 | log.SetFlags(log.Llongfile | log.Ldate | log.Lmicroseconds) 34 | case config.LevelDebugBody: 35 | log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds) 36 | default: 37 | 
log.SetFlags(log.Lshortfile | log.LstdFlags) 38 | } 39 | } 40 | 41 | // ErrlogFd 标准输出错误输出文件 42 | func ErrlogFd(logDir string, cmdName string) *os.File { 43 | if _, err := os.Stat(logDir); os.IsNotExist(err) { 44 | err = os.Mkdir(logDir, 0777) 45 | if err != nil { 46 | log.Fatalln("logs dir create error", err.Error()) 47 | } 48 | } 49 | errFile := filepath.Join(logDir, fmt.Sprintf("%s.err.log", cmdName)) 50 | fd, err := os.OpenFile(errFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0664) 51 | if err != nil { 52 | //报错并退出 53 | log.Fatalln("open log file error", err.Error()) 54 | } 55 | return fd 56 | } 57 | -------------------------------------------------------------------------------- /logging/timewriter.go: -------------------------------------------------------------------------------- 1 | //TimeWriter implements io.Writer to roll daily and comporess log file time 2 | //Clone from https://github.com/longbozhan/timewriter 3 | 4 | package logging 5 | 6 | import ( 7 | "compress/gzip" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "io/ioutil" 12 | "os" 13 | "path/filepath" 14 | "sort" 15 | "strconv" 16 | "strings" 17 | "sync" 18 | "time" 19 | ) 20 | 21 | // io.WriteCloser 22 | var _ io.WriteCloser = (*TimeWriter)(nil) 23 | 24 | const ( 25 | compressSuffix = ".gz" 26 | timeFormat = "2006-01-02 15:04:05" 27 | ) 28 | 29 | // TimeWriter 实体 30 | type TimeWriter struct { 31 | Dir string 32 | Prefix string 33 | Compress bool 34 | ReserveDay int 35 | 36 | curFilename string 37 | file *os.File 38 | mu sync.Mutex 39 | startMill sync.Once 40 | millCh chan bool 41 | } 42 | 43 | func (l *TimeWriter) Write(p []byte) (n int, err error) { 44 | l.mu.Lock() 45 | defer l.mu.Unlock() 46 | 47 | if l.file == nil { 48 | if err = l.openExistingOrNew(len(p)); err != nil { 49 | fmt.Printf("write fail, msg(%s)\n", err) 50 | return 0, err 51 | } 52 | } 53 | 54 | if l.curFilename != l.filename() { 55 | l.rotate() 56 | } 57 | 58 | n, err = l.file.Write(p) 59 | 60 | return n, err 61 | } 62 | 63 | // Close 关闭 64 
func (l *TimeWriter) Close() error {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.close()
}

// Rotate closes the current file, opens the file for today's date and
// schedules cleanup/compression of older files.
func (l *TimeWriter) Rotate() error {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.rotate()
}

// close closes the current file, if any. l.file is cleared even when Close fails.
func (l *TimeWriter) close() error {
	if l.file == nil {
		return nil
	}
	err := l.file.Close()
	l.file = nil
	return err
}

// rotate is the unlocked body of Rotate; Write and Rotate call it with l.mu held.
func (l *TimeWriter) rotate() error {
	if err := l.close(); err != nil {
		return err
	}
	if err := l.openNew(); err != nil {
		return err
	}

	l.mill()
	return nil
}

// oldLogFiles lists this writer's log files in l.Dir, plain or gzip-compressed,
// excluding the current file, sorted newest first by the date in the name.
// Files that do not match the naming scheme are skipped silently.
func (l *TimeWriter) oldLogFiles() ([]logInfo, error) {
	files, err := ioutil.ReadDir(l.Dir)
	if err != nil {
		return nil, fmt.Errorf("can't read log file directory: %s", err)
	}
	logFiles := []logInfo{}

	prefix, ext := l.prefixAndExt()
	current := filepath.Base(l.curFilename)

	for _, f := range files {
		if f.IsDir() || f.Name() == current {
			continue
		}
		if t, err := l.timeFromName(f.Name(), prefix, ext); err == nil {
			logFiles = append(logFiles, logInfo{t, f})
			continue
		}
		if t, err := l.timeFromName(f.Name(), prefix, ext+compressSuffix); err == nil {
			logFiles = append(logFiles, logInfo{t, f})
		}
	}

	sort.Sort(byFormatTime(logFiles))

	return logFiles, nil
}

// timeFromName extracts the local-midnight date from a file name of the form
// <prefix><YYYYMMDD><ext>. It returns an error if the name does not match or the
// date is invalid.
func (l *TimeWriter) timeFromName(filename, prefix, ext string) (time.Time, error) {
	if !strings.HasPrefix(filename, prefix) {
		return time.Time{}, errors.New("mismatched prefix")
	}
	if !strings.HasSuffix(filename, ext) {
		return time.Time{}, errors.New("mismatched extension")
	}
	// prefix and ext may overlap on short names (e.g. "<Prefix>.log"); slicing
	// them off without this check would panic.
	if len(filename) < len(prefix)+len(ext) {
		return time.Time{}, errors.New("mismatched date")
	}
	ts := filename[len(prefix) : len(filename)-len(ext)]
	if len(ts) != 8 {
		return time.Time{}, errors.New("mismatched date")
	}
	year, err := strconv.ParseInt(ts[0:4], 10, 64)
	if err != nil {
		return time.Time{}, err
	}
	month, err := strconv.ParseInt(ts[4:6], 10, 64)
	if err != nil {
		return time.Time{}, err
	}
	day, err := strconv.ParseInt(ts[6:8], 10, 64)
	if err != nil {
		return time.Time{}, err
	}
	// ParseInLocation also rejects out-of-range months/days (e.g. "20241301").
	timeStr := fmt.Sprintf("%04d-%02d-%02d 00:00:00", year, month, day)
	return time.ParseInLocation(timeFormat, timeStr, time.Local)
}

// prefixAndExt splits today's base file name "<Prefix>.<YYYYMMDD>.log" into the
// part before the 8-digit date ("<Prefix>.") and the extension (".log").
func (l *TimeWriter) prefixAndExt() (prefix, ext string) {
	filename := filepath.Base(l.filename())
	ext = filepath.Ext(filename)
	prefix = filename[:len(filename)-len(ext)-8]
	return prefix, ext
}

// millRunOnce removes files older than ReserveDay days (when ReserveDay > 0) and
// gzip-compresses the remaining uncompressed ones (when Compress). It returns
// the first error encountered, but keeps processing the other files.
func (l *TimeWriter) millRunOnce() error {
	if l.ReserveDay == 0 && !l.Compress {
		return nil
	}

	files, err := l.oldLogFiles()
	if err != nil {
		return err
	}

	var compress, remove []logInfo

	if l.ReserveDay > 0 {
		diff := time.Duration(int64(24*time.Hour) * int64(l.ReserveDay))
		cutoff := time.Now().Add(-1 * diff)

		var remaining []logInfo
		for _, f := range files {
			if f.timestamp.Before(cutoff) {
				remove = append(remove, f)
			} else {
				remaining = append(remaining, f)
			}
		}

		files = remaining
	}

	if l.Compress {
		for _, f := range files {
			if !strings.HasSuffix(f.Name(), compressSuffix) {
				compress = append(compress, f)
			}
		}
	}

	for _, f := range remove {
		errRemove := os.Remove(filepath.Join(l.Dir, f.Name()))
		if err == nil && errRemove != nil {
			err = errRemove
		}
	}
	for _, f := range compress {
		fn := filepath.Join(l.Dir, f.Name())
		errCompress := compressLogFile(fn, fn+compressSuffix)
		if err == nil && errCompress != nil {
			err = errCompress
		}
	}

	return err
}

// millRun performs one cleanup pass per value received on l.millCh.
func (l *TimeWriter) millRun() {
	for range l.millCh {
		_ = l.millRunOnce()
	}
}

// mill starts the cleanup goroutine on first use, then requests a pass
// without blocking; if a request is already pending, this one is dropped.
func (l *TimeWriter) mill() {
	l.startMill.Do(func() {
		l.millCh = make(chan bool, 1)
		go l.millRun()
	})
	select {
	case l.millCh <- true:
	default:
	}
}

// openNew opens (creating if needed) today's log file and makes it current.
func (l *TimeWriter) openNew() error {
	name := l.filename()
	err := os.MkdirAll(l.Dir, 0744)
	if err != nil {
		return fmt.Errorf("can't make directories for new logfile: %s", err)
	}

	mode := os.FileMode(0644)

	// O_APPEND rather than O_TRUNC: if today's file already exists (a manual
	// Rotate on the same day, or another process writing the same log), keep
	// its contents instead of wiping them.
	f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_APPEND, mode)
	if err != nil {
		return fmt.Errorf("can't open new logfile: %s", err)
	}
	l.curFilename = name
	l.file = f
	return nil
}

// openExistingOrNew appends to today's file if it exists, otherwise creates it.
// writeLen is currently unused.
func (l *TimeWriter) openExistingOrNew(writeLen int) error {

	filename := l.filename()
	if _, err := os.Stat(filename); os.IsNotExist(err) {
		return l.openNew()
	} else if err != nil {
		return fmt.Errorf("error getting log file info: %s", err)
	}

	file, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		return l.openNew()
	}
	l.curFilename = filename
	l.file = file
	return nil
}

// filename returns the path of today's log file: <Dir>/<Prefix>.<YYYYMMDD>.log,
// or the same name under os.TempDir() when Dir is empty.
func (l *TimeWriter) filename() string {
	year, month, day := time.Now().Date()
	date := fmt.Sprintf("%04d%02d%02d", year, month, day)
	name := fmt.Sprintf("%s.%s.log", l.Prefix, date)
	if l.Dir != "" {
		return filepath.Join(l.Dir, name)
	}
	return filepath.Join(os.TempDir(), name)
}

// compressLogFile gzips src into dst (same file mode) and removes src on
// success. On failure dst is removed and the error is wrapped.
func compressLogFile(src, dst string) (err error) {
	f, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open log file: %v", err)
	}
	defer f.Close()

	fi, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("failed to stat log file: %v", err)
	}

	gzf, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, fi.Mode())
	if err != nil {
		return fmt.Errorf("failed to open compressed log file: %v", err)
	}
	defer gzf.Close()

	gz := gzip.NewWriter(gzf)

	defer func() {
		if err != nil {
			os.Remove(dst)
			err = fmt.Errorf("failed to compress log file: %v", err)
		}
	}()

	if _, err := io.Copy(gz, f); err != nil {
		return err
	}
	if err := gz.Close(); err != nil {
		return err
	}
	if err := gzf.Close(); err != nil {
		return err
	}

	if err := f.Close(); err != nil {
		return err
	}
	if err := os.Remove(src); err != nil {
		return err
	}

	return nil
}

// logInfo pairs a log file's directory entry with the date parsed from its name.
type logInfo struct {
	timestamp time.Time
	os.FileInfo
}

// byFormatTime sorts logInfo values newest first by the date parsed from the file name.
339 | type byFormatTime []logInfo 340 | 341 | func (b byFormatTime) Less(i, j int) bool { 342 | return b[i].timestamp.After(b[j].timestamp) 343 | } 344 | 345 | func (b byFormatTime) Swap(i, j int) { 346 | b[i], b[j] = b[j], b[i] 347 | } 348 | 349 | func (b byFormatTime) Len() int { 350 | return len(b) 351 | } 352 | -------------------------------------------------------------------------------- /nat/bridge.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net" 7 | 8 | "github.com/keminar/anyproxy/config" 9 | "github.com/keminar/anyproxy/utils/trace" 10 | ) 11 | 12 | // Bridge 桥接 13 | type Bridge struct { 14 | bridgeHub *BridgeHub 15 | client *Client 16 | 17 | reqID uint //请求id 18 | conn *net.TCPConn 19 | 20 | // Buffered channel of outbound messages. 21 | send chan []byte 22 | } 23 | 24 | // Unregister 包外面调用取消注册 25 | func (b *Bridge) Unregister() { 26 | b.bridgeHub.unregister <- b 27 | } 28 | 29 | // 向websocket hub写数据 30 | func (b *Bridge) Write(p []byte) (n int, err error) { 31 | // 先把p拷贝一份,否则会被外面的CopyBuffer再次修改,因为是引入传递 32 | body := make([]byte, len(p)) 33 | copy(body, p) 34 | msg := &Message{ID: b.reqID, Body: body} 35 | 36 | if config.DebugLevel >= config.LevelDebugBody { 37 | md5Val, _ := md5Byte(msg.Body) 38 | log.Println("nat_debug_write_chan", msg.ID, md5Val) 39 | } 40 | 41 | cmsg := &CMessage{client: b.client, message: msg} 42 | b.client.hub.broadcast <- cmsg 43 | return len(p), nil 44 | } 45 | 46 | // Open 通知websocket 创建连接 47 | func (b *Bridge) Open() { 48 | msg := &Message{ID: b.reqID, Method: METHOD_CREATE} 49 | //b.client.send <- msg //注意:不能直接写send会与close有并发安全冲突 50 | cmsg := &CMessage{client: b.client, message: msg} 51 | b.client.hub.broadcast <- cmsg 52 | } 53 | 54 | // CloseWrite 通知tcp关闭连接 55 | func (b *Bridge) CloseWrite() { 56 | msg := &Message{ID: b.reqID, Method: METHOD_CLOSE} 57 | cmsg := &CMessage{client: b.client, message: msg} 58 | 
b.client.hub.broadcast <- cmsg 59 | } 60 | 61 | // WritePump 从websocket hub读数据写到请求http端 62 | func (b *Bridge) WritePump() (written int64, err error) { 63 | defer func() { 64 | b.conn.CloseWrite() 65 | if config.DebugLevel >= config.LevelDebug { 66 | log.Println("net_debug_write_proxy_close") 67 | } 68 | }() 69 | for { 70 | select { 71 | case message, ok := <-b.send: //ok为判断channel是否关闭 72 | if !ok { 73 | if config.DebugLevel >= config.LevelDebug { 74 | log.Println("nat_debug_bridge_send_chan_closed") 75 | } 76 | return 77 | } 78 | var nw int 79 | nw, err = b.conn.Write(message) 80 | if config.DebugLevel >= config.LevelDebugBody { 81 | md5Val, _ := md5Byte(message) 82 | log.Println("nat_debug_write_proxy", md5Val, err, "\n", string(message)) 83 | } 84 | if err != nil { 85 | return 86 | } 87 | written += int64(nw) 88 | } 89 | } 90 | } 91 | 92 | // CopyBuffer 传输数据 93 | func (b *Bridge) CopyBuffer(dst io.Writer, src io.Reader, srcname string) (written int64, err error) { 94 | //如果设置过大会耗内存高,4k比较合理 95 | size := 4 * 1024 96 | buf := make([]byte, size) 97 | i := 0 98 | for { 99 | i++ 100 | if config.DebugLevel >= config.LevelDebug { 101 | log.Printf("%s bridge of %s proxy, n=%d\n", trace.ID(b.reqID), srcname, i) 102 | } 103 | nr, er := src.Read(buf) 104 | if nr > 0 { 105 | if config.DebugLevel >= config.LevelDebugBody { 106 | md5Val, _ := md5Byte(buf[0:nr]) 107 | log.Println("net_debug_copy_buffer", trace.ID(b.reqID), srcname, i, nr, md5Val) 108 | } 109 | nw, ew := dst.Write(buf[0:nr]) 110 | if nw > 0 { 111 | written += int64(nw) 112 | } 113 | if ew != nil { 114 | err = ew 115 | break 116 | } 117 | if nr != nw { 118 | err = io.ErrShortWrite 119 | break 120 | } 121 | } 122 | if er != nil { 123 | if er != io.EOF { 124 | err = er 125 | } 126 | if config.DebugLevel >= config.LevelDebug { 127 | log.Println("nat_debug_read_error", srcname, er) 128 | } 129 | break 130 | } 131 | 132 | } 133 | return written, err 134 | } 135 | 
-------------------------------------------------------------------------------- /nat/bridge_hub.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "log" 5 | "net" 6 | 7 | "github.com/keminar/anyproxy/config" 8 | ) 9 | 10 | // BridgeHub 桥接组 11 | type BridgeHub struct { 12 | // Registered clients. 13 | bridges map[*Bridge]bool 14 | 15 | // Inbound messages from the clients. 16 | broadcast chan *Message 17 | 18 | // Register requests from the clients. 19 | register chan *Bridge 20 | 21 | // Unregister requests from clients. 22 | unregister chan *Bridge 23 | } 24 | 25 | func newBridgeHub() *BridgeHub { 26 | // 无缓冲通道,保证并发安全 27 | return &BridgeHub{ 28 | broadcast: make(chan *Message), 29 | register: make(chan *Bridge), 30 | unregister: make(chan *Bridge), 31 | bridges: make(map[*Bridge]bool), 32 | } 33 | } 34 | 35 | func (h *BridgeHub) run() { 36 | for { 37 | select { 38 | case bridge := <-h.register: 39 | h.bridges[bridge] = true 40 | case bridge := <-h.unregister: 41 | if _, ok := h.bridges[bridge]; ok { 42 | delete(h.bridges, bridge) 43 | close(bridge.send) 44 | } 45 | case message := <-h.broadcast: 46 | if config.DebugLevel >= config.LevelDebug { 47 | log.Println("bridge nums", len(h.bridges)) 48 | } 49 | if config.DebugLevel >= config.LevelDebugBody { 50 | md5Val, _ := md5Byte(message.Body) 51 | log.Println("nat_debug_write_bridge_hub", message.ID, message.Method, md5Val) 52 | } 53 | Exit: 54 | for bridge := range h.bridges { 55 | if bridge.reqID != message.ID { 56 | continue 57 | } 58 | if message.Method == METHOD_CLOSE { 59 | close(bridge.send) 60 | delete(h.bridges, bridge) 61 | break Exit 62 | } 63 | select { 64 | case bridge.send <- message.Body: 65 | break Exit 66 | default: // 当send chan写不进时会走进default,防止某一个send卡着影响整个系统 67 | log.Println("net_bridge_send_chan_full", message.ID) 68 | close(bridge.send) 69 | delete(h.bridges, bridge) 70 | } 71 | } 72 | } 73 | } 74 | } 75 | 76 | // Register 注册 
77 | func (h *BridgeHub) Register(c *Client, ID uint, conn *net.TCPConn) *Bridge { 78 | b := &Bridge{bridgeHub: h, reqID: ID, conn: conn, send: make(chan []byte, 100), client: c} 79 | h.register <- b 80 | return b 81 | } 82 | -------------------------------------------------------------------------------- /nat/client.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net" 7 | "time" 8 | 9 | "github.com/gorilla/websocket" 10 | "github.com/keminar/anyproxy/config" 11 | "github.com/keminar/anyproxy/utils/trace" 12 | ) 13 | 14 | var interruptClose bool 15 | 16 | // Client is a middleman between the websocket connection and the hub. 17 | type Client struct { 18 | hub *Hub 19 | 20 | // The websocket connection. 21 | conn *websocket.Conn 22 | 23 | // Buffered channel of outbound messages. 24 | send chan *Message 25 | 26 | // 用户 27 | User string 28 | 29 | Email string 30 | 31 | // 订阅特征 32 | Subscribe []SubscribeMessage 33 | } 34 | 35 | // 写数据到websocket的对端 36 | func (c *Client) writePump() { 37 | ticker := time.NewTicker(pingPeriod) 38 | defer func() { 39 | ticker.Stop() 40 | c.conn.Close() 41 | }() 42 | for { 43 | select { 44 | case message, ok := <-c.send: //ok为判断channel是否关闭 45 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 46 | if !ok { 47 | log.Println("nat_debug_client_send_chan_close") 48 | // The hub closed the channel. 
49 | c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 50 | return 51 | } 52 | 53 | w, err := c.conn.NextWriter(websocket.BinaryMessage) 54 | if err != nil { 55 | return 56 | } 57 | 58 | if config.DebugLevel >= config.LevelDebugBody { 59 | md5Val, _ := md5Byte(message.Body) 60 | log.Println("nat_debug_write_websocket", message.ID, message.Method, md5Val, "\n", string(message.Body)) 61 | } 62 | msgByte, _ := message.encode() 63 | w.Write(msgByte) 64 | if err := w.Close(); err != nil { 65 | return 66 | } 67 | case <-ticker.C: 68 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 69 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 70 | return 71 | } 72 | } 73 | } 74 | } 75 | 76 | // 服务器从websocket的客户端读取数据 77 | func (c *Client) serverReadPump() { 78 | defer func() { 79 | c.hub.unregister <- c 80 | c.conn.Close() 81 | }() 82 | c.conn.SetReadDeadline(time.Now().Add(pongWait)) 83 | c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 84 | for { 85 | _, p, err := c.conn.ReadMessage() 86 | if err != nil { 87 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 88 | log.Printf("nat_debug_read_message_error: %v", err) 89 | } 90 | break 91 | } 92 | msg, err := decodeMessage(p) 93 | if err != nil { 94 | log.Printf("nat_debug_decode_message_error: %v", err) 95 | break 96 | } 97 | if config.DebugLevel >= config.LevelDebugBody { 98 | md5Val, _ := md5Byte(msg.Body) 99 | log.Println("nat_debug_read_from_websocket", msg.ID, msg.Method, md5Val) 100 | } 101 | ServerBridge.broadcast <- msg 102 | } 103 | } 104 | 105 | // 本地从websocket服务端取数据 106 | func (c *Client) localReadPump() { 107 | for { 108 | _, p, err := c.conn.ReadMessage() 109 | if err != nil { 110 | log.Println("nat_local_debug_read_error", err.Error()) 111 | return 112 | } 113 | 114 | msg, err := decodeMessage(p) 115 | if err != nil { 116 | log.Println("nat_local_debug_decode_error", err.Error()) 
117 | return 118 | } 119 | if config.DebugLevel >= config.LevelDebugBody { 120 | md5Val, _ := md5Byte(msg.Body) 121 | log.Println("nat_local_read_from_websocket_message", msg.ID, msg.Method, md5Val) 122 | } 123 | 124 | if msg.Method == METHOD_CREATE { 125 | proxConn := dialProxy() //创建本地与本地代理端口之间的连接 126 | b := LocalBridge.Register(c, msg.ID, proxConn.(*net.TCPConn)) 127 | go func() { 128 | written, err := b.WritePump() 129 | logCopyErr(trace.ID(msg.ID), "nat_local_debug websocket->local", err) 130 | if config.DebugLevel >= config.LevelDebug { 131 | log.Println(trace.ID(msg.ID), "nat debug response size", written) 132 | } 133 | }() 134 | 135 | // 从tcp返回数据到ws 136 | go func() { 137 | defer b.Unregister() 138 | readSize, err := b.CopyBuffer(b, proxConn, "local") 139 | logCopyErr(trace.ID(msg.ID), "nat_local_debug local->websocket", err) 140 | if config.DebugLevel >= config.LevelDebug { 141 | log.Println(trace.ID(msg.ID), "nat debug request body size", readSize) 142 | } 143 | b.CloseWrite() 144 | }() 145 | } else { 146 | LocalBridge.broadcast <- msg 147 | } 148 | } 149 | } 150 | 151 | func logCopyErr(traceID, name string, err error) { 152 | if err == nil { 153 | return 154 | } 155 | if config.DebugLevel >= config.LevelLong { 156 | log.Println(traceID, name, err.Error()) 157 | } else if err != io.EOF { 158 | log.Println(traceID, name, err.Error()) 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /nat/client_hub.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "log" 5 | 6 | "github.com/keminar/anyproxy/config" 7 | "github.com/keminar/anyproxy/proto/http" 8 | ) 9 | 10 | // Hub maintains the set of active clients and broadcasts messages to the 11 | // clients. 12 | type Hub struct { 13 | // Registered clients. 14 | clients map[*Client]bool 15 | 16 | // Inbound messages from the clients. 
17 | broadcast chan *CMessage 18 | 19 | // Register requests from the clients. 20 | register chan *Client 21 | 22 | // Unregister requests from clients. 23 | unregister chan *Client 24 | } 25 | 26 | func newHub() *Hub { 27 | // 无缓冲通道,保证并发安全 28 | return &Hub{ 29 | broadcast: make(chan *CMessage), 30 | register: make(chan *Client), 31 | unregister: make(chan *Client), 32 | clients: make(map[*Client]bool), 33 | } 34 | } 35 | 36 | func (h *Hub) run() { 37 | for { 38 | select { 39 | case client := <-h.register: 40 | h.clients[client] = true 41 | case client := <-h.unregister: 42 | if _, ok := h.clients[client]; ok { 43 | close(client.send) 44 | delete(h.clients, client) 45 | log.Printf("client email %s disconnected, total client nums %d\n", client.Email, len(h.clients)) 46 | } 47 | case cmessage := <-h.broadcast: 48 | if config.DebugLevel >= config.LevelDebug { 49 | log.Println("client nums", len(h.clients)) 50 | } 51 | if config.DebugLevel >= config.LevelDebugBody { 52 | md5Val, _ := md5Byte(cmessage.message.Body) 53 | log.Println("nat_debug_write_client_hub", cmessage.message.ID, cmessage.message.Method, md5Val) 54 | } 55 | // 使用broadcast 无缓冲且不会关闭解决并发问题 56 | // 如果在外部直接写client.send,会与close()有并发安全冲突 57 | Exit: 58 | for client := range h.clients { 59 | if client != cmessage.client { 60 | continue 61 | } 62 | select { 63 | case client.send <- cmessage.message: 64 | break Exit 65 | default: // 当send chan写不进时会走进default,防止某一个send卡着影响整个系统 66 | close(client.send) 67 | delete(h.clients, client) 68 | log.Printf("net_client_send_chan_full, client email %s disconnected\n", client.Email) 69 | } 70 | } 71 | } 72 | } 73 | } 74 | 75 | // GetClient 获取某一个订阅者 76 | func (h *Hub) GetClient(header http.Header) *Client { 77 | for client := range h.clients { 78 | for _, s := range client.Subscribe { 79 | val := header.Get(s.Key) 80 | //log.Println("debug", client.Email, s.Key, s.Val, val) 81 | if val != "" && val == s.Val { 82 | return client 83 | } 84 | } 85 | } 86 | return nil 87 | } 88 | 
-------------------------------------------------------------------------------- /nat/conn.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | "net/http" 8 | "strings" 9 | "time" 10 | 11 | "github.com/gorilla/websocket" 12 | "github.com/keminar/anyproxy/utils/conf" 13 | "github.com/keminar/anyproxy/utils/tools" 14 | ) 15 | 16 | const ( 17 | // Time allowed to write a message to the peer. 18 | writeWait = 10 * time.Second 19 | 20 | // Time allowed to read the next pong message from the peer. 21 | pongWait = 60 * time.Second 22 | 23 | // Send pings to peer with this period. Must be less than pongWait. 24 | pingPeriod = (pongWait * 9) / 10 25 | ) 26 | 27 | var ( 28 | newline = []byte{'\n'} 29 | space = []byte{' '} 30 | ) 31 | 32 | var upgrader = websocket.Upgrader{ 33 | ReadBufferSize: 1024, 34 | WriteBufferSize: 1024, 35 | } 36 | 37 | // ServerHub 服务端的ws链接信息 38 | var ServerHub *Hub 39 | 40 | // ServerBridge 服务端的http与ws链接 41 | var ServerBridge *BridgeHub 42 | 43 | // serverStart 是否开启服务 44 | var serverStart = false 45 | 46 | // Eable 检查是否可以发送nat请求 47 | func Eable() bool { 48 | if !serverStart { 49 | return false 50 | } 51 | if len(ServerHub.clients) == 0 { 52 | return false 53 | } 54 | return true 55 | } 56 | 57 | // NewServer 开启服务 58 | func NewServer(addr *string) { 59 | ServerHub = newHub() 60 | go ServerHub.run() 61 | ServerBridge = newBridgeHub() 62 | go ServerBridge.run() 63 | serverStart = true 64 | 65 | http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { 66 | serveWs(ServerHub, w, r) 67 | }) 68 | 69 | log.Printf("Listening for websocket connections on %s\n", *addr) 70 | 71 | // 延迟启动 72 | time.Sleep(2 * time.Second) 73 | for i := 0; i < 1000; i++ { 74 | // 副服务,出错不退出并定时重试。方便主服务做平滑重启 75 | err := http.ListenAndServe(*addr, nil) 76 | if err != nil { 77 | log.Printf("ListenAndServe: num=%d, err=%v ,retry\n", i, err) 78 | } 79 | time.Sleep(10 * 
time.Second) 80 | } 81 | } 82 | 83 | // serveWs handles websocket requests from the peer. 84 | func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { 85 | conn, err := upgrader.Upgrade(w, r, nil) 86 | if err != nil { 87 | log.Println("serveWs", err) 88 | return 89 | } 90 | 91 | // 认证 92 | var user AuthMessage 93 | err = conn.ReadJSON(&user) 94 | if err != nil { 95 | // 客户端没配置user, email会主动断开 96 | log.Println("serveWs", "maybe client close", err) 97 | return 98 | } 99 | if user.Email == "" { // 增强验证 100 | log.Println("serveWs", "client email is empty") 101 | conn.WriteMessage(websocket.TextMessage, []byte("email error")) 102 | return 103 | } 104 | xtime := time.Now().Unix() 105 | if xtime-user.Xtime > 300 { 106 | log.Printf("serveWs client email %s ignore, xtime is error\n", user.Email) 107 | conn.WriteMessage(websocket.TextMessage, []byte("xtime err, please check local time")) 108 | return 109 | } 110 | if user.User != conf.RouterConfig.Websocket.User { 111 | log.Printf("serveWs client email %s ignore, user is error\n", user.Email) 112 | conn.WriteMessage(websocket.TextMessage, []byte("user err")) 113 | return 114 | } 115 | 116 | token, err := tools.Md5Str(fmt.Sprintf("%s|%s|%d", user.User, conf.RouterConfig.Websocket.Pass, user.Xtime)) 117 | if err != nil || user.Token != token { 118 | log.Printf("serveWs client email %s ignore, token is error\n", user.Email) 119 | conn.WriteMessage(websocket.TextMessage, []byte("token err")) 120 | return 121 | } 122 | conn.WriteMessage(websocket.TextMessage, []byte("ok")) 123 | 124 | // 订阅 125 | var tmpSub []SubscribeMessage 126 | err = conn.ReadJSON(&tmpSub) 127 | if err != nil { 128 | log.Printf("serveWs client email %s ignore, %v\n", user.Email, err) 129 | return 130 | } 131 | var subscribe []SubscribeMessage 132 | for _, sub := range tmpSub { 133 | if sub.Key != "" && sub.Val != "" { 134 | subscribe = append(subscribe, sub) 135 | } 136 | } 137 | if len(subscribe) == 0 { 138 | log.Printf("serveWs client email %s 
ignore, subscribe is empty\n", user.Email) 139 | conn.WriteMessage(websocket.TextMessage, []byte("subscribe empty err")) 140 | return 141 | } 142 | conn.WriteMessage(websocket.TextMessage, []byte("ok")) 143 | 144 | clientNum := len(hub.clients) 145 | // 注册连接 146 | client := &Client{hub: hub, conn: conn, send: make(chan *Message, SEND_CHAN_LEN), User: user.User, Email: user.Email, Subscribe: subscribe} 147 | client.hub.register <- client 148 | clientNum++ //这里不用len计算是因为chan异步不确认谁先执行 149 | 150 | remote := getIPAdress(r, []string{"X-Real-IP"}) 151 | log.Printf("serveWs client email %s ip %s connected, subscribe %v, total client nums %d\n", user.Email, remote, subscribe, clientNum) 152 | 153 | go client.writePump() 154 | go client.serverReadPump() 155 | } 156 | 157 | // getIPAdress 客户端IP 158 | func getIPAdress(req *http.Request, head []string) string { 159 | var ipAddress string 160 | // X-Forwarded-For容易被伪造,最好不用 161 | if len(head) == 0 { 162 | head = []string{"X-Real-IP"} 163 | } 164 | for _, h := range head { 165 | for _, ip := range strings.Split(req.Header.Get(h), ",") { 166 | ip = strings.TrimSpace(ip) 167 | realIP := net.ParseIP(ip) 168 | if realIP != nil { 169 | ipAddress = ip 170 | } 171 | } 172 | } 173 | if len(ipAddress) == 0 { 174 | ipAddress, _, _ = net.SplitHostPort(req.RemoteAddr) 175 | } 176 | return ipAddress 177 | } 178 | -------------------------------------------------------------------------------- /nat/handler.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "crypto/md5" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "log" 9 | "net" 10 | "net/http" 11 | "net/url" 12 | "os" 13 | "os/signal" 14 | "strings" 15 | "time" 16 | 17 | "github.com/keminar/anyproxy/utils/conf" 18 | "github.com/keminar/anyproxy/utils/tools" 19 | 20 | "github.com/gorilla/websocket" 21 | "github.com/keminar/anyproxy/config" 22 | ) 23 | 24 | // ClientHub 客户端的ws信息 25 | var ClientHub *Hub 26 | 27 | // 
LocalBridge 客户端的ws与http关系 28 | var LocalBridge *BridgeHub 29 | 30 | var tempDelay time.Duration 31 | 32 | // ConnectServer 连接到websocket服务 33 | func ConnectServer(addr *string) { 34 | if conf.RouterConfig.Websocket.User == "" || conf.RouterConfig.Websocket.Email == "" { 35 | log.Println("ws user or email empty, donot connect") 36 | return 37 | } 38 | interruptClose = false 39 | interrupt := make(chan os.Signal, 1) 40 | signal.Notify(interrupt, os.Interrupt) 41 | 42 | ClientHub = newHub() 43 | go ClientHub.run() 44 | LocalBridge = newBridgeHub() 45 | go LocalBridge.run() 46 | 47 | addrs := strings.Split(*addr, "://") 48 | if addrs[0] == "ws" && len(addrs) == 2 { 49 | *addr = addrs[1] 50 | } 51 | for { 52 | connect(addr, interrupt) 53 | if interruptClose { 54 | break 55 | } 56 | } 57 | } 58 | 59 | // 连接本地Proxy服务 60 | func dialProxy() net.Conn { 61 | connTimeout := time.Duration(5) * time.Second 62 | var err error 63 | localProxy := fmt.Sprintf("%s:%d", "127.0.0.1", config.ListenPort) 64 | proxyConn, err := net.DialTimeout("tcp", localProxy, connTimeout) 65 | if err != nil { 66 | log.Println("dial local proxy", err) 67 | } 68 | log.Printf("local websocket connecting to %s", localProxy) 69 | return proxyConn 70 | } 71 | 72 | // 认证连接并交换数据 73 | func connect(addr *string, interrupt chan os.Signal) { 74 | u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"} 75 | log.Printf("connecting to %s", u.String()) 76 | 77 | h := http.Header{} 78 | if conf.RouterConfig.Websocket.Host != "" { 79 | h.Add("Host", conf.RouterConfig.Websocket.Host) 80 | } 81 | c, _, err := websocket.DefaultDialer.Dial(u.String(), h) 82 | if err != nil { 83 | log.Println("dial:", err) 84 | time.Sleep(time.Duration(3) * time.Second) 85 | return 86 | } 87 | defer c.Close() 88 | 89 | w := newClientHandler(c) 90 | err = w.auth(conf.RouterConfig.Websocket.User, conf.RouterConfig.Websocket.Pass, conf.RouterConfig.Websocket.Email) 91 | if err != nil { 92 | log.Println("auth:", err) 93 | 94 | if tempDelay == 0 { 95 
| tempDelay = 3 * time.Second 96 | } else { 97 | tempDelay *= 2 98 | } 99 | if max := 1 * time.Minute; tempDelay > max { 100 | tempDelay = max 101 | } 102 | time.Sleep(tempDelay) 103 | return 104 | } 105 | tempDelay = 0 106 | err = w.subscribe(conf.RouterConfig.Websocket.Subscribe) 107 | if err != nil { 108 | log.Println("subscribe:", err) 109 | time.Sleep(time.Duration(3) * time.Second) 110 | return 111 | } 112 | log.Println("websocket auth and subscribe ok") 113 | 114 | client := &Client{hub: ClientHub, conn: c, send: make(chan *Message, SEND_CHAN_LEN)} 115 | client.hub.register <- client 116 | defer func() { 117 | client.hub.unregister <- client 118 | }() 119 | 120 | go client.writePump() 121 | done := make(chan struct{}) 122 | go func() { //客户端的client.readRump 123 | defer close(done) 124 | client.localReadPump() 125 | }() 126 | 127 | for { 128 | select { 129 | case <-done: 130 | return 131 | case <-interrupt: 132 | log.Println("interrupt") 133 | interruptClose = true 134 | 135 | // Cleanly close the connection by sending a close message and then 136 | // waiting (with timeout) for the server to close the connection. 
137 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 138 | if err != nil { 139 | log.Println("write close:", err) 140 | return 141 | } 142 | select { 143 | case <-done: 144 | case <-time.After(time.Second): 145 | } 146 | return 147 | } 148 | } 149 | } 150 | 151 | // ClientHandler 认证助手 152 | type ClientHandler struct { 153 | c *websocket.Conn 154 | } 155 | 156 | func newClientHandler(c *websocket.Conn) *ClientHandler { 157 | return &ClientHandler{c: c} 158 | } 159 | 160 | // auth 认证 161 | func (h *ClientHandler) auth(user string, pass string, email string) error { 162 | xtime := time.Now().Unix() 163 | token, err := tools.Md5Str(fmt.Sprintf("%s|%s|%d", user, pass, xtime)) 164 | if err != nil { 165 | return err 166 | } 167 | msg := AuthMessage{User: user, Token: token, Xtime: xtime, Email: email} 168 | return h.ask(&msg) 169 | } 170 | 171 | // subscribe 订阅 172 | func (h *ClientHandler) subscribe(sub []conf.Subscribe) error { 173 | msg := []SubscribeMessage{} 174 | for _, s := range sub { 175 | msg = append(msg, SubscribeMessage{Key: s.Key, Val: s.Val}) 176 | } 177 | return h.ask(&msg) 178 | } 179 | 180 | func (h *ClientHandler) ask(v interface{}) error { 181 | err := h.c.WriteJSON(v) 182 | if err != nil { 183 | return err 184 | } 185 | ticker := time.NewTicker(3 * time.Second) 186 | defer func() { 187 | ticker.Stop() 188 | }() 189 | 190 | send := make(chan []byte) 191 | go func() { 192 | defer close(send) 193 | _, message, _ := h.c.ReadMessage() 194 | send <- message 195 | }() 196 | select { 197 | case message, ok := <-send: //ok为判断channel是否关闭 198 | if !ok { 199 | return errors.New("fail") 200 | } 201 | if string(message) != "ok" { 202 | return errors.New("fail, " + string(message)) 203 | } 204 | case <-ticker.C: 205 | return errors.New("timeout") 206 | } 207 | return nil 208 | } 209 | 210 | // md5 211 | func md5Byte(data []byte) (string, error) { 212 | h := md5.New() 213 | h.Write(data) 214 | cipherStr 
:= h.Sum(nil) 215 | return hex.EncodeToString(cipherStr), nil 216 | } 217 | -------------------------------------------------------------------------------- /nat/message.go: -------------------------------------------------------------------------------- 1 | package nat 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | ) 7 | 8 | // METHOD_CREATE 创建连接命令 9 | const METHOD_CREATE = "create" 10 | 11 | // METHOD_CLOSE 关闭连接命令 12 | const METHOD_CLOSE = "close" 13 | 14 | // SEND_CHAN_LEN 发送通道长度 15 | const SEND_CHAN_LEN = 200 16 | 17 | // AuthMessage 认证 18 | type AuthMessage struct { 19 | User string 20 | Email string 21 | Token string 22 | Xtime int64 23 | } 24 | 25 | // SubscribeMessage 订阅 26 | type SubscribeMessage struct { 27 | Key string 28 | Val string 29 | } 30 | 31 | // Message 普通消息体 32 | type Message struct { 33 | ID uint 34 | Method string 35 | Body []byte 36 | } 37 | 38 | // CMessage 普通消息体的复合类型,标记要向哪个Client发送 39 | type CMessage struct { 40 | client *Client 41 | message *Message 42 | } 43 | 44 | // 转成二进制 45 | func (m *Message) encode() ([]byte, error) { 46 | var buf bytes.Buffer 47 | enc := gob.NewEncoder(&buf) 48 | err := enc.Encode(*m) 49 | return buf.Bytes(), err 50 | } 51 | 52 | // 转成struct 53 | func decodeMessage(data []byte) (*Message, error) { 54 | var buf bytes.Buffer 55 | var m Message 56 | _, err := buf.Write(data) 57 | if err != nil { 58 | return &m, err 59 | } 60 | dec := gob.NewDecoder(&buf) 61 | err = dec.Decode(&m) 62 | return &m, err 63 | } 64 | -------------------------------------------------------------------------------- /proto/client.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "net" 8 | 9 | "github.com/keminar/anyproxy/utils/trace" 10 | ) 11 | 12 | // ClientHandler 客户端处理 13 | func ClientHandler(ctx context.Context, tcpConn *net.TCPConn) error { 14 | req := NewRequest(ctx, tcpConn) 15 | 16 | // test if the underlying fd is nil 
17 | remoteAddr := tcpConn.RemoteAddr() 18 | if remoteAddr == nil { 19 | log.Println(trace.ID(req.ID), "ClientHandler(): oops, clientConn.fd is nil!") 20 | return errors.New("clientConn.fd is nil") 21 | } 22 | log.Println(trace.ID(req.ID), "remoteAddr:"+remoteAddr.String()) 23 | 24 | ok, err := req.ReadRequest("client") 25 | if err != nil && ok == false { 26 | log.Println("req err", err.Error()) 27 | return err 28 | } 29 | return req.Stream.response() 30 | } 31 | -------------------------------------------------------------------------------- /proto/http.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "fmt" 7 | "log" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/keminar/anyproxy/config" 13 | "github.com/keminar/anyproxy/crypto" 14 | "github.com/keminar/anyproxy/nat" 15 | "github.com/keminar/anyproxy/proto/http" 16 | "github.com/keminar/anyproxy/proto/text" 17 | "github.com/keminar/anyproxy/utils/conf" 18 | "github.com/keminar/anyproxy/utils/trace" 19 | ) 20 | 21 | // badRequestError is a literal string (used by in the server in HTML, 22 | // unescaped) to tell the user why their request was bad. It should 23 | // be plain text without user info or other embedded errors. 
24 | type badRequestError string 25 | 26 | func (e badRequestError) Error() string { return "Bad Request: " + string(e) } 27 | 28 | type httpStream struct { 29 | req *Request 30 | Method string // http请求方法 31 | RequestURI string //读求原值,非解密值 32 | URL *url.URL //http请求地址信息 33 | Proto string //形如 http/1.0 或 http/1.1 34 | Host string //域名含端口 35 | Header http.Header //http请求头部 36 | FirstLine string //第一行字串 37 | BodyBuf []byte 38 | clientUnRead int 39 | tp *text.Reader 40 | } 41 | 42 | func newHTTPStream(req *Request) *httpStream { 43 | c := &httpStream{ 44 | req: req, 45 | } 46 | return c 47 | } 48 | 49 | // 检查是不是HTTP请求 50 | func (that *httpStream) validHead() bool { 51 | if that.req.reader.Buffered() < 8 { 52 | return false 53 | } 54 | tmpBuf, err := that.req.reader.Peek(8) 55 | if err != nil { 56 | return false 57 | } 58 | // 解析方法名 59 | s1 := bytes.IndexByte(tmpBuf, ' ') 60 | if s1 < 0 { 61 | return false 62 | } 63 | that.Method = strings.ToUpper(string(tmpBuf[:s1])) 64 | 65 | isHTTP := false 66 | allMethods := []string{"CONNECT", "OPTIONS", "DELETE", "TRACE", "POST", "HEAD", "GET", "PUT"} 67 | for _, one := range allMethods { 68 | if one == that.Method { 69 | isHTTP = true 70 | } 71 | } 72 | if isHTTP { 73 | return that.readFistLine() 74 | } 75 | return false 76 | } 77 | 78 | // 会在keep.go调用,所以要独立出来 79 | func (that *httpStream) readFistLine() bool { 80 | var err error 81 | // 下面是http的内容了,用封装的reader比较好按行取内容 82 | that.tp = text.NewReader(that.req.reader) 83 | // First line: GET /index.html HTTP/1.0 84 | if that.FirstLine, err = that.tp.ReadLine(true); err != nil { 85 | return false 86 | } 87 | 88 | var ok bool 89 | that.RequestURI, that.Proto, ok = parseRequestLine(that.FirstLine) 90 | if !ok { 91 | // 格式非http请求, 报错 92 | return false 93 | } 94 | if that.Proto != "HTTP/1.0" && that.Proto != "HTTP/1.1" { 95 | return false 96 | } 97 | return true 98 | } 99 | 100 | func (that *httpStream) readRequest(from string) (canProxy bool, err error) { 101 | rawurl := that.RequestURI 
102 | if that.Method == "CONNECT" && from == "server" { 103 | key := []byte(getToken()) 104 | x1, err := base64.StdEncoding.DecodeString(that.RequestURI) 105 | if err != nil { 106 | return false, err 107 | } 108 | if len(x1) > 0 { 109 | x2, err := crypto.DecryptAES(x1, key) 110 | if err != nil { 111 | return false, err 112 | } 113 | rawurl = string(x2) 114 | } 115 | } 116 | if config.DebugLevel >= config.LevelDebug { 117 | log.Println(trace.ID(that.req.ID), "rawurl:", rawurl) 118 | } 119 | justAuthority := that.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/") 120 | addedScheme := false 121 | if justAuthority { 122 | //CONNECT是http的,如果RequestURI不是/开头,则为域名且不带http://, 这里补上 123 | rawurl = "http://" + rawurl 124 | addedScheme = true 125 | } 126 | 127 | if that.URL, err = url.ParseRequestURI(rawurl); err != nil { 128 | return false, err 129 | } 130 | 131 | // 读取http的头部信息 132 | // Subsequent lines: Key: value. 133 | that.Header, err = that.tp.ReadHeader() 134 | if err != nil { 135 | return false, err 136 | } 137 | 138 | // 首先header里的host可能会没传,有遇到taobao的个别CONNECT请求,所以优先使用that.URL.Host, 但这个也可能没传域名,比如 GET /test HTTP/1.1 这些情况都用原FirstLine值 139 | // 另外如果全信that.URL.Host,当手机代理走电脑再走iptables代理后访问百度贴吧有遇到首行中的域名被变成了ip请求会403。所以头部host也要看,当不一致时将FirstLine更新 140 | that.Host = that.URL.Host 141 | if that.URL.Host == "" { 142 | that.Host = that.Header.Get("Host") 143 | //在通过本地websocket接收服务端透传http请求再转发到charles遇到 144 | //如果遇到 Charles proxy malformed request url error 145 | //解决方法:在Charles 的 proxy 菜单下的 Proxy Settings. 开启选项 enable Transparent HTTP proxying. 
146 | // 或者开启本软件的首行增加域名配置(firstLine.custom) 147 | that.URL.Host = that.Host 148 | that.URL.Scheme = "http" 149 | } else if that.Header.Get("Host") != "" { 150 | if that.Header.Get("Host") != that.URL.Host { 151 | if config.DebugLevel >= config.LevelDebug { 152 | fmt.Println(trace.ID(that.req.ID), "headerHost:", that.Header.Get("Host"), "urlHost:", that.URL.Host) 153 | fmt.Println(trace.ID(that.req.ID), "firstLine:", that.FirstLine) 154 | } 155 | // 有些header里的域名没带端口,拼接上端口 156 | that.Host = that.Header.Get("Host") 157 | if strings.Contains(that.URL.Host, ":") && !strings.Contains(that.Host, ":") { 158 | that.Host += ":" + that.URL.Port() 159 | } 160 | // 赋值回URL来生成RequestURI 161 | that.URL.Host = that.Host 162 | } 163 | } 164 | 165 | if addedScheme { 166 | // 去掉拼的http://, CONNECT请求首行单独处理 167 | that.RequestURI = that.URL.String() 168 | that.FirstLine = fmt.Sprintf("%s %s %s", that.Method, that.RequestURI[7:], that.Proto) 169 | } else { 170 | // 在代理部分vue本地开发环境时,有些用到websocket技术的请求首行带域名反而会404 171 | // 这种情况可以把域名配置在自定义配置里去掉首行域名 172 | if strings.ToLower(that.URL.Scheme) == "http" && firstLineHost(that.URL.Host) == "off" { 173 | if config.DebugLevel >= config.LevelDebug { 174 | log.Println(trace.ID(that.req.ID), "firstline host removed") 175 | } 176 | that.URL.Scheme = "" 177 | that.URL.Host = "" 178 | } 179 | that.RequestURI = that.URL.String() 180 | that.FirstLine = fmt.Sprintf("%s %s %s", that.Method, that.RequestURI, that.Proto) 181 | } 182 | 183 | that.readBody() 184 | that.getNameIPPort() 185 | 186 | //debug 187 | if config.DebugLevel >= config.LevelDebug { 188 | fmt.Println(trace.ID(that.req.ID), that.FirstLine) 189 | for k, v := range that.Header { 190 | fmt.Println(trace.ID(that.req.ID), k, "=", v) 191 | } 192 | fmt.Println(trace.ID(that.req.ID), string(that.BodyBuf)) 193 | } 194 | return true, nil 195 | } 196 | 197 | func firstLineHost(host string) string { 198 | host = strings.ReplaceAll(host, ":", ".") 199 | if val, ok := conf.RouterConfig.FirstLine.Custom[host]; 
ok { 200 | return val 201 | } 202 | if conf.RouterConfig.FirstLine.Host == "off" { 203 | return "off" 204 | } 205 | return "on" 206 | } 207 | 208 | func (that *httpStream) readBody() { 209 | that.clientUnRead = -1 210 | if that.Method == "CONNECT" { 211 | // 多层代理按长连接处理 212 | that.BodyBuf = that.req.reader.UnreadBuf(-1) 213 | return 214 | } 215 | if that.Proto == "HTTP/1.1" { 216 | //websocket 按长连接处理 217 | if test, ok := that.Header["Connection"]; ok && test[0] == "Upgrade" { 218 | that.BodyBuf = that.req.reader.UnreadBuf(-1) 219 | return 220 | } 221 | //todo chunk的暂没处理支持, 按长连接处理 222 | if _, ok := that.Header["Transfer-Encoding"]; ok { 223 | that.BodyBuf = that.req.reader.UnreadBuf(-1) 224 | return 225 | } 226 | // 主要处理IE复用链接请求不同域名的问题 227 | if contentLen, ok := that.Header["Content-Length"]; ok { 228 | if bodyLen, err := parseContentLength(contentLen[0]); err == nil { 229 | that.BodyBuf = that.req.reader.UnreadBuf(int(bodyLen)) 230 | that.clientUnRead = int(bodyLen) - len(that.BodyBuf) 231 | return 232 | } 233 | } 234 | //默认没有body,不需要读了,返回 235 | that.clientUnRead = 0 236 | return 237 | } 238 | // 其它按长连接处理 239 | that.BodyBuf = that.req.reader.UnreadBuf(-1) 240 | return 241 | } 242 | 243 | // getNameIPPort 分析请求目标 244 | func (that *httpStream) getNameIPPort() { 245 | splitStr := strings.Split(that.Host, ":") 246 | that.req.DstName = splitStr[0] 247 | if len(splitStr) == 2 { 248 | // 优先Host中的端口 249 | c, _ := strconv.ParseUint(splitStr[1], 0, 16) 250 | that.req.DstPort = uint16(c) 251 | if that.req.DstPort > 0 { 252 | return 253 | } 254 | } 255 | 256 | c, _ := strconv.ParseUint(that.URL.Port(), 0, 16) 257 | that.req.DstPort = uint16(c) 258 | if that.req.DstPort == 0 { 259 | if that.URL.Scheme == "https" { 260 | that.req.DstPort = 443 261 | } else { 262 | that.req.DstPort = 80 263 | } 264 | } 265 | } 266 | 267 | // Request 请求地址 268 | func (that *httpStream) Request() string { 269 | if that.RequestURI[0] == '/' { 270 | return that.Host + that.RequestURI 271 | } 272 | 
return that.RequestURI 273 | } 274 | 275 | // badRequest 400响应 276 | func (that *httpStream) badRequest(err error) { 277 | 278 | const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" 279 | 280 | publicErr := "400 Bad Request" 281 | if err != nil { 282 | publicErr = "400 Bad Request" + ": " + err.Error() 283 | } 284 | 285 | fmt.Fprintf(that.req.conn, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) 286 | } 287 | 288 | func (that *httpStream) response() error { 289 | specialHeader := "Anyproxy-Action" 290 | if config.DebugLevel >= config.LevelDebug { 291 | log.Println(trace.ID(that.req.ID), "nat server status:", nat.Eable(), ",special header:", that.Header.Get(specialHeader)) 292 | } 293 | if that.Method != "CONNECT" && nat.Eable() { //CONNECT 请求不支持ws转发 294 | if that.Header.Get(specialHeader) == "websocket" { 295 | that.Header.Del(specialHeader) 296 | tunnel := newWsTunnel(that.req, that.Header) 297 | if tunnel.getTarget(that.req.DstName) { 298 | // 先将请求头部发出 299 | tunnel.buffer.Write([]byte(fmt.Sprintf("%s\r\n", that.FirstLine))) 300 | that.Header.Write(tunnel.buffer) 301 | tunnel.buffer.Write([]byte("\r\n")) 302 | // 多读取的body部分 303 | tunnel.buffer.Write(that.BodyBuf) 304 | ok := tunnel.transfer() 305 | if ok { 306 | that.showIP("WS") 307 | return nil 308 | } 309 | // 请求不成,则走普通转发 310 | } 311 | } 312 | } 313 | tunnel := newTunnel(that.req) 314 | if that.Method == "CONNECT" { 315 | that.showIP("CONNECT") 316 | err := tunnel.handshake(protoHTTPS, that.req.DstName, "", that.req.DstPort) 317 | if err != nil { 318 | log.Println(trace.ID(that.req.ID), "handshake err", err.Error()) 319 | return err 320 | } 321 | // 遇到过后端连不上先输出established导致某手机app闪退 322 | // 所以要在能连上后端的情况下再输出 323 | _, err = that.req.conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) 324 | if err != nil { 325 | log.Println(trace.ID(that.req.ID), "write err", err.Error()) 326 | return err 327 | } 328 | tunnel.transfer(-1) 329 | } else { 330 | 
that.showIP("HTTP") 331 | err := tunnel.handshake(protoHTTP, that.req.DstName, "", that.req.DstPort) 332 | if err != nil { 333 | log.Println(trace.ID(that.req.ID), "handshake err", err.Error()) 334 | return err 335 | } 336 | 337 | // 先将请求头部发出 338 | tunnel.Write([]byte(fmt.Sprintf("%s\r\n", that.FirstLine))) 339 | that.Header.Write(tunnel) 340 | tunnel.Write([]byte("\r\n")) 341 | // 多读取的body部分 342 | tunnel.Write(that.BodyBuf) 343 | 344 | tunnel.transfer(that.clientUnRead) 345 | } 346 | return nil 347 | } 348 | 349 | func (that *httpStream) showIP(method string) { 350 | if method == "CONNECT" { 351 | log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", method, that.req.conn.RemoteAddr().String(), that.req.DstName, that.req.DstPort)) 352 | } else { 353 | log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s", method, that.req.conn.RemoteAddr().String(), that.Request())) 354 | } 355 | } 356 | 357 | // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. 358 | func parseRequestLine(line string) (requestURI, proto string, ok bool) { 359 | s1 := strings.Index(line, " ") 360 | s2 := strings.Index(line[s1+1:], " ") 361 | if s1 < 0 || s2 < 0 { 362 | return 363 | } 364 | s2 += s1 + 1 365 | return line[s1+1 : s2], line[s2+1:], true 366 | } 367 | 368 | // parseContentLength trims whitespace from s and returns -1 if no value 369 | // is set, or the value if it's >= 0. 
370 | func parseContentLength(cl string) (int64, error) { 371 | cl = strings.TrimSpace(cl) 372 | if cl == "" { 373 | return -1, nil 374 | } 375 | n, err := strconv.ParseInt(cl, 10, 64) 376 | if err != nil || n < 0 { 377 | return 0, fmt.Errorf("bad Content-Length %s", cl) 378 | } 379 | return n, nil 380 | } 381 | -------------------------------------------------------------------------------- /proto/http/header.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "io" 5 | "net/http/httptrace" 6 | "net/textproto" 7 | "sort" 8 | "strings" 9 | "sync" 10 | ) 11 | 12 | // A Header represents a MIME-style header mapping 13 | // keys to sets of values. 14 | type Header map[string][]string 15 | 16 | // Add adds the key, value pair to the header. 17 | // It appends to any existing values associated with key. 18 | func (h Header) Add(key, value string) { 19 | key = CanonicalMIMEHeaderKey(key) 20 | h[key] = append(h[key], value) 21 | } 22 | 23 | // Set sets the header entries associated with key to 24 | // the single element value. It replaces any existing 25 | // values associated with key. 26 | func (h Header) Set(key, value string) { 27 | h[CanonicalMIMEHeaderKey(key)] = []string{value} 28 | } 29 | 30 | // Get gets the first value associated with the given key. 31 | // It is case insensitive; CanonicalMIMEHeaderKey is used 32 | // to canonicalize the provided key. 33 | // If there are no values associated with the key, Get returns "". 34 | // To access multiple values of a key, or to use non-canonical keys, 35 | // access the map directly. 36 | func (h Header) Get(key string) string { 37 | if h == nil { 38 | return "" 39 | } 40 | v := h[CanonicalMIMEHeaderKey(key)] 41 | if len(v) == 0 { 42 | return "" 43 | } 44 | return v[0] 45 | } 46 | 47 | // Del deletes the values associated with key. 
48 | func (h Header) Del(key string) { 49 | delete(h, CanonicalMIMEHeaderKey(key)) 50 | } 51 | 52 | // Write writes a header in wire format. 53 | func (h Header) Write(w io.Writer) error { 54 | return h.write(w, nil) 55 | } 56 | 57 | func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error { 58 | return h.writeSubset(w, nil, trace) 59 | } 60 | 61 | // CanonicalMIMEHeaderKey 转换常见的头部为固定格式,其它不变 62 | // 已符合规则的和包含特殊字符的不转 63 | func CanonicalMIMEHeaderKey(s string) string { 64 | // Quick check for canonical encoding. 65 | upper := true 66 | for i := 0; i < len(s); i++ { 67 | c := s[i] 68 | if !validHeaderFieldByte(c) { 69 | return s 70 | } 71 | if upper && 'a' <= c && c <= 'z' { 72 | return canonicalMIMEHeaderKey([]byte(s)) 73 | } 74 | if !upper && 'A' <= c && c <= 'Z' { 75 | return canonicalMIMEHeaderKey([]byte(s)) 76 | } 77 | //遇到中划线切换大小写? 78 | upper = c == '-' 79 | } 80 | return s 81 | } 82 | 83 | const toLower = 'a' - 'A' 84 | 85 | // validHeaderFieldByte reports whether b is a valid byte in a header 86 | // field name. RFC 7230 says: 87 | // header-field = field-name ":" OWS field-value OWS 88 | // field-name = token 89 | // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / 90 | // "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA 91 | // token = 1*tchar 92 | func validHeaderFieldByte(b byte) bool { 93 | return int(b) < len(isTokenTable) && isTokenTable[b] 94 | } 95 | 96 | // canonicalMIMEHeaderKey 转换常见的头部为固定格式,其它不变 97 | // 包含特殊字符的不转,已符合规则的也会产生临时变量 98 | func canonicalMIMEHeaderKey(a []byte) string { 99 | // See if a looks like a header key. If not, return it unchanged. 100 | for _, c := range a { 101 | if validHeaderFieldByte(c) { 102 | continue 103 | } 104 | // Don't canonicalize. 105 | return string(a) 106 | } 107 | 108 | upper := true 109 | tmp := make([]byte, len(a)) 110 | for i, c := range a { 111 | // Canonicalize: first letter upper case 112 | // and upper case after each dash. 113 | // (Host, User-Agent, If-Modified-Since). 
114 | // MIME headers are ASCII only, so no Unicode issues. 115 | if upper && 'a' <= c && c <= 'z' { 116 | c -= toLower 117 | } else if !upper && 'A' <= c && c <= 'Z' { 118 | c += toLower 119 | } 120 | tmp[i] = c 121 | upper = c == '-' // for next time 122 | } 123 | // The compiler recognizes m[string(byteSlice)] as a special 124 | // case, so a copy of a's bytes into a new string does not 125 | // happen in this map lookup: 126 | if v := commonHeader[string(tmp)]; v != "" { 127 | return v 128 | } 129 | // 除了commonHeader之外,原来长什么样还返回什么样 130 | return string(a) 131 | } 132 | 133 | // commonHeader interns common header strings. 134 | var commonHeader = make(map[string]string) 135 | 136 | func init() { 137 | for _, v := range []string{ 138 | "Accept", 139 | "Accept-Charset", 140 | "Accept-Encoding", 141 | "Accept-Language", 142 | "Accept-Ranges", 143 | "Cache-Control", 144 | "Cc", 145 | "Connection", 146 | "Content-Id", 147 | "Content-Language", 148 | "Content-Length", 149 | "Content-Transfer-Encoding", 150 | "Content-Type", 151 | "Cookie", 152 | "Date", 153 | "Dkim-Signature", 154 | "Etag", 155 | "Expires", 156 | "From", 157 | "Host", 158 | "If-Modified-Since", 159 | "If-None-Match", 160 | "In-Reply-To", 161 | "Last-Modified", 162 | "Location", 163 | "Message-Id", 164 | "Mime-Version", 165 | "Pragma", 166 | "Received", 167 | "Return-Path", 168 | "Server", 169 | "Set-Cookie", 170 | "Subject", 171 | "To", 172 | "User-Agent", 173 | "Via", 174 | "X-Forwarded-For", 175 | "X-Imforwards", 176 | "X-Powered-By", 177 | } { 178 | commonHeader[v] = v 179 | } 180 | } 181 | 182 | // isTokenTable is a copy of net/http/lex.go's isTokenTable. 
183 | // See https://httpwg.github.io/specs/rfc7230.html#rule.token.separators 184 | var isTokenTable = [127]bool{ 185 | '!': true, 186 | '#': true, 187 | '$': true, 188 | '%': true, 189 | '&': true, 190 | '\'': true, 191 | '*': true, 192 | '+': true, 193 | '-': true, 194 | '.': true, 195 | '0': true, 196 | '1': true, 197 | '2': true, 198 | '3': true, 199 | '4': true, 200 | '5': true, 201 | '6': true, 202 | '7': true, 203 | '8': true, 204 | '9': true, 205 | 'A': true, 206 | 'B': true, 207 | 'C': true, 208 | 'D': true, 209 | 'E': true, 210 | 'F': true, 211 | 'G': true, 212 | 'H': true, 213 | 'I': true, 214 | 'J': true, 215 | 'K': true, 216 | 'L': true, 217 | 'M': true, 218 | 'N': true, 219 | 'O': true, 220 | 'P': true, 221 | 'Q': true, 222 | 'R': true, 223 | 'S': true, 224 | 'T': true, 225 | 'U': true, 226 | 'W': true, 227 | 'V': true, 228 | 'X': true, 229 | 'Y': true, 230 | 'Z': true, 231 | '^': true, 232 | '_': true, 233 | '`': true, 234 | 'a': true, 235 | 'b': true, 236 | 'c': true, 237 | 'd': true, 238 | 'e': true, 239 | 'f': true, 240 | 'g': true, 241 | 'h': true, 242 | 'i': true, 243 | 'j': true, 244 | 'k': true, 245 | 'l': true, 246 | 'm': true, 247 | 'n': true, 248 | 'o': true, 249 | 'p': true, 250 | 'q': true, 251 | 'r': true, 252 | 's': true, 253 | 't': true, 254 | 'u': true, 255 | 'v': true, 256 | 'w': true, 257 | 'x': true, 258 | 'y': true, 259 | 'z': true, 260 | '|': true, 261 | '~': true, 262 | } 263 | 264 | var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") 265 | 266 | // stringWriter implements WriteString on a Writer. 267 | type stringWriter struct { 268 | w io.Writer 269 | } 270 | 271 | func (w stringWriter) WriteString(s string) (n int, err error) { 272 | return w.w.Write([]byte(s)) 273 | } 274 | 275 | type keyValues struct { 276 | key string 277 | values []string 278 | } 279 | 280 | // A headerSorter implements sort.Interface by sorting a []keyValues 281 | // by key. 
It's used as a pointer, so it can fit in a sort.Interface 282 | // interface value without allocation. 283 | type headerSorter struct { 284 | kvs []keyValues 285 | } 286 | 287 | func (s *headerSorter) Len() int { return len(s.kvs) } 288 | func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } 289 | func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } 290 | 291 | var headerSorterPool = sync.Pool{ 292 | New: func() interface{} { return new(headerSorter) }, 293 | } 294 | 295 | // sortedKeyValues returns h's keys sorted in the returned kvs 296 | // slice. The headerSorter used to sort is also returned, for possible 297 | // return to headerSorterCache. 298 | func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { 299 | hs = headerSorterPool.Get().(*headerSorter) 300 | if cap(hs.kvs) < len(h) { 301 | hs.kvs = make([]keyValues, 0, len(h)) 302 | } 303 | kvs = hs.kvs[:0] 304 | for k, vv := range h { 305 | if !exclude[k] { 306 | kvs = append(kvs, keyValues{k, vv}) 307 | } 308 | } 309 | hs.kvs = kvs 310 | sort.Sort(hs) 311 | return kvs, hs 312 | } 313 | 314 | // WriteSubset writes a header in wire format. 315 | // If exclude is not nil, keys where exclude[key] == true are not written. 
316 | func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { 317 | return h.writeSubset(w, exclude, nil) 318 | } 319 | 320 | func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error { 321 | ws, ok := w.(io.StringWriter) 322 | if !ok { 323 | ws = stringWriter{w} 324 | } 325 | kvs, sorter := h.sortedKeyValues(exclude) 326 | var formattedVals []string 327 | for _, kv := range kvs { 328 | for _, v := range kv.values { 329 | v = headerNewlineToSpace.Replace(v) 330 | v = textproto.TrimString(v) 331 | for _, s := range []string{kv.key, ": ", v, "\r\n"} { 332 | if _, err := ws.WriteString(s); err != nil { 333 | headerSorterPool.Put(sorter) 334 | return err 335 | } 336 | } 337 | if trace != nil && trace.WroteHeaderField != nil { 338 | formattedVals = append(formattedVals, v) 339 | } 340 | } 341 | if trace != nil && trace.WroteHeaderField != nil { 342 | trace.WroteHeaderField(kv.key, formattedVals) 343 | formattedVals = nil 344 | } 345 | } 346 | headerSorterPool.Put(sorter) 347 | return nil 348 | } 349 | -------------------------------------------------------------------------------- /proto/keep.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net" 9 | 10 | "github.com/keminar/anyproxy/utils/trace" 11 | ) 12 | 13 | // KeepHandler HTTP/1.1复用处理 14 | func KeepHandler(ctx context.Context, tcpConn *net.TCPConn, buf []byte) error { 15 | req := NewRequestWithBuf(ctx, tcpConn, buf) 16 | 17 | // test if the underlying fd is nil 18 | remoteAddr := tcpConn.RemoteAddr() 19 | if remoteAddr == nil { 20 | log.Println(trace.ID(req.ID), "ClientHandler(): oops, clientConn.fd is nil!") 21 | return errors.New("clientConn.fd is nil") 22 | } 23 | // 日志方便查询有走到keep.go的记录 24 | log.Println(trace.ID(req.ID), "remoteAddr:"+remoteAddr.String()) 25 | 26 | // 和client.go统一代码好维护 27 | ok, err := req.ReadRequest("client") 28 
| if err != nil && !ok { 29 | log.Println(trace.ID(req.ID), "req err", err.Error()) 30 | return err 31 | } 32 | // 增加一个协议判断的日志 33 | if req.Proto != "http" { 34 | err = fmt.Errorf("is not http request %s: %s", req.Proto, string(buf)) 35 | log.Println(trace.ID(req.ID), err.Error()) 36 | return err 37 | } 38 | return req.Stream.response() 39 | } 40 | 41 | // 预判后面的包是不是http链接 42 | func isKeepAliveHttp(ctx context.Context, tcpConn *net.TCPConn, buf []byte) bool { 43 | req := NewRequestWithBuf(ctx, tcpConn, buf) 44 | req.ReadRequest("client") 45 | return req.Proto == "http" 46 | } 47 | -------------------------------------------------------------------------------- /proto/request.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/keminar/anyproxy/utils/conf" 8 | 9 | "github.com/keminar/anyproxy/grace" 10 | "github.com/keminar/anyproxy/proto/tcp" 11 | ) 12 | 13 | // AesToken 加密密钥, 必须16位长度 14 | var AesToken = "anyproxyproxyany" 15 | 16 | // Request 请求类 17 | type Request struct { 18 | ID uint 19 | ctx context.Context 20 | conn *net.TCPConn 21 | reader *tcp.Reader 22 | Proto string //http 23 | 24 | Stream stream 25 | DstName string //目标域名 26 | DstIP string //目标ip 27 | DstPort uint16 //目标端口 28 | } 29 | 30 | // NewRequest 请求类 31 | func NewRequest(ctx context.Context, conn *net.TCPConn) *Request { 32 | // 取traceID 33 | traceID, _ := ctx.Value(grace.TraceIDContextKey).(uint) 34 | c := &Request{ 35 | ctx: ctx, 36 | ID: traceID, 37 | conn: conn, 38 | reader: tcp.NewReader(conn), 39 | } 40 | return c 41 | } 42 | 43 | // NewRequestWithBuf 请求类,前带buf内容 44 | func NewRequestWithBuf(ctx context.Context, conn *net.TCPConn, buf []byte) *Request { 45 | // 取traceID 46 | traceID, _ := ctx.Value(grace.TraceIDContextKey).(uint) 47 | c := &Request{ 48 | ctx: ctx, 49 | ID: traceID, 50 | conn: conn, 51 | reader: tcp.NewReaderWithBuf(conn, buf), 52 | } 53 | return c 54 | } 55 | 56 | // 
ReadRequest 分析请求内容 57 | func (that *Request) ReadRequest(from string) (canProxy bool, err error) { 58 | //如果启用了tcpcopy 且目标地址也有配置,则进行tcpcopy转发 59 | if conf.RouterConfig.TcpCopy.Enable { 60 | if conf.RouterConfig.TcpCopy.IP != "" && conf.RouterConfig.TcpCopy.Port > 0 { 61 | s := newTCPCopy(that) 62 | that.Proto = "tcp" 63 | that.Stream = s 64 | return s.readRequest(from) 65 | } 66 | } 67 | _, err = that.reader.Peek(1) 68 | if err != nil { 69 | return false, err 70 | } 71 | 72 | var s stream 73 | protos := []string{"http", "socks5"} 74 | for _, v := range protos { 75 | switch v { 76 | case "http": 77 | s = newHTTPStream(that) 78 | if s.validHead() { 79 | that.Proto = v 80 | break 81 | } 82 | case "socks5": 83 | s = newSocks5Stream(that) 84 | if s.validHead() { 85 | that.Proto = v 86 | break 87 | } 88 | } 89 | if that.Proto != "" { 90 | break 91 | } 92 | } 93 | if that.Proto == "" { 94 | s = newTCPStream(that) 95 | that.Proto = "tcp" 96 | } 97 | that.Stream = s 98 | return s.readRequest(from) 99 | } 100 | 101 | // 加密Token 102 | func getToken() string { 103 | if conf.RouterConfig.Token == "" { 104 | return AesToken 105 | } 106 | return conf.RouterConfig.Token 107 | } 108 | -------------------------------------------------------------------------------- /proto/server.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "net" 8 | 9 | "github.com/keminar/anyproxy/utils/trace" 10 | ) 11 | 12 | // ServerHandler 服务端处理 13 | func ServerHandler(ctx context.Context, tcpConn *net.TCPConn) error { 14 | req := NewRequest(ctx, tcpConn) 15 | 16 | // test if the underlying fd is nil 17 | remoteAddr := tcpConn.RemoteAddr() 18 | if remoteAddr == nil { 19 | log.Println(trace.ID(req.ID), "ClientHandler(): oops, clientConn.fd is nil!") 20 | return errors.New("clientConn.fd is nil") 21 | } 22 | log.Println(trace.ID(req.ID), "remoteAddr:"+remoteAddr.String()) 23 | 24 | ok, err := 
req.ReadRequest("server") 25 | if err != nil && ok == false { 26 | log.Println("req err", err.Error()) 27 | return err 28 | } 29 | 30 | // server 只支持通过client/server和server连接,后续还要加安全密钥检查 31 | if req.Proto != "http" { 32 | return errors.New("Not http method") 33 | } 34 | stream, ok := req.Stream.(*httpStream) 35 | if !ok || stream.Method != "CONNECT" { 36 | return errors.New("Not CONNECT method") 37 | } 38 | return req.Stream.response() 39 | } 40 | -------------------------------------------------------------------------------- /proto/socks5.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net" 11 | "strconv" 12 | 13 | "github.com/keminar/anyproxy/utils/trace" 14 | ) 15 | 16 | type socks5Stream struct { 17 | req *Request 18 | } 19 | 20 | func newSocks5Stream(req *Request) *socks5Stream { 21 | c := &socks5Stream{ 22 | req: req, 23 | } 24 | return c 25 | } 26 | 27 | func (that *socks5Stream) validHead() bool { 28 | if that.req.reader.Buffered() < 2 { 29 | return false 30 | } 31 | 32 | tmpBuf, err := that.req.reader.Peek(2) 33 | if err != nil { 34 | return false 35 | } 36 | 37 | isSocks5 := len(tmpBuf) >= 2 && tmpBuf[0] == 0x05 38 | if isSocks5 { 39 | // 如果是SOCKS5则把已读信息从缓存区释放掉 40 | that.req.reader.UnreadBuf(-1) 41 | } 42 | return isSocks5 43 | } 44 | 45 | func (that *socks5Stream) readRequest(from string) (canProxy bool, err error) { 46 | if err = that.ParseHeader(); err != nil { 47 | return false, err 48 | } 49 | return true, nil 50 | } 51 | 52 | func (that *socks5Stream) response() error { 53 | tunnel := newTunnel(that.req) 54 | 55 | var err error 56 | // 发送socks5应答 57 | _, err = that.req.conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 58 | if err != nil { 59 | log.Println(trace.ID(that.req.ID), "write err", err.Error()) 60 | return err 61 | } 62 | 63 | that.showIP() 64 | err = 
tunnel.handshake(protoTCP, that.req.DstName, that.req.DstIP, that.req.DstPort) 65 | if err != nil { 66 | log.Println(trace.ID(that.req.ID), "handshake err", err.Error()) 67 | return err 68 | } 69 | 70 | tunnel.transfer(-1) 71 | return nil 72 | } 73 | 74 | func (that *socks5Stream) showIP() { 75 | if that.req.DstName != "" { 76 | log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstName, that.req.DstPort)) 77 | } else { 78 | log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", "Socks5", that.req.conn.RemoteAddr().String(), that.req.DstIP, that.req.DstPort)) 79 | } 80 | } 81 | 82 | // parsing socks5 header, and return address and parsing error 83 | func (that *socks5Stream) ParseHeader() error { 84 | // response to socks5 client 85 | // see rfc 1982 for more details (https://tools.ietf.org/html/rfc1928) 86 | n, err := that.req.conn.Write([]byte{0x05, 0x00}) // version and no authentication required 87 | if err != nil { 88 | return err 89 | } 90 | 91 | // step2: process client Requests and does Reply 92 | /** 93 | +----+-----+-------+------+----------+----------+ 94 | |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | 95 | +----+-----+-------+------+----------+----------+ 96 | | 1 | 1 | X'00' | 1 | Variable | 2 | 97 | +----+-----+-------+------+----------+----------+ 98 | */ 99 | var buffer [1024]byte 100 | n, err = that.req.reader.Read(buffer[:]) 101 | if err != nil { 102 | return err 103 | } 104 | if n < 6 { 105 | return errors.New("not a socks protocol") 106 | } 107 | 108 | switch buffer[3] { 109 | case 0x01: 110 | // ipv4 address 111 | ipv4 := make([]byte, 4) 112 | if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv4, len(ipv4)); err != nil { 113 | return err 114 | } 115 | //fmt.Println(1) 116 | that.req.DstIP = net.IP(ipv4).String() 117 | case 0x04: 118 | // ipv6 119 | ipv6 := make([]byte, 16) 120 | if _, err := io.ReadAtLeast(bytes.NewReader(buffer[4:]), ipv6, len(ipv6)); 
err != nil { 121 | return err 122 | } 123 | that.req.DstIP = net.IP(ipv6).String() 124 | case 0x03: 125 | // domain 126 | addrLen := int(buffer[4]) 127 | domain := make([]byte, addrLen) 128 | if _, err := io.ReadAtLeast(bytes.NewReader(buffer[5:]), domain, addrLen); err != nil { 129 | return err 130 | } 131 | //fmt.Println(2) 132 | that.req.DstName = string(domain) 133 | } 134 | 135 | port := make([]byte, 2) 136 | err = binary.Read(bytes.NewReader(buffer[n-2:n]), binary.BigEndian, &port) 137 | if err != nil { 138 | return err 139 | } 140 | 141 | portStr := strconv.Itoa((int(port[0]) << 8) | int(port[1])) 142 | c, err := strconv.ParseUint(portStr, 0, 16) 143 | if err != nil { 144 | return err 145 | } 146 | that.req.DstPort = uint16(c) 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /proto/stats/counter.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "log" 5 | "runtime" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | ) 10 | 11 | type Counter struct { 12 | access sync.RWMutex 13 | name string 14 | active int64 // 活跃时间, 判断计数器是否可以清理 15 | minute int // 打印日志时间, 当前分钟数不再打印 16 | value int64 17 | } 18 | 19 | func (c *Counter) Add(delta int64) int64 { 20 | defer func() { 21 | if err := recover(); err != nil { 22 | const size = 32 << 10 23 | buf := make([]byte, size) 24 | buf = buf[:runtime.Stack(buf, false)] 25 | log.Printf("panic stats: %v\n%s", err, buf) 26 | } 27 | }() 28 | c.access.Lock() 29 | defer c.access.Unlock() 30 | tmp := atomic.AddInt64(&c.value, delta) 31 | 32 | now := time.Now().Minute() 33 | if now != c.minute { 34 | // 打印上一分钟的上行下行字节数 35 | if tmp > 1e6 { 36 | log.Println(c.name, tmp/1e6, "MB") 37 | } else if tmp > 1e3 { 38 | log.Println(c.name, tmp/1e3, "KB") 39 | } else { 40 | log.Println(c.name, tmp, "Bytes") 41 | } 42 | c.minute = now 43 | c.active = time.Now().Unix() 44 | tmp = atomic.SwapInt64(&c.value, 0) 45 | } 46 | 
return tmp 47 | } 48 | -------------------------------------------------------------------------------- /proto/stats/stats.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type Manager struct { 10 | access sync.RWMutex 11 | counters map[string]*Counter 12 | } 13 | 14 | func NewManager() *Manager { 15 | m := &Manager{ 16 | counters: make(map[string]*Counter), 17 | } 18 | return m 19 | } 20 | 21 | func (m *Manager) RegisterCounter(name string) *Counter { 22 | m.access.Lock() 23 | defer m.access.Unlock() 24 | 25 | if _, found := m.counters[name]; found { 26 | m.counters[name].active = time.Now().Unix() 27 | return m.counters[name] 28 | } 29 | c := new(Counter) 30 | c.name = name 31 | m.counters[name] = c 32 | return c 33 | } 34 | 35 | func (m *Manager) UnregisterCounter() { 36 | m.access.Lock() 37 | defer m.access.Unlock() 38 | 39 | now := time.Now().Unix() 40 | 41 | for _, v := range m.counters { 42 | if now-v.active > 300 { 43 | delete(m.counters, v.name) 44 | } 45 | } 46 | log.Println("stats links:", len(m.counters)) 47 | } 48 | -------------------------------------------------------------------------------- /proto/stream.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | 8 | "github.com/keminar/anyproxy/proto/tcp" 9 | "github.com/keminar/anyproxy/utils/trace" 10 | ) 11 | 12 | const SO_ORIGINAL_DST = 80 13 | 14 | type stream interface { 15 | validHead() bool 16 | readRequest(from string) (canProxy bool, err error) 17 | response() error 18 | } 19 | 20 | type tcpStream struct { 21 | req *Request 22 | } 23 | 24 | func newTCPStream(req *Request) *tcpStream { 25 | c := &tcpStream{ 26 | req: req, 27 | } 28 | return c 29 | } 30 | 31 | func (that *tcpStream) validHead() bool { 32 | return true 33 | } 34 | func (that *tcpStream) readRequest(from string) 
(canProxy bool, err error) { 35 | return true, nil 36 | } 37 | 38 | // 处理iptables转发的流量 39 | func (that *tcpStream) response() error { 40 | tunnel := newTunnel(that.req) 41 | var err error 42 | var newTCPConn *net.TCPConn 43 | that.req.DstIP, that.req.DstPort, newTCPConn, err = GetOriginalDstAddr(that.req.conn) 44 | if err != nil { 45 | log.Println(trace.ID(that.req.ID), "GetOriginalDstAddr err", err.Error()) 46 | return err 47 | } 48 | defer newTCPConn.Close() 49 | 50 | that.showIP("TCP") 51 | err = tunnel.handshake(protoTCP, "", that.req.DstIP, uint16(that.req.DstPort)) 52 | if err != nil { 53 | log.Println(trace.ID(that.req.ID), "dail err", err.Error()) 54 | return err 55 | } 56 | 57 | // 将前面读的字节补上 58 | tmpBuf := that.req.reader.UnreadBuf(-1) 59 | tunnel.Write(tmpBuf) 60 | // 切换为新连接 61 | reader := tcp.NewReader(newTCPConn) 62 | that.req.reader = reader 63 | that.req.conn = newTCPConn 64 | tunnel.transfer(-1) 65 | return nil 66 | } 67 | 68 | func (that *tcpStream) showIP(method string) { 69 | log.Println(trace.ID(that.req.ID), fmt.Sprintf("%s %s -> %s:%d", method, that.req.conn.RemoteAddr().String(), that.req.DstIP, that.req.DstPort)) 70 | } 71 | -------------------------------------------------------------------------------- /proto/stream_addr.go: -------------------------------------------------------------------------------- 1 | // 条件编译 https://segmentfault.com/a/1190000017846997 2 | 3 | // +build !windows 4 | 5 | package proto 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "net" 11 | "strings" 12 | "syscall" 13 | 14 | "github.com/keminar/anyproxy/config" 15 | ) 16 | 17 | // GetOriginalDstAddr 目标 18 | func GetOriginalDstAddr(tcpConn *net.TCPConn) (dstIP string, dstPort uint16, newTCPConn *net.TCPConn, err error) { 19 | if tcpConn == nil { 20 | err = errors.New("ERR: tcpConn is nil") 21 | return 22 | } 23 | 24 | // test if the underlying fd is nil 25 | if tcpConn.RemoteAddr() == nil { 26 | err = errors.New("ERR: clientConn.fd is nil") 27 | return 28 | } 29 | 30 | 
srcipport := fmt.Sprintf("%v", tcpConn.RemoteAddr()) 31 | 32 | newTCPConn = nil 33 | // connection => file, will make a copy 34 | // 会使得连接变成阻塞模式,需要自己手动 close 原来的 tcp 连接 35 | tcpConnFile, err := tcpConn.File() 36 | if err != nil { 37 | err = fmt.Errorf("GETORIGINALDST|%v->?->FAILEDTOBEDETERMINED|ERR: %v", srcipport, err) 38 | return 39 | } 40 | // 旧链接关闭 41 | tcpConn.Close() 42 | // 文件句柄关闭 43 | defer tcpConnFile.Close() 44 | 45 | mreq, err := syscall.GetsockoptIPv6Mreq(int(tcpConnFile.Fd()), syscall.IPPROTO_IP, SO_ORIGINAL_DST) 46 | if err != nil { 47 | err = fmt.Errorf("GETORIGINALDST|%v->?->FAILEDTOBEDETERMINED|ERR: getsocketopt(SO_ORIGINAL_DST) failed: %v", srcipport, err) 48 | return 49 | } 50 | 51 | // 开新连接 52 | newConn, err := net.FileConn(tcpConnFile) 53 | if err != nil { 54 | err = fmt.Errorf("GETORIGINALDST|%v->?->%v|ERR: could not create a FileConn from clientConnFile=%+v: %v", srcipport, mreq, tcpConnFile, err) 55 | return 56 | } 57 | if _, ok := newConn.(*net.TCPConn); ok { 58 | newTCPConn = newConn.(*net.TCPConn) 59 | 60 | // only support ipv4 61 | dstIP = net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7]).String() 62 | dstPort = uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3]) 63 | 64 | ipArr := strings.Split(srcipport, ":") 65 | // 来源和目标地址是同一个ip,且目标端口和本服务是同一个端口 66 | if ipArr[0] == dstIP && dstPort == config.ListenPort { 67 | err = fmt.Errorf("may be loop call: %s=>%s:%d", srcipport, dstIP, dstPort) 68 | } 69 | return 70 | } 71 | err = fmt.Errorf("GETORIGINALDST|%v|ERR: newConn is not a *net.TCPConn, instead it is: %T (%v)", srcipport, newConn, newConn) 72 | return 73 | } 74 | -------------------------------------------------------------------------------- /proto/stream_windows.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | // GetOriginalDstAddr 目标 9 | func GetOriginalDstAddr(tcpConn *net.TCPConn) (dstIP 
string, dstPort uint16, newTCPConn *net.TCPConn, err error) { 10 | err = errors.New("ERR: windows can not work") 11 | return 12 | } 13 | -------------------------------------------------------------------------------- /proto/tcp/reader.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | ) 8 | 9 | const ( 10 | defaultBufSize = 4096 11 | minReadBufferSize = 16 12 | maxConsecutiveEmptyReads = 100 13 | ) 14 | 15 | var ( 16 | errNegativeRead = errors.New("tcpReader: reader returned negative count from Read") 17 | //ErrBufferFull Buffer is full 18 | ErrBufferFull = errors.New("tcpReader: buffer full") 19 | //ErrNegativeCount negative count 20 | ErrNegativeCount = errors.New("tcpReader: negative count") 21 | ) 22 | 23 | // A Reader implements convenience methods for reading requests 24 | // or responses from a text protocol network connection. 25 | type Reader struct { 26 | buf []byte 27 | rd io.Reader 28 | r, w int // buf read and write positions 29 | err error 30 | lastByte int // last byte read for UnreadByte; -1 means invalid 31 | } 32 | 33 | // NewReader returns a new Reader whose buffer has the default size. 34 | func NewReader(rd io.Reader) *Reader { 35 | return NewReaderSize(rd, defaultBufSize) 36 | } 37 | 38 | // NewReaderWithBuf 带有前置buf内容的Reader实例 39 | func NewReaderWithBuf(rd io.Reader, buf []byte) *Reader { 40 | r := new(Reader) 41 | r.reset(buf, rd) 42 | r.w = len(buf) 43 | return r 44 | } 45 | 46 | // NewReaderSize returns a new Reader whose buffer has at least the specified 47 | // size. If the argument io.Reader is already a Reader with large enough 48 | // size, it returns the underlying Reader. 49 | func NewReaderSize(rd io.Reader, size int) *Reader { 50 | // Is it already a Reader? 
51 | b, ok := rd.(*Reader) 52 | if ok && len(b.buf) >= size { 53 | return b 54 | } 55 | if size < minReadBufferSize { 56 | size = minReadBufferSize 57 | } 58 | r := new(Reader) 59 | r.reset(make([]byte, size), rd) 60 | return r 61 | } 62 | 63 | func (b *Reader) reset(buf []byte, r io.Reader) { 64 | *b = Reader{ 65 | buf: buf, 66 | rd: r, 67 | lastByte: -1, 68 | } 69 | } 70 | 71 | func (b *Reader) readErr() error { 72 | err := b.err 73 | b.err = nil 74 | return err 75 | } 76 | 77 | // 一次性读一些数据,如果读出的数据大于要返回的数据则放入buf,否则不放buf 78 | func (b *Reader) Read(p []byte) (n int, err error) { 79 | n = len(p) 80 | if n == 0 { 81 | return 0, errors.New("read buf len is 0") 82 | } 83 | //当buf中已经没有未读内容 84 | if b.r == b.w { 85 | if b.err != nil { 86 | return 0, b.readErr() 87 | } 88 | if len(p) >= len(b.buf) { 89 | //要读取的内容大于buf长度,则直接从网络读取 90 | //因为没有先存于buf再copy到p的必要了 91 | n, b.err = b.rd.Read(p) 92 | if n < 0 { 93 | panic(errNegativeRead) 94 | } 95 | if n > 0 { 96 | b.lastByte = int(p[n-1]) 97 | } 98 | return n, b.readErr() 99 | } 100 | // 一次性读取b.buf长度内容,再分多次读到p中 101 | // 这里不使用b.fill方法,因为b.fill会循环读取 102 | b.r = 0 103 | b.w = 0 104 | n, b.err = b.rd.Read(b.buf) 105 | if n < 0 { 106 | panic(errNegativeRead) 107 | } 108 | if n == 0 { 109 | return 0, b.readErr() 110 | } 111 | b.w += n 112 | } 113 | // 从buf的未读内容中读取尽量多的内容到p 114 | n = copy(p, b.buf[b.r:b.w]) 115 | b.r += n 116 | b.lastByte = int(b.buf[b.r-1]) 117 | return n, nil 118 | } 119 | 120 | //ReadLine 在buf中查换换行符并截断返回, 找不到就返回buf 121 | func (b *Reader) ReadLine(dropBreak bool) (line []byte, isPrefix bool, err error) { 122 | line, err = b.ReadSlice('\n') 123 | if err == ErrBufferFull { 124 | // Handle the case where "\r\n" straddles the buffer. 125 | if len(line) > 0 && line[len(line)-1] == '\r' { 126 | // Put the '\r' back on buf and drop it from line. 127 | // Let the next call to ReadLine check for "\r\n". 
128 | if b.r == 0 { 129 | // should be unreachable 130 | panic("bufio: tried to rewind past start of buffer") 131 | } 132 | //将读取位置前移1,把\r放回buf中,此时buf中只有1位数据 133 | b.r-- 134 | line = line[:len(line)-1] 135 | } 136 | return line, true, nil 137 | } 138 | 139 | if len(line) == 0 { 140 | if err != nil { 141 | line = nil 142 | } 143 | return 144 | } 145 | err = nil 146 | 147 | if dropBreak { 148 | if line[len(line)-1] == '\n' { 149 | drop := 1 150 | if len(line) > 1 && line[len(line)-2] == '\r' { 151 | drop = 2 152 | } 153 | line = line[:len(line)-drop] 154 | } 155 | } 156 | return 157 | } 158 | 159 | // Size returns the size of the underlying buffer in bytes. 160 | func (b *Reader) Size() int { return len(b.buf) } 161 | 162 | // Buffered returns the number of bytes that can be read from the current buffer. 163 | func (b *Reader) Buffered() int { return b.w - b.r } 164 | 165 | // ReadSlice reads until the first occurrence of delim in the input, 166 | // returning a slice pointing at the bytes in the buffer. 167 | // The bytes stop being valid at the next read. 168 | // If ReadSlice encounters an error before finding a delimiter, 169 | // it returns all the data in the buffer and the error itself (often io.EOF). 170 | // ReadSlice fails with error ErrBufferFull if the buffer fills without a delim. 171 | // Because the data returned from ReadSlice will be overwritten 172 | // by the next I/O operation, most clients should use 173 | // ReadBytes or ReadString instead. 174 | // ReadSlice returns err != nil if and only if line does not end in delim. 175 | func (b *Reader) ReadSlice(delim byte) (line []byte, err error) { 176 | s := 0 // search start index 177 | for { 178 | // Search buffer. 179 | if i := bytes.IndexByte(b.buf[b.r+s:b.w], delim); i >= 0 { 180 | i += s 181 | line = b.buf[b.r : b.r+i+1] 182 | b.r += i + 1 183 | break 184 | } 185 | 186 | // Pending error? 
187 | if b.err != nil { 188 | line = b.buf[b.r:b.w] 189 | b.r = b.w 190 | err = b.readErr() 191 | break 192 | } 193 | 194 | // Buffer full? 195 | if b.Buffered() >= len(b.buf) { 196 | b.r = b.w 197 | line = b.buf 198 | err = ErrBufferFull 199 | break 200 | } 201 | 202 | s = b.w - b.r // do not rescan area we scanned before 203 | 204 | // buf未满,继续填充数据 205 | b.fill() // buffer is not full 206 | } 207 | 208 | // Handle last byte, if any. 209 | if i := len(line) - 1; i >= 0 { 210 | b.lastByte = int(line[i]) 211 | } 212 | 213 | return 214 | } 215 | 216 | // fill 填充buf数据. 217 | func (b *Reader) fill() { 218 | // Slide existing data to beginning. 219 | if b.r > 0 { 220 | copy(b.buf, b.buf[b.r:b.w]) 221 | b.w -= b.r 222 | b.r = 0 223 | } 224 | 225 | if b.w >= len(b.buf) { 226 | panic("bufio: tried to fill full buffer") 227 | } 228 | 229 | // Read new data: try a limited number of times. 230 | for i := maxConsecutiveEmptyReads; i > 0; i-- { 231 | n, err := b.rd.Read(b.buf[b.w:]) 232 | if n < 0 { 233 | panic(errNegativeRead) 234 | } 235 | b.w += n 236 | if err != nil { 237 | b.err = err 238 | return 239 | } 240 | if n > 0 { 241 | return 242 | } 243 | } 244 | b.err = io.ErrNoProgress 245 | } 246 | 247 | // Peek 返回当前读取位置后面的N个字节,如果不够会调用fill填充buf 248 | // 和Read不同的是Peek不会更新buf里读到的内容为已读 249 | func (b *Reader) Peek(n int) ([]byte, error) { 250 | if n < 0 { 251 | return nil, ErrNegativeCount 252 | } 253 | 254 | b.lastByte = -1 255 | 256 | for b.w-b.r < n && b.w-b.r < len(b.buf) && b.err == nil { 257 | b.fill() // b.w-b.r < len(b.buf) => buffer is not full 258 | } 259 | 260 | if n > len(b.buf) { 261 | return b.buf[b.r:b.w], ErrBufferFull 262 | } 263 | 264 | // 0 <= n <= len(b.buf) 265 | var err error 266 | if avail := b.w - b.r; avail < n { 267 | // not enough data in buffer 268 | n = avail 269 | err = b.readErr() 270 | if err == nil { 271 | err = ErrBufferFull 272 | } 273 | } 274 | return b.buf[b.r : b.r+n], err 275 | } 276 | 277 | // UnreadBuf 获取已读到buf但未从buf读走的内容 278 | func (b 
*Reader) UnreadBuf(max int) (data []byte) { 279 | if max == 0 { 280 | return 281 | } 282 | if max > 0 && b.Buffered() > max { 283 | data = b.buf[b.r:(b.r + max)] 284 | b.r = b.r + max 285 | } else { // max = -1 286 | data = b.buf[b.r:b.w] 287 | b.r = b.w 288 | } 289 | if i := len(data) - 1; i >= 0 { 290 | b.lastByte = int(data[i]) 291 | } 292 | return 293 | } 294 | -------------------------------------------------------------------------------- /proto/tcpcopy.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | 7 | "github.com/keminar/anyproxy/utils/conf" 8 | "github.com/keminar/anyproxy/utils/trace" 9 | ) 10 | 11 | type tcpCopy struct { 12 | req *Request 13 | } 14 | 15 | func newTCPCopy(req *Request) *tcpCopy { 16 | c := &tcpCopy{ 17 | req: req, 18 | } 19 | return c 20 | } 21 | 22 | func (that *tcpCopy) validHead() bool { 23 | return true 24 | } 25 | func (that *tcpCopy) readRequest(from string) (canProxy bool, err error) { 26 | return true, nil 27 | } 28 | 29 | func (that *tcpCopy) response() error { 30 | tunnel := newTunnel(that.req) 31 | if ip, ok := tunnel.isAllowed([]string{}); !ok { 32 | return errors.New(ip + " is not allowed") 33 | } 34 | var err error 35 | that.req.DstIP = conf.RouterConfig.TcpCopy.IP 36 | that.req.DstPort = conf.RouterConfig.TcpCopy.Port 37 | 38 | network, connAddr := tunnel.buildAddress("", that.req.DstIP, that.req.DstPort, true) 39 | if connAddr == "" { 40 | err = errors.New("target address is empty") 41 | return err 42 | } 43 | err = tunnel.dail(network, connAddr, 0) 44 | if err != nil { 45 | log.Println(trace.ID(that.req.ID), "dail err", err.Error()) 46 | return err 47 | } 48 | tunnel.curState = stateNew 49 | 50 | tunnel.transfer(-1) 51 | return nil 52 | } 53 | -------------------------------------------------------------------------------- /proto/text/reader.go: 
-------------------------------------------------------------------------------- 1 | package text 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/keminar/anyproxy/proto/http" 7 | "github.com/keminar/anyproxy/proto/tcp" 8 | ) 9 | 10 | // A Reader implements convenience methods for reading requests 11 | // or responses from a text protocol network connection. 12 | type Reader struct { 13 | R *tcp.Reader 14 | buf []byte // a re-usable buffer for readContinuedLineSlice 15 | } 16 | 17 | // NewReader returns a new Reader reading from r. 18 | // 19 | // To avoid denial of service attacks, the provided bufio.Reader 20 | // should be reading from an io.LimitReader or similar Reader to bound 21 | // the size of responses. 22 | func NewReader(r *tcp.Reader) *Reader { 23 | return &Reader{R: r} 24 | } 25 | 26 | // ReadLine reads a single line from r, 27 | // eliding the final \n or \r\n from the returned string. 28 | func (r *Reader) ReadLine(dropBreak bool) (string, error) { 29 | line, err := r.readLineSlice(dropBreak) 30 | return string(line), err 31 | } 32 | 33 | // ReadLineBytes is like ReadLine but returns a []byte instead of a string. 34 | func (r *Reader) ReadLineBytes(dropBreak bool) ([]byte, error) { 35 | line, err := r.readLineSlice(dropBreak) 36 | if line != nil { 37 | buf := make([]byte, len(line)) 38 | copy(buf, line) 39 | line = buf 40 | } 41 | return line, err 42 | } 43 | 44 | func (r *Reader) readLineSlice(dropBreak bool) ([]byte, error) { 45 | var line []byte 46 | for { 47 | l, more, err := r.R.ReadLine(dropBreak) 48 | if err != nil { 49 | return nil, err 50 | } 51 | // Avoid the copy if the first call produced a full line. 52 | if line == nil && !more { 53 | return l, nil 54 | } 55 | line = append(line, l...) 56 | if !more { 57 | break 58 | } 59 | } 60 | return line, nil 61 | } 62 | 63 | // Trim 去掉头和尾的空格和\t 64 | // It does not assume Unicode or UTF-8. 
65 | func Trim(s []byte) []byte { 66 | i := 0 67 | for i < len(s) && (s[i] == ' ' || s[i] == '\t') { 68 | i++ 69 | } 70 | n := len(s) 71 | for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { 72 | n-- 73 | } 74 | return s[i:n] 75 | } 76 | 77 | // DropBreak 去掉尾部换行 78 | func DropBreak(line []byte) []byte { 79 | if line[len(line)-1] == '\n' { 80 | drop := 1 81 | if len(line) > 1 && line[len(line)-2] == '\r' { 82 | drop = 2 83 | } 84 | line = line[:len(line)-drop] 85 | } 86 | return line 87 | } 88 | 89 | // upcomingHeaderNewlines returns an approximation of the number of newlines 90 | // that will be in this header. If it gets confused, it returns 0. 91 | func (r *Reader) upcomingHeaderNewlines() (n int) { 92 | // Try to determine the 'hint' size. 93 | r.R.Peek(1) // force a buffer load if empty 94 | s := r.R.Buffered() 95 | if s == 0 { 96 | return 97 | } 98 | peek, _ := r.R.Peek(s) 99 | for len(peek) > 0 { 100 | i := bytes.IndexByte(peek, '\n') 101 | if i < 3 { 102 | // Not present (-1) or found within the next few bytes, 103 | // implying we're at the end ("\r\n\r\n" or "\n\n") 104 | return 105 | } 106 | n++ 107 | peek = peek[i+1:] 108 | } 109 | return 110 | } 111 | 112 | // A ProtocolError describes a protocol violation such 113 | // as an invalid response or a hung-up connection. 114 | type ProtocolError string 115 | 116 | func (p ProtocolError) Error() string { 117 | return string(p) 118 | } 119 | 120 | // ReadHeader reads a MIME-style header from r. 121 | // The header is a sequence of possibly continued Key: Value lines 122 | // ending in a blank line. 123 | // The returned map m maps CanonicalMIMEHeaderKey(key) to a 124 | // sequence of values in the same order encountered in the input. 
125 | // 126 | // For example, consider this input: 127 | // 128 | // My-Key: Value 1 129 | // Long-Key: Even 130 | // Longer Value 131 | // My-Key: Value 2 132 | // 133 | // Given that input, ReadMIMEHeader returns the map: 134 | // 135 | // map[string][]string{ 136 | // "My-Key": {"Value 1", "Value 2"}, 137 | // "Long-Key": {"Even Longer Value"}, 138 | // } 139 | // 140 | func (r *Reader) ReadHeader() (http.Header, error) { 141 | var strs []string 142 | hint := r.upcomingHeaderNewlines() 143 | if hint > 0 { 144 | strs = make([]string, hint) 145 | } 146 | 147 | m := make(http.Header, hint) 148 | 149 | // The first line cannot start with a leading space. 150 | if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { 151 | line, err := r.readLineSlice(false) 152 | if err != nil { 153 | return m, err 154 | } 155 | return m, ProtocolError("malformed MIME header initial line: " + string(line)) 156 | } 157 | var headerIsEnd bool 158 | lastEnd := make([]byte, 2) 159 | for { 160 | if headerIsEnd { 161 | return m, nil 162 | } 163 | kv, err := r.readLineSlice(false) 164 | if len(kv) == 0 { 165 | return m, err 166 | } 167 | // 发现头结束符,检查上一行是不是也是有换行符 168 | if len(kv) == 2 && kv[len(kv)-2] == '\r' && kv[len(kv)-1] == '\n' { 169 | if lastEnd[0] == '\r' && lastEnd[1] == '\n' { 170 | headerIsEnd = true 171 | continue 172 | } 173 | } 174 | // 记录当前行尾字符,为下一行检查提供帮助 175 | if len(kv) >= 2 { 176 | copy(lastEnd, kv[len(kv)-2:]) 177 | } 178 | kv = DropBreak(kv) 179 | if len(kv) == 0 { 180 | return m, nil 181 | } 182 | 183 | // Key ends at first colon; should not have trailing spaces 184 | // but they appear in the wild, violating specs, so we remove 185 | // them if present. 
186 | i := bytes.IndexByte(kv, ':') 187 | if i < 0 { 188 | return m, ProtocolError("malformed MIME header line: " + string(kv)) 189 | } 190 | endKey := i 191 | //跳过:前的空格 192 | for endKey > 0 && kv[endKey-1] == ' ' { 193 | endKey-- 194 | } 195 | key := http.CanonicalMIMEHeaderKey(string(kv[:endKey])) 196 | 197 | // As per RFC 7230 field-name is a token, tokens consist of one or more chars. 198 | // We could return a ProtocolError here, but better to be liberal in what we 199 | // accept, so if we get an empty key, skip it. 200 | if key == "" { 201 | continue 202 | } 203 | 204 | // 跳过: 后的空格 205 | i++ // skip colon 206 | for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { 207 | i++ 208 | } 209 | value := string(kv[i:]) 210 | 211 | vv := m[key] 212 | if vv == nil && len(strs) > 0 { 213 | // More than likely this will be a single-element key. 214 | // Most headers aren't multi-valued. 215 | // Set the capacity on strs[0] to 1, so any future append 216 | // won't extend the slice into the other strings. 
217 | vv, strs = strs[:1:1], strs[1:] 218 | vv[0] = value 219 | m[key] = vv 220 | } else { 221 | m[key] = append(vv, value) 222 | } 223 | 224 | if err != nil { 225 | return m, err 226 | } 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /proto/tunnel.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "bufio" 5 | "encoding/base64" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "github.com/keminar/anyproxy/proto/stats" 16 | 17 | "github.com/keminar/anyproxy/config" 18 | "github.com/keminar/anyproxy/crypto" 19 | "github.com/keminar/anyproxy/proto/tcp" 20 | "github.com/keminar/anyproxy/utils/cache" 21 | "github.com/keminar/anyproxy/utils/conf" 22 | "github.com/keminar/anyproxy/utils/tools" 23 | "github.com/keminar/anyproxy/utils/trace" 24 | "golang.org/x/net/proxy" 25 | ) 26 | 27 | const ( 28 | stateNew int = iota 29 | stateActive 30 | stateClosed 31 | stateIdle 32 | ) 33 | 34 | const protoTCP = "tcp" 35 | const protoHTTP = "http" 36 | const protoHTTPS = "https" 37 | 38 | // 上行统计 39 | var inbound *stats.Manager 40 | 41 | // 下行统计 42 | var outbound *stats.Manager 43 | 44 | func init() { 45 | inbound = stats.NewManager() 46 | outbound = stats.NewManager() 47 | go func() { 48 | ticker := time.NewTicker(1 * time.Minute) 49 | defer ticker.Stop() 50 | for range ticker.C { 51 | //log.Println("ticker...") 52 | inbound.UnregisterCounter() 53 | outbound.UnregisterCounter() 54 | } 55 | }() 56 | } 57 | 58 | // 转发实体 59 | type tunnel struct { 60 | req *Request 61 | conn *net.TCPConn // 后端服务 62 | curState int 63 | 64 | inboundIP string // 来源IP 65 | 66 | inbountCounter *stats.Counter 67 | outbountCounter *stats.Counter 68 | 69 | readSize int64 70 | writeSize int64 71 | 72 | clientUnRead int 73 | 74 | buf []byte 75 | } 76 | 77 | // newTunnel 实例 78 | func newTunnel(req *Request) *tunnel { 79 | s := 
&tunnel{ 80 | req: req, 81 | } 82 | 83 | s.inboundIP = tools.GetRemoteIp(req.conn.RemoteAddr().String()) 84 | return s 85 | } 86 | 87 | // copyBuffer 传输数据 88 | func (s *tunnel) copyBuffer(dst io.Writer, src *tcp.Reader, srcname string) (written int64, err error) { 89 | //如果设置过大会耗内存高,4k比较合理 90 | size := 4 * 1024 91 | buf := make([]byte, size) 92 | i := 0 93 | for { 94 | i++ 95 | if config.DebugLevel >= config.LevelDebug { 96 | log.Printf("%s receive from %s, n=%d\n", trace.ID(s.req.ID), srcname, i) 97 | } 98 | nr, er := src.Read(buf) 99 | if nr > 0 { 100 | // 如果为HTTP/1.1的Keep-alive情况下 101 | if srcname == "request" && s.clientUnRead >= 0 { 102 | // 之前已读完,说明要建新链接 或是 升级为长链接 103 | if s.clientUnRead == 0 { 104 | // 如果包是http协议则认为http复用 105 | if isKeepAliveHttp(s.req.ctx, s.req.conn, buf[0:nr]) { 106 | // 关闭与旧的服务器的连接的写 107 | s.conn.CloseWrite() 108 | // 状态变成已空闲,不能为关闭,会导致下面逻辑的Client也被关闭 109 | s.curState = stateIdle 110 | 111 | //todo 如果域名不同跳出交换数据, 因为这个逻辑会出现N次,应该在http.go实现 112 | //fmt.Println(string(buf[0:nr])) 113 | s.buf = make([]byte, nr) 114 | copy(s.buf, buf[0:nr]) 115 | break 116 | } else { 117 | //可能是http upgrade为websocket, 保持交换数据 118 | //比如经过nginx proxy -> 本程序 -> 旧版本的centrifugo 119 | s.clientUnRead = -1 120 | } 121 | } else { 122 | // 未读完 123 | s.clientUnRead -= nr 124 | } 125 | } 126 | if config.DebugLevel >= config.LevelDebugBody { 127 | log.Printf("%s receive from %s, n=%d, data len: %d\n", trace.ID(s.req.ID), srcname, i, nr) 128 | fmt.Println(trace.ID(s.req.ID), string(buf[0:nr])) 129 | } 130 | nw, ew := dst.Write(buf[0:nr]) 131 | if nw > 0 { 132 | written += int64(nw) 133 | if srcname == "request" { 134 | s.inbountCounter.Add(int64(nw)) 135 | } else { 136 | s.outbountCounter.Add(int64(nw)) 137 | } 138 | } 139 | if ew != nil { 140 | err = ew 141 | break 142 | } 143 | if nr != nw { 144 | err = io.ErrShortWrite 145 | break 146 | } 147 | } 148 | if er != nil { 149 | if er != io.EOF { 150 | err = er 151 | } else { 152 | s.logCopyErr(srcname+" read", er) 153 | if 
srcname == "server" { 154 | // 技巧:keep-alive 复用连接时写,后端收到CloseWrite后响应EOF,当收到EOF时说明body都收完了。 155 | if s.curState == stateIdle { 156 | //可以开始复用了, 带上之前读过的缓存 157 | KeepHandler(s.req.ctx, s.req.conn, s.buf) 158 | break 159 | } else if s.curState != stateClosed { 160 | // 如果非客户端导致的服务端关闭,则关闭客户端读 161 | // Notice: 如果只是CloseRead(),当在windows上执行时,且是做为订阅端从服务器收到请求再转到charles 162 | // 等服务时,当请求的地址返回足够长的内容时会触发卡住问题。 163 | // 流程如 curl -> anyproxy(server) -> ws -> anyproxy(windows) -> charles 164 | // 用Close()可以解决卡住,不过客户端会收到use of closed network connection的错误提醒 165 | dst.(*net.TCPConn).Close() 166 | } 167 | } 168 | } 169 | 170 | if srcname == "request" { 171 | // 当客户端断开或出错了,服务端也不用再读了,可以关闭,解决读Server卡住不能到EOF的问题 172 | s.conn.CloseWrite() 173 | s.curState = stateClosed 174 | } 175 | break 176 | } 177 | } 178 | return written, err 179 | } 180 | 181 | // transfer 交换数据 182 | func (s *tunnel) transfer(clientUnRead int) { 183 | if config.DebugLevel >= config.LevelLong { 184 | log.Println(trace.ID(s.req.ID), "transfer start") 185 | } 186 | s.curState = stateActive 187 | s.clientUnRead = clientUnRead 188 | done := make(chan struct{}) 189 | 190 | //发送请求 191 | go func() { 192 | defer func() { 193 | close(done) 194 | }() 195 | //不能和外层共用err 196 | var err error 197 | s.readSize, err = s.copyBuffer(s.conn, s.req.reader, "request") 198 | s.logCopyErr("request->server", err) 199 | if config.DebugLevel >= config.LevelLong { 200 | log.Println(trace.ID(s.req.ID), "request body size", s.readSize) 201 | } 202 | }() 203 | 204 | var err error 205 | //取返回结果 206 | s.writeSize, err = s.copyBuffer(s.req.conn, tcp.NewReader(s.conn), "server") 207 | s.logCopyErr("server->request", err) 208 | 209 | <-done 210 | // 不管是不是正常结束,只要server结束了,函数就会返回,然后底层会自动断开与client的连接 211 | if config.DebugLevel >= config.LevelLong { 212 | log.Println(trace.ID(s.req.ID), "transfer finished, response size", s.writeSize) 213 | } 214 | } 215 | 216 | // 上行写入 217 | func (s *tunnel) Write(p []byte) (n int, err error) { 218 | n, err = 
s.conn.Write(p) 219 | if s.inbountCounter != nil { 220 | s.inbountCounter.Add(int64(n)) 221 | } 222 | return 223 | } 224 | 225 | func (s *tunnel) logCopyErr(name string, err error) { 226 | if err == nil { 227 | return 228 | } 229 | if config.DebugLevel >= config.LevelLong { 230 | log.Println(trace.ID(s.req.ID), name, err.Error()) 231 | } else if err != io.EOF { 232 | log.Println(trace.ID(s.req.ID), name, err.Error()) 233 | } 234 | } 235 | 236 | // dail tcp连接 237 | func (s *tunnel) dail(network, connAddr string, second int64) error { 238 | if config.DebugLevel >= config.LevelLong { 239 | log.Printf("%s create new connection to server %s\n", trace.ID(s.req.ID), connAddr) 240 | } 241 | 242 | connTimeout := time.Duration(5) * time.Second 243 | if second > 0 { 244 | connTimeout = time.Duration(second) * time.Second 245 | } 246 | conn, err := net.DialTimeout(network, connAddr, connTimeout) 247 | if err != nil { 248 | return err 249 | } 250 | s.conn = conn.(*net.TCPConn) 251 | return nil 252 | } 253 | 254 | // 注册计数器, 日志地址优先使用域名 255 | func (s *tunnel) registerCounter(dstName, dstIP string, dstPort uint16) { 256 | // 日志地址优先使用域名 257 | var logAddr string 258 | if dstName != "" { 259 | logAddr = fmt.Sprintf("%s:%d", dstName, dstPort) 260 | } else { 261 | if strings.Contains(dstIP, ":") { 262 | logAddr = fmt.Sprintf("[%s]:%d", dstIP, dstPort) 263 | } else { 264 | logAddr = fmt.Sprintf("%s:%d", dstIP, dstPort) 265 | } 266 | } 267 | uplink := fmt.Sprintf("inbound>>>%s>>>%s>>>uplink", s.inboundIP, logAddr) 268 | downlink := fmt.Sprintf("inbound>>>%s>>>%s>>>downlink", s.inboundIP, logAddr) 269 | s.inbountCounter = inbound.RegisterCounter(uplink) 270 | s.outbountCounter = outbound.RegisterCounter(downlink) 271 | } 272 | 273 | // 连接地址优先使用IP 274 | func (s *tunnel) buildAddress(dstName, dstIP string, dstPort uint16, addCounter bool) (network string, connAddr string) { 275 | network = "tcp" 276 | if dstIP != "" { 277 | if strings.Contains(dstIP, ":") { 278 | network = "tcp6" 279 | 
connAddr = fmt.Sprintf("[%s]:%d", dstIP, dstPort) 280 | } else { 281 | connAddr = fmt.Sprintf("%s:%d", dstIP, dstPort) 282 | } 283 | } else if dstName != "" { 284 | connAddr = fmt.Sprintf("%s:%d", dstName, dstPort) 285 | } 286 | 287 | if addCounter && connAddr != "" { 288 | s.registerCounter(dstName, dstIP, dstPort) 289 | } 290 | return 291 | } 292 | 293 | // DNS解析 294 | func (s *tunnel) lookup(dstName, dstIP string) (string, cache.DialState) { 295 | state := cache.StateNone 296 | if dstName != "" { 297 | dstIP, state = cache.ResolveLookup.Lookup(s.req.ID, dstName) 298 | if dstIP == "" { 299 | s1 := time.Now() 300 | upIPs, _ := net.LookupIP(dstName) 301 | if time.Since(s1).Seconds() > 1 && config.DebugLevel >= config.LevelLong { 302 | log.Println(trace.ID(s.req.ID), "dns look up costtime", time.Since(s1).Seconds()) 303 | } 304 | if len(upIPs) > 0 { 305 | dstIP = upIPs[0].String() 306 | cache.ResolveLookup.Store(dstName, dstIP, cache.StateNew, time.Duration(10)*time.Minute) 307 | return dstIP, cache.StateNew 308 | } 309 | } 310 | } 311 | return dstIP, state 312 | } 313 | 314 | // 查询配置 315 | func findHost(dstName, dstIP string) conf.Host { 316 | for _, h := range conf.RouterConfig.Hosts { 317 | confMatch := getString(h.Match, conf.RouterConfig.Default.Match, "equal") 318 | switch confMatch { 319 | case "equal": 320 | if h.Name == dstName || h.Name == dstIP { 321 | return h 322 | } 323 | case "contain": 324 | if strings.Contains(dstName, h.Name) || strings.Contains(dstIP, h.Name) { 325 | return h 326 | } 327 | default: 328 | //todo 329 | } 330 | } 331 | return conf.Host{} 332 | } 333 | 334 | // 取值,如为空取默认 335 | func getString(val string, def string, def2 string) string { 336 | if val == "" { 337 | if def == "" { 338 | return def2 339 | } 340 | return def 341 | } 342 | return val 343 | } 344 | 345 | // handshake 和server握手 346 | func (s *tunnel) handshake(proto string, dstName, dstIP string, dstPort uint16) (err error) { 347 | var state cache.DialState 348 | // 
先取下配置,再决定要不要走本地dns解析,否则未解析域名DNS解析再超时卡半天,又不会被缓存 349 | host := findHost(dstName, dstIP) 350 | if ip, ok := s.isAllowed(host.AllowIP); !ok { 351 | err = fmt.Errorf("%s is not allowed", ip) 352 | return err 353 | } 354 | var confTarget string 355 | if proto == protoTCP { 356 | confTarget = getString(host.Target, conf.RouterConfig.Default.TCPTarget, "auto") 357 | } else { 358 | confTarget = getString(host.Target, conf.RouterConfig.Default.Target, "auto") 359 | } 360 | confDNS := getString(host.DNS, conf.RouterConfig.Default.DNS, "local") 361 | 362 | // tcp 请求,如果是解析的IP被禁(代理端也无法telnet),不知道域名又无法使用远程dns解析,只能手动换ip 363 | // 如golang.org 解析为180.97.235.30 不通,配置改为 216.239.37.1就行 364 | if host.IP != "" { 365 | dstIP = host.IP 366 | } else if dstName != "" && confDNS != "remote" { 367 | // http请求的dns解析 368 | dstIP, state = s.lookup(dstName, dstIP) 369 | } 370 | 371 | // 检查是否要换端口 372 | for _, p := range host.Port { 373 | if p.From == dstPort { 374 | dstPort = p.To 375 | break 376 | } 377 | } 378 | 379 | if confTarget == "deny" { 380 | err = fmt.Errorf("deny visit %s (%s)", dstName, dstIP) 381 | return 382 | } 383 | proxyScheme := config.ProxyScheme 384 | proxyServer := config.ProxyServer 385 | proxyPort := config.ProxyPort 386 | if host.Proxy != "" { //如果有自定义代理,则走自定义 387 | suffixLen := 5 388 | // 如果单域名代理配置以" last"或" deny"结尾,忽略全局的代理,并做相应的动作 389 | opIdx := len(host.Proxy) - suffixLen 390 | opName := "" 391 | if len(host.Proxy) >= suffixLen && host.Proxy[opIdx:opIdx+1] == " " { 392 | opName = host.Proxy[opIdx+1:] 393 | host.Proxy = host.Proxy[:opIdx] 394 | } 395 | 396 | // 支持多代理以逗号分隔,依次找到能用的 397 | for _, hostProxy := range strings.Split(host.Proxy, ",") { 398 | hostProxy = strings.TrimSpace(hostProxy) 399 | proxyScheme2, proxyServer2, proxyPort2, err := getProxyServer(hostProxy) 400 | if err != nil { 401 | // 如果自定义代理不可用,confTarget走原来逻辑 402 | log.Println(trace.ID(s.req.ID), "host.proxy err", err) 403 | } else { 404 | proxyScheme = proxyScheme2 405 | proxyServer = proxyServer2 406 | 
proxyPort = proxyPort2 407 | if confTarget != "remote" { //如果有定制代理,就不能用local 和 auto 408 | confTarget = "remote" 409 | } 410 | opName = "" 411 | break 412 | } 413 | } 414 | if opName == "last" { //没通的代理,走本地 415 | proxyServer = "" 416 | } else if opName == "deny" { 417 | err = fmt.Errorf("all proxy dail fail %s", host.Proxy) 418 | return 419 | } 420 | } 421 | if proxyServer != "" && proxyPort > 0 && confTarget != "local" { 422 | if confTarget == "auto" { 423 | if state != cache.StateFail { 424 | //local dial成功则返回,走本地网络 425 | //auto 只能优化ip ping 不通的情况,能dail通访问不了的需要手动remote 426 | network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true) 427 | if connAddr != "" { 428 | err = s.dail(network, connAddr, 1) 429 | if err == nil { 430 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("auto to %s", connAddr)) 431 | s.curState = stateNew 432 | return 433 | } 434 | } 435 | if dstName != "" && dstIP != "" { 436 | cache.ResolveLookup.Store(dstName, dstIP, cache.StateFail, time.Duration(1)*time.Hour) 437 | } 438 | } 439 | //fail的auto 等于用remote访问,但ip在remote访问可能也是不通的,强制用远程dns 440 | //如果又想远程,又想用本地dns请配置中单独指定 441 | //有一种情况是ip能dail通,auto模式就是会用local,但是transfer时接不到数据包,这种也要配置中单独指定remote 442 | confDNS = "remote" 443 | } 444 | // remote 请求 445 | var targetAddr string 446 | var targetNet string 447 | if confDNS == "remote" { 448 | if dstName == "" { 449 | dstName = dstIP 450 | } 451 | targetNet, targetAddr = s.buildAddress(dstName, "", dstPort, false) 452 | } else { 453 | targetNet, targetAddr = s.buildAddress("", dstIP, dstPort, false) 454 | } 455 | if targetAddr == "" || targetAddr[0] == ':' { 456 | err = errors.New("target host is empty") 457 | return 458 | } 459 | 460 | network, connAddr := s.buildAddress(proxyServer, "", proxyPort, true) 461 | switch proxyScheme { 462 | case "socks5": 463 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr)) 464 | err = s.socks5(network, connAddr, targetNet, targetAddr) 465 | case "tunnel": 466 | 
log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr)) 467 | err = s.httpConnect(network, connAddr, targetAddr, true) 468 | case "http": 469 | if proto == protoHTTP { //可避免转发到charles显示2次域名,且部分电脑请求出错 470 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s", connAddr)) 471 | err = s.dail(network, connAddr, 0) 472 | } else { 473 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("PROXY %s for %s", connAddr, targetAddr)) 474 | err = s.httpConnect(network, connAddr, targetAddr, false) 475 | } 476 | default: 477 | err = fmt.Errorf("proxy scheme %s is error", proxyScheme) 478 | return 479 | } 480 | } else { 481 | network, connAddr := s.buildAddress(dstName, dstIP, dstPort, true) 482 | if connAddr != "" { 483 | if dstName == "" { 484 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s", connAddr)) 485 | } else { 486 | log.Println(trace.ID(s.req.ID), fmt.Sprintf("direct to %s for %s", connAddr, dstName)) 487 | } 488 | err = s.dail(network, connAddr, 0) 489 | } else { 490 | err = errors.New("dstName && dstIP is empty") 491 | } 492 | } 493 | if err != nil { 494 | return 495 | } 496 | s.curState = stateNew 497 | return 498 | } 499 | 500 | // getProxyServer 解析代理服务器 501 | func getProxyServer(proxySpec string) (string, string, uint16, error) { 502 | if proxySpec == "" { 503 | return "", "", 0, errors.New("proxy 长度为空") 504 | } 505 | proxyScheme := "tunnel" 506 | var proxyServer string 507 | var proxyPort uint16 508 | // 先检查协议 509 | tmp := strings.Split(proxySpec, "://") 510 | if len(tmp) == 2 { 511 | proxyScheme = tmp[0] 512 | proxySpec = tmp[1] 513 | } 514 | // 检查端口,和上面的顺序不能反 515 | tmp = strings.Split(proxySpec, ":") 516 | if len(tmp) == 2 { 517 | portInt, err := strconv.Atoi(tmp[1]) 518 | if err == nil { 519 | proxyServer = tmp[0] 520 | proxyPort = uint16(portInt) 521 | // 检查是否可连通, 内网不好时100毫秒不够,调整到300 522 | connTimeout := time.Duration(300) * time.Millisecond 523 | conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", proxyServer, 
proxyPort), connTimeout) 524 | if err != nil { 525 | return "", "", 0, err 526 | } 527 | conn.Close() 528 | return proxyScheme, proxyServer, proxyPort, nil 529 | } 530 | return "", "", 0, err 531 | } 532 | return "", "", 0, errors.New("proxy 格式不对") 533 | } 534 | 535 | // socket5代理 536 | func (s *tunnel) socks5(network, connAddr string, targetNet, targetAddr string) (err error) { 537 | var dialProxy proxy.Dialer 538 | dialProxy, err = proxy.SOCKS5(network, connAddr, nil, proxy.Direct) 539 | if err != nil { 540 | log.Println(trace.ID(s.req.ID), "socket5 err", err.Error()) 541 | return 542 | } 543 | 544 | var conn net.Conn 545 | conn, err = dialProxy.Dial(targetNet, targetAddr) 546 | if err != nil { 547 | log.Println(trace.ID(s.req.ID), "dail err", err.Error()) 548 | return 549 | } 550 | s.conn = conn.(*net.TCPConn) 551 | return 552 | } 553 | 554 | // http代理 555 | func (s *tunnel) httpConnect(network, connAddr string, target string, encrypt bool) (err error) { 556 | err = s.dail(network, connAddr, 0) 557 | if err != nil { 558 | log.Println(trace.ID(s.req.ID), "dail err", err.Error()) 559 | return 560 | } 561 | var connectString string 562 | if encrypt { 563 | key := []byte(getToken()) 564 | var x1 []byte 565 | x1, err = crypto.EncryptAES([]byte(target), key) 566 | if err != nil { 567 | log.Println(trace.ID(s.req.ID), "encrypt err", err.Error()) 568 | return 569 | } 570 | // CONNECT实现的加密 571 | connectString = fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", base64.StdEncoding.EncodeToString(x1)) 572 | } else { 573 | connectString = fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", target) 574 | } 575 | fmt.Fprintf(s.conn, connectString) 576 | var status string 577 | status, err = bufio.NewReader(s.conn).ReadString('\n') 578 | if err != nil { 579 | log.Printf("%s PROXY ERR: Could not find response to CONNECT: err=%v", trace.ID(s.req.ID), err) 580 | return 581 | } 582 | // 检查是不是200返回 583 | if strings.Contains(status, "200") == false { 584 | log.Printf("%s PROXY ERR: Proxy response to 
CONNECT was: %s.\n", trace.ID(s.req.ID), strconv.Quote(status)) 585 | err = fmt.Errorf("Proxy response was: %s", strconv.Quote(status)) 586 | } 587 | return 588 | } 589 | 590 | // IP限制 591 | func (s *tunnel) isAllowed(allows []string) (string, bool) { 592 | allows = append(allows, conf.RouterConfig.AllowIP...) 593 | if len(allows) == 0 { 594 | return "", true 595 | } 596 | 597 | userIP := net.ParseIP(s.inboundIP) 598 | for _, p := range allows { 599 | if iPInCIDR(userIP, p) { 600 | return "", true 601 | } 602 | } 603 | return s.inboundIP, false 604 | } 605 | 606 | // iPInCIDR 判断IP地址是否在指定的CIDR范围内,支持ipv4和ipv6 607 | // cidr 示例 "192.168.1.0/24" "2001:db8:1234:5678::/64" 608 | func iPInCIDR(ip net.IP, cidr string) bool { 609 | _, ipNet, err := net.ParseCIDR(cidr) 610 | if err != nil { 611 | // 可能cidr是一个单ip的情况 612 | return ip.String() == cidr 613 | } 614 | return ipNet.Contains(ip) 615 | } 616 | -------------------------------------------------------------------------------- /proto/websocket.go: -------------------------------------------------------------------------------- 1 | package proto 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "log" 7 | 8 | "github.com/keminar/anyproxy/config" 9 | "github.com/keminar/anyproxy/nat" 10 | "github.com/keminar/anyproxy/proto/http" 11 | "github.com/keminar/anyproxy/utils/conf" 12 | "github.com/keminar/anyproxy/utils/trace" 13 | ) 14 | 15 | // 转发实体 16 | type wsTunnel struct { 17 | req *Request 18 | header http.Header 19 | 20 | readSize int64 21 | writeSize int64 22 | 23 | buffer *bytes.Buffer 24 | } 25 | 26 | // newTunnel 实例 27 | func newWsTunnel(req *Request, header http.Header) *wsTunnel { 28 | s := &wsTunnel{ 29 | req: req, 30 | header: header, 31 | buffer: new(bytes.Buffer), 32 | } 33 | return s 34 | } 35 | 36 | // 检查ws转发是否允许 37 | func (s *wsTunnel) getTarget(dstName string) (ok bool) { 38 | if dstName == "" { 39 | return false 40 | } 41 | host := findHost(dstName, dstName) 42 | var confTarget string 43 | confTarget = 
getString(host.Target, conf.RouterConfig.Default.Target, "auto") 44 | 45 | if confTarget == "deny" { 46 | return false 47 | } 48 | return true 49 | } 50 | 51 | // transfer 交换数据 52 | func (s *wsTunnel) transfer() bool { 53 | if config.DebugLevel >= config.LevelLong { 54 | log.Println(trace.ID(s.req.ID), "websocket transfer start") 55 | } 56 | 57 | c := nat.ServerHub.GetClient(s.header) 58 | if c == nil { 59 | // 走旧转发 60 | log.Println(trace.ID(s.req.ID), "websocket subscribe not found") 61 | return false 62 | } 63 | b := nat.ServerBridge.Register(c, s.req.ID, s.req.conn) 64 | defer func() { 65 | b.Unregister() 66 | }() 67 | 68 | // 发送创建连接请求 69 | b.Open() 70 | var err error 71 | done := make(chan struct{}) 72 | 73 | //发送请求给websocket 74 | go func() { 75 | defer close(done) 76 | b.Write([]byte(s.buffer.String())) 77 | s.readSize, err = b.CopyBuffer(b, s.req.reader, "request") 78 | s.logCopyErr("request->websocket", err) 79 | if config.DebugLevel >= config.LevelDebug { 80 | log.Println(trace.ID(s.req.ID), "request body size", s.readSize) 81 | } 82 | b.CloseWrite() 83 | }() 84 | //取返回结果写入请求端 85 | s.writeSize, err = b.WritePump() 86 | s.logCopyErr("websocket->request", err) 87 | 88 | <-done 89 | // 不管是不是正常结束,只要server结束了,函数就会返回,然后底层会自动断开与client的连接 90 | if config.DebugLevel >= config.LevelDebug { 91 | log.Println(trace.ID(s.req.ID), "websocket transfer finished, response size", s.writeSize) 92 | } 93 | return true 94 | } 95 | 96 | func (s *wsTunnel) logCopyErr(name string, err error) { 97 | if err == nil { 98 | return 99 | } 100 | if config.DebugLevel >= config.LevelLong { 101 | log.Println(trace.ID(s.req.ID), name, err.Error()) 102 | } else if err != io.EOF { 103 | log.Println(trace.ID(s.req.ID), name, err.Error()) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SCRIPT=$(readlink -f $0) 3 | 
ROOT_DIR=$(dirname $SCRIPT)/../ 4 | cd $ROOT_DIR 5 | 6 | mkdir -p dist/ 7 | 8 | # 路径 9 | GOCMD="go" 10 | GITCMD="git" 11 | 12 | # 目标文件前缀 13 | BIN="anyproxy" 14 | 15 | # 版本号 16 | ARCH="amd64" 17 | 18 | #组装变量 19 | GOBUILD="${GOCMD} build" 20 | VER=`${GITCMD} describe --tags $(${GITCMD} rev-list --tags --max-count=1)` 21 | GOVER=`${GOCMD} version` 22 | COMMIT_SHA1=`${GITCMD} rev-parse HEAD` 23 | HELP_PRE="github.com/keminar/anyproxy/utils/help" 24 | LDFLAGS="-X '${HELP_PRE}.goVersion=${GOVER}'" 25 | LDFLAGS="${LDFLAGS} -X '${HELP_PRE}.gitHash=${COMMIT_SHA1}'" 26 | LDFLAGS="${LDFLAGS} -X '${HELP_PRE}.version=${VER}'" 27 | 28 | # 编译 29 | echo "build ..." 30 | if [ "$1" == "all" ] || [ "$1" == "linux" ] ;then 31 | echo " for linux" 32 | CGO_ENABLED=0 GOOS=linux GOARCH=${ARCH} ${GOBUILD} -trimpath -ldflags "$LDFLAGS" -o dist/${BIN}-${ARCH}-${VER} anyproxy.go 33 | fi 34 | 35 | if [ "$1" == "all" ] || [ "$1" == "mac" ] ;then 36 | echo " for mac" 37 | CGO_ENABLED=0 GOOS=darwin GOARCH=${ARCH} ${GOBUILD} -trimpath -ldflags "$LDFLAGS" -o dist/${BIN}-darwin-${ARCH}-${VER} anyproxy.go 38 | fi 39 | 40 | if [ "$1" == "all" ] || [ "$1" == "windows" ] ;then 41 | echo " for windows" 42 | CGO_ENABLED=0 GOOS=windows GOARCH=${ARCH} ${GOBUILD} -trimpath -ldflags "$LDFLAGS" -o dist/${BIN}-windows-${ARCH}-${VER}.exe anyproxy.go 43 | fi 44 | 45 | if [ "$1" == "all" ] || [ "$1" == "alpine" ] ;then 46 | echo " for alpine" 47 | CGO_ENABLED=0 GOOS=linux GOARCH=${ARCH} ${GOBUILD} -tags netgo -trimpath -ldflags "$LDFLAGS" -o dist/${BIN}-alpine-${ARCH}-${VER} anyproxy.go 48 | fi -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SCRIPT=$(readlink -f $0) 3 | ROOT_DIR=$(dirname $SCRIPT)/../ 4 | cd $ROOT_DIR 5 | ulimit -n 65536 6 | 7 | if ! 
(sudo cat /etc/passwd|grep ^anyproxy: > /dev/null); then 8 | echo "添加账号" 9 | sudo useradd -M -s /sbin/nologin anyproxy 10 | fi 11 | 12 | # 要启动的端口 13 | port=3000 14 | if ! (sudo iptables -t nat -L|grep "redir ports $port" > /dev/null); then 15 | echo "添加iptables" 16 | # anyproxy 账号不走代理直接请求 17 | sudo iptables -t nat -A OUTPUT -p tcp -m owner --uid-owner anyproxy -j RETURN 18 | 19 | # 注:如果有虚拟机,虚拟机的启动账号不要走代理端口,会导致虚拟机里的浏览器设置的代理无效果 20 | # 如果有本地账号要连mysql等服务,也不要走代理端口,会一直卡住(目前发现mysql协议不兼容) 21 | 22 | #指定root账号走代理 23 | sudo iptables -t nat -A OUTPUT -p tcp -d 192.168.0.0/16 -m owner --uid-owner 0 -j RETURN 24 | sudo iptables -t nat -A OUTPUT -p tcp -d 172.17.0.0/16 -m owner --uid-owner 0 -j RETURN 25 | sudo iptables -t nat -A OUTPUT -p tcp -m multiport --dport 80,443 -m owner --uid-owner 0 -j REDIRECT --to-port $port 26 | 27 | #sudo iptables -t nat -L -n --line-number 28 | #sudo iptables -t nat -D OUTPUT 3 29 | fi 30 | echo "启动anyproxy" 31 | sudo -u anyproxy ./anyproxy -daemon -l $port 32 | -------------------------------------------------------------------------------- /scripts/win-start.bat: -------------------------------------------------------------------------------- 1 | cmd /k "anyproxy-windows-amd64-v1.0.exe -c router.yaml" -------------------------------------------------------------------------------- /utils/cache/cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | "time" 7 | 8 | "github.com/keminar/anyproxy/config" 9 | "github.com/keminar/anyproxy/utils/trace" 10 | ) 11 | 12 | // ResolveLookup 解析缓存 13 | var ResolveLookup *resolveLookupCache 14 | 15 | func init() { 16 | ResolveLookup = newResolveLookupCache() 17 | } 18 | 19 | // DialState 状态 20 | type DialState int 21 | 22 | const ( 23 | //StateNew 新值,未dial失败值 24 | StateNew DialState = iota 25 | //StateFail ipv4地址 dial失败 26 | StateFail 27 | //StateNone 不存在的地址 28 | StateNone 29 | ) 30 | 31 | type cacheEntry struct { 32 
| ipv4 string //ip v4地址 33 | state DialState //是否可连通 34 | expires time.Time 35 | } 36 | type resolveLookupCache struct { 37 | ips map[string]*cacheEntry 38 | keys []string 39 | next int 40 | mu sync.Mutex 41 | } 42 | 43 | // newResolveLookupCache 初始化 44 | func newResolveLookupCache() *resolveLookupCache { 45 | return &resolveLookupCache{ 46 | ips: make(map[string]*cacheEntry), 47 | keys: make([]string, 65536), 48 | } 49 | } 50 | 51 | // Lookup 查找 52 | func (c *resolveLookupCache) Lookup(logID uint, host string) (string, DialState) { 53 | c.mu.Lock() 54 | defer c.mu.Unlock() 55 | hit := c.ips[host] 56 | if hit != nil { 57 | if hit.expires.After(time.Now()) { 58 | if config.DebugLevel >= config.LevelDebug { 59 | log.Println(trace.ID(logID), "lookup(): CACHE_HIT", hit.state) 60 | } 61 | return hit.ipv4, hit.state 62 | } 63 | if config.DebugLevel >= config.LevelDebug { 64 | log.Println(trace.ID(logID), "lookup(): CACHE_EXPIRED") 65 | } 66 | delete(c.ips, host) 67 | } else { 68 | if config.DebugLevel >= config.LevelDebug { 69 | log.Println(trace.ID(logID), "lookup(): CACHE_MISS") 70 | } 71 | } 72 | return "", StateNone 73 | } 74 | 75 | // Store 保存,只有65535个位置,删除之前的占用 76 | func (c *resolveLookupCache) Store(host, ipv4 string, state DialState, d time.Duration) { 77 | c.mu.Lock() 78 | defer c.mu.Unlock() 79 | hit := c.ips[host] 80 | if hit != nil { 81 | hit.ipv4 = ipv4 82 | hit.state = state 83 | hit.expires = time.Now().Add(d) 84 | return 85 | } 86 | // 删除原位置内的值 87 | delete(c.ips, c.keys[c.next]) 88 | c.keys[c.next] = host 89 | c.next = (c.next + 1) & 65535 90 | c.ips[host] = &cacheEntry{ipv4: ipv4, state: state, expires: time.Now().Add(d)} 91 | } 92 | -------------------------------------------------------------------------------- /utils/conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | 7 | "github.com/fsnotify/fsnotify" 8 | ) 9 | 10 | // RouterConfig 配置 11 | var 
RouterConfig *Router 12 | 13 | // LoadAllConfig 加载顺序要求,不写成init 14 | func LoadAllConfig(filePath string) { 15 | var err error 16 | if filePath == "" { 17 | filePath, err = GetPath("router.yaml") 18 | } else if !fileExists(filePath) { 19 | filePath, err = GetPath(filePath) 20 | } 21 | if err != nil { 22 | log.Println(fmt.Sprintf("config file %s path err:%s", "router", err.Error())) 23 | return 24 | } 25 | conf, err := LoadRouterConfig(filePath) 26 | if err != nil { 27 | log.Println(fmt.Sprintf("config file %s load err:%s", "router", err.Error())) 28 | return 29 | } 30 | RouterConfig = &conf 31 | if conf.Watcher { 32 | go notify(filePath) 33 | } 34 | } 35 | 36 | func notify(filePath string) { 37 | watcher, err := fsnotify.NewWatcher() 38 | if err != nil { 39 | log.Println("config new notify watcher err", err) 40 | return 41 | } 42 | defer watcher.Close() 43 | 44 | done := make(chan bool) 45 | go func() { 46 | defer close(done) 47 | for { 48 | select { 49 | case event, ok := <-watcher.Events: 50 | if !ok { 51 | return 52 | } 53 | if event.Op&fsnotify.Write == fsnotify.Write { 54 | conf, err := LoadRouterConfig(filePath) 55 | if err != nil { 56 | log.Println(fmt.Sprintf("config file %s load err:%s", "router", err.Error())) 57 | } else { 58 | RouterConfig = &conf 59 | log.Println("config file reloaded:", filePath) 60 | } 61 | } 62 | case err, ok := <-watcher.Errors: 63 | if !ok { 64 | return 65 | } 66 | log.Println("config notify watcher error:", err) 67 | } 68 | } 69 | }() 70 | 71 | err = watcher.Add(filePath) 72 | if err != nil { 73 | log.Println("config notify add file err", err) 74 | return 75 | } 76 | <-done 77 | } 78 | -------------------------------------------------------------------------------- /utils/conf/path.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "runtime" 7 | ) 8 | 9 | // AppSrcPath 源码根目录 10 | var AppSrcPath string 11 | 12 | // AppPath 二进制文件根目录 
13 | var AppPath string 14 | 15 | func init() { 16 | _, file, _, _ := runtime.Caller(0) 17 | upDir := ".." + string(filepath.Separator) 18 | var err error 19 | if AppSrcPath, err = filepath.Abs(filepath.Dir(filepath.Join(file, upDir, upDir))); err != nil { 20 | panic(err) 21 | } 22 | 23 | if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { 24 | panic(err) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /utils/conf/router.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "errors" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | 10 | "gopkg.in/yaml.v2" 11 | ) 12 | 13 | type PortMap struct { 14 | From uint16 `yaml:"from"` //原目标地址 15 | To uint16 `yaml:"to"` //新目标地址 16 | } 17 | 18 | // Host 域名 19 | type Host struct { 20 | Name string `yaml:"name"` //域名关键字 21 | Match string `yaml:"match"` //contain 包含, equal 完全相等, preg 正则 22 | Target string `yaml:"target"` //local 当前环境, remote 远程, deny 禁止, auto根据dial选择 23 | DNS string `yaml:"dns"` //local 当前环境, remote 远程, 仅当target使用remote有效 24 | IP string `yaml:"ip"` //本地解析ip 25 | Port []PortMap `yaml:"port"` //目标端口转换 26 | Proxy string `yaml:"proxy"` //指定代理服务器 27 | AllowIP []string `yaml:"allowIP"` //可以访问的客户端IP 28 | } 29 | 30 | // Log 日志 31 | type Log struct { 32 | Dir string `yaml:"dir"` 33 | } 34 | 35 | // Subscribe 订阅标志 36 | type Subscribe struct { 37 | Key string `yaml:"key"` //Header的key 38 | Val string `yaml:"val"` //Header的val 39 | } 40 | 41 | // Websocket 与服务端websocket通信 42 | type Websocket struct { 43 | Listen string `yaml:"listen"` //websocket 监听 44 | Connect string `yaml:"connect"` //websocket 连接 45 | Host string `yaml:"host"` //connect的域名 46 | User string `yaml:"user"` //认证用户 47 | Pass string `yaml:"pass"` //密码 48 | Email string `yaml:"email"` //邮箱 49 | Subscribe []Subscribe `yaml:"subscribe"` //订阅信息 50 | } 51 | 52 | // Default 域名 53 | type Default struct { 54 | Match string 
`yaml:"match"` //默认域名比对 55 | Target string `yaml:"target"` //http默认访问策略 56 | DNS string `yaml:"dns"` //默认的DNS服务器 57 | Proxy string `yaml:"proxy"` //全局代理服务器 58 | TCPTarget string `yaml:"tcpTarget"` //tcp默认访问策略 59 | } 60 | 61 | type TcpCopy struct { 62 | Enable bool `yaml:"enable"` //是否开启 63 | IP string `yaml:"ip"` //ip 64 | Port uint16 `yaml:"port"` //新目标地址 65 | } 66 | 67 | // http首行请求格式,一般vue本地项目要把域名配置为off 68 | // 注意:custom配置域名和端口中间的冒号改为点,如localhost:5173配置为localhost.5173 69 | type FirstLine struct { 70 | Host string `yaml:"host"` //是否带Host, on带,off不带,默认带 71 | Custom map[string]string `yaml:"custom"` //按域名配带Host,on带,off不带,其他用默认 72 | } 73 | 74 | // Router 配置文件模型 75 | type Router struct { 76 | Listen string `yaml:"listen"` //监听端口 77 | Network string `yaml:"network"` //监听协议 78 | Log Log `yaml:"log"` //日志目录 79 | Watcher bool `yaml:"watcher"` //是否监听配置文件变化 80 | Token string `yaml:"token"` //加密值, 和tunnel通信密钥, 必须16位长度 81 | TcpCopy TcpCopy `yaml:"tcpcopy"` //进行tcp转发模式 82 | Default Default `yaml:"default"` //默认配置 83 | Hosts []Host `yaml:"hosts"` //域名列表 84 | AllowIP []string `yaml:"allowIP"` //可以访问的客户端IP 85 | FirstLine FirstLine `yaml:"firstLine"` //http请求首行域名和头部域名相同时删除首行域名 86 | Websocket Websocket `yaml:"websocket"` //会话订阅请求信息 87 | } 88 | 89 | // LoadRouterConfig 加载配置 90 | func LoadRouterConfig(configPath string) (cnf Router, err error) { 91 | data, err := ioutil.ReadFile(configPath) 92 | if err != nil { 93 | return 94 | } 95 | err = yaml.Unmarshal(data, &cnf) 96 | return 97 | } 98 | 99 | // 获取文件路径 100 | func GetPath(filename string) (string, error) { 101 | // 当前登录用户所在目录 102 | workPath, err := os.Getwd() 103 | if err != nil { 104 | panic(err) 105 | } 106 | configPath := filepath.Join(workPath, "conf", filename) 107 | if !fileExists(configPath) { 108 | configPath = filepath.Join(AppPath, "conf", filename) 109 | if !fileExists(configPath) { 110 | configPath = filepath.Join(AppSrcPath, "conf", filename) 111 | if !fileExists(configPath) { 112 | log.Println("workPath:", workPath) 
113 | log.Println("appPath:", AppPath) 114 | return "", errors.New("conf/" + filename + " not found") 115 | } 116 | } 117 | } 118 | return configPath, nil 119 | } 120 | 121 | // fileExists reports whether the named file or directory exists. 122 | func fileExists(name string) bool { 123 | if _, err := os.Stat(name); err != nil { 124 | if os.IsNotExist(err) { 125 | return false 126 | } 127 | } 128 | return true 129 | } 130 | -------------------------------------------------------------------------------- /utils/daemon/daemon.go: -------------------------------------------------------------------------------- 1 | // https://github.com/immortal/immortal/blob/master/fork.go 2 | // https://github.com/icattlecoder/godaemon/blob/master/godaemon.go 3 | 4 | package daemon 5 | 6 | import ( 7 | "flag" 8 | "log" 9 | "os" 10 | ) 11 | 12 | var daemon = flag.Bool("daemon", false, "run app as a daemon") 13 | 14 | // Daemonize 后台化 15 | func Daemonize(envName string, fd *os.File) { 16 | if !flag.Parsed() { 17 | flag.Parse() 18 | } 19 | // 如果启用daemon模式,Fork的进行在主进程退出后PPID为1 20 | if *daemon && os.Getppid() > 1 { 21 | // 为了兼容平滑重启,和二重保证不死循环。主动替换daemon参数 22 | args := os.Args[1:] 23 | for i := 0; i < len(args); i++ { 24 | if args[i] == "-daemon" || args[i] == "-daemon=true" { 25 | args[i] = "-daemon=false" 26 | break 27 | } 28 | } 29 | if pid, err := Fork(envName, fd, args); err != nil { 30 | log.Fatalf("error while forking: %s", err) 31 | } else { 32 | if pid > 0 { 33 | os.Exit(0) 34 | } 35 | } 36 | } 37 | } 38 | 39 | // IsForeground 是否在前台执行 40 | // 因ppid的方案在graceful平滑重启时不准,又不想加太多的外部参数, 所以用环境变量 41 | func IsForeground(envName string) bool { 42 | if os.Getenv(envName) == "" { 43 | return true 44 | } 45 | return false 46 | } 47 | -------------------------------------------------------------------------------- /utils/daemon/daemon_fork.go: -------------------------------------------------------------------------------- 1 | // 条件编译 https://segmentfault.com/a/1190000017846997 2 | 3 | // +build 
!windows 4 | 5 | package daemon 6 | 7 | import ( 8 | "os" 9 | "os/exec" 10 | "syscall" 11 | ) 12 | 13 | // Fork crete a new process 14 | func Fork(envName string, fd *os.File, args []string) (int, error) { 15 | cmd := exec.Command(os.Args[0], args...) 16 | val := os.Getenv(envName) 17 | if val == "" { //若未设置则为空字符串 18 | //为子进程设置特殊的环境变量标识 19 | os.Setenv(envName, "daemon") 20 | } 21 | cmd.Env = os.Environ() 22 | cmd.Stdin = nil 23 | //为捕获执行程序的输出,非设置新进程的os.Stdout 不要理解错 24 | //新进程的os.Stdout.Name()值还是默认值,但输出到/dev/stdout的这边能获取到 25 | //这边必须设置,否则新进程内的错误可能捕获不到 26 | // 用 os.NewFile(uintptr(syscall.Stderr), "/dev/stderr").WriteString("test\n") 复现 27 | cmd.Stdout = fd 28 | cmd.Stderr = fd 29 | cmd.ExtraFiles = nil 30 | cmd.SysProcAttr = &syscall.SysProcAttr{ 31 | // Setsid is used to detach the process from the parent (normally a shell) 32 | // 33 | // The disowning of a child process is accomplished by executing the system call 34 | // setpgrp() or setsid(), (both of which have the same functionality) as soon as 35 | // the child is forked. These calls create a new process session group, make the 36 | // child process the session leader, and set the process group ID to the process 37 | // ID of the child. https://bsdmag.org/unix-kernel-system-calls/ 38 | Setsid: true, 39 | } 40 | if err := cmd.Start(); err != nil { 41 | return 0, err 42 | } 43 | return cmd.Process.Pid, nil 44 | } 45 | -------------------------------------------------------------------------------- /utils/daemon/daemon_windows.go: -------------------------------------------------------------------------------- 1 | package daemon 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | ) 7 | 8 | // Fork crete a new process 9 | func Fork(envName string, fd *os.File, args []string) (int, error) { 10 | cmd := exec.Command(os.Args[0], args...) 
11 | val := os.Getenv(envName) 12 | if val == "" { //若未设置则为空字符串 13 | //为子进程设置特殊的环境变量标识 14 | os.Setenv(envName, "daemon") 15 | } 16 | cmd.Env = os.Environ() 17 | cmd.Stdin = nil 18 | //为捕获执行程序的输出,非设置新进程的os.Stdout 不要理解错 19 | //新进程的os.Stdout.Name()值还是默认值,但输出到/dev/stdout的这边能获取到 20 | //这边必须设置,否则新进程内的错误可能捕获不到 21 | // 用 os.NewFile(uintptr(syscall.Stderr), "/dev/stderr").WriteString("test\n") 复现 22 | cmd.Stdout = fd 23 | cmd.Stderr = fd 24 | cmd.ExtraFiles = nil 25 | if err := cmd.Start(); err != nil { 26 | return 0, err 27 | } 28 | return cmd.Process.Pid, nil 29 | } 30 | -------------------------------------------------------------------------------- /utils/help/help.go: -------------------------------------------------------------------------------- 1 | package help 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strconv" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | var ( 12 | version string 13 | gitHash string 14 | goVersion string 15 | ) 16 | 17 | // Usage 帮助 18 | func Usage() { 19 | fmt.Fprintf(os.Stdout, "%s\n\n", versionString("anyproxy")) 20 | fmt.Fprintf(os.Stdout, "Usage: %s -l listenaddress -p proxies \n", os.Args[0]) 21 | fmt.Fprintf(os.Stdout, " Proxies any tcp port transparently\n\n") 22 | fmt.Fprintf(os.Stdout, "Mandatory\n") 23 | fmt.Fprintf(os.Stdout, " -l=ADDRPORT Address and port to listen on (e.g., :3000 or 127.0.0.1:3000)\n") 24 | fmt.Fprintf(os.Stdout, " -p=PROXIES Address and ports of upstream proxy servers to use\n") 25 | fmt.Fprintf(os.Stdout, " (e.g., 10.1.1.1:80 will use http proxy, socks5://10.2.2.2:3128 use socks5 proxy,\n") 26 | fmt.Fprintf(os.Stdout, " tunnel://10.2.2.2:3001 use tunnel proxy)\n") 27 | fmt.Fprintf(os.Stdout, " -c=FILEPATH Config file path, default is router.yaml\n") 28 | fmt.Fprintf(os.Stdout, "Optional\n") 29 | fmt.Fprintf(os.Stdout, " -ws-listen Websocket address and port to listen on\n") 30 | fmt.Fprintf(os.Stdout, " -ws-connect Websocket Address and port to connect\n") 31 | fmt.Fprintf(os.Stdout, " -daemon Run as a Unix daemon\n") 32 
| fmt.Fprintf(os.Stdout, " -mode Run mode(proxy, tunnel). proxy mode default\n") 33 | fmt.Fprintf(os.Stdout, " -debug Debug mode (0, 1, 2, 3)\n") 34 | fmt.Fprintf(os.Stdout, " -pprof Pprof port, disable if empty\n") 35 | fmt.Fprintf(os.Stdout, " -v Show build version\n\n") 36 | fmt.Fprintf(os.Stdout, " -h This usage message\n\n") 37 | 38 | fmt.Fprintf(os.Stdout, "Before starting anyproxy, be sure to change the number of available file handles to at least 65535\n") 39 | fmt.Fprintf(os.Stdout, "with \"ulimit -n 65535\"\n") //重要 40 | fmt.Fprintf(os.Stdout, "Some other tunables that enable higher performance:\n") 41 | fmt.Fprintf(os.Stdout, " net.core.netdev_max_backlog = 2048\n") 42 | fmt.Fprintf(os.Stdout, " net.core.somaxconn = 1024\n") 43 | fmt.Fprintf(os.Stdout, " net.core.rmem_default = 8388608\n") 44 | fmt.Fprintf(os.Stdout, " net.core.rmem_max = 16777216\n") 45 | fmt.Fprintf(os.Stdout, " net.core.wmem_max = 16777216\n") 46 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_tw_reuse = 1 \n") //重要 ,//sysctl -w net.ipv4.tcp_tw_reuse=1 47 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_fin_timeout = 30\n") //重要, //sysctl -w net.ipv4.tcp_fin_timeout=30 48 | fmt.Fprintf(os.Stdout, " net.ipv4.ip_local_port_range = 2000 65000\n") 49 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_window_scaling = 1\n") 50 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_max_syn_backlog = 3240000\n") 51 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_max_tw_buckets = 1440000\n") 52 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_mem = 50576 64768 98152\n") 53 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_rmem = 4096 87380 16777216\n") 54 | fmt.Fprintf(os.Stdout, " NOTE: if you see syn flood warnings in your logs, you need to adjust tcp_max_syn_backlog, tcp_synack_retries and tcp_abort_on_overflow\n") 55 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_syncookies = 1\n") 56 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_wmem = 4096 65536 16777216\n") 57 | fmt.Fprintf(os.Stdout, " net.ipv4.tcp_congestion_control = cubic\n\n") 58 | 59 | fmt.Fprintf(os.Stdout, 
"Report bugs to https://github.com/keminar/anyproxy or .\n") 60 | fmt.Fprintf(os.Stdout, "Thanks to https://github.com/ryanchapman/go-any-proxy.git\n") 61 | } 62 | 63 | // 版本 64 | func ShowVersion() { 65 | fmt.Fprintf(os.Stdout, "%s\n\n", versionString("anyproxy")) 66 | } 67 | 68 | func versionString(name string) (v string) { 69 | now := time.Now().Unix() 70 | buildNum := strings.ToUpper(strconv.FormatInt(now, 36)) 71 | buildDate := time.Unix(now, 0).Format(time.UnixDate) 72 | v = fmt.Sprintf("%s %s (build %v, %v)", name, version, buildNum, buildDate) 73 | v += fmt.Sprintf("\nGit Commit Hash: %s", gitHash) 74 | v += fmt.Sprintf("\nGoLang Version: %s", goVersion) 75 | return 76 | } 77 | -------------------------------------------------------------------------------- /utils/tools/string.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "crypto/md5" 5 | "encoding/hex" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // GetPort 从 127.0.0.1:3000 结构中取出3000 11 | func GetPort(addr string) string { 12 | for i := len(addr) - 1; i >= 0; i-- { 13 | if addr[i] == ':' { 14 | return addr[i+1:] 15 | } 16 | } 17 | return "" 18 | } 19 | 20 | // GetIp 从 127.0.0.1:3000 结构中取出127.0.0.1 21 | func GetRemoteIp(addr string) string { 22 | for i := len(addr) - 1; i >= 1; i-- { 23 | if addr[i] == ':' { 24 | return addr[0:i] 25 | } 26 | } 27 | return addr 28 | } 29 | 30 | func Md5Str(str string) (string, error) { 31 | h := md5.New() 32 | h.Write([]byte(str)) 33 | cipherStr := h.Sum(nil) 34 | return hex.EncodeToString(cipherStr), nil 35 | } 36 | 37 | // 支持只输入端口的形式 38 | func FillPort(port string) string { 39 | if !strings.Contains(port, ":") { 40 | d, err := strconv.Atoi(port) 41 | if err == nil && strconv.Itoa(d) == port { //说明输入为纯数字 42 | port = ":" + port 43 | } 44 | } 45 | return port 46 | } 47 | -------------------------------------------------------------------------------- /utils/trace/trace.go: 
-------------------------------------------------------------------------------- 1 | package trace 2 | 3 | import "fmt" 4 | 5 | // ID 日志ID 6 | func ID(id uint) string { 7 | return fmt.Sprintf("ID #%d,", id) 8 | } 9 | --------------------------------------------------------------------------------