├── go.mod ├── Dockerfile ├── .gitignore ├── auth.go ├── cmd └── socks5 │ └── main.go ├── .github └── workflows │ └── go-cross-build.yml ├── LICENSE ├── udp_netip.go ├── README.md ├── simple_server.go ├── udp.go ├── common.go ├── all_test.go ├── client.go └── server.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/wzshiming/socks5 2 | 3 | go 1.15 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:alpine AS builder 2 | WORKDIR /go/src/github.com/wzshiming/socks5/ 3 | COPY . . 4 | ENV CGO_ENABLED=0 5 | RUN go install ./cmd/socks5 6 | 7 | FROM alpine 8 | EXPOSE 1080 9 | COPY --from=builder /go/bin/socks5 /usr/local/bin/ 10 | ENTRYPOINT [ "/usr/local/bin/socks5" ] 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /auth.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | // AuthenticationFunc Authentication interface is implemented 4 | type AuthenticationFunc func(cmd Command, username, password string) bool 5 | 6 | // Auth authentication processing 7 | func (f AuthenticationFunc) Auth(cmd Command, username, password string) bool { 8 | return f(cmd, username, password) 9 | } 10 | 11 | // Authentication proxy authentication 12 | type Authentication interface { 13 | 
Auth(cmd Command, username, password string) bool 14 | } 15 | 16 | // UserAuth basic authentication 17 | func UserAuth(username, password string) Authentication { 18 | return AuthenticationFunc(func(c Command, u, p string) bool { 19 | return username == u && password == p 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /cmd/socks5/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | 8 | "github.com/wzshiming/socks5" 9 | ) 10 | 11 | var address string 12 | var username string 13 | var password string 14 | 15 | func init() { 16 | flag.StringVar(&address, "a", ":1080", "listen on the address") 17 | flag.StringVar(&username, "u", "", "username") 18 | flag.StringVar(&password, "p", "", "password") 19 | flag.Parse() 20 | } 21 | 22 | func main() { 23 | logger := log.New(os.Stderr, "[socks5] ", log.LstdFlags) 24 | svc := &socks5.Server{ 25 | Logger: logger, 26 | } 27 | if username != "" { 28 | svc.Authentication = socks5.UserAuth(username, password) 29 | } 30 | err := svc.ListenAndServe("tcp", address) 31 | if err != nil { 32 | logger.Println(err) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /.github/workflows/go-cross-build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Go 13 | uses: actions/setup-go@v2 14 | with: 15 | go-version: 1.17 16 | - name: Build Cross Platform 17 | uses: wzshiming/action-go-build-cross-plantform@v1 18 | - name: Upload Release Assets 19 | uses: wzshiming/action-upload-release-assets@v1 20 | env: 21 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | - name: Log into registry 23 | run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login 
ghcr.io -u ${{ github.actor }} --password-stdin 24 | - name: Upload Release Images 25 | uses: wzshiming/action-upload-release-images@v1 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 wzshiming 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /udp_netip.go: -------------------------------------------------------------------------------- 1 | //go:build go1.18 2 | // +build go1.18 3 | 4 | package socks5 5 | 6 | import ( 7 | "net" 8 | "net/netip" 9 | ) 10 | 11 | // ReadFromUDPAddrPort implements the net.UDPConn ReadFromUDPAddrPort method. 
12 | func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { 13 | udpConn, ok := c.PacketConn.(*net.UDPConn) 14 | if !ok { 15 | return 0, addr, errUnsupportedMethod 16 | } 17 | return udpConn.ReadFromUDPAddrPort(b) 18 | } 19 | 20 | // ReadMsgUDPAddrPort implements the net.UDPConn ReadMsgUDPAddrPort method. 21 | func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { 22 | udpConn, ok := c.PacketConn.(*net.UDPConn) 23 | if !ok { 24 | return 0, 0, 0, addr, errUnsupportedMethod 25 | } 26 | return udpConn.ReadMsgUDPAddrPort(b, oob) 27 | } 28 | 29 | // WriteToUDPAddrPort implements the net.UDPConn WriteToUDPAddrPort method. 30 | func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { 31 | udpConn, ok := c.PacketConn.(*net.UDPConn) 32 | if !ok { 33 | return 0, errUnsupportedMethod 34 | } 35 | return udpConn.WriteToUDPAddrPort(b, addr) 36 | } 37 | 38 | // WriteMsgUDPAddrPort implements the net.UDPConn WriteMsgUDPAddrPort method. 
39 | func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) { 40 | udpConn, ok := c.PacketConn.(*net.UDPConn) 41 | if !ok { 42 | return 0, 0, errUnsupportedMethod 43 | } 44 | return udpConn.WriteMsgUDPAddrPort(b, oob, addr) 45 | } 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # socks5 2 | 3 | Socks5/Socks5h server and client 4 | 5 | [![Build](https://github.com/wzshiming/socks5/actions/workflows/go-cross-build.yml/badge.svg)](https://github.com/wzshiming/socks5/actions/workflows/go-cross-build.yml) 6 | [![Go Report Card](https://goreportcard.com/badge/github.com/wzshiming/socks5)](https://goreportcard.com/report/github.com/wzshiming/socks5) 7 | [![GoDoc](https://pkg.go.dev/badge/github.com/wzshiming/socks5)](https://pkg.go.dev/github.com/wzshiming/socks5) 8 | [![GitHub license](https://img.shields.io/github/license/wzshiming/socks5.svg)](https://github.com/wzshiming/socks5/blob/master/LICENSE) 9 | 10 | This project is to add protocol support for the [Bridge](https://github.com/wzshiming/bridge), or it can be used alone 11 | 12 | The following is the implementation of other proxy protocols 13 | 14 | - [Socks4](https://github.com/wzshiming/socks4) 15 | - [HTTP Proxy](https://github.com/wzshiming/httpproxy) 16 | - [Shadow Socks](https://github.com/wzshiming/shadowsocks) 17 | - [SSH Proxy](https://github.com/wzshiming/sshproxy) 18 | - [Any Proxy](https://github.com/wzshiming/anyproxy) 19 | - [Emux](https://github.com/wzshiming/emux) 20 | 21 | ## Usage 22 | 23 | [API Documentation](https://godoc.org/github.com/wzshiming/socks5) 24 | 25 | [Example](https://github.com/wzshiming/socks5/blob/master/cmd/socks5/main.go) 26 | 27 | - [x] Support for the CONNECT command 28 | - [x] Support for the BIND command 29 | - [x] Support for the ASSOCIATE command 30 | 31 | ## License 32 | 33 | Licensed under 
the MIT License. See [LICENSE](https://github.com/wzshiming/socks5/blob/master/LICENSE) for the full license text. 34 | -------------------------------------------------------------------------------- /simple_server.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/url" 8 | "time" 9 | ) 10 | 11 | // SimpleServer is a simplified server, which can be configured as easily as client. 12 | type SimpleServer struct { 13 | Server 14 | Listener net.Listener 15 | Network string 16 | Address string 17 | Username string 18 | Password string 19 | } 20 | 21 | // NewServer creates a new NewSimpleServer 22 | func NewSimpleServer(addr string) (*SimpleServer, error) { 23 | s := &SimpleServer{} 24 | u, err := url.Parse(addr) 25 | if err != nil { 26 | return nil, err 27 | } 28 | switch u.Scheme { 29 | case "socks5", "socks5h": 30 | default: 31 | return nil, fmt.Errorf("unsupported protocol '%s'", u.Scheme) 32 | } 33 | host := u.Host 34 | port := u.Port() 35 | if port == "" { 36 | port = "1080" 37 | hostname := u.Hostname() 38 | host = net.JoinHostPort(hostname, port) 39 | } 40 | if u.User != nil { 41 | s.Username = u.User.Username() 42 | s.Password, _ = u.User.Password() 43 | s.Authentication = UserAuth(s.Username, s.Password) 44 | } 45 | 46 | s.Address = host 47 | s.Network = "tcp" 48 | s.ListenBindReuseTimeout = time.Second / 2 49 | return s, nil 50 | } 51 | 52 | // Run the server 53 | func (s *SimpleServer) Run(ctx context.Context) error { 54 | var listenConfig net.ListenConfig 55 | if s.Listener == nil { 56 | listener, err := listenConfig.Listen(ctx, s.Network, s.Address) 57 | if err != nil { 58 | return err 59 | } 60 | s.Listener = listener 61 | } 62 | s.Address = s.Listener.Addr().String() 63 | return s.Serve(s.Listener) 64 | } 65 | 66 | // Start the server 67 | func (s *SimpleServer) Start(ctx context.Context) error { 68 | var listenConfig net.ListenConfig 69 | if 
s.Listener == nil { 70 | listener, err := listenConfig.Listen(ctx, s.Network, s.Address) 71 | if err != nil { 72 | return err 73 | } 74 | s.Listener = listener 75 | } 76 | s.Address = s.Listener.Addr().String() 77 | go s.Serve(s.Listener) 78 | return nil 79 | } 80 | 81 | // Close closes the listener 82 | func (s *SimpleServer) Close() error { 83 | if s.Listener == nil { 84 | return nil 85 | } 86 | return s.Listener.Close() 87 | } 88 | 89 | // ProxyURL returns the URL of the proxy 90 | func (s *SimpleServer) ProxyURL() string { 91 | u := url.URL{ 92 | Scheme: "socks5", 93 | Host: s.Address, 94 | } 95 | if s.Username != "" { 96 | u.User = url.UserPassword(s.Username, s.Password) 97 | } 98 | return u.String() 99 | } 100 | -------------------------------------------------------------------------------- /udp.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "net" 7 | "time" 8 | ) 9 | 10 | var ( 11 | errBadHeader = errors.New("bad header") 12 | errUnsupportedMethod = errors.New("unsupported method") 13 | ) 14 | 15 | type UDPConn struct { 16 | bufRead [maxUdpPacket]byte 17 | bufWrite [maxUdpPacket]byte 18 | proxyAddress net.Addr 19 | defaultTarget net.Addr 20 | prefix []byte 21 | net.PacketConn 22 | } 23 | 24 | func NewUDPConn(raw net.PacketConn, proxyAddress net.Addr, defaultTarget net.Addr) (*UDPConn, error) { 25 | conn := &UDPConn{ 26 | PacketConn: raw, 27 | proxyAddress: proxyAddress, 28 | defaultTarget: defaultTarget, 29 | prefix: []byte{0, 0, 0}, 30 | } 31 | return conn, nil 32 | } 33 | 34 | // ReadFrom implements the net.PacketConn ReadFrom method. 
35 | func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 36 | n, addr, err = c.PacketConn.ReadFrom(c.bufRead[:]) 37 | if err != nil { 38 | return 0, nil, err 39 | } 40 | if n < len(c.prefix) || addr.String() != c.proxyAddress.String() { 41 | return 0, nil, errBadHeader 42 | } 43 | buf := bytes.NewBuffer(c.bufRead[len(c.prefix):n]) 44 | a, err := readAddr(buf) 45 | if err != nil { 46 | return 0, nil, err 47 | } 48 | n = copy(p, buf.Bytes()) 49 | return n, a, nil 50 | } 51 | 52 | // WriteTo implements the net.PacketConn WriteTo method. 53 | func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 54 | buf := bytes.NewBuffer(c.bufWrite[:0]) 55 | buf.Write(c.prefix) 56 | err = writeAddrWithStr(buf, addr.String()) 57 | if err != nil { 58 | return 0, err 59 | } 60 | n, err = buf.Write(p) 61 | if err != nil { 62 | return 0, err 63 | } 64 | 65 | data := buf.Bytes() 66 | _, err = c.PacketConn.WriteTo(data, c.proxyAddress) 67 | if err != nil { 68 | return 0, err 69 | } 70 | return n, nil 71 | } 72 | 73 | // Read implements the net.Conn Read method. 74 | func (c *UDPConn) Read(b []byte) (int, error) { 75 | n, addr, err := c.ReadFrom(b) 76 | if err != nil { 77 | return 0, err 78 | } 79 | if addr.String() != c.defaultTarget.String() { 80 | return c.Read(b) 81 | } 82 | return n, nil 83 | } 84 | 85 | // Write implements the net.Conn Write method. 86 | func (c *UDPConn) Write(b []byte) (int, error) { 87 | return c.WriteTo(b, c.defaultTarget) 88 | } 89 | 90 | // RemoteAddr implements the net.Conn RemoteAddr method. 91 | func (c *UDPConn) RemoteAddr() net.Addr { 92 | return c.defaultTarget 93 | } 94 | 95 | // SetReadBuffer implements the net.UDPConn SetReadBuffer method. 
96 | func (c *UDPConn) SetReadBuffer(bytes int) error { 97 | udpConn, ok := c.PacketConn.(*net.UDPConn) 98 | if !ok { 99 | return errUnsupportedMethod 100 | } 101 | return udpConn.SetReadBuffer(bytes) 102 | } 103 | 104 | // SetWriteBuffer implements the net.UDPConn SetWriteBuffer method. 105 | func (c *UDPConn) SetWriteBuffer(bytes int) error { 106 | udpConn, ok := c.PacketConn.(*net.UDPConn) 107 | if !ok { 108 | return errUnsupportedMethod 109 | } 110 | return udpConn.SetWriteBuffer(bytes) 111 | } 112 | 113 | // SetDeadline implements the Conn SetDeadline method. 114 | func (c *UDPConn) SetDeadline(t time.Time) error { 115 | udpConn, ok := c.PacketConn.(*net.UDPConn) 116 | if !ok { 117 | return errUnsupportedMethod 118 | } 119 | return udpConn.SetDeadline(t) 120 | } 121 | 122 | // SetReadDeadline implements the Conn SetReadDeadline method. 123 | func (c *UDPConn) SetReadDeadline(t time.Time) error { 124 | udpConn, ok := c.PacketConn.(*net.UDPConn) 125 | if !ok { 126 | return errUnsupportedMethod 127 | } 128 | return udpConn.SetReadDeadline(t) 129 | } 130 | 131 | // SetWriteDeadline implements the Conn SetWriteDeadline method. 132 | func (c *UDPConn) SetWriteDeadline(t time.Time) error { 133 | udpConn, ok := c.PacketConn.(*net.UDPConn) 134 | if !ok { 135 | return errUnsupportedMethod 136 | } 137 | return udpConn.SetWriteDeadline(t) 138 | } 139 | 140 | // ReadFromUDP implements the net.UDPConn ReadFromUDP method. 141 | func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { 142 | udpConn, ok := c.PacketConn.(*net.UDPConn) 143 | if !ok { 144 | return 0, nil, errUnsupportedMethod 145 | } 146 | return udpConn.ReadFromUDP(b) 147 | } 148 | 149 | // ReadMsgUDP implements the net.UDPConn ReadMsgUDP method. 
150 | func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { 151 | udpConn, ok := c.PacketConn.(*net.UDPConn) 152 | if !ok { 153 | return 0, 0, 0, nil, errUnsupportedMethod 154 | } 155 | return udpConn.ReadMsgUDP(b, oob) 156 | } 157 | 158 | // WriteToUDP implements the net.UDPConn WriteToUDP method. 159 | func (c *UDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { 160 | udpConn, ok := c.PacketConn.(*net.UDPConn) 161 | if !ok { 162 | return 0, errUnsupportedMethod 163 | } 164 | return udpConn.WriteToUDP(b, addr) 165 | } 166 | 167 | // WriteMsgUDP implements the net.UDPConn WriteMsgUDP method. 168 | func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { 169 | udpConn, ok := c.PacketConn.(*net.UDPConn) 170 | if !ok { 171 | return 0, 0, errUnsupportedMethod 172 | } 173 | return udpConn.WriteMsgUDP(b, oob, addr) 174 | } 175 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "errors" 7 | "io" 8 | "math" 9 | "net" 10 | "os" 11 | "reflect" 12 | "runtime" 13 | "strconv" 14 | "strings" 15 | ) 16 | 17 | var ( 18 | errStringTooLong = errors.New("string too long") 19 | errUserAuthFailed = errors.New("user authentication failed") 20 | errNoSupportedAuth = errors.New("no supported authentication mechanism") 21 | errUnrecognizedAddrType = errors.New("unrecognized address type") 22 | ) 23 | 24 | const ( 25 | maxUdpPacket = math.MaxUint16 - 28 26 | // maxHeaderSize is the maximum size of SOCKS5 header (3 bytes) plus address (up to 255 bytes for domain name) 27 | // 3 bytes (header: RSV, RSV, FRAG) + 1 byte (ATYP) + 1 byte (domain length) + 255 bytes (domain) + 2 bytes (port) = 262 bytes 28 | maxHeaderSize = 3 + 1 + 1 + 255 + 2 29 | ) 30 | 31 | const ( 32 | socks5Version = 0x05 33 | ) 34 | 35 | 
const ( 36 | ConnectCommand Command = 0x01 37 | BindCommand Command = 0x02 38 | AssociateCommand Command = 0x03 39 | ) 40 | 41 | // Command is a SOCKS Command. 42 | type Command byte 43 | 44 | func (cmd Command) String() string { 45 | switch cmd { 46 | case ConnectCommand: 47 | return "socks connect" 48 | case BindCommand: 49 | return "socks bind" 50 | case AssociateCommand: 51 | return "socks associate" 52 | default: 53 | return "socks " + strconv.Itoa(int(cmd)) 54 | } 55 | } 56 | 57 | const ( 58 | successReply reply = 0x00 59 | serverFailure reply = 0x01 60 | ruleFailure reply = 0x02 61 | networkUnreachable reply = 0x03 62 | hostUnreachable reply = 0x04 63 | connectionRefused reply = 0x05 64 | ttlExpired reply = 0x06 65 | commandNotSupported reply = 0x07 66 | addrTypeNotSupported reply = 0x08 67 | ) 68 | 69 | func errToReply(err error) reply { 70 | if err == nil { 71 | return successReply 72 | } 73 | msg := err.Error() 74 | resp := hostUnreachable 75 | if strings.Contains(msg, "refused") { 76 | resp = connectionRefused 77 | } else if strings.Contains(msg, "network is unreachable") { 78 | resp = networkUnreachable 79 | } 80 | return resp 81 | } 82 | 83 | // reply is a SOCKS Command reply code. 
84 | type reply byte 85 | 86 | func (code reply) String() string { 87 | switch code { 88 | case successReply: 89 | return "succeeded" 90 | case serverFailure: 91 | return "general SOCKS server failure" 92 | case ruleFailure: 93 | return "connection not allowed by ruleset" 94 | case networkUnreachable: 95 | return "network unreachable" 96 | case hostUnreachable: 97 | return "host unreachable" 98 | case connectionRefused: 99 | return "connection refused" 100 | case ttlExpired: 101 | return "TTL expired" 102 | case commandNotSupported: 103 | return "Command not supported" 104 | case addrTypeNotSupported: 105 | return "address type not supported" 106 | default: 107 | return "unknown code: " + strconv.Itoa(int(code)) 108 | } 109 | } 110 | 111 | const ( 112 | ipv4Address = 0x01 113 | fqdnAddress = 0x03 114 | ipv6Address = 0x04 115 | ) 116 | 117 | // address is a SOCKS-specific address. 118 | // Either Name or IP is used exclusively. 119 | type address struct { 120 | Name string // fully-qualified domain name 121 | IP net.IP 122 | Port int 123 | } 124 | 125 | func (a *address) Network() string { return "socks5" } 126 | 127 | func (a *address) String() string { 128 | if a == nil { 129 | return "" 130 | } 131 | return a.Address() 132 | } 133 | 134 | // Address returns a string suitable to dial; prefer returning IP-based 135 | // address, fallback to Name 136 | func (a address) Address() string { 137 | port := strconv.Itoa(a.Port) 138 | if 0 != len(a.IP) { 139 | return net.JoinHostPort(a.IP.String(), port) 140 | } 141 | return net.JoinHostPort(a.Name, port) 142 | } 143 | 144 | // authMethod is a SOCKS authentication method. 
145 | type authMethod byte 146 | 147 | const ( 148 | noAuth authMethod = 0x00 // no authentication required 149 | gssapiAuth authMethod = 0x01 // use GSSAPI 150 | userAuth authMethod = 0x02 // use username/password 151 | noAcceptable authMethod = 0xff // no acceptable authentication methods 152 | ) 153 | 154 | const ( 155 | userAuthVersion = 0x01 156 | authSuccess = 0x00 157 | authFailure = 0x01 158 | ) 159 | 160 | func readBytes(r io.Reader) ([]byte, error) { 161 | var buf [1]byte 162 | _, err := r.Read(buf[:]) 163 | if err != nil { 164 | return nil, err 165 | } 166 | bytes := make([]byte, buf[0]) 167 | _, err = io.ReadFull(r, bytes) 168 | if err != nil { 169 | return nil, err 170 | } 171 | return bytes, nil 172 | } 173 | 174 | func writeBytes(w io.Writer, b []byte) error { 175 | _, err := w.Write([]byte{byte(len(b))}) 176 | if err != nil { 177 | return err 178 | } 179 | _, err = w.Write(b) 180 | return err 181 | } 182 | 183 | func readByte(r io.Reader) (byte, error) { 184 | var buf [1]byte 185 | _, err := r.Read(buf[:]) 186 | if err != nil { 187 | return 0, err 188 | } 189 | return buf[0], nil 190 | } 191 | 192 | func readAddr(r io.Reader) (*address, error) { 193 | address := &address{} 194 | 195 | var addrType [1]byte 196 | if _, err := r.Read(addrType[:]); err != nil { 197 | return nil, err 198 | } 199 | 200 | switch addrType[0] { 201 | case ipv4Address: 202 | addr := make(net.IP, net.IPv4len) 203 | if _, err := io.ReadFull(r, addr); err != nil { 204 | return nil, err 205 | } 206 | address.IP = addr 207 | case ipv6Address: 208 | addr := make(net.IP, net.IPv6len) 209 | if _, err := io.ReadFull(r, addr); err != nil { 210 | return nil, err 211 | } 212 | address.IP = addr 213 | case fqdnAddress: 214 | if _, err := r.Read(addrType[:]); err != nil { 215 | return nil, err 216 | } 217 | addrLen := int(addrType[0]) 218 | fqdn := make([]byte, addrLen) 219 | if _, err := io.ReadFull(r, fqdn); err != nil { 220 | return nil, err 221 | } 222 | address.Name = string(fqdn) 223 
| default: 224 | return nil, errUnrecognizedAddrType 225 | } 226 | var port [2]byte 227 | if _, err := io.ReadFull(r, port[:]); err != nil { 228 | return nil, err 229 | } 230 | address.Port = int(binary.BigEndian.Uint16(port[:])) 231 | return address, nil 232 | } 233 | 234 | func writeAddr(w io.Writer, addr *address) error { 235 | if addr == nil { 236 | _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0, 0, 0}) 237 | if err != nil { 238 | return err 239 | } 240 | return nil 241 | } 242 | if addr.IP != nil { 243 | if ip4 := addr.IP.To4(); ip4 != nil { 244 | _, err := w.Write([]byte{ipv4Address}) 245 | if err != nil { 246 | return err 247 | } 248 | _, err = w.Write(ip4) 249 | if err != nil { 250 | return err 251 | } 252 | } else if ip6 := addr.IP.To16(); ip6 != nil { 253 | _, err := w.Write([]byte{ipv6Address}) 254 | if err != nil { 255 | return err 256 | } 257 | _, err = w.Write(ip6) 258 | if err != nil { 259 | return err 260 | } 261 | } else { 262 | _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0}) 263 | if err != nil { 264 | return err 265 | } 266 | } 267 | } else if addr.Name != "" { 268 | if len(addr.Name) > 255 { 269 | return errStringTooLong 270 | } 271 | _, err := w.Write([]byte{fqdnAddress, byte(len(addr.Name))}) 272 | if err != nil { 273 | return err 274 | } 275 | _, err = w.Write([]byte(addr.Name)) 276 | if err != nil { 277 | return err 278 | } 279 | } else { 280 | _, err := w.Write([]byte{ipv4Address, 0, 0, 0, 0}) 281 | if err != nil { 282 | return err 283 | } 284 | } 285 | var p [2]byte 286 | binary.BigEndian.PutUint16(p[:], uint16(addr.Port)) 287 | _, err := w.Write(p[:]) 288 | return err 289 | } 290 | 291 | func writeAddrWithStr(w io.Writer, addr string) error { 292 | host, port, err := splitHostPort(addr) 293 | if err != nil { 294 | return err 295 | } 296 | if ip := net.ParseIP(host); ip != nil { 297 | return writeAddr(w, &address{IP: ip, Port: port}) 298 | } 299 | return writeAddr(w, &address{Name: host, Port: port}) 300 | } 301 | 302 | func 
splitHostPort(address string) (string, int, error) { 303 | host, port, err := net.SplitHostPort(address) 304 | if err != nil { 305 | return "", 0, err 306 | } 307 | portnum, err := strconv.Atoi(port) 308 | if err != nil { 309 | return "", 0, err 310 | } 311 | if 0 > portnum || portnum > 0xffff { 312 | return "", 0, errors.New("port number out of range " + port) 313 | } 314 | return host, portnum, nil 315 | } 316 | 317 | // isClosedConnError reports whether err is an error from use of a closed 318 | // network connection. 319 | func isClosedConnError(err error) bool { 320 | if err == nil { 321 | return false 322 | } 323 | 324 | str := err.Error() 325 | if strings.Contains(str, "use of closed network connection") { 326 | return true 327 | } 328 | 329 | if runtime.GOOS == "windows" { 330 | if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { 331 | if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { 332 | const WSAECONNABORTED = 10053 333 | const WSAECONNRESET = 10054 334 | if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { 335 | return true 336 | } 337 | } 338 | } 339 | } 340 | return false 341 | } 342 | 343 | func errno(v error) uintptr { 344 | if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { 345 | return uintptr(rv.Uint()) 346 | } 347 | return 0 348 | } 349 | 350 | // tunnel create tunnels for two io.ReadWriteCloser 351 | func tunnel(ctx context.Context, c1, c2 io.ReadWriteCloser, buf1, buf2 []byte) error { 352 | errCh := make(chan error, 2) 353 | go func() { 354 | _, err := io.CopyBuffer(c1, c2, buf1) 355 | errCh <- err 356 | }() 357 | go func() { 358 | _, err := io.CopyBuffer(c2, c1, buf2) 359 | errCh <- err 360 | }() 361 | defer func() { 362 | _ = c1.Close() 363 | _ = c2.Close() 364 | }() 365 | 366 | select { 367 | case err := <-errCh: 368 | return err 369 | case <-ctx.Done(): 370 | return ctx.Err() 371 | } 372 | } 373 | 374 | // BytesPool is an interface for getting and returning temporary 375 | // bytes for 
use by io.CopyBuffer. 376 | type BytesPool interface { 377 | Get() []byte 378 | Put([]byte) 379 | } 380 | -------------------------------------------------------------------------------- /all_test.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/rand" 7 | "net" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strings" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | var testServer = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 17 | rw.Write([]byte("ok")) 18 | })) 19 | 20 | func TestServerAndStdClient(t *testing.T) { 21 | listen, err := net.Listen("tcp", ":0") 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | defer listen.Close() 26 | 27 | proxy := NewServer() 28 | go proxy.Serve(listen) 29 | 30 | cli := testServer.Client() 31 | cli.Transport = &http.Transport{ 32 | Proxy: func(request *http.Request) (*url.URL, error) { 33 | return url.Parse("socks5://" + listen.Addr().String()) 34 | }, 35 | } 36 | resp, err := cli.Get(testServer.URL) 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | resp.Body.Close() 41 | } 42 | 43 | func TestServerAndAuthStdClient(t *testing.T) { 44 | listen, err := net.Listen("tcp", ":0") 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | defer listen.Close() 49 | 50 | proxy := NewServer() 51 | proxy.Authentication = UserAuth("u", "p") 52 | go proxy.Serve(listen) 53 | 54 | cli := testServer.Client() 55 | cli.Transport = &http.Transport{ 56 | Proxy: func(request *http.Request) (*url.URL, error) { 57 | return url.Parse("socks5://u:p@" + listen.Addr().String()) 58 | }, 59 | } 60 | resp, err := cli.Get(testServer.URL) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | resp.Body.Close() 65 | } 66 | 67 | func TestServerAndAuthClient(t *testing.T) { 68 | listen, err := net.Listen("tcp", ":0") 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | defer listen.Close() 73 | 74 | proxy := NewServer() 75 | 
proxy.Authentication = UserAuth("u", "p") 76 | go proxy.Serve(listen) 77 | 78 | dial, err := NewDialer("socks5://u:p@" + listen.Addr().String()) 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | cli := testServer.Client() 83 | cli.Transport = &http.Transport{ 84 | DialContext: dial.DialContext, 85 | } 86 | 87 | resp, err := cli.Get(testServer.URL) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | resp.Body.Close() 92 | 93 | } 94 | 95 | func TestServerAndClient(t *testing.T) { 96 | listen, err := net.Listen("tcp", ":0") 97 | if err != nil { 98 | t.Fatal(err) 99 | } 100 | defer listen.Close() 101 | 102 | proxy := NewServer() 103 | go proxy.Serve(listen) 104 | 105 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | cli := testServer.Client() 110 | cli.Transport = &http.Transport{ 111 | DialContext: dial.DialContext, 112 | } 113 | 114 | resp, err := cli.Get(testServer.URL) 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | resp.Body.Close() 119 | 120 | } 121 | 122 | func TestServerAndClientWithDomain(t *testing.T) { 123 | listen, err := net.Listen("tcp", ":0") 124 | if err != nil { 125 | t.Fatal(err) 126 | } 127 | defer listen.Close() 128 | 129 | proxy := NewServer() 130 | go proxy.Serve(listen) 131 | 132 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 133 | if err != nil { 134 | t.Fatal(err) 135 | } 136 | cli := testServer.Client() 137 | cli.Transport = &http.Transport{ 138 | DialContext: dial.DialContext, 139 | } 140 | resp, err := cli.Get(strings.ReplaceAll(testServer.URL, "127.0.0.1", "localhost")) 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | resp.Body.Close() 145 | } 146 | 147 | func TestServerAndClientWithServerDomain(t *testing.T) { 148 | listen, err := net.Listen("tcp", ":0") 149 | if err != nil { 150 | t.Fatal(err) 151 | } 152 | defer listen.Close() 153 | 154 | proxy := NewServer() 155 | go proxy.Serve(listen) 156 | 157 | dial, err := NewDialer("socks5h://" + 
listen.Addr().String()) 158 | if err != nil { 159 | t.Fatal(err) 160 | } 161 | cli := testServer.Client() 162 | cli.Transport = &http.Transport{ 163 | DialContext: dial.DialContext, 164 | } 165 | resp, err := cli.Get(strings.ReplaceAll(testServer.URL, "127.0.0.1", "localhost")) 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | resp.Body.Close() 170 | } 171 | 172 | func TestUDP(t *testing.T) { 173 | packet, err := net.ListenPacket("udp", "127.0.0.1:0") 174 | if err != nil { 175 | t.Fatal(err) 176 | } 177 | defer packet.Close() 178 | go func() { 179 | var buf [maxUdpPacket]byte 180 | for { 181 | n, addr, err := packet.ReadFrom(buf[:]) 182 | if err != nil { 183 | return 184 | } 185 | _, err = packet.WriteTo(buf[:n], addr) 186 | if err != nil { 187 | return 188 | } 189 | } 190 | }() 191 | 192 | listen, err := net.Listen("tcp", ":0") 193 | if err != nil { 194 | t.Fatal(err) 195 | } 196 | defer listen.Close() 197 | 198 | proxy := NewServer() 199 | go proxy.Serve(listen) 200 | 201 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 202 | if err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | conn, err := dial.Dial("udp", packet.LocalAddr().String()) 207 | if err != nil { 208 | t.Fatal(err) 209 | } 210 | 211 | want := make([]byte, 1024) 212 | rand.Read(want) 213 | _, err = conn.Write(want) 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | 218 | got := make([]byte, len(want)) 219 | _, err = conn.Read(got) 220 | if err != nil { 221 | t.Fatal(err) 222 | } 223 | if !bytes.Equal(want, got) { 224 | t.Fail() 225 | } 226 | } 227 | 228 | func TestUDPMultiple(t *testing.T) { 229 | // Create multiple UDP echo servers 230 | const numServers = 3 231 | echoServers := make([]net.PacketConn, numServers) 232 | for i := 0; i < numServers; i++ { 233 | packet, err := net.ListenPacket("udp", "127.0.0.1:0") 234 | if err != nil { 235 | t.Fatal(err) 236 | } 237 | defer packet.Close() 238 | echoServers[i] = packet 239 | 240 | // Start echo server 241 | go func(p 
net.PacketConn) { 242 | var buf [maxUdpPacket]byte 243 | for { 244 | n, addr, err := p.ReadFrom(buf[:]) 245 | if err != nil { 246 | return 247 | } 248 | _, err = p.WriteTo(buf[:n], addr) 249 | if err != nil { 250 | return 251 | } 252 | } 253 | }(packet) 254 | } 255 | 256 | // Start SOCKS5 proxy 257 | listen, err := net.Listen("tcp", ":0") 258 | if err != nil { 259 | t.Fatal(err) 260 | } 261 | defer listen.Close() 262 | 263 | proxy := NewServer() 264 | go proxy.Serve(listen) 265 | 266 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 267 | if err != nil { 268 | t.Fatal(err) 269 | } 270 | 271 | // Create UDP association to first server 272 | conn, err := dial.Dial("udp", echoServers[0].LocalAddr().String()) 273 | if err != nil { 274 | t.Fatal(err) 275 | } 276 | defer conn.Close() 277 | 278 | pc, ok := conn.(net.PacketConn) 279 | if !ok { 280 | t.Fatal("connection is not a PacketConn") 281 | } 282 | 283 | // Test sending to multiple different destinations 284 | for i := 0; i < numServers; i++ { 285 | echoAddr := echoServers[i].LocalAddr() 286 | want := []byte(strings.Repeat(string(rune('A'+i)), 100)) 287 | 288 | // Send to this echo server 289 | _, err = pc.WriteTo(want, echoAddr) 290 | if err != nil { 291 | t.Fatalf("WriteTo server %d failed: %v", i, err) 292 | } 293 | 294 | // Read response 295 | got := make([]byte, len(want)*2) 296 | n, addr, err := pc.ReadFrom(got) 297 | if err != nil { 298 | t.Fatalf("ReadFrom server %d failed: %v", i, err) 299 | } 300 | got = got[:n] 301 | 302 | // Verify response came from correct server 303 | if addr.String() != echoAddr.String() { 304 | t.Errorf("Response from wrong address: got %v, want %v", addr, echoAddr) 305 | } 306 | 307 | // Verify data 308 | if !bytes.Equal(want, got) { 309 | t.Errorf("Echo from server %d failed: got %x, want %x", i, got, want) 310 | } 311 | } 312 | } 313 | 314 | func TestBind(t *testing.T) { 315 | listen, err := net.Listen("tcp", ":0") 316 | if err != nil { 317 | t.Fatal(err) 318 | } 319 
| defer listen.Close() 320 | 321 | proxy := NewServer() 322 | go proxy.Serve(listen) 323 | 324 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 325 | if err != nil { 326 | t.Fatal(err) 327 | } 328 | 329 | listener, err := dial.Listen(context.Background(), "tcp", ":10000") 330 | if err != nil { 331 | t.Fatal(err) 332 | } 333 | go http.Serve(listener, nil) 334 | time.Sleep(time.Second / 10) 335 | resp, err := http.Get("http://127.0.0.1:10000") 336 | if err != nil { 337 | t.Fatal(err) 338 | } 339 | resp.Body.Close() 340 | } 341 | 342 | func TestBindWithSerialAndParallel(t *testing.T) { 343 | listen, err := net.Listen("tcp", ":0") 344 | if err != nil { 345 | t.Fatal(err) 346 | } 347 | defer listen.Close() 348 | 349 | proxy := NewServer() 350 | go proxy.Serve(listen) 351 | 352 | dial, err := NewDialer("socks5://" + listen.Addr().String()) 353 | if err != nil { 354 | t.Fatal(err) 355 | } 356 | 357 | listener, err := dial.Listen(context.Background(), "tcp", ":10001") 358 | if err != nil { 359 | t.Fatal(err) 360 | } 361 | go http.Serve(listener, nil) 362 | time.Sleep(time.Second) 363 | 364 | for i := 0; i < 3; i++ { 365 | resp, err := http.Get("http://127.0.0.1:10001") 366 | if err != nil { 367 | t.Fatal(err) 368 | } 369 | resp.Body.Close() 370 | } 371 | 372 | const numRequests = 5 373 | errCh := make(chan error, numRequests) 374 | 375 | for i := 0; i < numRequests; i++ { 376 | go func() { 377 | resp, err := http.Get("http://127.0.0.1:10001") 378 | if err != nil { 379 | errCh <- err 380 | return 381 | } 382 | resp.Body.Close() 383 | errCh <- nil 384 | }() 385 | } 386 | 387 | for i := 0; i < numRequests; i++ { 388 | if err := <-errCh; err != nil { 389 | t.Fatal(err) 390 | } 391 | } 392 | } 393 | 394 | func TestSimpleServer(t *testing.T) { 395 | s, err := NewSimpleServer("socks5://u:p@:0") 396 | if err != nil { 397 | t.Fatal(err) 398 | } 399 | s.Start(context.Background()) 400 | defer s.Close() 401 | 402 | dial, err := NewDialer(s.ProxyURL()) 403 | if err != 
nil { 404 | t.Fatal(err) 405 | } 406 | cli := testServer.Client() 407 | cli.Transport = &http.Transport{ 408 | DialContext: dial.DialContext, 409 | } 410 | 411 | resp, err := cli.Get(testServer.URL) 412 | if err != nil { 413 | t.Fatal(err) 414 | } 415 | resp.Body.Close() 416 | } 417 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "net/url" 10 | "time" 11 | ) 12 | 13 | // Dialer is a SOCKS5 dialer. 14 | type Dialer struct { 15 | // ProxyNetwork network between a proxy server and a client 16 | ProxyNetwork string 17 | // ProxyAddress proxy server address 18 | ProxyAddress string 19 | // ProxyDial specifies the optional dial function for 20 | // establishing the transport connection. 21 | ProxyDial func(ctx context.Context, network string, address string) (net.Conn, error) 22 | // ProxyPacketDial specifies the optional proxyPacketDial function for 23 | // establishing the transport connection. 24 | ProxyPacketDial func(ctx context.Context, network string, address string) (net.PacketConn, error) 25 | // Username use username authentication if not empty 26 | Username string 27 | // Password use password authentication if not empty, 28 | // only valid if username is set 29 | Password string 30 | // IsResolve resolve domain name on locally 31 | IsResolve bool 32 | // Resolver optionally specifies an alternate resolver to use 33 | Resolver *net.Resolver 34 | // Timeout is the maximum amount of time a dial will wait for 35 | // a connect to complete. The default is no timeout 36 | Timeout time.Duration 37 | } 38 | 39 | // NewDialer returns a new Dialer that dials through the provided 40 | // proxy server's network and address. 
41 | func NewDialer(addr string) (*Dialer, error) { 42 | d := &Dialer{ 43 | ProxyNetwork: "tcp", 44 | Timeout: time.Minute, 45 | } 46 | u, err := url.Parse(addr) 47 | if err != nil { 48 | return nil, err 49 | } 50 | switch u.Scheme { 51 | case "socks5": 52 | d.IsResolve = true 53 | case "socks5h": 54 | default: 55 | return nil, fmt.Errorf("unsupported protocol '%s'", u.Scheme) 56 | } 57 | host := u.Host 58 | port := u.Port() 59 | if port == "" { 60 | port = "1080" 61 | hostname := u.Hostname() 62 | host = net.JoinHostPort(hostname, port) 63 | } 64 | if u.User != nil { 65 | d.Username = u.User.Username() 66 | d.Password, _ = u.User.Password() 67 | } 68 | d.ProxyAddress = host 69 | return d, nil 70 | } 71 | 72 | // DialContext connects to the provided address on the provided network. 73 | func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 74 | switch network { 75 | default: 76 | return nil, fmt.Errorf("unsupported network %q", network) 77 | case "tcp", "tcp4", "tcp6": 78 | return d.do(ctx, ConnectCommand, address) 79 | case "udp", "udp4", "udp6": 80 | return d.do(ctx, AssociateCommand, address) 81 | } 82 | } 83 | 84 | // Dial connects to the provided address on the provided network. 
85 | func (d *Dialer) Dial(network, address string) (net.Conn, error) { 86 | return d.DialContext(context.Background(), network, address) 87 | } 88 | 89 | func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Listener, error) { 90 | switch network { 91 | default: 92 | return nil, fmt.Errorf("unsupported network %q", network) 93 | case "tcp", "tcp4", "tcp6": 94 | } 95 | return &listener{ctx: ctx, d: d, address: address}, nil 96 | } 97 | 98 | func (d *Dialer) do(ctx context.Context, cmd Command, address string) (net.Conn, error) { 99 | if d.IsResolve { 100 | host, port, err := net.SplitHostPort(address) 101 | if err != nil { 102 | return nil, err 103 | } 104 | if host != "" { 105 | ip := net.ParseIP(host) 106 | if ip == nil { 107 | ipaddr, err := d.resolver().LookupIP(ctx, "ip4", host) 108 | if err != nil { 109 | ipaddr, err = d.resolver().LookupIP(ctx, "ip", host) 110 | if err != nil { 111 | return nil, err 112 | } 113 | } 114 | host := ipaddr[0].String() 115 | address = net.JoinHostPort(host, port) 116 | } 117 | } 118 | } 119 | 120 | conn, err := d.proxyDial(ctx, d.ProxyNetwork, d.ProxyAddress) 121 | if err != nil { 122 | return nil, err 123 | } 124 | 125 | return d.connect(ctx, conn, cmd, address) 126 | } 127 | 128 | func (d *Dialer) connect(ctx context.Context, conn net.Conn, cmd Command, address string) (net.Conn, error) { 129 | if d.Timeout != 0 { 130 | deadline := time.Now().Add(d.Timeout) 131 | if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { 132 | subCtx, cancel := context.WithDeadline(ctx, deadline) 133 | defer cancel() 134 | ctx = subCtx 135 | } 136 | } 137 | if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { 138 | conn.SetDeadline(deadline) 139 | defer conn.SetDeadline(time.Time{}) 140 | } 141 | 142 | err := d.connectAuth(conn) 143 | if err != nil { 144 | return nil, err 145 | } 146 | 147 | switch cmd { 148 | default: 149 | return nil, fmt.Errorf("unsupported Command %s", cmd) 150 | case ConnectCommand: 151 | _, err 
:= d.connectCommand(conn, ConnectCommand, address) 152 | if err != nil { 153 | return nil, err 154 | } 155 | return conn, nil 156 | case BindCommand: 157 | _, err := d.connectCommand(conn, BindCommand, address) 158 | if err != nil { 159 | return nil, err 160 | } 161 | return conn, nil 162 | case AssociateCommand: 163 | targetIP, targetPort, err := splitHostPort(address) 164 | if err != nil { 165 | return nil, err 166 | } 167 | 168 | addr, err := d.connectCommand(conn, AssociateCommand, ":0") 169 | if err != nil { 170 | return nil, err 171 | } 172 | 173 | proxyIP, proxyPort, err := splitHostPort(addr.String()) 174 | if err != nil { 175 | return nil, err 176 | } 177 | 178 | udpConn, err := d.proxyPacketDial(ctx, "udp", ":0") 179 | if err != nil { 180 | return nil, err 181 | } 182 | 183 | targetAddr := &net.UDPAddr{ 184 | IP: net.ParseIP(targetIP), 185 | Port: targetPort, 186 | } 187 | proxyAddr := &net.UDPAddr{ 188 | IP: net.ParseIP(proxyIP), 189 | Port: proxyPort, 190 | } 191 | wrapConn, err := NewUDPConn(udpConn, proxyAddr, targetAddr) 192 | if err != nil { 193 | return nil, err 194 | } 195 | 196 | go func() { 197 | var buf [1]byte 198 | for { 199 | _, err := conn.Read(buf[:]) 200 | if err != nil { 201 | wrapConn.Close() 202 | break 203 | } 204 | } 205 | }() 206 | return wrapConn, nil 207 | } 208 | 209 | } 210 | 211 | func (d *Dialer) connectAuth(conn net.Conn) error { 212 | _, err := conn.Write([]byte{socks5Version}) 213 | if err != nil { 214 | return err 215 | } 216 | if d.Username == "" { 217 | err = writeBytes(conn, []byte{byte(noAuth)}) 218 | if err != nil { 219 | return err 220 | } 221 | } else { 222 | err = writeBytes(conn, []byte{byte(noAuth), byte(userAuth)}) 223 | if err != nil { 224 | return err 225 | } 226 | } 227 | 228 | var header [2]byte 229 | _, err = io.ReadFull(conn, header[:]) 230 | if err != nil { 231 | return err 232 | } 233 | if header[0] != socks5Version { 234 | return fmt.Errorf("unexpected protocol version %d", header[0]) 235 | } 236 | if 
authMethod(header[1]) == noAcceptable { 237 | return fmt.Errorf("no acceptable authentication methods %d", authMethod(header[1])) 238 | } 239 | switch authMethod(header[1]) { 240 | default: 241 | return fmt.Errorf("authentication method not supported %d", authMethod(header[1])) 242 | case noAuth: 243 | case userAuth: 244 | if d.Username == "" { 245 | return errors.New("need username/password") 246 | } 247 | 248 | if len(d.Username) == 0 || len(d.Username) > 255 || len(d.Password) == 0 || len(d.Password) > 255 { 249 | return errors.New("invalid username/password") 250 | } 251 | _, err = conn.Write([]byte{userAuthVersion}) 252 | if err != nil { 253 | return err 254 | } 255 | err = writeBytes(conn, []byte(d.Username)) 256 | if err != nil { 257 | return err 258 | } 259 | err = writeBytes(conn, []byte(d.Password)) 260 | if err != nil { 261 | return err 262 | } 263 | 264 | _, err := io.ReadFull(conn, header[:]) 265 | if err != nil { 266 | return err 267 | } 268 | if header[0] != userAuthVersion { 269 | return fmt.Errorf("invalid username/password version %d", header[0]) 270 | } 271 | if header[1] != authSuccess { 272 | return fmt.Errorf("username/password authentication failed %d", header[1]) 273 | } 274 | } 275 | return nil 276 | } 277 | 278 | func (d *Dialer) connectCommand(conn net.Conn, cmd Command, address string) (net.Addr, error) { 279 | _, err := conn.Write([]byte{socks5Version, byte(cmd), 0}) 280 | if err != nil { 281 | return nil, err 282 | } 283 | err = writeAddrWithStr(conn, address) 284 | if err != nil { 285 | return nil, err 286 | } 287 | 288 | return d.readReply(conn) 289 | } 290 | 291 | func (d *Dialer) readReply(conn net.Conn) (net.Addr, error) { 292 | var header [3]byte 293 | _, err := io.ReadFull(conn, header[:]) 294 | if err != nil { 295 | return nil, err 296 | } 297 | 298 | if header[0] != socks5Version { 299 | return nil, fmt.Errorf("unexpected protocol version %d", header[0]) 300 | } 301 | 302 | if reply(header[1]) != successReply { 303 | return 
nil, fmt.Errorf("unknown error %s", reply(header[1]).String()) 304 | } 305 | 306 | return readAddr(conn) 307 | } 308 | 309 | func (d *Dialer) resolver() *net.Resolver { 310 | if d.Resolver == nil { 311 | return net.DefaultResolver 312 | } 313 | return d.Resolver 314 | } 315 | 316 | func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Conn, error) { 317 | proxyDial := d.ProxyDial 318 | if proxyDial == nil { 319 | var dialer net.Dialer 320 | proxyDial = dialer.DialContext 321 | } 322 | return proxyDial(ctx, network, address) 323 | } 324 | 325 | func (d *Dialer) proxyPacketDial(ctx context.Context, network, address string) (net.PacketConn, error) { 326 | proxyPacketDial := d.ProxyPacketDial 327 | if proxyPacketDial == nil { 328 | var listener net.ListenConfig 329 | proxyPacketDial = listener.ListenPacket 330 | } 331 | return proxyPacketDial(ctx, network, address) 332 | } 333 | 334 | type listener struct { 335 | ctx context.Context 336 | d *Dialer 337 | address string 338 | } 339 | 340 | // Accept waits for and returns the next connection to the listener. 341 | func (l *listener) Accept() (net.Conn, error) { 342 | conn, err := l.d.do(l.ctx, BindCommand, l.address) 343 | if err != nil { 344 | return nil, err 345 | } 346 | addr, err := l.d.readReply(conn) 347 | if err != nil { 348 | return nil, err 349 | } 350 | return &connect{Conn: conn, remoteAddr: addr}, nil 351 | } 352 | 353 | // Close closes the listener. 354 | func (l *listener) Close() error { 355 | return nil 356 | } 357 | 358 | // address returns the listener's network address. 
359 | func (l *listener) Addr() net.Addr { 360 | return nil 361 | } 362 | 363 | type connect struct { 364 | net.Conn 365 | remoteAddr net.Addr 366 | } 367 | 368 | func (c *connect) RemoteAddr() net.Addr { 369 | return c.remoteAddr 370 | } 371 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package socks5 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | ) 13 | 14 | // Server is accepting connections and handling the details of the SOCKS5 protocol 15 | type Server struct { 16 | // Authentication is proxy authentication 17 | Authentication Authentication 18 | // ProxyDial specifies the optional proxyDial function for 19 | // establishing the transport connection. 20 | ProxyDial func(ctx context.Context, network string, address string) (net.Conn, error) 21 | // ProxyListen specifies the optional proxyListen function for 22 | // establishing the transport connection. 23 | ProxyListen func(ctx context.Context, network string, address string) (net.Listener, error) 24 | // ProxyListenPacket specifies the optional proxyListenPacket function for 25 | // establishing the transport connection. 26 | ProxyListenPacket func(ctx context.Context, network string, address string) (net.PacketConn, error) 27 | // PacketForwardAddress specifies the packet forwarding address 28 | PacketForwardAddress func(ctx context.Context, destinationAddr string, packet net.PacketConn, conn net.Conn) (net.IP, int, error) 29 | // ProxyListenBind specifies the optional proxyListenBind function for 30 | // establishing the transport connection. 
31 | ProxyListenBind func(ctx context.Context, network string, address string) (net.Listener, error) 32 | // ListenBindReuseTimeout is the timeout for reusing bind listener 33 | ListenBindReuseTimeout time.Duration 34 | // ListenBindAcceptTimeout is the timeout for accepting connections on bind listener 35 | ListenBindAcceptTimeout time.Duration 36 | // reserveListenBind is a pool for reusing bind listeners across requests. 37 | reserveListenBind reserveListen 38 | // Logger error log 39 | Logger Logger 40 | // Context is default context 41 | Context context.Context 42 | // BytesPool getting and returning temporary bytes for use by io.CopyBuffer 43 | BytesPool BytesPool 44 | } 45 | 46 | type Logger interface { 47 | Println(v ...interface{}) 48 | } 49 | 50 | // NewServer creates a new Server 51 | func NewServer() *Server { 52 | return &Server{ 53 | ListenBindReuseTimeout: time.Second / 2, 54 | } 55 | } 56 | 57 | // ListenAndServe is used to create a listener and serve on it 58 | func (s *Server) ListenAndServe(network, addr string) error { 59 | l, err := s.proxyListen(s.context(), network, addr) 60 | if err != nil { 61 | return err 62 | } 63 | return s.Serve(l) 64 | } 65 | 66 | func (s *Server) proxyListen(ctx context.Context, network, address string) (net.Listener, error) { 67 | proxyListen := s.ProxyListen 68 | if proxyListen == nil { 69 | var listenConfig net.ListenConfig 70 | proxyListen = listenConfig.Listen 71 | } 72 | return proxyListen(ctx, network, address) 73 | } 74 | 75 | func (s *Server) proxyListenBind(ctx context.Context, network, address string) (net.Listener, error) { 76 | proxyListenBind := s.ProxyListenBind 77 | if proxyListenBind == nil { 78 | var listenConfig net.ListenConfig 79 | proxyListenBind = listenConfig.Listen 80 | } 81 | return proxyListenBind(ctx, network, address) 82 | } 83 | 84 | // Serve is used to serve connections from a listener 85 | func (s *Server) Serve(l net.Listener) error { 86 | for { 87 | conn, err := l.Accept() 88 | if err 
!= nil { 89 | return err 90 | } 91 | go s.ServeConn(conn) 92 | } 93 | } 94 | 95 | // ServeConn is used to serve a single connection. 96 | func (s *Server) ServeConn(conn net.Conn) { 97 | defer conn.Close() 98 | err := s.serveConn(conn) 99 | if err != nil && s.Logger != nil && !isClosedConnError(err) { 100 | s.Logger.Println(err) 101 | } 102 | } 103 | 104 | func (s *Server) serveConn(conn net.Conn) error { 105 | version, err := readByte(conn) 106 | if err != nil { 107 | return err 108 | } 109 | if version != socks5Version { 110 | return fmt.Errorf("unsupported SOCKS version: %d", version) 111 | } 112 | 113 | req := &request{ 114 | Version: socks5Version, 115 | Conn: conn, 116 | } 117 | 118 | methods, err := readBytes(conn) 119 | if err != nil { 120 | return err 121 | } 122 | 123 | if s.Authentication != nil && bytes.IndexByte(methods, byte(userAuth)) != -1 { 124 | _, err := conn.Write([]byte{socks5Version, byte(userAuth)}) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | header, err := readByte(conn) 130 | if err != nil { 131 | return err 132 | } 133 | if header != userAuthVersion { 134 | return fmt.Errorf("unsupported auth version: %d", header) 135 | } 136 | 137 | username, err := readBytes(conn) 138 | if err != nil { 139 | return err 140 | } 141 | req.Username = string(username) 142 | 143 | password, err := readBytes(conn) 144 | if err != nil { 145 | return err 146 | } 147 | req.Password = string(password) 148 | 149 | if !s.Authentication.Auth(req.Command, req.Username, req.Password) { 150 | _, err := conn.Write([]byte{userAuthVersion, authFailure}) 151 | if err != nil { 152 | return err 153 | } 154 | return errUserAuthFailed 155 | } 156 | _, err = conn.Write([]byte{userAuthVersion, authSuccess}) 157 | if err != nil { 158 | return err 159 | } 160 | } else if s.Authentication == nil && bytes.IndexByte(methods, byte(noAuth)) != -1 { 161 | _, err := conn.Write([]byte{socks5Version, byte(noAuth)}) 162 | if err != nil { 163 | return err 164 | } 165 | } else { 
166 | _, err := conn.Write([]byte{socks5Version, byte(noAcceptable)}) 167 | if err != nil { 168 | return err 169 | } 170 | return errNoSupportedAuth 171 | } 172 | 173 | var header [3]byte 174 | _, err = io.ReadFull(conn, header[:]) 175 | if err != nil { 176 | return err 177 | } 178 | 179 | if header[0] != socks5Version { 180 | return fmt.Errorf("unsupported Command version: %d", header[0]) 181 | } 182 | 183 | req.Command = Command(header[1]) 184 | 185 | dest, err := readAddr(conn) 186 | if err != nil { 187 | if err == errUnrecognizedAddrType { 188 | err := sendReply(conn, addrTypeNotSupported, nil) 189 | if err != nil { 190 | return err 191 | } 192 | } 193 | return err 194 | } 195 | req.DestinationAddr = dest 196 | err = s.handle(req) 197 | if err != nil { 198 | return err 199 | } 200 | 201 | return nil 202 | } 203 | 204 | func (s *Server) handle(req *request) error { 205 | switch req.Command { 206 | case ConnectCommand: 207 | return s.handleConnect(req) 208 | case BindCommand: 209 | return s.handleBind(req) 210 | case AssociateCommand: 211 | return s.handleAssociate(req) 212 | default: 213 | if err := sendReply(req.Conn, commandNotSupported, nil); err != nil { 214 | return err 215 | } 216 | return fmt.Errorf("unsupported Command: %v", req.Command) 217 | } 218 | } 219 | 220 | func (s *Server) handleConnect(req *request) error { 221 | ctx := s.context() 222 | target, err := s.proxyDial(ctx, "tcp", req.DestinationAddr.Address()) 223 | if err != nil { 224 | if err := sendReply(req.Conn, errToReply(err), nil); err != nil { 225 | return fmt.Errorf("failed to send reply: %v", err) 226 | } 227 | return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) 228 | } 229 | defer target.Close() 230 | 231 | localAddr := target.LocalAddr() 232 | local, ok := localAddr.(*net.TCPAddr) 233 | if !ok { 234 | return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String()) 235 | } 236 | bind := address{IP: 
local.IP, Port: local.Port} 237 | if err := sendReply(req.Conn, successReply, &bind); err != nil { 238 | return fmt.Errorf("failed to send reply: %v", err) 239 | } 240 | 241 | var buf1, buf2 []byte 242 | if s.BytesPool != nil { 243 | buf1 = s.BytesPool.Get() 244 | buf2 = s.BytesPool.Get() 245 | defer func() { 246 | s.BytesPool.Put(buf1) 247 | s.BytesPool.Put(buf2) 248 | }() 249 | } else { 250 | buf1 = make([]byte, 32*1024) 251 | buf2 = make([]byte, 32*1024) 252 | } 253 | return tunnel(ctx, target, req.Conn, buf1, buf2) 254 | } 255 | 256 | func (s *Server) handleBind(req *request) error { 257 | ctx := s.context() 258 | addr := req.DestinationAddr.String() 259 | 260 | var listener net.Listener 261 | var err error 262 | if s.ListenBindReuseTimeout > 0 { 263 | listener, err = s.reserveListenBind.getOrNew(addr, func() (net.Listener, error) { 264 | return s.proxyListenBind(ctx, "tcp", addr) 265 | }, s.ListenBindReuseTimeout, s.ListenBindAcceptTimeout, s.Logger) 266 | } else { 267 | listener, err = s.proxyListenBind(ctx, "tcp", addr) 268 | } 269 | if err != nil { 270 | if err := sendReply(req.Conn, errToReply(err), nil); err != nil { 271 | return fmt.Errorf("failed to send reply: %v", err) 272 | } 273 | return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) 274 | } 275 | 276 | localAddr := listener.Addr() 277 | local, ok := localAddr.(*net.TCPAddr) 278 | if !ok { 279 | listener.Close() 280 | return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String()) 281 | } 282 | bind := address{IP: local.IP, Port: local.Port} 283 | if err := sendReply(req.Conn, successReply, &bind); err != nil { 284 | listener.Close() 285 | return fmt.Errorf("failed to send reply: %v", err) 286 | } 287 | 288 | conn, err := listener.Accept() 289 | if err != nil { 290 | listener.Close() 291 | if err := sendReply(req.Conn, errToReply(err), nil); err != nil { 292 | return fmt.Errorf("failed to send reply: %v", err) 293 | 
}
		return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err)
	}
	// Only the one accepted connection is relayed for this request.
	listener.Close()

	remoteAddr := conn.RemoteAddr()
	local, ok = remoteAddr.(*net.TCPAddr)
	if !ok {
		return fmt.Errorf("connect to %v failed: remote address is %s://%s", req.DestinationAddr, remoteAddr.Network(), remoteAddr.String())
	}
	// Report the address of the peer that connected back to the client.
	bind = address{IP: local.IP, Port: local.Port}
	if err := sendReply(req.Conn, successReply, &bind); err != nil {
		return fmt.Errorf("failed to send reply: %v", err)
	}

	var buf1, buf2 []byte
	if s.BytesPool != nil {
		buf1 = s.BytesPool.Get()
		buf2 = s.BytesPool.Get()
		defer func() {
			s.BytesPool.Put(buf1)
			s.BytesPool.Put(buf2)
		}()
	} else {
		buf1 = make([]byte, 32*1024)
		buf2 = make([]byte, 32*1024)
	}
	return tunnel(ctx, conn, req.Conn, buf1, buf2)
}

// handleAssociate serves a UDP ASSOCIATE request.
//
// It opens a packet listener on the requested address and replies to the
// client with the address it should send datagrams to. It then relays
// datagrams until the packet listener fails or the TCP control connection
// is closed (the watcher goroutine below closes udpConn when req.Conn
// stops being readable).
//
// The first peer that sends a datagram becomes the client. Datagrams from
// the client carry a SOCKS UDP header naming the target; the header is
// stripped before forwarding. Datagrams from any other peer are prefixed
// with a SOCKS UDP header naming that peer and sent back to the client.
func (s *Server) handleAssociate(req *request) error {
	ctx := s.context()
	destinationAddr := req.DestinationAddr.String()
	udpConn, err := s.proxyListenPacket(ctx, "udp", destinationAddr)
	if err != nil {
		if err := sendReply(req.Conn, errToReply(err), nil); err != nil {
			return fmt.Errorf("failed to send reply: %v", err)
		}
		return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err)
	}
	defer udpConn.Close()

	replyPacketForwardAddress := defaultReplyPacketForwardAddress
	if s.PacketForwardAddress != nil {
		replyPacketForwardAddress = s.PacketForwardAddress
	}
	ip, port, err := replyPacketForwardAddress(ctx, destinationAddr, udpConn, req.Conn)
	if err != nil {
		// Report the failure instead of leaving the client waiting for a
		// reply that never arrives.
		if err := sendReply(req.Conn, errToReply(err), nil); err != nil {
			return fmt.Errorf("failed to send reply: %v", err)
		}
		return err
	}
	bind := address{IP: ip, Port: port}
	if err := sendReply(req.Conn, successReply, &bind); err != nil {
		return fmt.Errorf("failed to send reply: %v", err)
	}

	// The association lives as long as the control connection. Once it
	// can no longer be read, close udpConn so the relay loop below
	// returns.
	go func() {
		var buf [1]byte
		for {
			_, err := req.Conn.Read(buf[:])
			if err != nil {
				udpConn.Close()
				break
			}
		}
	}()

	var (
		sourceAddr net.Addr
		wantSource string
		buf        [maxUdpPacket]byte
		replyBuf   [maxHeaderSize]byte
	)

	for {
		n, addr, err := udpConn.ReadFrom(buf[:])
		if err != nil {
			return err
		}

		if sourceAddr == nil {
			sourceAddr = addr
			wantSource = sourceAddr.String()
		}

		gotAddr := addr.String()
		if wantSource == gotAddr {
			// Packet from client to target: RSV(2) FRAG(1) ATYP ADDR PORT DATA.
			if n < 3 {
				continue
			}
			// Fragmentation is not supported. RFC 1928 §7 requires such
			// datagrams (FRAG != 0) to be dropped.
			if buf[2] != 0 {
				continue
			}
			reader := bytes.NewBuffer(buf[3:n])
			targetAddr, err := readAddr(reader)
			if err != nil {
				if s.Logger != nil {
					s.Logger.Println(err)
				}
				continue
			}
			target := &net.UDPAddr{
				IP:   targetAddr.IP,
				Port: targetAddr.Port,
			}
			_, err = udpConn.WriteTo(reader.Bytes(), target)
			if err != nil {
				return err
			}
		} else {
			// Packet from a target back to the client: prepend a header
			// (RSV=0, FRAG=0, sender address) to the payload.
			headWriter := bytes.NewBuffer(replyBuf[:0])
			headWriter.Write([]byte{0, 0, 0})
			err = writeAddrWithStr(headWriter, gotAddr)
			if err != nil {
				if s.Logger != nil {
					s.Logger.Println(err)
				}
				continue
			}
			prefixLen := headWriter.Len()

			// Check if data length plus header exceeds maximum UDP packet limit
			if prefixLen+n > maxUdpPacket {
				if s.Logger != nil {
					s.Logger.Println(fmt.Errorf("dropping packet: data length (%d) + header length (%d) = %d exceeds max UDP packet size %d", n, prefixLen, prefixLen+n, maxUdpPacket))
				}
				continue
			}

			// Shift the payload right in place (copy handles the overlap),
			// then write the header in front of it.
			copy(buf[prefixLen:prefixLen+n], buf[:n])
			copy(buf[:prefixLen], headWriter.Bytes())

			_, err = udpConn.WriteTo(buf[:prefixLen+n], sourceAddr)
			if err != nil {
				return err
			}
		}
	}
}

// proxyDial opens an outbound connection. It uses s.ProxyDial when set and
// a zero-value net.Dialer otherwise.
func (s *Server) proxyDial(ctx context.Context, network, address string) (net.Conn, error) {
	proxyDial := s.ProxyDial
	if proxyDial == nil {
		var dialer net.Dialer
		proxyDial = dialer.DialContext
	}
	return proxyDial(ctx, network, address)
}

// proxyListenPacket opens a packet listener. It uses s.ProxyListenPacket
// when set and a zero-value net.ListenConfig otherwise.
func (s *Server) proxyListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
	proxyListenPacket := s.ProxyListenPacket
	if proxyListenPacket == nil {
		var listener net.ListenConfig
		proxyListenPacket = listener.ListenPacket
	}
	return proxyListenPacket(ctx, network, address)
}

// context returns s.Context, or context.Background() when none is set.
func (s *Server) context() context.Context {
	if s.Context == nil {
		return context.Background()
	}
	return s.Context
}

// sendReply writes a reply header to w, then the encoded address addr.
// The header is the protocol version, the reply code resp and a reserved
// zero byte.
func sendReply(w io.Writer, resp reply, addr *address) error {
	_, err := w.Write([]byte{socks5Version, byte(resp), 0})
	if err != nil {
		return err
	}
	err = writeAddr(w, addr)
	return err
}

// request is a parsed client request together with the control connection
// it arrived on.
type request struct {
	Version         uint8
	Command         Command
	DestinationAddr *address
	Username        string
	Password        string
	Conn            net.Conn
}

// defaultReplyPacketForwardAddress chooses the address advertised to the
// client for a UDP ASSOCIATE. It pairs the local IP of the TCP control
// connection with the local port of the UDP packet listener.
func defaultReplyPacketForwardAddress(ctx context.Context, destinationAddr string, packet net.PacketConn, conn net.Conn) (net.IP, int, error) {
	udpLocal := packet.LocalAddr()
	udpLocalAddr, ok := udpLocal.(*net.UDPAddr)
	if !ok {
		return nil, 0, fmt.Errorf("connect to %v failed: local address is %s://%s", destinationAddr, udpLocal.Network(), udpLocal.String())
	}

	tcpLocal := conn.LocalAddr()
	tcpLocalAddr, ok := tcpLocal.(*net.TCPAddr)
	if !ok {
		return nil, 0, fmt.Errorf("connect to %v failed: local address is %s://%s", destinationAddr, tcpLocal.Network(), tcpLocal.String())
	}
	return tcpLocalAddr.IP, udpLocalAddr.Port, nil
}

// reserveListen shares listeners between callers, keyed by a string the
// caller chooses. The mutex guards reservedListeners.
type reserveListen struct {
	mut               sync.Mutex
	reservedListeners map[string]*reserved
}

// reserved is one shared listener. Connections accepted on base are handed
// to waiting holdListeners through conns.
type reserved struct {
	key   string
	base  net.Listener
	conns chan net.Conn
}

// holdListener is a caller's handle on a reserved listener.
type holdListener struct {
	r      *reserved
	closed atomic.Bool
}

// getOrNew returns a listener handle for key.
//
// If a reservation for key exists, the handle shares it. Otherwise newFunc
// creates the underlying listener and reserved.run starts accepting on it in
// the background. reuse bounds how long an accepted connection waits for a
// handle to take it. When accept > 0 and the listener supports SetDeadline,
// accept is the per-Accept deadline.
//
// When run stops, the reservation is removed from the map. A later call for
// the same key then creates a fresh listener instead of returning a handle
// on a closed one.
func (r *reserveListen) getOrNew(key string, newFunc func() (net.Listener, error), reuse, accept time.Duration, logger Logger) (net.Listener, error) {
	r.mut.Lock()
	defer r.mut.Unlock()

	if reserve := r.reservedListeners[key]; reserve != nil {
		return &holdListener{r: reserve}, nil
	}

	listener, err := newFunc()
	if err != nil {
		return nil, err
	}
	reserve := &reserved{
		key:   key,
		base:  listener,
		conns: make(chan net.Conn),
	}
	if r.reservedListeners == nil {
		r.reservedListeners = map[string]*reserved{}
	}
	r.reservedListeners[key] = reserve

	if accept > 0 {
		if _, ok := listener.(setDeadline); !ok {
			accept = 0
			if logger != nil {
				logger.Println("reserve bind listener does not support SetDeadline, disabling accept timeout")
			}
		}
	}
	go func() {
		reserve.run(reuse, accept, logger)

		// run closed the base listener and conns. Drop the entry, unless it
		// has already been replaced, so the key can be reserved again.
		r.mut.Lock()
		if r.reservedListeners[key] == reserve {
			delete(r.reservedListeners, key)
		}
		r.mut.Unlock()
	}()
	return &holdListener{r: reserve}, nil
}

// setDeadline is implemented by listeners that support accept deadlines.
type setDeadline interface {
	SetDeadline(t time.Time) error
}

// run accepts connections on r.base and hands each one to a
// holdListener.Accept waiting on r.conns.
//
// It returns, closing r.base and r.conns, in two cases. Either Accept fails,
// which includes the accept deadline expiring, or no handle takes an
// accepted connection within reuse, in which case that connection is closed.
func (r *reserved) run(reuse, accept time.Duration, logger Logger) {
	defer func() {
		r.base.Close()
		close(r.conns)
	}()

	for {
		if accept > 0 {
			// getOrNew only passes accept > 0 when base implements setDeadline.
			r.base.(setDeadline).SetDeadline(time.Now().Add(accept))
		}
		conn, err := r.base.Accept()
		if err != nil {
			if logger != nil {
				logger.Println("reserve bind listen accept error:", err)
			}
			return
		}

		select {
		case r.conns <- conn:
		case <-time.After(reuse):
			conn.Close()
			if logger != nil {
				logger.Println("reserve bind listen reuse timeout")
			}
			return
		}
	}
}

// Accept waits for the next connection from the shared listener. It returns
// net.ErrClosed once this handle is closed or the shared listener has
// stopped.
func (h *holdListener) Accept() (net.Conn, error) {
	if h.closed.Load() {
		return nil, net.ErrClosed
	}
	conn, ok := <-h.r.conns
	if !ok {
		h.closed.Store(true)
		return nil, net.ErrClosed
	}
	return conn, nil
}

// Close marks this handle closed and returns net.ErrClosed if it already
// was. It does not close the shared base listener, and it does not unblock
// an Accept that is already waiting.
func (h *holdListener) Close() error {
	if h.closed.Swap(true) {
		return net.ErrClosed
	}
	return nil
}

// Addr returns the address of the shared base listener.
func (h *holdListener) Addr() net.Addr {
	return h.r.base.Addr()
}

--------------------------------------------------------------------------------