├── .github ├── renovate.json ├── update_dependencies.sh └── workflows │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── brutal.go ├── brutal_linux.go ├── brutal_stub.go ├── client.go ├── client_conn.go ├── client_conn_wait.go ├── error.go ├── go.mod ├── go.sum ├── h2mux.go ├── h2mux_conn.go ├── padding.go ├── protocol.go ├── protocol_conn.go ├── server.go ├── server_conn.go └── session.go /.github/renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "commitMessagePrefix": "[dependencies]", 4 | "extends": [ 5 | "config:base", 6 | ":disableRateLimiting" 7 | ], 8 | "golang": { 9 | "enabled": false 10 | }, 11 | "packageRules": [ 12 | { 13 | "matchManagers": [ 14 | "github-actions" 15 | ], 16 | "groupName": "github-actions" 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /.github/update_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PROJECTS=$(dirname "$0")/../.. 4 | go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD) 5 | go mod tidy 6 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - '.github/**' 11 | - '!.github/workflows/lint.yml' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | 17 | jobs: 18 | build: 19 | name: Build 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | - name: Setup Go 27 | uses: actions/setup-go@v5 28 | with: 29 | go-version: ^1.23 30 | - name: Cache go module 31 | uses: actions/cache@v4 32 | with: 33 | path: | 34 | ~/go/pkg/mod 35 | key: go-${{ hashFiles('**/go.sum') }} 36 | - name: golangci-lint 37 | uses: golangci/golangci-lint-action@v6 38 | with: 39 | version: latest 40 | args: . -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - '.github/**' 11 | - '!.github/workflows/debug.yml' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | 17 | jobs: 18 | build: 19 | name: Linux 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | - name: Setup Go 27 | uses: actions/setup-go@v5 28 | with: 29 | go-version: ^1.23 30 | - name: Build 31 | run: | 32 | make test 33 | build_go120: 34 | name: Linux (Go 1.20) 35 | runs-on: ubuntu-latest 36 | steps: 37 | - name: Checkout 38 | uses: actions/checkout@v4 39 | with: 40 | fetch-depth: 0 41 | - name: Setup Go 42 | uses: actions/setup-go@v5 43 | with: 44 | go-version: ~1.20 45 | continue-on-error: true 46 | - name: Build 47 | run: | 48 | make test 49 | build_go121: 50 | name: Linux (Go 1.21) 51 | runs-on: ubuntu-latest 52 | steps: 53 | - name: Checkout 54 | uses: actions/checkout@v4 55 | with: 56 | fetch-depth: 0 57 | - name: Setup Go 58 | uses: actions/setup-go@v5 59 | with: 60 | go-version: ~1.21 61 | continue-on-error: true 62 | - name: Build 63 | run: | 64 | make test 65 | build_go122: 66 | name: Linux (Go 1.22) 67 | runs-on: ubuntu-latest 68 | steps: 69 | - name: Checkout 70 | uses: actions/checkout@v4 71 | with: 72 | fetch-depth: 0 73 | - name: Setup Go 74 | uses: actions/setup-go@v5 75 | with: 76 | go-version: ~1.22 77 | continue-on-error: true 78 | - name: Build 79 | run: | 80 | make test 81 | build_windows: 82 | name: Windows 83 | runs-on: windows-latest 84 | steps: 85 | - name: Checkout 86 | uses: actions/checkout@v4 87 | with: 88 | fetch-depth: 0 89 | - name: Setup Go 90 | uses: actions/setup-go@v5 91 | with: 92 | go-version: ^1.23 93 | continue-on-error: true 94 | - name: Build 95 | run: | 96 | make test 97 | build_darwin: 98 | name: macOS 99 | runs-on: macos-latest 100 | steps: 101 | - name: Checkout 102 | uses: actions/checkout@v4 103 | with: 104 | fetch-depth: 0 105 | - name: Setup Go 106 | uses: actions/setup-go@v5 107 | with: 108 | go-version: ^1.23 109 | continue-on-error: true 110 | - name: Build 111 | run: | 112 | make test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /vendor/ 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - gofumpt 5 | - govet 6 | - gci 7 | - staticcheck 8 | - paralleltest 9 | - ineffassign 10 | 11 | linters-settings: 12 | gci: 13 | custom-order: true 14 | sections: 15 | - standard 16 | - prefix(github.com/sagernet/) 17 | - default 18 | 19 | run: 20 | go: "1.23" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 by nekohasekai 2 | 3 | This program is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | This program is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with this program. If not, see . -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | fmt: 2 | @gofumpt -l -w . 3 | @gofmt -s -w . 4 | @gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" . 5 | 6 | fmt_install: 7 | go install -v mvdan.cc/gofumpt@latest 8 | go install -v github.com/daixiang0/gci@latest 9 | 10 | lint: 11 | GOOS=linux golangci-lint run 12 | GOOS=android golangci-lint run 13 | GOOS=windows golangci-lint run 14 | GOOS=darwin golangci-lint run 15 | GOOS=freebsd golangci-lint run 16 | 17 | lint_install: 18 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 19 | 20 | test: 21 | go test ./... 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sing-mux 2 | 3 | Simple multiplex library. -------------------------------------------------------------------------------- /brutal.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | 7 | "github.com/sagernet/sing/common" 8 | "github.com/sagernet/sing/common/buf" 9 | E "github.com/sagernet/sing/common/exceptions" 10 | "github.com/sagernet/sing/common/varbin" 11 | ) 12 | 13 | const ( 14 | BrutalExchangeDomain = "_BrutalBwExchange" 15 | BrutalMinSpeedBPS = 65536 16 | ) 17 | 18 | func WriteBrutalRequest(writer io.Writer, receiveBPS uint64) error { 19 | return binary.Write(writer, binary.BigEndian, receiveBPS) 20 | } 21 | 22 | func ReadBrutalRequest(reader io.Reader) (uint64, error) { 23 | var receiveBPS uint64 24 | err := binary.Read(reader, binary.BigEndian, &receiveBPS) 25 | return receiveBPS, err 26 | } 27 | 28 | func WriteBrutalResponse(writer io.Writer, receiveBPS uint64, ok bool, message string) error { 29 | buffer := buf.New() 30 | defer buffer.Release() 31 | common.Must(binary.Write(buffer, binary.BigEndian, ok)) 32 | if ok { 33 | common.Must(binary.Write(buffer, binary.BigEndian, receiveBPS)) 34 | } else { 35 | err := varbin.Write(buffer, binary.BigEndian, message) 36 | if err != nil { 37 | return err 38 | } 39 | } 40 | return common.Error(writer.Write(buffer.Bytes())) 41 | } 42 | 43 | func ReadBrutalResponse(reader io.Reader) (uint64, error) { 44 | var ok bool 45 | err := binary.Read(reader, binary.BigEndian, &ok) 46 | if err != nil { 47 | return 0, err 48 | } 49 | if ok { 50 | var receiveBPS uint64 51 | err = binary.Read(reader, binary.BigEndian, &receiveBPS) 52 | return receiveBPS, err 53 | } else { 54 | var message string 55 | message, err = varbin.ReadValue[string](reader, binary.BigEndian) 56 | if err != nil { 57 | return 0, err 58 | } 59 | return 0, E.New("remote error: ", message) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /brutal_linux.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "net" 5 | "os" 6 | "reflect" 7 | "syscall" 8 | "unsafe" 9 | _ "unsafe" 10 | 11 | "github.com/sagernet/sing/common" 12 | "github.com/sagernet/sing/common/control" 13 | E "github.com/sagernet/sing/common/exceptions" 14 | 15 | "golang.org/x/sys/unix" 16 | ) 17 | 18 | const ( 19 | BrutalAvailable = true 20 | TCP_BRUTAL_PARAMS = 23301 21 | ) 22 | 23 | type TCPBrutalParams struct { 24 | Rate uint64 25 | CwndGain uint32 26 | } 27 | 28 | //go:linkname setsockopt syscall.setsockopt 29 | func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) 30 | 31 | func SetBrutalOptions(conn net.Conn, sendBPS uint64) error { 32 | syscallConn, loaded := common.Cast[syscall.Conn](conn) 33 | if !loaded { 34 | return E.New( 35 | "brutal: nested multiplexing is not supported: ", 36 | "cannot convert ", reflect.TypeOf(conn), " to syscall.Conn, final type: ", reflect.TypeOf(common.Top(conn)), 37 | ) 38 | } 39 | return control.Conn(syscallConn, func(fd uintptr) error { 40 | err := unix.SetsockoptString(int(fd), unix.IPPROTO_TCP, unix.TCP_CONGESTION, "brutal") 41 | if err != nil { 42 | return E.Extend( 43 | os.NewSyscallError("setsockopt IPPROTO_TCP TCP_CONGESTION brutal", err), 44 | "please make sure you have installed the tcp-brutal kernel module", 45 | ) 46 | } 47 | params := TCPBrutalParams{ 48 | Rate: sendBPS, 49 | CwndGain: 20, // hysteria2 default 50 | } 51 | err = setsockopt(int(fd), unix.IPPROTO_TCP, TCP_BRUTAL_PARAMS, unsafe.Pointer(¶ms), unsafe.Sizeof(params)) 52 | if err != nil { 53 | return os.NewSyscallError("setsockopt IPPROTO_TCP TCP_BRUTAL_PARAMS", err) 54 | } 55 | return nil 56 | }) 57 | } 58 | -------------------------------------------------------------------------------- /brutal_stub.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package mux 4 | 5 | import ( 6 | "net" 7 | 8 | E "github.com/sagernet/sing/common/exceptions" 9 | ) 10 | 11 | const BrutalAvailable = false 12 | 13 | func SetBrutalOptions(conn net.Conn, sendBPS uint64) error { 14 | return E.New("TCP Brutal is only supported on Linux") 15 | } 16 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "sync" 7 | 8 | "github.com/sagernet/sing/common" 9 | "github.com/sagernet/sing/common/bufio" 10 | E "github.com/sagernet/sing/common/exceptions" 11 | "github.com/sagernet/sing/common/logger" 12 | M "github.com/sagernet/sing/common/metadata" 13 | N "github.com/sagernet/sing/common/network" 14 | "github.com/sagernet/sing/common/x/list" 15 | ) 16 | 17 | type Client struct { 18 | dialer N.Dialer 19 | logger logger.Logger 20 | protocol byte 21 | maxConnections int 22 | minStreams int 23 | maxStreams int 24 | padding bool 25 | access sync.Mutex 26 | connections list.List[abstractSession] 27 | brutal BrutalOptions 28 | } 29 | 30 | type Options struct { 31 | Dialer N.Dialer 32 | Logger logger.Logger 33 | Protocol string 34 | MaxConnections int 35 | MinStreams int 36 | MaxStreams int 37 | Padding bool 38 | Brutal BrutalOptions 39 | } 40 | 41 | type BrutalOptions struct { 42 | Enabled bool 43 | SendBPS uint64 44 | ReceiveBPS uint64 45 | } 46 | 47 | func NewClient(options Options) (*Client, error) { 48 | client := &Client{ 49 | dialer: options.Dialer, 50 | logger: options.Logger, 51 | maxConnections: options.MaxConnections, 52 | minStreams: options.MinStreams, 53 | maxStreams: options.MaxStreams, 54 | padding: options.Padding, 55 | brutal: options.Brutal, 56 | } 57 | if client.dialer == nil { 58 | client.dialer = N.SystemDialer 59 | } 60 | if client.maxStreams == 0 && client.maxConnections == 0 { 61 | client.minStreams = 8 62 | } 63 | switch options.Protocol { 64 | case "", "h2mux": 65 | client.protocol = ProtocolH2Mux 66 | case "smux": 67 | client.protocol = ProtocolSmux 68 | case "yamux": 69 | client.protocol = ProtocolYAMux 70 | default: 71 | return nil, E.New("unknown protocol: " + options.Protocol) 72 | } 73 | return client, nil 74 | } 75 | 76 | func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 77 | switch N.NetworkName(network) { 78 | case N.NetworkTCP: 79 | stream, err := c.openStream(ctx) 80 | if err != nil { 81 | return nil, err 82 | } 83 | return &clientConn{Conn: stream, destination: destination}, nil 84 | case N.NetworkUDP: 85 | stream, err := c.openStream(ctx) 86 | if err != nil { 87 | return nil, err 88 | } 89 | extendedConn := bufio.NewExtendedConn(stream) 90 | return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil 91 | default: 92 | return nil, E.Extend(N.ErrUnknownNetwork, network) 93 | } 94 | } 95 | 96 | func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 97 | stream, err := c.openStream(ctx) 98 | if err != nil { 99 | return nil, err 100 | } 101 | extendedConn := bufio.NewExtendedConn(stream) 102 | return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil 103 | } 104 | 105 | func (c *Client) openStream(ctx context.Context) (net.Conn, error) { 106 | var ( 107 | session abstractSession 108 | stream net.Conn 109 | err error 110 | ) 111 | for attempts := 0; attempts < 2; attempts++ { 112 | session, err = c.offer(ctx) 113 | if err != nil { 114 | continue 115 | } 116 | stream, err = session.Open() 117 | if err != nil { 118 | continue 119 | } 120 | break 121 | } 122 | if err != nil { 123 | return nil, err 124 | } 125 | return &wrapStream{stream}, nil 126 | } 127 | 128 | func (c *Client) offer(ctx context.Context) (abstractSession, error) { 129 | c.access.Lock() 130 | defer c.access.Unlock() 131 | 132 | var sessions []abstractSession 133 | for element := c.connections.Front(); element != nil; { 134 | if element.Value.IsClosed() { 135 | element.Value.Close() 136 | nextElement := element.Next() 137 | c.connections.Remove(element) 138 | element = nextElement 139 | continue 140 | } 141 | sessions = append(sessions, element.Value) 142 | element = element.Next() 143 | } 144 | if c.brutal.Enabled { 145 | if len(sessions) > 0 { 146 | return sessions[0], nil 147 | } 148 | return c.offerNew(ctx) 149 | } 150 | session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) 151 | if session == nil { 152 | return c.offerNew(ctx) 153 | } 154 | numStreams := session.NumStreams() 155 | if numStreams == 0 { 156 | return session, nil 157 | } 158 | if c.maxConnections > 0 { 159 | if len(sessions) >= c.maxConnections || numStreams < c.minStreams { 160 | return session, nil 161 | } 162 | } else { 163 | if c.maxStreams > 0 && numStreams < c.maxStreams { 164 | return session, nil 165 | } 166 | } 167 | return c.offerNew(ctx) 168 | } 169 | 170 | func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { 171 | ctx, cancel := context.WithTimeout(ctx, TCPTimeout) 172 | defer cancel() 173 | conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) 174 | if err != nil { 175 | return nil, err 176 | } 177 | var version byte 178 | if c.padding { 179 | version = Version1 180 | } else { 181 | version = Version0 182 | } 183 | conn = newProtocolConn(conn, Request{ 184 | Version: version, 185 | Protocol: c.protocol, 186 | Padding: c.padding, 187 | }) 188 | if c.padding { 189 | conn = newPaddingConn(conn) 190 | } 191 | session, err := newClientSession(conn, c.protocol) 192 | if err != nil { 193 | conn.Close() 194 | return nil, err 195 | } 196 | if c.brutal.Enabled { 197 | err = c.brutalExchange(ctx, conn, session) 198 | if err != nil { 199 | conn.Close() 200 | session.Close() 201 | return nil, E.Cause(err, "brutal exchange") 202 | } 203 | } 204 | c.connections.PushBack(session) 205 | return session, nil 206 | } 207 | 208 | func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error { 209 | stream, err := session.Open() 210 | if err != nil { 211 | return err 212 | } 213 | conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}} 214 | err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS) 215 | if err != nil { 216 | return err 217 | } 218 | serverReceiveBPS, err := ReadBrutalResponse(conn) 219 | if err != nil { 220 | return err 221 | } 222 | conn.Close() 223 | sendBPS := c.brutal.SendBPS 224 | if serverReceiveBPS < sendBPS { 225 | sendBPS = serverReceiveBPS 226 | } 227 | clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS) 228 | if clientBrutalErr != nil { 229 | c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client")) 230 | } 231 | return nil 232 | } 233 | 234 | func (c *Client) Reset() { 235 | c.access.Lock() 236 | defer c.access.Unlock() 237 | for _, session := range c.connections.Array() { 238 | session.Close() 239 | } 240 | c.connections.Init() 241 | } 242 | 243 | func (c *Client) Close() error { 244 | c.Reset() 245 | return nil 246 | } 247 | -------------------------------------------------------------------------------- /client_conn.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "net" 7 | "sync" 8 | 9 | "github.com/sagernet/sing/common" 10 | "github.com/sagernet/sing/common/buf" 11 | E "github.com/sagernet/sing/common/exceptions" 12 | M "github.com/sagernet/sing/common/metadata" 13 | N "github.com/sagernet/sing/common/network" 14 | ) 15 | 16 | type clientConn struct { 17 | net.Conn 18 | destination M.Socksaddr 19 | requestWritten bool 20 | responseRead bool 21 | } 22 | 23 | func (c *clientConn) NeedHandshake() bool { 24 | return !c.requestWritten 25 | } 26 | 27 | func (c *clientConn) readResponse() error { 28 | response, err := ReadStreamResponse(c.Conn) 29 | if err != nil { 30 | return err 31 | } 32 | if response.Status == statusError { 33 | return E.New("remote error: ", response.Message) 34 | } 35 | return nil 36 | } 37 | 38 | func (c *clientConn) Read(b []byte) (n int, err error) { 39 | if !c.responseRead { 40 | err = c.readResponse() 41 | if err != nil { 42 | return 43 | } 44 | c.responseRead = true 45 | } 46 | return c.Conn.Read(b) 47 | } 48 | 49 | func (c *clientConn) Write(b []byte) (n int, err error) { 50 | if c.requestWritten { 51 | return c.Conn.Write(b) 52 | } 53 | request := StreamRequest{ 54 | Network: N.NetworkTCP, 55 | Destination: c.destination, 56 | } 57 | buffer := buf.NewSize(streamRequestLen(request) + len(b)) 58 | defer buffer.Release() 59 | err = EncodeStreamRequest(request, buffer) 60 | if err != nil { 61 | return 62 | } 63 | buffer.Write(b) 64 | _, err = c.Conn.Write(buffer.Bytes()) 65 | if err != nil { 66 | return 67 | } 68 | c.requestWritten = true 69 | return len(b), nil 70 | } 71 | 72 | func (c *clientConn) LocalAddr() net.Addr { 73 | return c.Conn.LocalAddr() 74 | } 75 | 76 | func (c *clientConn) RemoteAddr() net.Addr { 77 | return c.destination.TCPAddr() 78 | } 79 | 80 | func (c *clientConn) ReaderReplaceable() bool { 81 | return c.responseRead 82 | } 83 | 84 | func (c *clientConn) WriterReplaceable() bool { 85 | return c.requestWritten 86 | } 87 | 88 | func (c *clientConn) NeedAdditionalReadDeadline() bool { 89 | return true 90 | } 91 | 92 | func (c *clientConn) Upstream() any { 93 | return c.Conn 94 | } 95 | 96 | var _ N.NetPacketConn = (*clientPacketConn)(nil) 97 | 98 | type clientPacketConn struct { 99 | N.AbstractConn 100 | conn N.ExtendedConn 101 | access sync.Mutex 102 | destination M.Socksaddr 103 | requestWritten bool 104 | responseRead bool 105 | readWaitOptions N.ReadWaitOptions 106 | } 107 | 108 | func (c *clientPacketConn) NeedHandshake() bool { 109 | return !c.requestWritten 110 | } 111 | 112 | func (c *clientPacketConn) readResponse() error { 113 | response, err := ReadStreamResponse(c.conn) 114 | if err != nil { 115 | return err 116 | } 117 | if response.Status == statusError { 118 | return E.New("remote error: ", response.Message) 119 | } 120 | return nil 121 | } 122 | 123 | func (c *clientPacketConn) Read(b []byte) (n int, err error) { 124 | if !c.responseRead { 125 | err = c.readResponse() 126 | if err != nil { 127 | return 128 | } 129 | c.responseRead = true 130 | } 131 | var length uint16 132 | err = binary.Read(c.conn, binary.BigEndian, &length) 133 | if err != nil { 134 | return 135 | } 136 | if cap(b) < int(length) { 137 | return 0, io.ErrShortBuffer 138 | } 139 | return io.ReadFull(c.conn, b[:length]) 140 | } 141 | 142 | func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { 143 | request := StreamRequest{ 144 | Network: N.NetworkUDP, 145 | Destination: c.destination, 146 | } 147 | rLen := streamRequestLen(request) 148 | if len(payload) > 0 { 149 | rLen += 2 + len(payload) 150 | } 151 | buffer := buf.NewSize(rLen) 152 | defer buffer.Release() 153 | err = EncodeStreamRequest(request, buffer) 154 | if err != nil { 155 | return 156 | } 157 | if len(payload) > 0 { 158 | common.Must( 159 | binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 160 | common.Error(buffer.Write(payload)), 161 | ) 162 | } 163 | _, err = c.conn.Write(buffer.Bytes()) 164 | if err != nil { 165 | return 166 | } 167 | c.requestWritten = true 168 | return len(payload), nil 169 | } 170 | 171 | func (c *clientPacketConn) Write(b []byte) (n int, err error) { 172 | if !c.requestWritten { 173 | c.access.Lock() 174 | if c.requestWritten { 175 | c.access.Unlock() 176 | } else { 177 | defer c.access.Unlock() 178 | return c.writeRequest(b) 179 | } 180 | } 181 | err = binary.Write(c.conn, binary.BigEndian, uint16(len(b))) 182 | if err != nil { 183 | return 184 | } 185 | return c.conn.Write(b) 186 | } 187 | 188 | func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { 189 | if !c.responseRead { 190 | err = c.readResponse() 191 | if err != nil { 192 | return 193 | } 194 | c.responseRead = true 195 | } 196 | var length uint16 197 | err = binary.Read(c.conn, binary.BigEndian, &length) 198 | if err != nil { 199 | return 200 | } 201 | _, err = buffer.ReadFullFrom(c.conn, int(length)) 202 | return 203 | } 204 | 205 | func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { 206 | if !c.requestWritten { 207 | c.access.Lock() 208 | if c.requestWritten { 209 | c.access.Unlock() 210 | } else { 211 | defer c.access.Unlock() 212 | defer buffer.Release() 213 | return common.Error(c.writeRequest(buffer.Bytes())) 214 | } 215 | } 216 | bLen := buffer.Len() 217 | binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) 218 | return c.conn.WriteBuffer(buffer) 219 | } 220 | 221 | func (c *clientPacketConn) FrontHeadroom() int { 222 | return 2 223 | } 224 | 225 | func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 226 | if !c.responseRead { 227 | err = c.readResponse() 228 | if err != nil { 229 | return 230 | } 231 | c.responseRead = true 232 | } 233 | var length uint16 234 | err = binary.Read(c.conn, binary.BigEndian, &length) 235 | if err != nil { 236 | return 237 | } 238 | if cap(p) < int(length) { 239 | return 0, nil, io.ErrShortBuffer 240 | } 241 | n, err = io.ReadFull(c.conn, p[:length]) 242 | return 243 | } 244 | 245 | func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 246 | if !c.requestWritten { 247 | c.access.Lock() 248 | if c.requestWritten { 249 | c.access.Unlock() 250 | } else { 251 | defer c.access.Unlock() 252 | return c.writeRequest(p) 253 | } 254 | } 255 | err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) 256 | if err != nil { 257 | return 258 | } 259 | return c.conn.Write(p) 260 | } 261 | 262 | func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 263 | err = c.ReadBuffer(buffer) 264 | return 265 | } 266 | 267 | func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 268 | return c.WriteBuffer(buffer) 269 | } 270 | 271 | func (c *clientPacketConn) LocalAddr() net.Addr { 272 | return c.conn.LocalAddr() 273 | } 274 | 275 | func (c *clientPacketConn) RemoteAddr() net.Addr { 276 | return c.destination.UDPAddr() 277 | } 278 | 279 | func (c *clientPacketConn) NeedAdditionalReadDeadline() bool { 280 | return true 281 | } 282 | 283 | func (c *clientPacketConn) Upstream() any { 284 | return c.conn 285 | } 286 | 287 | var _ N.NetPacketConn = (*clientPacketAddrConn)(nil) 288 | 289 | type clientPacketAddrConn struct { 290 | N.AbstractConn 291 | conn N.ExtendedConn 292 | access sync.Mutex 293 | destination M.Socksaddr 294 | requestWritten bool 295 | responseRead bool 296 | readWaitOptions N.ReadWaitOptions 297 | } 298 | 299 | func (c *clientPacketAddrConn) NeedHandshake() bool { 300 | return !c.requestWritten 301 | } 302 | 303 | func (c *clientPacketAddrConn) readResponse() error { 304 | response, err := ReadStreamResponse(c.conn) 305 | if err != nil { 306 | return err 307 | } 308 | if response.Status == statusError { 309 | return E.New("remote error: ", response.Message) 310 | } 311 | return nil 312 | } 313 | 314 | func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 315 | if !c.responseRead { 316 | err = c.readResponse() 317 | if err != nil { 318 | return 319 | } 320 | c.responseRead = true 321 | } 322 | destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn) 323 | if err != nil { 324 | return 325 | } 326 | if destination.IsFqdn() { 327 | addr = destination 328 | } else { 329 | addr = destination.UDPAddr() 330 | } 331 | var length uint16 332 | err = binary.Read(c.conn, binary.BigEndian, &length) 333 | if err != nil { 334 | return 335 | } 336 | if cap(p) < int(length) { 337 | return 0, nil, io.ErrShortBuffer 338 | } 339 | n, err = io.ReadFull(c.conn, p[:length]) 340 | return 341 | } 342 | 343 | func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { 344 | request := StreamRequest{ 345 | Network: N.NetworkUDP, 346 | Destination: c.destination, 347 | PacketAddr: true, 348 | } 349 | rLen := streamRequestLen(request) 350 | if len(payload) > 0 { 351 | rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) 352 | } 353 | buffer := buf.NewSize(rLen) 354 | defer buffer.Release() 355 | err = EncodeStreamRequest(request, buffer) 356 | if err != nil { 357 | return 358 | } 359 | if len(payload) > 0 { 360 | err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 361 | if err != nil { 362 | return 363 | } 364 | common.Must( 365 | binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 366 | common.Error(buffer.Write(payload)), 367 | ) 368 | } 369 | _, err = c.conn.Write(buffer.Bytes()) 370 | if err != nil { 371 | return 372 | } 373 | c.requestWritten = true 374 | return len(payload), nil 375 | } 376 | 377 | func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 378 | if !c.requestWritten { 379 | c.access.Lock() 380 | if c.requestWritten { 381 | c.access.Unlock() 382 | } else { 383 | defer c.access.Unlock() 384 | return c.writeRequest(p, M.SocksaddrFromNet(addr)) 385 | } 386 | } 387 | err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr)) 388 | if err != nil { 389 | return 390 | } 391 | err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) 392 | if err != nil { 393 | return 394 | } 395 | return c.conn.Write(p) 396 | } 397 | 398 | func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 399 | if !c.responseRead { 400 | err = c.readResponse() 401 | if err != nil { 402 | return 403 | } 404 | c.responseRead = true 405 | } 406 | destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn) 407 | if err != nil { 408 | return 409 | } 410 | var length uint16 411 | err = binary.Read(c.conn, binary.BigEndian, &length) 412 | if err != nil { 413 | return 414 | } 415 | _, err = buffer.ReadFullFrom(c.conn, int(length)) 416 | return 417 | } 418 | 419 | func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 420 | if !c.requestWritten { 421 | c.access.Lock() 422 | if c.requestWritten { 423 | c.access.Unlock() 424 | } else { 425 | defer c.access.Unlock() 426 | defer buffer.Release() 427 | return common.Error(c.writeRequest(buffer.Bytes(), destination)) 428 | } 429 | } 430 | bLen := buffer.Len() 431 | header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) 432 | err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 433 | if err != nil { 434 | return err 435 | } 436 | common.Must(binary.Write(header, binary.BigEndian, uint16(bLen))) 437 | return c.conn.WriteBuffer(buffer) 438 | } 439 | 440 | func (c *clientPacketAddrConn) LocalAddr() net.Addr { 441 | return c.conn.LocalAddr() 442 | } 443 | 444 | func (c *clientPacketAddrConn) FrontHeadroom() int { 445 | return 2 + M.MaxSocksaddrLength 446 | } 447 | 448 | func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool { 449 | return true 450 | } 451 | 452 | func (c *clientPacketAddrConn) Upstream() any { 453 | return c.conn 454 | } 455 | -------------------------------------------------------------------------------- /client_conn_wait.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | 6 | "github.com/sagernet/sing/common/buf" 7 | M "github.com/sagernet/sing/common/metadata" 8 | N "github.com/sagernet/sing/common/network" 9 | ) 10 | 11 | var _ N.PacketReadWaiter = (*clientPacketConn)(nil) 12 | 13 | func (c *clientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { 14 | c.readWaitOptions = options 15 | return false 16 | } 17 | 18 | func (c *clientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { 19 | if !c.responseRead { 20 | err = c.readResponse() 21 | if err != nil { 22 | return 23 | } 24 | c.responseRead = true 25 | } 26 | var length uint16 27 | err = binary.Read(c.conn, binary.BigEndian, &length) 28 | if err != nil { 29 | return 30 | } 31 | buffer = c.readWaitOptions.NewPacketBuffer() 32 | _, err = buffer.ReadFullFrom(c.conn, int(length)) 33 | if err != nil { 34 | buffer.Release() 35 | return nil, M.Socksaddr{}, err 36 | } 37 | c.readWaitOptions.PostReturn(buffer) 38 | return 39 | } 40 | 41 | var _ N.PacketReadWaiter = (*clientPacketAddrConn)(nil) 42 | 43 | func (c *clientPacketAddrConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { 44 | c.readWaitOptions = options 45 | return false 46 | } 47 | 48 | func (c *clientPacketAddrConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { 49 | if !c.responseRead { 50 | err = c.readResponse() 51 | if err != nil { 52 | return 53 | } 54 | c.responseRead = true 55 | } 56 | destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn) 57 | if err != nil { 58 | return 59 | } 60 | var length uint16 61 | err = binary.Read(c.conn, binary.BigEndian, &length) 62 | if err != nil { 63 | return 64 | } 65 | buffer = c.readWaitOptions.NewPacketBuffer() 66 | _, err = buffer.ReadFullFrom(c.conn, int(length)) 67 | if err != nil { 68 | buffer.Release() 69 | return nil, M.Socksaddr{}, err 70 | } 71 | c.readWaitOptions.PostReturn(buffer) 72 | return 73 | } 74 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "io" 5 | "net" 6 | 7 | "github.com/hashicorp/yamux" 8 | ) 9 | 10 | type wrapStream struct { 11 | net.Conn 12 | } 13 | 14 | func (w *wrapStream) Read(p []byte) (n int, err error) { 15 | n, err = w.Conn.Read(p) 16 | err = wrapError(err) 17 | return 18 | } 19 | 20 | func (w *wrapStream) Write(p []byte) (n int, err error) { 21 | n, err = w.Conn.Write(p) 22 | err = wrapError(err) 23 | return 24 | } 25 | 26 | func (w *wrapStream) Upstream() any { 27 | return w.Conn 28 | } 29 | 30 | func wrapError(err error) error { 31 | switch err { 32 | case yamux.ErrStreamClosed: 33 | return io.EOF 34 | default: 35 | return err 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/sagernet/sing-mux 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/hashicorp/yamux v0.1.2 7 | github.com/sagernet/sing v0.6.7 8 | github.com/sagernet/smux v1.5.34-mod.1 9 | golang.org/x/net v0.34.0 10 | golang.org/x/sys v0.30.0 11 | ) 12 | 13 | require golang.org/x/text v0.21.0 // indirect 14 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= 5 | github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= 9 | github.com/sagernet/sing v0.6.0 h1:jT55zAXrG7H3x+s/FlrC15xQy3LcmuZ2GGA9+8IJdt0= 10 | github.com/sagernet/sing v0.6.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= 11 | github.com/sagernet/sing v0.6.7 h1:NIWBLZ9AUWDXAQBKGleKwsitbQrI9M0nqoheXhUKnrI= 12 | github.com/sagernet/sing v0.6.7/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= 13 | github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= 14 | github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo= 15 | github.com/sagernet/smux v1.5.34-mod.1 h1:xZljMK3fVOX4HC+ND1N7eOiweqEa9bxRTKlliqe9DJE= 16 | github.com/sagernet/smux v1.5.34-mod.1/go.mod h1:qI3fpNiLZmwrh83DmbJHX7sAsc2R/gbqdWw0/WzciU0= 17 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 18 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 19 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 20 | github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 21 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 22 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 23 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 24 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 25 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 26 | golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= 27 | golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= 28 | golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 29 | golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 30 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 31 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 32 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 33 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 34 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 35 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 37 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 38 | -------------------------------------------------------------------------------- /h2mux.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "io" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "sync" 12 | "time" 13 | 14 | "github.com/sagernet/sing/common/atomic" 15 | "github.com/sagernet/sing/common/buf" 16 | "github.com/sagernet/sing/common/bufio" 17 | E "github.com/sagernet/sing/common/exceptions" 18 | N "github.com/sagernet/sing/common/network" 19 | 20 | "golang.org/x/net/http2" 21 | ) 22 | 23 | const idleTimeout = 30 * time.Second 24 | 25 | var _ abstractSession = (*h2MuxServerSession)(nil) 26 | 27 | type h2MuxServerSession struct { 28 | server http2.Server 29 | active atomic.Int32 30 | conn net.Conn 31 | inbound chan net.Conn 32 | done chan struct{} 33 | } 34 | 35 | func newH2MuxServer(conn net.Conn) *h2MuxServerSession { 36 | session := &h2MuxServerSession{ 37 | conn: conn, 38 | inbound: make(chan net.Conn), 39 | done: make(chan struct{}), 40 | server: http2.Server{ 41 | IdleTimeout: idleTimeout, 42 | MaxReadFrameSize: buf.BufferSize, 43 | }, 44 | } 45 | go func() { 46 | session.server.ServeConn(conn, &http2.ServeConnOpts{ 47 | Handler: session, 48 | }) 49 | _ = session.Close() 50 | }() 51 | return session 52 | } 53 | 54 | func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 55 | s.active.Add(1) 56 | defer s.active.Add(-1) 57 | writer.WriteHeader(http.StatusOK) 58 | conn := newHTTP2Wrapper(newHTTPConn(request.Body, writer), writer.(http.Flusher)) 59 | s.inbound <- conn 60 | select { 61 | case <-conn.done: 62 | case <-s.done: 63 | _ = conn.Close() 64 | } 65 | } 66 | 67 | func (s *h2MuxServerSession) Open() (net.Conn, error) { 68 | return nil, os.ErrInvalid 69 | } 70 | 71 | func (s *h2MuxServerSession) Accept() (net.Conn, error) { 72 | select { 73 | case conn := <-s.inbound: 74 | return conn, nil 75 | case <-s.done: 76 | return nil, os.ErrClosed 77 | } 78 | } 79 | 80 | func (s *h2MuxServerSession) NumStreams() int { 81 | return int(s.active.Load()) 82 | } 83 | 84 | func (s *h2MuxServerSession) Close() error { 85 | select { 86 | case <-s.done: 87 | default: 88 | close(s.done) 89 | } 90 | return s.conn.Close() 91 | } 92 | 93 | func (s *h2MuxServerSession) IsClosed() bool { 94 | select { 95 | case <-s.done: 96 | return true 97 | default: 98 | return false 99 | } 100 | } 101 | 102 | func (s *h2MuxServerSession) CanTakeNewRequest() bool { 103 | return false 104 | } 105 | 106 | type h2MuxConnWrapper struct { 107 | N.ExtendedConn 108 | flusher http.Flusher 109 | access sync.Mutex 110 | closed bool 111 | done chan struct{} 112 | } 113 | 114 | func newHTTP2Wrapper(conn net.Conn, flusher http.Flusher) *h2MuxConnWrapper { 115 | return &h2MuxConnWrapper{ 116 | ExtendedConn: bufio.NewExtendedConn(conn), 117 | flusher: flusher, 118 | done: make(chan struct{}), 119 | } 120 | } 121 | 122 | func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) { 123 | w.access.Lock() 124 | defer w.access.Unlock() 125 | if w.closed { 126 | return 0, net.ErrClosed 127 | } 128 | n, err = w.ExtendedConn.Write(p) 129 | if err == nil { 130 | w.flusher.Flush() 131 | } 132 | return 133 | } 134 | 135 | func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error { 136 | w.access.Lock() 137 | defer w.access.Unlock() 138 | if w.closed { 139 | return net.ErrClosed 140 | } 141 | err := w.ExtendedConn.WriteBuffer(buffer) 142 | if err == nil { 143 | w.flusher.Flush() 144 | } 145 | return err 146 | } 147 | 148 | func (w *h2MuxConnWrapper) Close() error { 149 | w.access.Lock() 150 | select { 151 | case <-w.done: 152 | default: 153 | close(w.done) 154 | } 155 | w.closed = true 156 | w.access.Unlock() 157 | return w.ExtendedConn.Close() 158 | } 159 | 160 | func (w *h2MuxConnWrapper) Upstream() any { 161 | return w.ExtendedConn 162 | } 163 | 164 | var _ abstractSession = (*h2MuxClientSession)(nil) 165 | 166 | type h2MuxClientSession struct { 167 | transport *http2.Transport 168 | clientConn *http2.ClientConn 169 | access sync.RWMutex 170 | closed bool 171 | } 172 | 173 | func newH2MuxClient(conn net.Conn) (*h2MuxClientSession, error) { 174 | session := &h2MuxClientSession{ 175 | transport: &http2.Transport{ 176 | DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { 177 | return conn, nil 178 | }, 179 | ReadIdleTimeout: idleTimeout, 180 | MaxReadFrameSize: buf.BufferSize, 181 | }, 182 | } 183 | session.transport.ConnPool = session 184 | clientConn, err := session.transport.NewClientConn(conn) 185 | if err != nil { 186 | return nil, err 187 | } 188 | session.clientConn = clientConn 189 | return session, nil 190 | } 191 | 192 | func (s *h2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) { 193 | return s.clientConn, nil 194 | } 195 | 196 | func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) { 197 | s.Close() 198 | } 199 | 200 | func (s *h2MuxClientSession) Open() (net.Conn, error) { 201 | pipeInReader, pipeInWriter := io.Pipe() 202 | request := &http.Request{ 203 | Method: http.MethodConnect, 204 | Body: pipeInReader, 205 | URL: &url.URL{Scheme: "https", Host: "localhost"}, 206 | } 207 | connCtx, cancel := context.WithCancel(context.Background()) 208 | request = request.WithContext(connCtx) 209 | conn := newLateHTTPConn(pipeInWriter, cancel) 210 | requestDone := make(chan struct{}) 211 | go func() { 212 | select { 213 | case <-requestDone: 214 | return 215 | case <-time.After(TCPTimeout): 216 | cancel() 217 | } 218 | }() 219 | go func() { 220 | response, err := s.transport.RoundTrip(request) 221 | close(requestDone) 222 | if err != nil { 223 | conn.setup(nil, err) 224 | } else if response.StatusCode != 200 { 225 | response.Body.Close() 226 | conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status)) 227 | } else { 228 | conn.setup(response.Body, nil) 229 | } 230 | }() 231 | return conn, nil 232 | } 233 | 234 | func (s *h2MuxClientSession) Accept() (net.Conn, error) { 235 | return nil, os.ErrInvalid 236 | } 237 | 238 | func (s *h2MuxClientSession) NumStreams() int { 239 | return s.clientConn.State().StreamsActive 240 | } 241 | 242 | func (s *h2MuxClientSession) Close() error { 243 | s.access.Lock() 244 | defer s.access.Unlock() 245 | if s.closed { 246 | return os.ErrClosed 247 | } 248 | s.closed = true 249 | return s.clientConn.Close() 250 | } 251 | 252 | func (s *h2MuxClientSession) IsClosed() bool { 253 | s.access.RLock() 254 | defer s.access.RUnlock() 255 | return s.closed || s.clientConn.State().Closed 256 | } 257 | 258 | func (s *h2MuxClientSession) CanTakeNewRequest() bool { 259 | return s.clientConn.CanTakeNewRequest() 260 | } 261 | -------------------------------------------------------------------------------- /h2mux_conn.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "os" 8 | "time" 9 | 10 | "github.com/sagernet/sing/common" 11 | "github.com/sagernet/sing/common/baderror" 12 | M "github.com/sagernet/sing/common/metadata" 13 | ) 14 | 15 | type httpConn struct { 16 | reader io.Reader 17 | writer io.Writer 18 | create chan struct{} 19 | err error 20 | cancel context.CancelFunc 21 | } 22 | 23 | func newHTTPConn(reader io.Reader, writer io.Writer) *httpConn { 24 | return &httpConn{ 25 | reader: reader, 26 | writer: writer, 27 | } 28 | } 29 | 30 | func newLateHTTPConn(writer io.Writer, cancel context.CancelFunc) *httpConn { 31 | return &httpConn{ 32 | create: make(chan struct{}), 33 | writer: writer, 34 | cancel: cancel, 35 | } 36 | } 37 | 38 | func (c *httpConn) setup(reader io.Reader, err error) { 39 | c.reader = reader 40 | c.err = err 41 | close(c.create) 42 | } 43 | 44 | func (c *httpConn) Read(b []byte) (n int, err error) { 45 | if c.reader == nil { 46 | <-c.create 47 | if c.err != nil { 48 | return 0, c.err 49 | } 50 | } 51 | n, err = c.reader.Read(b) 52 | return n, baderror.WrapH2(err) 53 | } 54 | 55 | func (c *httpConn) Write(b []byte) (n int, err error) { 56 | n, err = c.writer.Write(b) 57 | return n, baderror.WrapH2(err) 58 | } 59 | 60 | func (c *httpConn) Close() error { 61 | if c.cancel != nil { 62 | c.cancel() 63 | } 64 | return common.Close(c.reader, c.writer) 65 | } 66 | 67 | func (c *httpConn) LocalAddr() net.Addr { 68 | return M.Socksaddr{} 69 | } 70 | 71 | func (c *httpConn) RemoteAddr() net.Addr { 72 | return M.Socksaddr{} 73 | } 74 | 75 | func (c *httpConn) SetDeadline(t time.Time) error { 76 | return os.ErrInvalid 77 | } 78 | 79 | func (c *httpConn) SetReadDeadline(t time.Time) error { 80 | return os.ErrInvalid 81 | } 82 | 83 | func (c *httpConn) SetWriteDeadline(t time.Time) error { 84 | return os.ErrInvalid 85 | } 86 | 87 | func (c *httpConn) NeedAdditionalReadDeadline() bool { 88 | return true 89 | } 90 | -------------------------------------------------------------------------------- /padding.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "math/rand" 7 | "net" 8 | 9 | "github.com/sagernet/sing/common" 10 | "github.com/sagernet/sing/common/buf" 11 | "github.com/sagernet/sing/common/bufio" 12 | N "github.com/sagernet/sing/common/network" 13 | "github.com/sagernet/sing/common/rw" 14 | ) 15 | 16 | const kFirstPaddings = 16 17 | 18 | type paddingConn struct { 19 | N.ExtendedConn 20 | writer N.VectorisedWriter 21 | readPadding int 22 | writePadding int 23 | readRemaining int 24 | paddingRemaining int 25 | } 26 | 27 | func newPaddingConn(conn net.Conn) net.Conn { 28 | writer, isVectorised := bufio.CreateVectorisedWriter(conn) 29 | if isVectorised { 30 | return &vectorisedPaddingConn{ 31 | paddingConn{ 32 | ExtendedConn: bufio.NewExtendedConn(conn), 33 | writer: bufio.NewVectorisedWriter(conn), 34 | }, 35 | writer, 36 | } 37 | } else { 38 | return &paddingConn{ 39 | ExtendedConn: bufio.NewExtendedConn(conn), 40 | writer: bufio.NewVectorisedWriter(conn), 41 | } 42 | } 43 | } 44 | 45 | func (c *paddingConn) Read(p []byte) (n int, err error) { 46 | if c.readRemaining > 0 { 47 | if len(p) > c.readRemaining { 48 | p = p[:c.readRemaining] 49 | } 50 | n, err = c.ExtendedConn.Read(p) 51 | if err != nil { 52 | return 53 | } 54 | c.readRemaining -= n 55 | return 56 | } 57 | if c.paddingRemaining > 0 { 58 | err = rw.SkipN(c.ExtendedConn, c.paddingRemaining) 59 | if err != nil { 60 | return 61 | } 62 | c.paddingRemaining = 0 63 | } 64 | if c.readPadding < kFirstPaddings { 65 | var paddingHdr []byte 66 | if len(p) >= 4 { 67 | paddingHdr = p[:4] 68 | } else { 69 | paddingHdr = make([]byte, 4) 70 | } 71 | _, err = io.ReadFull(c.ExtendedConn, paddingHdr) 72 | if err != nil { 73 | return 74 | } 75 | originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) 76 | paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) 77 | if len(p) > originalDataSize { 78 | p = p[:originalDataSize] 79 | } 80 | n, err = c.ExtendedConn.Read(p) 81 | if err != nil { 82 | return 83 | } 84 | c.readPadding++ 85 | c.readRemaining = originalDataSize - n 86 | c.paddingRemaining = paddingLen 87 | return 88 | } 89 | return c.ExtendedConn.Read(p) 90 | } 91 | 92 | func (c *paddingConn) Write(p []byte) (n int, err error) { 93 | for pLen := len(p); pLen > 0; { 94 | var data []byte 95 | if pLen > 65535 { 96 | data = p[:65535] 97 | p = p[65535:] 98 | pLen -= 65535 99 | } else { 100 | data = p 101 | pLen = 0 102 | } 103 | var writeN int 104 | writeN, err = c.write(data) 105 | n += writeN 106 | if err != nil { 107 | break 108 | } 109 | } 110 | return n, err 111 | } 112 | 113 | func (c *paddingConn) write(p []byte) (n int, err error) { 114 | if c.writePadding < kFirstPaddings { 115 | paddingLen := 256 + rand.Intn(512) 116 | buffer := buf.NewSize(4 + len(p) + paddingLen) 117 | defer buffer.Release() 118 | header := buffer.Extend(4) 119 | binary.BigEndian.PutUint16(header[:2], uint16(len(p))) 120 | binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) 121 | common.Must1(buffer.Write(p)) 122 | buffer.Extend(paddingLen) 123 | _, err = c.ExtendedConn.Write(buffer.Bytes()) 124 | if err == nil { 125 | n = len(p) 126 | } 127 | c.writePadding++ 128 | return 129 | } 130 | return c.ExtendedConn.Write(p) 131 | } 132 | 133 | func (c *paddingConn) ReadBuffer(buffer *buf.Buffer) error { 134 | p := buffer.FreeBytes() 135 | if c.readRemaining > 0 { 136 | if len(p) > c.readRemaining { 137 | p = p[:c.readRemaining] 138 | } 139 | n, err := c.ExtendedConn.Read(p) 140 | if err != nil { 141 | return err 142 | } 143 | c.readRemaining -= n 144 | buffer.Truncate(n) 145 | return nil 146 | } 147 | if c.paddingRemaining > 0 { 148 | err := rw.SkipN(c.ExtendedConn, c.paddingRemaining) 149 | if err != nil { 150 | return err 151 | } 152 | c.paddingRemaining = 0 153 | } 154 | if c.readPadding < kFirstPaddings { 155 | var paddingHdr []byte 156 | if len(p) >= 4 { 157 | paddingHdr = p[:4] 158 | } else { 159 | paddingHdr = make([]byte, 4) 160 | } 161 | _, err := io.ReadFull(c.ExtendedConn, paddingHdr) 162 | if err != nil { 163 | return err 164 | } 165 | originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) 166 | paddingLen := int(binary.BigEndian.Uint16(paddingHdr[2:])) 167 | 168 | if len(p) > originalDataSize { 169 | p = p[:originalDataSize] 170 | } 171 | n, err := c.ExtendedConn.Read(p) 172 | if err != nil { 173 | return err 174 | } 175 | c.readPadding++ 176 | c.readRemaining = originalDataSize - n 177 | c.paddingRemaining = paddingLen 178 | buffer.Truncate(n) 179 | return nil 180 | } 181 | return c.ExtendedConn.ReadBuffer(buffer) 182 | } 183 | 184 | func (c *paddingConn) WriteBuffer(buffer *buf.Buffer) error { 185 | if c.writePadding < kFirstPaddings { 186 | bufferLen := buffer.Len() 187 | if bufferLen > 65535 { 188 | return common.Error(c.Write(buffer.Bytes())) 189 | } 190 | paddingLen := 256 + rand.Intn(512) 191 | header := buffer.ExtendHeader(4) 192 | binary.BigEndian.PutUint16(header[:2], uint16(bufferLen)) 193 | binary.BigEndian.PutUint16(header[2:], uint16(paddingLen)) 194 | buffer.Extend(paddingLen) 195 | c.writePadding++ 196 | } 197 | return c.ExtendedConn.WriteBuffer(buffer) 198 | } 199 | 200 | func (c *paddingConn) FrontHeadroom() int { 201 | return 4 + 256 + 1024 202 | } 203 | 204 | func (c *paddingConn) Upstream() any { 205 | return c.ExtendedConn 206 | } 207 | 208 | type vectorisedPaddingConn struct { 209 | paddingConn 210 | writer N.VectorisedWriter 211 | } 212 | 213 | func (c *vectorisedPaddingConn) WriteVectorised(buffers []*buf.Buffer) error { 214 | if c.writePadding < kFirstPaddings { 215 | bufferLen := buf.LenMulti(buffers) 216 | if bufferLen > 65535 { 217 | defer buf.ReleaseMulti(buffers) 218 | for _, buffer := range buffers { 219 | _, err := c.Write(buffer.Bytes()) 220 | if err != nil { 221 | return err 222 | } 223 | } 224 | return nil 225 | } 226 | paddingLen := 256 + rand.Intn(512) 227 | header := buf.NewSize(4) 228 | common.Must( 229 | binary.Write(header, binary.BigEndian, uint16(bufferLen)), 230 | binary.Write(header, binary.BigEndian, uint16(paddingLen)), 231 | ) 232 | c.writePadding++ 233 | padding := buf.NewSize(paddingLen) 234 | padding.Extend(paddingLen) 235 | buffers = append(append([]*buf.Buffer{header}, buffers...), padding) 236 | } 237 | return c.writer.WriteVectorised(buffers) 238 | } 239 | -------------------------------------------------------------------------------- /protocol.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/sagernet/sing/common" 10 | "github.com/sagernet/sing/common/buf" 11 | E "github.com/sagernet/sing/common/exceptions" 12 | M "github.com/sagernet/sing/common/metadata" 13 | N "github.com/sagernet/sing/common/network" 14 | "github.com/sagernet/sing/common/rw" 15 | "github.com/sagernet/sing/common/varbin" 16 | ) 17 | 18 | const ( 19 | ProtocolSmux = iota 20 | ProtocolYAMux 21 | ProtocolH2Mux 22 | ) 23 | 24 | const ( 25 | Version0 = iota 26 | Version1 27 | ) 28 | 29 | const ( 30 | TCPTimeout = 5 * time.Second 31 | ) 32 | 33 | var Destination = M.Socksaddr{ 34 | Fqdn: "sp.mux.sing-box.arpa", 35 | Port: 444, 36 | } 37 | 38 | type Request struct { 39 | Version byte 40 | Protocol byte 41 | Padding bool 42 | } 43 | 44 | func ReadRequest(reader io.Reader) (*Request, error) { 45 | var ( 46 | version byte 47 | protocol byte 48 | ) 49 | err := binary.Read(reader, binary.BigEndian, &version) 50 | if err != nil { 51 | return nil, err 52 | } 53 | if version < Version0 || version > Version1 { 54 | return nil, E.New("unsupported version: ", version) 55 | } 56 | err = binary.Read(reader, binary.BigEndian, &protocol) 57 | if err != nil { 58 | return nil, err 59 | } 60 | var paddingEnabled bool 61 | if version == Version1 { 62 | err = binary.Read(reader, binary.BigEndian, &paddingEnabled) 63 | if err != nil { 64 | return nil, err 65 | } 66 | if paddingEnabled { 67 | var paddingLen uint16 68 | err = binary.Read(reader, binary.BigEndian, &paddingLen) 69 | if err != nil { 70 | return nil, err 71 | } 72 | err = rw.SkipN(reader, int(paddingLen)) 73 | if err != nil { 74 | return nil, err 75 | } 76 | } 77 | } 78 | return &Request{Version: version, Protocol: protocol, Padding: paddingEnabled}, nil 79 | } 80 | 81 | func EncodeRequest(request Request, payload []byte) *buf.Buffer { 82 | var requestLen int 83 | requestLen += 2 84 | var paddingLen uint16 85 | if request.Version == Version1 { 86 | requestLen += 1 87 | if request.Padding { 88 | requestLen += 2 89 | paddingLen = uint16(256 + rand.Intn(512)) 90 | requestLen += int(paddingLen) 91 | } 92 | } 93 | buffer := buf.NewSize(requestLen + len(payload)) 94 | common.Must( 95 | buffer.WriteByte(request.Version), 96 | buffer.WriteByte(request.Protocol), 97 | ) 98 | if request.Version == Version1 { 99 | common.Must(binary.Write(buffer, binary.BigEndian, request.Padding)) 100 | if request.Padding { 101 | common.Must(binary.Write(buffer, binary.BigEndian, paddingLen)) 102 | buffer.Extend(int(paddingLen)) 103 | } 104 | } 105 | common.Must1(buffer.Write(payload)) 106 | return buffer 107 | } 108 | 109 | const ( 110 | flagUDP = 1 111 | flagAddr = 2 112 | statusSuccess = 0 113 | statusError = 1 114 | ) 115 | 116 | type StreamRequest struct { 117 | Network string 118 | Destination M.Socksaddr 119 | PacketAddr bool 120 | } 121 | 122 | func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) { 123 | var flags uint16 124 | err := binary.Read(reader, binary.BigEndian, &flags) 125 | if err != nil { 126 | return nil, err 127 | } 128 | destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 129 | if err != nil { 130 | return nil, err 131 | } 132 | var network string 133 | var udpAddr bool 134 | if flags&flagUDP == 0 { 135 | network = N.NetworkTCP 136 | } else { 137 | network = N.NetworkUDP 138 | udpAddr = flags&flagAddr != 0 139 | } 140 | return &StreamRequest{network, destination, udpAddr}, nil 141 | } 142 | 143 | func streamRequestLen(request StreamRequest) int { 144 | var rLen int 145 | rLen += 1 // version 146 | rLen += 2 // flags 147 | rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination) 148 | return rLen 149 | } 150 | 151 | func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) error { 152 | destination := request.Destination 153 | var flags uint16 154 | if request.Network == N.NetworkUDP { 155 | flags |= flagUDP 156 | } 157 | if request.PacketAddr { 158 | flags |= flagAddr 159 | if !destination.IsValid() { 160 | destination = Destination 161 | } 162 | } 163 | common.Must(binary.Write(buffer, binary.BigEndian, flags)) 164 | return M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 165 | } 166 | 167 | type StreamResponse struct { 168 | Status uint8 169 | Message string 170 | } 171 | 172 | func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) { 173 | var response StreamResponse 174 | err := binary.Read(reader, binary.BigEndian, &response.Status) 175 | if err != nil { 176 | return nil, err 177 | } 178 | if response.Status == statusError { 179 | response.Message, err = varbin.ReadValue[string](reader, binary.BigEndian) 180 | if err != nil { 181 | return nil, err 182 | } 183 | } 184 | return &response, nil 185 | } 186 | -------------------------------------------------------------------------------- /protocol_conn.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/sagernet/sing/common/buf" 7 | "github.com/sagernet/sing/common/bufio" 8 | N "github.com/sagernet/sing/common/network" 9 | ) 10 | 11 | type protocolConn struct { 12 | net.Conn 13 | request Request 14 | requestWritten bool 15 | } 16 | 17 | func newProtocolConn(conn net.Conn, request Request) net.Conn { 18 | writer, isVectorised := bufio.CreateVectorisedWriter(conn) 19 | if isVectorised { 20 | return &vectorisedProtocolConn{ 21 | protocolConn{ 22 | Conn: conn, 23 | request: request, 24 | }, 25 | writer, 26 | } 27 | } else { 28 | return &protocolConn{ 29 | Conn: conn, 30 | request: request, 31 | } 32 | } 33 | } 34 | 35 | func (c *protocolConn) NeedHandshake() bool { 36 | return !c.requestWritten 37 | } 38 | 39 | func (c *protocolConn) Write(p []byte) (n int, err error) { 40 | if c.requestWritten { 41 | return c.Conn.Write(p) 42 | } 43 | buffer := EncodeRequest(c.request, p) 44 | n, err = c.Conn.Write(buffer.Bytes()) 45 | buffer.Release() 46 | if err == nil { 47 | n-- 48 | } 49 | c.requestWritten = true 50 | return n, err 51 | } 52 | 53 | func (c *protocolConn) Upstream() any { 54 | return c.Conn 55 | } 56 | 57 | type vectorisedProtocolConn struct { 58 | protocolConn 59 | writer N.VectorisedWriter 60 | } 61 | 62 | func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error { 63 | if c.requestWritten { 64 | return c.writer.WriteVectorised(buffers) 65 | } 66 | c.requestWritten = true 67 | buffer := EncodeRequest(c.request, nil) 68 | return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...)) 69 | } 70 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "net" 6 | 7 | "github.com/sagernet/sing/common/bufio" 8 | "github.com/sagernet/sing/common/debug" 9 | E "github.com/sagernet/sing/common/exceptions" 10 | "github.com/sagernet/sing/common/logger" 11 | M "github.com/sagernet/sing/common/metadata" 12 | N "github.com/sagernet/sing/common/network" 13 | "github.com/sagernet/sing/common/task" 14 | ) 15 | 16 | // Deprecated: Use ServiceHandlerEx instead. 17 | // 18 | //nolint:staticcheck 19 | type ServiceHandler interface { 20 | N.TCPConnectionHandler 21 | N.UDPConnectionHandler 22 | } 23 | 24 | type ServiceHandlerEx interface { 25 | N.TCPConnectionHandlerEx 26 | N.UDPConnectionHandlerEx 27 | } 28 | 29 | type Service struct { 30 | newStreamContext func(context.Context, net.Conn) context.Context 31 | logger logger.ContextLogger 32 | handler ServiceHandler 33 | handlerEx ServiceHandlerEx 34 | padding bool 35 | brutal BrutalOptions 36 | } 37 | 38 | type ServiceOptions struct { 39 | NewStreamContext func(context.Context, net.Conn) context.Context 40 | Logger logger.ContextLogger 41 | Handler ServiceHandler 42 | HandlerEx ServiceHandlerEx 43 | Padding bool 44 | Brutal BrutalOptions 45 | } 46 | 47 | func NewService(options ServiceOptions) (*Service, error) { 48 | if options.Brutal.Enabled && !BrutalAvailable && !debug.Enabled { 49 | return nil, E.New("TCP Brutal is only supported on Linux") 50 | } 51 | return &Service{ 52 | newStreamContext: options.NewStreamContext, 53 | logger: options.Logger, 54 | handler: options.Handler, 55 | handlerEx: options.HandlerEx, 56 | padding: options.Padding, 57 | brutal: options.Brutal, 58 | }, nil 59 | } 60 | 61 | // Deprecated: Use NewConnectionEx instead. 62 | func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 63 | return s.newConnection(ctx, conn, metadata.Source) 64 | } 65 | 66 | func (s *Service) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { 67 | err := s.newConnection(ctx, conn, source) 68 | N.CloseOnHandshakeFailure(conn, onClose, err) 69 | if err != nil { 70 | s.logger.ErrorContext(ctx, E.Cause(err, "process multiplex connection from ", source)) 71 | } 72 | } 73 | 74 | func (s *Service) newConnection(ctx context.Context, conn net.Conn, source M.Socksaddr) error { 75 | request, err := ReadRequest(conn) 76 | if err != nil { 77 | return err 78 | } 79 | if request.Padding { 80 | conn = newPaddingConn(conn) 81 | } else if s.padding { 82 | return E.New("non-padded connection rejected") 83 | } 84 | session, err := newServerSession(conn, request.Protocol) 85 | if err != nil { 86 | return err 87 | } 88 | var group task.Group 89 | group.Append0(func(_ context.Context) error { 90 | for { 91 | stream, aErr := session.Accept() 92 | if aErr != nil { 93 | return aErr 94 | } 95 | streamCtx := s.newStreamContext(ctx, stream) 96 | go func() { 97 | hErr := s.newSession(streamCtx, conn, stream, source) 98 | if hErr != nil { 99 | stream.Close() 100 | s.logger.ErrorContext(streamCtx, E.Cause(hErr, "process multiplex stream")) 101 | } 102 | }() 103 | } 104 | }) 105 | group.Cleanup(func() { 106 | session.Close() 107 | }) 108 | return group.Run(ctx) 109 | } 110 | 111 | func (s *Service) newSession(ctx context.Context, sessionConn net.Conn, stream net.Conn, source M.Socksaddr) error { 112 | stream = &wrapStream{stream} 113 | request, err := ReadStreamRequest(stream) 114 | if err != nil { 115 | return E.Cause(err, "read multiplex stream request") 116 | } 117 | destination := request.Destination 118 | if request.Network == N.NetworkTCP { 119 | conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)} 120 | if request.Destination.Fqdn == BrutalExchangeDomain { 121 | defer stream.Close() 122 | var clientReceiveBPS uint64 123 | clientReceiveBPS, err = ReadBrutalRequest(conn) 124 | if err != nil { 125 | return E.Cause(err, "read brutal request") 126 | } 127 | if !s.brutal.Enabled { 128 | err = WriteBrutalResponse(conn, 0, false, "brutal is not enabled by the server") 129 | if err != nil { 130 | return E.Cause(err, "write brutal response") 131 | } 132 | return nil 133 | } 134 | sendBPS := s.brutal.SendBPS 135 | if clientReceiveBPS < sendBPS { 136 | sendBPS = clientReceiveBPS 137 | } 138 | err = SetBrutalOptions(sessionConn, sendBPS) 139 | if err != nil { 140 | // ignore error in test 141 | if !debug.Enabled { 142 | err = WriteBrutalResponse(conn, 0, false, E.Cause(err, "enable TCP Brutal").Error()) 143 | if err != nil { 144 | return E.Cause(err, "write brutal response") 145 | } 146 | return nil 147 | } 148 | } 149 | err = WriteBrutalResponse(conn, s.brutal.ReceiveBPS, true, "") 150 | if err != nil { 151 | return E.Cause(err, "write brutal response") 152 | } 153 | return nil 154 | } 155 | s.logger.InfoContext(ctx, "inbound multiplex connection to ", destination) 156 | if s.handler != nil { 157 | //nolint:staticcheck 158 | s.handler.NewConnection(ctx, conn, M.Metadata{Source: source, Destination: destination}) 159 | } else { 160 | s.handlerEx.NewConnectionEx(ctx, conn, source, destination, nil) 161 | } 162 | } else { 163 | var packetConn N.PacketConn 164 | if !request.PacketAddr { 165 | s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", destination) 166 | packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} 167 | } else { 168 | s.logger.InfoContext(ctx, "inbound multiplex packet connection") 169 | packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} 170 | } 171 | if s.handler != nil { 172 | //nolint:staticcheck 173 | s.handler.NewPacketConnection(ctx, packetConn, M.Metadata{Source: source, Destination: destination}) 174 | } else { 175 | s.handlerEx.NewPacketConnectionEx(ctx, packetConn, source, destination, nil) 176 | } 177 | } 178 | return nil 179 | } 180 | -------------------------------------------------------------------------------- /server_conn.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "net" 7 | "sync" 8 | 9 | "github.com/sagernet/sing/common" 10 | "github.com/sagernet/sing/common/buf" 11 | M "github.com/sagernet/sing/common/metadata" 12 | N "github.com/sagernet/sing/common/network" 13 | "github.com/sagernet/sing/common/varbin" 14 | ) 15 | 16 | type serverConn struct { 17 | N.ExtendedConn 18 | responseWritten bool 19 | } 20 | 21 | func (c *serverConn) NeedHandshake() bool { 22 | return !c.responseWritten 23 | } 24 | 25 | func (c *serverConn) HandshakeFailure(err error) error { 26 | errMessage := err.Error() 27 | buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) 28 | defer buffer.Release() 29 | common.Must( 30 | buffer.WriteByte(statusError), 31 | varbin.Write(buffer, binary.BigEndian, errMessage), 32 | ) 33 | return common.Error(c.ExtendedConn.Write(buffer.Bytes())) 34 | } 35 | 36 | func (c *serverConn) Write(b []byte) (n int, err error) { 37 | if c.responseWritten { 38 | return c.ExtendedConn.Write(b) 39 | } 40 | buffer := buf.NewSize(1 + len(b)) 41 | defer buffer.Release() 42 | common.Must( 43 | buffer.WriteByte(statusSuccess), 44 | common.Error(buffer.Write(b)), 45 | ) 46 | _, err = c.ExtendedConn.Write(buffer.Bytes()) 47 | if err != nil { 48 | return 49 | } 50 | c.responseWritten = true 51 | return len(b), nil 52 | } 53 | 54 | func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error { 55 | if c.responseWritten { 56 | return c.ExtendedConn.WriteBuffer(buffer) 57 | } 58 | buffer.ExtendHeader(1)[0] = statusSuccess 59 | c.responseWritten = true 60 | return c.ExtendedConn.WriteBuffer(buffer) 61 | } 62 | 63 | func (c *serverConn) FrontHeadroom() int { 64 | if !c.responseWritten { 65 | return 1 66 | } 67 | return 0 68 | } 69 | 70 | func (c *serverConn) NeedAdditionalReadDeadline() bool { 71 | return true 72 | } 73 | 74 | func (c *serverConn) Upstream() any { 75 | return c.ExtendedConn 76 | } 77 | 78 | type serverPacketConn struct { 79 | N.ExtendedConn 80 | access sync.Mutex 81 | destination M.Socksaddr 82 | responseWritten bool 83 | } 84 | 85 | func (c *serverPacketConn) NeedHandshake() bool { 86 | return !c.responseWritten 87 | } 88 | 89 | func (c *serverPacketConn) HandshakeFailure(err error) error { 90 | errMessage := err.Error() 91 | buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) 92 | defer buffer.Release() 93 | common.Must( 94 | buffer.WriteByte(statusError), 95 | varbin.Write(buffer, binary.BigEndian, errMessage), 96 | ) 97 | return common.Error(c.ExtendedConn.Write(buffer.Bytes())) 98 | } 99 | 100 | func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 101 | var length uint16 102 | err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 103 | if err != nil { 104 | return 105 | } 106 | _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 107 | if err != nil { 108 | return 109 | } 110 | destination = c.destination 111 | return 112 | } 113 | 114 | func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 115 | pLen := buffer.Len() 116 | common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) 117 | if !c.responseWritten { 118 | c.access.Lock() 119 | if c.responseWritten { 120 | c.access.Unlock() 121 | } else { 122 | defer c.access.Unlock() 123 | } 124 | buffer.ExtendHeader(1)[0] = statusSuccess 125 | c.responseWritten = true 126 | } 127 | return c.ExtendedConn.WriteBuffer(buffer) 128 | } 129 | 130 | func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 131 | var length uint16 132 | err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 133 | if err != nil { 134 | return 135 | } 136 | if cap(p) < int(length) { 137 | return 0, nil, io.ErrShortBuffer 138 | } 139 | n, err = io.ReadFull(c.ExtendedConn, p[:length]) 140 | return 141 | } 142 | 143 | func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 144 | if !c.responseWritten { 145 | c.access.Lock() 146 | if c.responseWritten { 147 | c.access.Unlock() 148 | } else { 149 | defer c.access.Unlock() 150 | _, err = c.ExtendedConn.Write([]byte{statusSuccess}) 151 | if err != nil { 152 | return 153 | } 154 | c.responseWritten = true 155 | } 156 | } 157 | err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) 158 | if err != nil { 159 | return 160 | } 161 | return c.ExtendedConn.Write(p) 162 | } 163 | 164 | func (c *serverPacketConn) NeedAdditionalReadDeadline() bool { 165 | return true 166 | } 167 | 168 | func (c *serverPacketConn) Upstream() any { 169 | return c.ExtendedConn 170 | } 171 | 172 | func (c *serverPacketConn) FrontHeadroom() int { 173 | if !c.responseWritten { 174 | return 3 175 | } 176 | return 2 177 | } 178 | 179 | type serverPacketAddrConn struct { 180 | N.ExtendedConn 181 | access sync.Mutex 182 | responseWritten bool 183 | } 184 | 185 | func (c *serverPacketAddrConn) NeedHandshake() bool { 186 | return !c.responseWritten 187 | } 188 | 189 | func (c *serverPacketAddrConn) HandshakeFailure(err error) error { 190 | errMessage := err.Error() 191 | buffer := buf.NewSize(1 + varbin.UvarintLen(uint64(len(errMessage))) + len(errMessage)) 192 | defer buffer.Release() 193 | common.Must( 194 | buffer.WriteByte(statusError), 195 | varbin.Write(buffer, binary.BigEndian, errMessage), 196 | ) 197 | return common.Error(c.ExtendedConn.Write(buffer.Bytes())) 198 | } 199 | 200 | func (c *serverPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 201 | destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) 202 | if err != nil { 203 | return 204 | } 205 | if destination.IsFqdn() { 206 | addr = destination 207 | } else { 208 | addr = destination.UDPAddr() 209 | } 210 | var length uint16 211 | err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 212 | if err != nil { 213 | return 214 | } 215 | if cap(p) < int(length) { 216 | return 0, nil, io.ErrShortBuffer 217 | } 218 | n, err = io.ReadFull(c.ExtendedConn, p[:length]) 219 | return 220 | } 221 | 222 | func (c *serverPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 223 | if !c.responseWritten { 224 | c.access.Lock() 225 | if c.responseWritten { 226 | c.access.Unlock() 227 | } else { 228 | defer c.access.Unlock() 229 | _, err = c.ExtendedConn.Write([]byte{statusSuccess}) 230 | if err != nil { 231 | return 232 | } 233 | c.responseWritten = true 234 | } 235 | } 236 | err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) 237 | if err != nil { 238 | return 239 | } 240 | err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) 241 | if err != nil { 242 | return 243 | } 244 | return c.ExtendedConn.Write(p) 245 | } 246 | 247 | func (c *serverPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 248 | destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) 249 | if err != nil { 250 | return 251 | } 252 | var length uint16 253 | err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 254 | if err != nil { 255 | return 256 | } 257 | _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 258 | if err != nil { 259 | return 260 | } 261 | return 262 | } 263 | 264 | func (c *serverPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 265 | pLen := buffer.Len() 266 | common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) 267 | err := M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination) 268 | if err != nil { 269 | return err 270 | } 271 | if !c.responseWritten { 272 | c.access.Lock() 273 | if c.responseWritten { 274 | c.access.Unlock() 275 | } else { 276 | defer c.access.Unlock() 277 | buffer.ExtendHeader(1)[0] = statusSuccess 278 | c.responseWritten = true 279 | } 280 | } 281 | return c.ExtendedConn.WriteBuffer(buffer) 282 | } 283 | 284 | func (c *serverPacketAddrConn) NeedAdditionalReadDeadline() bool { 285 | return true 286 | } 287 | 288 | func (c *serverPacketAddrConn) Upstream() any { 289 | return c.ExtendedConn 290 | } 291 | 292 | func (c *serverPacketAddrConn) FrontHeadroom() int { 293 | if !c.responseWritten { 294 | return 3 + M.MaxSocksaddrLength 295 | } 296 | return 2 + M.MaxSocksaddrLength 297 | } 298 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "reflect" 7 | 8 | E "github.com/sagernet/sing/common/exceptions" 9 | "github.com/sagernet/smux" 10 | 11 | "github.com/hashicorp/yamux" 12 | ) 13 | 14 | type abstractSession interface { 15 | Open() (net.Conn, error) 16 | Accept() (net.Conn, error) 17 | NumStreams() int 18 | Close() error 19 | IsClosed() bool 20 | CanTakeNewRequest() bool 21 | } 22 | 23 | func newClientSession(conn net.Conn, protocol byte) (abstractSession, error) { 24 | switch protocol { 25 | case ProtocolH2Mux: 26 | session, err := newH2MuxClient(conn) 27 | if err != nil { 28 | return nil, err 29 | } 30 | return session, nil 31 | case ProtocolSmux: 32 | client, err := smux.Client(conn, smuxConfig()) 33 | if err != nil { 34 | return nil, err 35 | } 36 | return &smuxSession{client}, nil 37 | case ProtocolYAMux: 38 | checkYAMuxConn(conn) 39 | client, err := yamux.Client(conn, yaMuxConfig()) 40 | if err != nil { 41 | return nil, err 42 | } 43 | return &yamuxSession{client}, nil 44 | default: 45 | return nil, E.New("unexpected protocol ", protocol) 46 | } 47 | } 48 | 49 | func newServerSession(conn net.Conn, protocol byte) (abstractSession, error) { 50 | switch protocol { 51 | case ProtocolH2Mux: 52 | return newH2MuxServer(conn), nil 53 | case ProtocolSmux: 54 | client, err := smux.Server(conn, smuxConfig()) 55 | if err != nil { 56 | return nil, err 57 | } 58 | return &smuxSession{client}, nil 59 | case ProtocolYAMux: 60 | checkYAMuxConn(conn) 61 | client, err := yamux.Server(conn, yaMuxConfig()) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return &yamuxSession{client}, nil 66 | default: 67 | return nil, E.New("unexpected protocol ", protocol) 68 | } 69 | } 70 | 71 | func checkYAMuxConn(conn net.Conn) { 72 | if conn.LocalAddr() == nil || conn.RemoteAddr() == nil { 73 | panic("found net.Conn with nil addr: " + reflect.TypeOf(conn).String()) 74 | } 75 | } 76 | 77 | var _ abstractSession = (*smuxSession)(nil) 78 | 79 | type smuxSession struct { 80 | *smux.Session 81 | } 82 | 83 | func (s *smuxSession) Open() (net.Conn, error) { 84 | return s.OpenStream() 85 | } 86 | 87 | func (s *smuxSession) Accept() (net.Conn, error) { 88 | return s.AcceptStream() 89 | } 90 | 91 | func (s *smuxSession) CanTakeNewRequest() bool { 92 | return true 93 | } 94 | 95 | type yamuxSession struct { 96 | *yamux.Session 97 | } 98 | 99 | func (y *yamuxSession) CanTakeNewRequest() bool { 100 | return true 101 | } 102 | 103 | func smuxConfig() *smux.Config { 104 | config := smux.DefaultConfig() 105 | config.KeepAliveDisabled = true 106 | return config 107 | } 108 | 109 | func yaMuxConfig() *yamux.Config { 110 | config := yamux.DefaultConfig() 111 | config.LogOutput = io.Discard 112 | config.StreamCloseTimeout = TCPTimeout 113 | config.StreamOpenTimeout = TCPTimeout 114 | return config 115 | } 116 | --------------------------------------------------------------------------------